Mastering Matplotlib: Easy Plotting Tips and Common Pitfalls Explained
Navigating a data-driven world: learning the art of visualization to express thoughts and knowledge.


We live in an era where everything in the world, and the world itself, is explained by data. We have to learn how to visualize data to be able to depict our thoughts and knowledge.

When it comes to creating plots in Python, Matplotlib stands out as one of the most popular tools. However, using it efficiently can be a bit tricky. This article is here to help, breaking down Matplotlib's features in a simple and practical way.

Matplotlib gives us two main ways to make plots. The first is the functional way, great for quick visualizations, especially in places like Jupyter Notebooks. The second is the object-oriented way, which is super useful for more complex plots. Personally, I like the second way more because it gives us better control over our plots, and it's easier to understand.

In this article, I'll introduce both methods, but we'll mostly focus on the object-oriented way. This choice allows us to explore different figures and features hands-on, getting a real feel for what Matplotlib can do.

So, let's dive into the article and uncover the magic of Matplotlib's plotting functions. Whether it's scatter plots, bar charts, quiver plots, or polar plots, each method has its own role in making your data come to life. Whether you're showing density distributions, comparing datasets, or drawing vector fields and polar curves, Matplotlib has almost everything you need. Let's see how we can make the most of this powerful library for all your plotting adventures!

Functional approach

The functional approach in Matplotlib involves using the pyplot interface, which relies on a global state to configure and create plots. This method offers a simple way to generate basic plots by directly calling functions from the pyplot module.

Consider the following example:

import matplotlib.pyplot as plt
import numpy as np 

# Generating sample data
x = np.linspace(0, 10, 100)
y = np.sin(x)

# Using the functional approach (plt.plot()) 
plt.plot(x, y, label='Sine Curve')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.title('Functional Approach')
plt.legend()
plt.show()        
Plotting Sine curve using pyplot

Here's a detailed explanation of the example:

  1. Global State Handling: pyplot manages a global state behind the scenes (the "current" figure and axes), allowing users to create and configure plots using simple functions.
  2. Plot Creation: plt.plot(x, y, label='Sine Curve') generates a line plot of y against x and labels it 'Sine Curve'.
  3. Axis Labels and Title: plt.xlabel('X-axis') and plt.ylabel('Y-axis') set the labels for the x-axis and y-axis, respectively. plt.title('Functional Approach') assigns a title.
  4. Legend Display: plt.legend() displays a legend based on the labels specified in the plot.
  5. Displaying the Plot: plt.show() renders and displays the finalized plot. This call is not always necessary; for example, in Jupyter notebooks the plot is displayed automatically.

In this functional approach, the pyplot interface abstracts much of the underlying complexity, making it suitable for quick and simple visualizations. However, it can become limiting when handling more intricate plot configurations or managing multiple subplots. In contrast, the object-oriented approach offers finer control over individual plot elements, making it more flexible for complex layouts and detailed customizations. Object-oriented code also tends to have a clearer structure, which improves readability and maintainability (at least for those with a programming background), making it preferable for intricate visualizations and larger projects.
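For comparison, here is a minimal sketch of the same sine plot written in the object-oriented style. Most pyplot labeling functions have an Axes counterpart with a set_ prefix (plt.xlabel() becomes ax.set_xlabel(), plt.title() becomes ax.set_title()); the title string used here is only illustrative.

```python
import matplotlib.pyplot as plt
import numpy as np

# Same data as the functional example
x = np.linspace(0, 10, 100)
y = np.sin(x)

# Object-oriented approach: create the Figure and Axes explicitly
fig, ax = plt.subplots()
ax.plot(x, y, label='Sine Curve')
ax.set_xlabel('X-axis')        # Axes methods use a set_ prefix
ax.set_ylabel('Y-axis')
ax.set_title('Object-Oriented Approach')
ax.legend()
```

Because fig and ax are ordinary variables, you can keep several figures open and modify any of them later without relying on pyplot's notion of the "current" plot.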

Object-oriented method

In the object-oriented approach, we work with figures and subplots (Axes objects).

  • A figure is the canvas or container where plots reside.
  • Subplots divide the figure into grids to accommodate multiple plots, allowing for side-by-side or stacked visualizations within a single figure.

The following picture shows a 2x2 grid of subplots within a single figure. Each subplot is positioned in a different cell of the grid, and they display plots of sin(x), cos(x), tan(x), and x^2.

Matplotlib figure and subplots

Creating such a figure using only the functional interface results in code that is challenging to maintain and extend in the future, so I always recommend the object-oriented interface for figures like this.

To create a new figure object, we can use plt.figure(); it initializes a blank canvas ready for plotting. Then we add individual subplots to it using the add_subplot() method.

# Creating a figure object explicitly
fig = plt.figure()

# Adding a subplot to figure using add_subplot()
ax = fig.add_subplot(2,2,1)  
# 2,2,1 means 2 rows, 2 columns, first subplot
ax.plot(x, np.sin(x))
ax.set_title('Sine Curve')

ax = fig.add_subplot(2,2,2)  
ax.plot(x, np.cos(x))
ax.set_title('Cos Curve')

ax = fig.add_subplot(2,2,3)  
ax.plot(x, np.tan(x))
ax.set_title('Tan Curve')

ax = fig.add_subplot(2,2,4)  
ax.plot(x, np.power(x,2))
ax.set_title('x^2')        

Alternatively, we can use the plt.subplot() and plt.subplots() convenience functions. Both create subplots within a single figure, but they serve different purposes: subplot() creates and returns a single subplot per call, while subplots() returns one figure together with all of its subplots. First, let's take a look at an example of using the first function:

# Adding a single subplot using subplot()
ax = plt.subplot(2,2,1)  
# 2,2,1 means 2 rows, 2 columns, first subplot
ax.plot(x, np.sin(x))
ax.set_title('Sine Curve')

ax = plt.subplot(2,2,2)  
ax.plot(x, np.cos(x))
ax.set_title('Cos Curve')

ax = plt.subplot(2,2,3)  
ax.plot(x, np.tan(x))
ax.set_title('Tan Curve')

ax = plt.subplot(2,2,4)  
ax.plot(x, np.power(x,2))
ax.set_title('x^2')        

As you can see in the sample code, subplot() is very similar to add_subplot(). However, subplot() does not hand us the figure directly (although it can still be reached through ax.figure or plt.gcf()). subplots(), on the other hand, returns both the figure and the collection of all the subplots:

# Adding four subplots using subplots() method
fig, axes = plt.subplots(2,2)  
# 2,2 means 2 rows, 2 columns
axes[0][0].plot(x, np.sin(x))
axes[0][0].set_title('Sine Curve')

axes[0][1].plot(x, np.cos(x))
axes[0][1].set_title('Cos Curve')

axes[1][0].plot(x, np.tan(x))
axes[1][0].set_title('Tan Curve')

axes[1][1].plot(x, np.power(x,2))
axes[1][1].set_title('x^2')        

In summary, plt.subplots() is a convenient way to create the figure and several subplots simultaneously, while plt.subplot() returns only one subplot per call.
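Since subplots() returns the Axes as a NumPy array, the four near-identical blocks above can be collapsed into a loop. This is a minimal sketch, assuming you are happy to pair each function with its title in a list; axes.flat walks the 2x2 array row by row, and axes[0, 0] is equivalent to axes[0][0].

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# Each entry pairs a function with the title of its subplot
curves = [
    (np.sin, 'Sine Curve'),
    (np.cos, 'Cos Curve'),
    (np.tan, 'Tan Curve'),
    (np.square, 'x^2'),
]

fig, axes = plt.subplots(2, 2)
# axes is a 2x2 NumPy array of Axes; .flat iterates it row by row
for ax, (func, title) in zip(axes.flat, curves):
    ax.plot(x, func(x))
    ax.set_title(title)
```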


Customizing figures

If you use the plt.figure() or subplots() functions, you have access to the figure object and can customize it.

fig.subplots_adjust() adjusts the spacing between subplots within a figure. By specifying parameters like left, right, top, bottom, wspace, and hspace, you can control the distance between subplots horizontally (wspace) and vertically (hspace), both expressed as a fraction of the average subplot width or height, as well as adjust the margins (left, right, top, bottom). This helps manage the layout and alignment of subplots within the figure.

fig.subplots_adjust(hspace=0.4, wspace=0.3)        
Aligning and adjusting subplots within a figure

There are several other figure customization functions that help modify various aspects of the figure. Some common ones include:

  • fig.set_size_inches(): Adjusts the size of the figure in inches.
  • fig.suptitle(): Sets a centered title for the entire figure.
  • fig.tight_layout(): Automatically adjusts subplot parameters so that subplots, labels, and titles fit within the figure area.
  • fig.savefig(): Saves the figure to a file (e.g., PNG, PDF, SVG).
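Here is a short sketch combining these figure-level calls. The output file name sine_cosine.png and the dpi value are hypothetical choices for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

fig, axes = plt.subplots(1, 2)
axes[0].plot(x, np.sin(x))
axes[1].plot(x, np.cos(x))

fig.set_size_inches(8, 3)                 # width, height in inches
fig.suptitle('Sine and Cosine')           # one title for the whole figure
fig.tight_layout()                        # recompute spacing so labels do not overlap
fig.savefig('sine_cosine.png', dpi=150)   # hypothetical output file name
```

Calling tight_layout() after setting the size matters, because the spacing it computes depends on the final figure dimensions.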


Customizing Plots

Customizing Matplotlib plots involves various techniques, including using styles and themes to alter the overall appearance. Annotations can be added to highlight specific points, and the library supports various axis scales, including logarithmic. You can customize grid lines and add axis labels for clarity. When dealing with multiple datasets, incorporating legends helps in distinguishing them.

Styling plots

Plots have some styling parameters (color, line style, marker style, line width, marker size, labels, and legends) that help differentiate and enhance the visual representation of the plotted data, making it easier to interpret and understand the relationships between the sets of data. For example, the following plot illustrates y = x^2 and y = x^3 by adjusting these parameters:

x = np.linspace(-2, 2, 20)
fig = plt.figure() 
ax = fig.add_subplot(1,2,1)
ax.plot(x, x**2, color='blue', linestyle='--', linewidth=1, marker='o', markersize=2, label='y = x^2')
ax.plot(x, x**3, color='red', linestyle='dotted', linewidth=1, marker='*', markersize=2, label='y = x^3')
ax.legend()        
Enhancing the appearance of line plots and differentiating between two sets of data.

Scaling and presenting multiple datasets

Now, I want to illustrate an annoying problem that often comes up when working with data. Guess what you will see if you run the following code:

x = np.linspace(-100, 100, 20)
fig = plt.figure() 
ax = fig.add_subplot(1,2,1)
ax.plot(x, x**2, color='blue', linestyle='--', linewidth=1, marker='o', markersize=2, label='y = x^2')
ax.plot(x, x**3, color='red', linestyle='dotted', linewidth=1, marker='*', markersize=2, label='y = x^3')
ax.legend()        

This code is very similar to the previous one; I just expanded the domain over which the data is plotted. But the result is surprising:

One dataset dominating the plot

You will see just a nearly horizontal blue line for y = x^2! The issue arises from the vast difference in the growth rates of the functions y = x^2 and y = x^3 within the specified range of x-values.

When both are plotted on the same graph with a linear scale, the y-axis has to span the range of y = x^3, which runs from -1,000,000 to 1,000,000 for x between -100 and 100. Over the same range, y = x^2 never exceeds 10,000, which is only 1% of the distance from zero to the top of the axis. As a result, the y = x^3 curve dominates the plot, and the y = x^2 curve appears nearly flat.

In other words, a single linear scale cannot accommodate both growth rates at once, so the slower-growing function looks flat even though its values vary substantially.

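To handle the problem, we may consider different strategies. We could simply plot each dataset separately; however, this makes it harder to compare the y-values at the same x-values. One way to mitigate the problem is to use a logarithmic scale for the y-axis. Because y = x^3 is negative for negative x, and a plain logarithmic scale cannot show negative values, we use the symmetric logarithmic scale "symlog", which is linear in a small band around zero and logarithmic outside it: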

ax.set_yscale("symlog")        
Using logarithmic scale in plots

The advantage of this method is simplicity: we only need to change the scale with one command. Now we can see that x^2 is also growing. While a logarithmic scale helps display a wider range of values more clearly, it does not solve the problem completely: equal vertical distances no longer correspond to equal differences in value, so readers can easily misjudge the magnitudes being compared.
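The width of the linear band around zero is controlled by the linthresh parameter. This sketch assumes Matplotlib 3.3 or newer (older releases called the keyword linthreshy) and compares a narrow and a wide linear band on the same data; the two threshold values are arbitrary examples.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-100, 100, 20)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
for ax, thresh in zip(axes, [1, 1000]):
    ax.plot(x, x**2, color='blue', linestyle='--', marker='o', markersize=2, label='y = x^2')
    ax.plot(x, x**3, color='red', linestyle='dotted', marker='*', markersize=2, label='y = x^3')
    # Linear between -thresh and +thresh, logarithmic outside that band
    ax.set_yscale('symlog', linthresh=thresh)
    ax.set_title(f'symlog, linthresh={thresh}')
    ax.legend()
```

A wider linear band keeps small values readable near zero, at the cost of compressing less of the large-value range.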

In my view, the best way is to use a secondary Y-axis. If the datasets share the same x-axis but have vastly different scales on the y-axis, you can use a secondary y-axis. This approach allows you to plot datasets with different scales on the same plot while maintaining clarity.

x = np.linspace(-100, 100, 20)
fig = plt.figure(figsize=(10,5)) 
ax1 = fig.add_subplot(1,2,1)
ax1.plot(x, x**2, color='blue', linestyle='--', linewidth=1, marker='o', markersize=2, label='y = x^2')
ax1.set_ylabel("y=x^2")
ax1.legend( loc="upper left")

ax2 = ax1.twinx()
ax2.plot(x, x**3, color='red', linestyle='dotted', linewidth=1, marker='*', markersize=2, label='y = x^3')
ax2.set_ylabel("y=x^3")
ax2.legend(loc="upper right")        

  • ax1 represents the primary y-axis on the left side, plotting y = x^2 (blue).
  • ax2 represents the secondary y-axis on the right side, plotting y = x^3 (red).
  • ax1.twinx() creates a twin of the primary axes ax1 that shares the same x-axis but has its own independent y-axis (ax2).

Plotting two different datasets on a single x-axis with two y-axes

Each dataset is plotted separately on its respective y-axis, allowing clear visualization of both datasets' trends without one overshadowing the other due to different scales.

This approach effectively visualizes datasets with different scales on the same x-axis, aiding in comparing their trends while preserving their individual characteristics. Adjust the parameters as needed for your specific datasets and visualization requirements.
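With twinx(), each Axes keeps its own legend, which is why the example above needs two legend() calls. One common way to get a single legend is to collect the handles and labels from both Axes. This is a sketch of that approach; coloring the y-axis labels to match the curves is an optional extra touch.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-100, 100, 20)
fig, ax1 = plt.subplots(figsize=(6, 4))
ax1.plot(x, x**2, color='blue', linestyle='--', label='y = x^2')
ax1.set_ylabel('y = x^2')

ax2 = ax1.twinx()  # shares the x-axis, has its own y-axis on the right
ax2.plot(x, x**3, color='red', linestyle='dotted', label='y = x^3')
ax2.set_ylabel('y = x^3')

# Collect lines and labels from both Axes and draw a single legend
handles1, labels1 = ax1.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(handles1 + handles2, labels1 + labels2, loc='upper left')

# Color each y-axis label like its curve so readers know which scale to read
ax1.yaxis.label.set_color('blue')
ax2.yaxis.label.set_color('red')
```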

Using annotations

Sometimes it's essential to draw attention to specific data points or areas within a plot. Annotations in Matplotlib are a valuable tool for highlighting particular points or regions. They let you add text, arrows, or markers at specific locations, emphasizing critical data points. Below is an example demonstrating the use of annotations to highlight points on a plot:

ax = plt.subplot()
x = np.arange(-180,180,1)
ax.plot(x,np.sin(x/180*np.pi), label="Sin(x)" , color="blue")
ax.plot(x,np.cos(x/180*np.pi), label="Cos(x)" , color="green")
ax.set_xticks(np.arange(-180,200,45))
ax.legend()
ax.grid()
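# xy is the point the arrow targets; xytext is where the text (and arrow tail) sits
ax.annotate("Sin x = Cos x", xy=(45, 0.7), xytext=(45, -0.4), arrowprops=dict(arrowstyle='->'), ha="center")
# Empty text: only draws a second arrow, towards the other intersection at -135 degrees
ax.annotate("", xy=(-135, -0.7), xytext=(8, -0.4), arrowprops=dict(arrowstyle='->'), ha="center")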
Annotations highlight specific data points

Common Plotting Functions

Now it's time to introduce various plotting types. Matplotlib isn't just about plotting mathematical functions. It offers a host of other helpful methods for handling categorical data, creating bar plots, histograms, scatter plots, and more.

Matplotlib is a versatile library that offers various plotting functions. Initially, I faced confusion due to the multitude of options available, making it challenging to create meaningful plots. To simplify the process, I organized the functions based on the types of data they are best suited for. Understanding these categorized groups can assist in selecting the most suitable function for your specific data.

Methods for Numerical Data:

  • plot(): Creates line plots or markers with x-axis and y-axis lists/arrays.
  • scatter(): Generates scatter plots using x-axis and y-axis lists/arrays.
  • hist(): Displays histograms for numerical data distribution (accepts a list/array).
  • boxplot(), violinplot(): Used for visualizing numerical data distributions.

Methods for Categorical Data:

  • bar(), barh(): Create vertical/horizontal bar charts for categorical data.
  • pie(): Represents categorical proportions using a list of values.

Methods for Specialized Data:

  • polar(): Constructs polar plots for radial or circular data representation.
  • quiver(): Plots 2D vector fields using coordinate and vector components.
  • imshow(), hexbin(), matshow(): Specialized functions for image and 2D array visualization.

Understanding these categorized groups can simplify the process of choosing the appropriate function tailored to your specific data type. Note that certain functions, such as boxplot(), bar(), or hist(), can handle both numerical and categorical data depending on their usage and input parameters.

plot() and scatter()

Both plot() and scatter() are used with numerical data, presenting relationships between X and Y variables. However, they differ fundamentally in how they visualize this relationship:

plot() typically creates line-based plots, emphasizing the connected nature of the data points. It's often used to display trends, sequences, or continuous data, showing the overall pattern through points that are connected by lines by default.

scatter() focuses on individual data points, emphasizing the distinct nature of each point. It doesn't connect points with lines, presenting data as separate markers. This function is commonly used to explore correlations, clusters, or distributions within a dataset, especially when highlighting individual data points is essential.
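A side-by-side sketch makes the difference concrete. It uses synthetic data from np.random.default_rng (the values are arbitrary) and shows two things: plot() joins points in the order given, so the data is sorted first, and scatter() can map extra variables to each marker's color (c) and size (s), which plot() cannot do per point.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(0, 1, 50)
y = 2 * x + rng.normal(0, 0.5, 50)
sizes = rng.uniform(10, 100, 50)   # one marker area (points^2) per point

fig, axes = plt.subplots(1, 2, figsize=(10, 4))

# plot() connects points in array order, so sort by x to get a readable line
order = np.argsort(x)
axes[0].plot(x[order], y[order], marker='o')
axes[0].set_title('plot(): points joined by lines')

# scatter() can map a third and fourth variable to color (c) and size (s)
points = axes[1].scatter(x, y, c=y, s=sizes, cmap='viridis')
fig.colorbar(points, ax=axes[1], label='y value')
axes[1].set_title('scatter(): individual markers')
```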

For example, this code generates 100 values of y that are linearly related to x with a certain amount of random noise added.

mu_x = 50
mu_y = 30
sigma_x = 10
sigma_y = 15
correlation_coefficient = 0.7

# Generate x from a normal distribution
x = np.random.normal(mu_x, sigma_x, 100)

# Build y from standardized x plus independent noise, so that
# corr(x, y) is about correlation_coefficient, mean(y) about mu_y, std(y) about sigma_y
z_x = (x - mu_x) / sigma_x
noise = np.random.normal(0, 1, 100)
y = mu_y + sigma_y * (correlation_coefficient * z_x + np.sqrt(1 - correlation_coefficient**2) * noise)

We can illustrate the correlation between the two datasets using scatter(), while the regression line (also known as the model line or least-squares line) is drawn with the plot() function.

from sklearn.linear_model import LinearRegression
model = LinearRegression()
model.fit(x.reshape(-1, 1), y)
slope = model.coef_[0]
intercept = model.intercept_
ax = plt.subplot()
ax.scatter(x,y)
x1 = x.min()
x2 = x.max()
ax.plot([x1,x2],[slope*x1+intercept, slope*x2+intercept],color="red")        

I believe the plot is inspiring and self-explanatory! It's always rewarding when visualizations effectively communicate insights and relationships within data!

Linear Regression Fit: Visualizing the fitted line on a scatter plot
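If you would rather not depend on scikit-learn for a straight-line fit, NumPy's np.polyfit() with degree 1 returns the least-squares slope and intercept directly, and np.corrcoef() gives the sample correlation. The sketch below swaps in these NumPy functions for LinearRegression and uses synthetic data (the slope of 0.5 and the noise level are arbitrary illustration values).

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)
x = rng.normal(50, 10, 100)
y = 0.5 * x + rng.normal(0, 5, 100)   # synthetic data for illustration

# Degree-1 polyfit returns [slope, intercept] of the least-squares line
slope, intercept = np.polyfit(x, y, 1)
r = np.corrcoef(x, y)[0, 1]           # sample Pearson correlation

fig, ax = plt.subplots()
ax.scatter(x, y)
x_line = np.array([x.min(), x.max()])
ax.plot(x_line, slope * x_line + intercept, color='red', label=f'fit, r = {r:.2f}')
ax.legend()
```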

hist()

As we saw already, the plot() and scatter() functions in Matplotlib are used to visualize relationships between two numerical datasets: plot() draws connected points or lines to show how corresponding values change together, and scatter() emphasizes individual points, revealing associations between paired values. Unlike plot() and scatter(), hist() doesn't require two sets of data. It works on a single dataset and illustrates its frequency distribution across defined intervals (bins), which makes it the tool of choice for understanding the distribution pattern of a single set of numerical data.

uniform = np.random.random(10000) * 8 - 4
ax = plt.subplot()
ax.hist(uniform, alpha=0.5, bins=30)

normal = np.random.normal(0,1,10000)
ax.hist(normal, color="green", alpha=0.5, bins=30)        

The code generates two histograms on one plot. The x-axis represents the value range, while the y-axis shows the frequency of occurrence. Alpha (transparency) makes it possible to overlay the histograms and see both at once: one shows a uniform distribution between -4 and +4, and the other depicts a normal (Gaussian) distribution centered around 0.

Histograms differentiate datasets based on their sample distributions.

Working with data always presents challenges, as it often involves unforeseen problems and complexities. Rarely is data as tidy and straightforward as in the previous example. For instance, when comparing two datasets, a common issue arises when one dataset contains far more samples than the other. Representing such data improperly can mislead readers.

Look at the following example:

uniform = np.random.random(10000) * 8 - 4
fig , axes = plt.subplots(1,2)
fig.set_figwidth(10)
normal = np.random.normal(0,1,1000)

axes[0].hist(uniform, alpha=0.5, bins=30)
axes[0].hist(normal, color="green", alpha=0.5, bins=30)
axes[0].set_ylabel("Frequency")   # the y-axis shows raw counts

axes[1].hist(uniform, alpha=0.5, bins=30, density=True)
axes[1].hist(normal, color="green", alpha=0.5, bins=30, density=True)
axes[1].set_ylabel("Density")     # the y-axis shows probability density
density=True

In the left plot, a potential misconception stems from the uniform sample being much larger (10,000 values) than the normal sample (1,000 values). Overlooking this difference, a reader may conclude that values near the mean are more likely under the uniform distribution than under the normal distribution, when in fact the normal distribution concentrates far more of its probability near 0. The apparent dissimilarity is a consequence of the unequal sample sizes, not of the distributions themselves, so sample sizes must be taken into account when interpreting such visualizations.

Normalization choices in histograms, such as setting density=True, significantly affect what the y-axis represents. With density=True, each bar's height is its count divided by the total number of samples and by the bin width, so the total area of each histogram equals 1. Because both histograms are then on the same probability-density scale, they can be compared fairly even when the sample sizes differ, which gives a more accurate picture of the distribution shapes.
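The sketch below checks that definition numerically: the areas of the density=True bars sum to 1. It also shows an alternative, not used in the original example: passing weights of 1/N to each sample gives relative frequencies, whose bar heights (rather than areas) sum to 1. The data is regenerated with np.random.default_rng for reproducibility.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
uniform = rng.uniform(-4, 4, 10000)
normal = rng.normal(0, 1, 1000)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))

# density=True: height = count / (total count * bin width), so bar areas sum to 1
axes[0].hist(uniform, bins=30, density=True, alpha=0.5)
heights, edges, _ = axes[0].hist(normal, bins=30, density=True, alpha=0.5, color='green')
area = np.sum(heights * np.diff(edges))
axes[0].set_ylabel('Density')

# Relative frequency: weight each sample by 1/N so bar heights sum to 1
axes[1].hist(uniform, bins=30, alpha=0.5, weights=np.ones(len(uniform)) / len(uniform))
rel_heights, _, _ = axes[1].hist(normal, bins=30, alpha=0.5, color='green',
                                 weights=np.ones(len(normal)) / len(normal))
axes[1].set_ylabel('Relative frequency')
```

Relative frequencies are easier to read as "fraction of samples per bin", but unlike densities they still depend on the bin width.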

boxplot()

The boxplot() function primarily focuses on displaying statistical parameters and summarizing the distribution's key features, such as the median, quartiles, and outliers, rather than explicitly illustrating the data's underlying distribution function.

Unlike histograms, which provide a visual representation of the data's shape and frequency distribution, boxplots prioritize conveying statistical summary measures and identifying variability between groups or categories within the dataset. While histograms offer insights into the data's distributional shape, boxplots excel in highlighting central tendencies and spread, making them complementary visualization tools for different analytical purposes.

The boxplot summarizes the distribution's centrality and dispersion but does not depict the data's shape.

For example, the box of a normal sample is relatively narrow because its values concentrate near the mean (around 68% lie within one standard deviation, by the empirical rule): the interquartile range of a standard normal distribution is about 1.35, compared with 4 for a uniform distribution on [-4, 4]. Even so, the box plot alone doesn't display the bell-shaped curve characteristic of a normal distribution.
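Here is a minimal sketch of boxplot() on the same kind of uniform and normal samples used in the histogram examples (regenerated with np.random.default_rng). It sets the tick labels with set_xticklabels() rather than a boxplot keyword, because that keyword was renamed in recent Matplotlib releases. The printed interquartile ranges should land close to the theoretical 4 and 1.35.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
uniform = rng.uniform(-4, 4, 10000)
normal = rng.normal(0, 1, 1000)

fig, ax = plt.subplots()
# Each box spans Q1..Q3 with a line at the median; by default the whiskers reach the
# furthest points within 1.5 * IQR of the box, and points beyond are drawn as outliers
ax.boxplot([uniform, normal])
ax.set_xticks([1, 2])
ax.set_xticklabels(['Uniform(-4, 4)', 'Normal(0, 1)'])

iqr = {}
for name, data in [('uniform', uniform), ('normal', normal)]:
    q1, q3 = np.percentile(data, [25, 75])
    iqr[name] = q3 - q1
    print(f'{name}: IQR = {iqr[name]:.2f}')
```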

bar() and barh()

Many people confuse bar charts with histograms. The confusion often arises because both charts use bars to represent data. Additionally, some bar charts may have numerical data on the x-axis, which can resemble a histogram. However, the key difference lies in the nature of the data and the purpose of the chart.

Tips to avoid confusion:

  • Consider the data type: If the data is categorical (e.g., fruit types), use a bar chart. If it's continuous (e.g., heights), use a histogram.
  • Think about the purpose: If the chart compares categories, it's a bar chart. If it shows the distribution of data, it's a histogram.

import random

# Define genders
genders = ["Male", "Female"]
cities = ["New York", "London", "Tokyo", "Paris", "Berlin"]
# Generate 100 customers
customers = []
for _ in range(100):
  # Randomly choose a gender
  gender = random.choice(genders)
  
  # Generate random customer data
  customer = {
    "id": random.randint(1, 10000),
    "name": f"Customer-{random.randint(1, 1000)}",
    "gender": gender,
    "age": random.randint(18, 80),
    "city": random.choice(cities),
  }
  customers.append(customer)

city_data = {}
for person in customers:
    city = person["city"]
    gender = person["gender"]
    if city not in city_data:
        city_data[city] = {"Male": 0, "Female": 0}
    city_data[city][gender] += 1

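# Plot grouped bar chart
cities = list(city_data.keys())
males = [city_data[city]["Male"] for city in cities]
females = [city_data[city]["Female"] for city in cities]

# align="edge" with a positive width draws each bar to the right of its tick;
# a negative width with align="edge" draws it to the left, so the groups sit side by side
plt.bar(cities, males, label='Male', width=0.4, align="edge")
plt.bar(cities, females, label='Female', width=-0.4, align="edge")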
plt.xlabel('City')
plt.ylabel('Count')
plt.title('Count of Males and Females in Each City')
plt.legend()
plt.show()        

The code generates 100 random customers with gender, age, name, and city attributes. It counts the number of males and females in each city and creates a grouped bar chart to display these counts. The male bars are drawn to the right of each city's tick and the female bars to the left, so the two groups sit side by side without overlapping.

Bar chart compares categories of data
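The section title also mentions barh(), which draws the same kind of chart horizontally; this often helps when category names are long. In this sketch the counts are random integers generated only for illustration, and the bars are positioned numerically (offset by half a bar height) instead of with align="edge", which is another common way to group bars.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
cities = ["New York", "London", "Tokyo", "Paris", "Berlin"]
males = rng.integers(5, 15, len(cities))     # random counts for illustration only
females = rng.integers(5, 15, len(cities))

y = np.arange(len(cities))   # one slot per city on the y-axis
h = 0.4                      # bar thickness; two bars fill 0.8 of each slot

fig, ax = plt.subplots()
ax.barh(y - h / 2, males, height=h, label='Male')
ax.barh(y + h / 2, females, height=h, label='Female')
ax.set_yticks(y)
ax.set_yticklabels(cities)
ax.set_xlabel('Count')
ax.legend()
```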

pie()

Pie charts are effective for showcasing categorical proportions, making it easy to visualize how individual categories contribute to the whole.

city_data = {}
for person in customers:
    city = person["city"]
    if city not in city_data:
        city_data[city] = 0
    city_data[city] += 1

plt.pie(list(city_data.values()), labels=list(city_data.keys()))  # pie() needs array-like values, not dict views
plt.title('Proportion of customers in each city')
plt.legend()
plt.show()        

This code generates a pie chart displaying the distribution of customers in different cities.

A bar chart is suitable for comparing individual categories, while a pie chart is effective for displaying the proportional contribution of each category to a whole.
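The manual counting loop above can be replaced by collections.Counter, and the autopct parameter prints each wedge's percentage so readers do not have to estimate angles. This is a sketch with randomly generated customer cities standing in for the customers list; Counter is a swap-in for the dictionary loop, not what the original code uses.

```python
import random
from collections import Counter

import matplotlib.pyplot as plt

random.seed(0)
cities = ["New York", "London", "Tokyo", "Paris", "Berlin"]
# Hypothetical customer cities; replace with your own data
customer_cities = [random.choice(cities) for _ in range(100)]

counts = Counter(customer_cities)   # city -> number of customers

fig, ax = plt.subplots()
wedges, texts, autotexts = ax.pie(
    list(counts.values()),
    labels=list(counts.keys()),
    autopct='%1.1f%%',              # print each wedge's percentage
    startangle=90,                  # first wedge starts at 12 o'clock
)
ax.set_title('Proportion of customers in each city')
```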

polar()

A polar plot represents data in a circular coordinate system, where angles and distances from the center (radius) display relationships. It visualizes information radially, often used for cyclic or periodic data representations like angles, direction, or periodic patterns.

A classic and beautiful example of a polar plot is the rose curve, also known as the "rhodonea curve." This curve creates a symmetric and aesthetically pleasing pattern.

theta = np.linspace(0, 2*np.pi, 1000)
n = 6  # Rose parameter: the curve has n petals for odd n and 2n petals for even n

r = np.cos(n*theta)  # Equation for a rose curve

plt.figure(figsize=(6, 6))
plt.polar(theta, r)
plt.title(f'Rose Curve (n={n})')
plt.show()        

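This code snippet plots the rose curve r = cos(6θ). Mathematically, a rose curve r = cos(nθ) has n petals when n is odd and 2n petals when n is even, so the textbook curve for n = 6 has 12 petals.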
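Because cos(nθ) is negative for part of each turn, the rose relies on negative radii, and a polar Axes may not draw negative r the textbook way (reflecting the point through the origin); depending on the radial limits, Matplotlib can instead shift the radial origin. One way to check the petal count independently of that behavior is to convert to Cartesian coordinates yourself, as in this sketch.

```python
import matplotlib.pyplot as plt
import numpy as np

theta = np.linspace(0, 2 * np.pi, 1000)

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, n in zip(axes, [3, 6]):
    r = np.cos(n * theta)
    # x = r cos(theta), y = r sin(theta): a negative r lands on the opposite side of the origin
    ax.plot(r * np.cos(theta), r * np.sin(theta))
    ax.set_aspect('equal')
    petals = n if n % 2 else 2 * n
    ax.set_title(f'r = cos({n}θ): {petals} petals')
```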

I wrote this code so you can compare different presentations of the same data:

ax1 = plt.subplot(1,2,1 , projection = "polar")
ax1.plot(theta, r)
ax1.set_title(f'Rose Curve (n={n})')
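ax1.set_position((0.05, 0.1, 0.4, 0.8))  # (left, bottom, width, height) as fractions of the figure

ax2 = plt.subplot(1,2,2)
ax2.plot(theta, r)
ax2.set_title(f'Plot (n={n})')
ax2.set_position((0.55, 0.1, 0.4, 0.8))  # keeps the right subplot inside the figure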

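You may have noticed that I didn't use ax.polar() in my code. That is because polar() exists only as a pyplot function; Axes objects have no polar() method, so calling it raises an error: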

AttributeError: 'Axes' object has no attribute 'polar'        

Instead, I passed projection="polar" when creating the subplot and then used the ordinary plot() method. The projection parameter defines the coordinate system of a subplot: setting projection="polar" turns ax1 into a polar Axes, so plot(theta, r) interprets its first argument as an angle and its second as a radius.

Polar offers circular visualization; Cartesian offers linear simplicity

In the polar view (left), the rose curve shows its symmetric pattern, although it can be complex to interpret. In the Cartesian view (right), the same data appears as a simple oscillating wave that is easy to read value by value, but the circular symmetry is lost.
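The projection keyword is also accepted by fig.add_subplot(), and plt.subplots() forwards it to every Axes through subplot_kw. This sketch shows both: a figure mixing one polar and one Cartesian Axes, and a figure where every Axes is polar. The spiral and cardioid are example curves chosen because their radii are never negative.

```python
import matplotlib.pyplot as plt
import numpy as np

theta = np.linspace(0, 2 * np.pi, 1000)
r = np.cos(6 * theta)

# Mixed projections: pass projection only to the subplot that should be polar
fig = plt.figure(figsize=(10, 4))
ax_polar = fig.add_subplot(1, 2, 1, projection='polar')
ax_cart = fig.add_subplot(1, 2, 2)
ax_polar.plot(theta, r)
ax_cart.plot(theta, r)
ax_cart.set_xlabel('θ (radians)')
ax_cart.set_ylabel('r')

# Same projection everywhere: plt.subplots() forwards subplot_kw to every Axes it creates
fig2, polar_axes = plt.subplots(1, 2, subplot_kw={'projection': 'polar'})
polar_axes[0].plot(theta, theta)               # Archimedean spiral, r = θ
polar_axes[1].plot(theta, 1 + np.cos(theta))   # cardioid, r = 1 + cos θ
```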

quiver()

The quiver() method is primarily designed for visualizing 2D vector fields. It plots arrows on a flat plane, representing the magnitude and direction of vectors at each data point.

Unlike many tutorials, I begin by specifying customized input parameters to illustrate how quiver works, because the default values can be somewhat confusing to explain. The following example code draws three vectors in Cartesian space. This is not an extraordinary plot; it is simply the standard representation of vectors in a Cartesian system that you are familiar with from school.

ax = plt.subplot()
ax.quiver([-1,0,1],[0,0,0],[-1,1,1],[1,-1,1],scale=1, scale_units="xy")
ax.set_aspect("equal")
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)
ax.grid()        
Visualization of Vectors in Cartesian Space

Let's break down each part of the code:

ax.quiver([-1, 0, 1], [0, 0, 0], [-1, 1, 1], [1, -1, 1], scale=1, scale_units="xy")        

This line generates the quiver plot. The four arrays provided as arguments represent the coordinates and components of vectors. The syntax is as follows:

  1. The first array [-1, 0, 1] represents the x-coordinates of the vector starting points.
  2. The second array [0, 0, 0] represents the y-coordinates of the vector starting points.
  3. The third array [-1, 1, 1] represents the x-components of the vectors.
  4. The fourth array [1, -1, 1] represents the y-components of the vectors.
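  5. The parameters scale=1 and scale_units="xy" make one unit of a vector component span exactly one unit on the x and y axes, so each arrow ends at its start point plus its (x, y) components. (The Matplotlib documentation also recommends angles="xy" for this purpose; here, the equal aspect ratio set with ax.set_aspect("equal") makes the default angles produce the same result.)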

If we do not set the scale and scale_units parameters, the result can be surprising, because quiver then chooses the arrow lengths with a built-in autoscaling algorithm based on the average vector length and the number of vectors. For example, if we set neither scale nor scale_units, the result will be:

For some scenarios, this visualization might be appropriate, but in our case, we want to see the vectors in accordance with grid lines.
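To reproduce the comparison yourself, this sketch draws the same three vectors twice: once with quiver's default autoscaling and once with explicit scaling in data units. It adds angles="xy" to the explicit version, as the Matplotlib documentation suggests for vectors measured in x/y data units.

```python
import matplotlib.pyplot as plt

X, Y = [-1, 0, 1], [0, 0, 0]
U, V = [-1, 1, 1], [1, -1, 1]

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
# Left: quiver picks the arrow scale itself (autoscaling)
axes[0].quiver(X, Y, U, V)
axes[0].set_title('default scaling')
# Right: one unit of (U, V) spans one unit of the data axes
axes[1].quiver(X, Y, U, V, angles='xy', scale_units='xy', scale=1)
axes[1].set_title("angles='xy', scale_units='xy', scale=1")

for ax in axes:
    ax.set_aspect('equal')
    ax.set_xlim(-3, 3)
    ax.set_ylim(-3, 3)
    ax.grid()
```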

Let's go through the rest of the code, because these lines are also essential for a clean and correct plot:

  • ax.set_aspect("equal"): This line ensures that the aspect ratio of the plot is set to "equal," meaning that one unit in the x-direction is equal to one unit in the y-direction. This is useful to prevent distortion in the visualization.
  • ax.set_xlim(-3, 3): Sets the x-axis limits of the plot to be between -3 and 3.
  • ax.set_ylim(-3, 3): Sets the y-axis limits of the plot to be between -3 and 3.
  • ax.grid(): Adds a grid to the plot for better reference.

In summary, the code creates a quiver plot with three vectors, sets an equal aspect ratio, limits the plot to a specific range in both x and y directions, and adds a grid for better visualization.

What I showed was only meant to help you understand the parameters. In practice, a quiver plot is rarely used to visualize just a few vectors. It is commonly employed in physics and electronics to plot a plane filled with vectors, illustrating data such as airflow, electric fields, or any other quantity represented by vectors (each with a starting point, a direction, and a magnitude). Adding many vectors one by one is not efficient, so we use meshgrid. Here is an example:

# Creating x and y arrays
x = np.arange(0, 2, 0.2)
y = np.arange(0, 2, 0.2)

# Creating u and v components using meshgrid function
X, Y = np.meshgrid(x, y)
u = np.cos(X)*Y
v = np.sin(Y)*Y
 
# creating plot
fig, ax = plt.subplots(figsize =(14, 8))
ax.quiver(X, Y, u, v)
 
ax.set_xticks([])
ax.set_yticks([])
ax.set_ylim([-0.3, 2.3])
ax.set_xlim([-0.3, 2.3])
ax.set_aspect('equal')        

In the given code, np.meshgrid(x, y) is used to create a grid of points in the form of coordinate matrices X and Y. np.meshgrid(x, y) takes the 1-dimensional arrays x and y and returns two 2-dimensional arrays (X and Y). These arrays represent the grid of points where vectors will be plotted.

The u and v components of the vectors are then computed from X and Y: u = cos(X)·Y and v = sin(Y)·Y. As a result, the arrows have zero length along the bottom row (Y = 0) and get longer toward the top, while their direction depends on both coordinates.

Using meshgrid for quiver() method
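A quick way to see what meshgrid() returns is to inspect the shapes and a few values. This small sketch rebuilds the x and y arrays from the example above and checks three things: every row of X repeats x, every column of Y repeats y, and u and v have one value per grid point.

```python
import numpy as np

x = np.arange(0, 2, 0.2)
y = np.arange(0, 2, 0.2)
X, Y = np.meshgrid(x, y)

# X and Y have shape (len(y), len(x)): one entry per grid point
print(X.shape, Y.shape)

# Every row of X is a copy of x, every column of Y is a copy of y
print(np.array_equal(X[3], x), np.array_equal(Y[:, 3], y))

# Element-wise arithmetic on the grids gives one (u, v) pair per point
u = np.cos(X) * Y
v = np.sin(Y) * Y
print(u.shape, v.shape)

# The bottom row (Y == 0) produces zero-length arrows
print(np.all(u[0] == 0), np.all(v[0] == 0))
```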

imshow(), hexbin(), matshow()

All three methods, imshow(), hexbin(), and matshow(), are designed for visualizing 2D data or arrays and allow customization through the colormap (cmap) parameter for color mapping. However, their primary applications differ:

  • imshow() is primarily used for displaying images or 2D datasets, commonly associated with photographic or scientific visualizations.
  • hexbin() is specifically crafted for creating hexagonal binning plots, valuable for visualizing the density distribution of large datasets.
  • matshow() is intended for displaying matrices or 2D arrays, often employed in the context of heatmaps or similar visualizations.

In summary, although all three methods share a common purpose of visualizing 2D data, they excel in specific applications, providing tailored features for diverse data types and visualization requirements.
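To see the three functions side by side, the sketch below feeds the same random points to all of them. imshow() and matshow() expect an already gridded array, so the points are first binned with np.histogram2d; the 25-bin grid is an arbitrary choice. hexbin() receives the raw points directly.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=2000)
y = rng.normal(size=2000)

# imshow() and matshow() need a 2D array, so bin the points first
counts, xedges, yedges = np.histogram2d(x, y, bins=25)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))
im = ax1.imshow(counts.T, origin='lower', cmap='viridis')  # image-style display
ax1.set_title('imshow()')
hb = ax2.hexbin(x, y, gridsize=25, cmap='viridis')  # bins the raw points itself
ax2.set_title('hexbin()')
ms = ax3.matshow(counts.T, cmap='viridis')          # matrix-style display
ax3.set_title('matshow()')
fig.tight_layout()
```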

The following example illustrates how imshow() works:

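import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Loading the image ("pic.jpeg" is a placeholder for your own file)
# and forcing 3 channels, so an RGBA or grayscale file also works
img = Image.open("pic.jpeg").convert("RGB")
data = np.array(img)  # shape: (height, width, 3)

# Keeping only one channel at a time; the other two stay at zero
red_channel = np.zeros_like(data)
red_channel[:, :, 0] = data[:, :, 0]

green_channel = np.zeros_like(data)
green_channel[:, :, 1] = data[:, :, 1]

blue_channel = np.zeros_like(data)
blue_channel[:, :, 2] = data[:, :, 2]

# Showing the original and the three channels in a 2x2 grid
fig, axes = plt.subplots(2, 2)
axes[0, 0].imshow(data)
axes[0, 1].imshow(red_channel)
axes[1, 0].imshow(green_channel)
axes[1, 1].imshow(blue_channel)
plt.show()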

The code opens an image, separates its RGB channels (red, green, and blue), and displays the original image and each channel in a 2x2 grid using Matplotlib's imshow().

Matplotlib's imshow()
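If you don't have an image at hand, you can try the same channel separation on a small synthetic array. In this sketch the 4×4 "image" is made up. Besides copying one channel into an otherwise black RGB image, it also shows a common alternative: displaying a single channel as a plain 2D array with a matching colormap.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# A made-up 4x4 RGB image: values 0-255 stored as uint8, like a loaded photo
rng = np.random.default_rng(1)
data = rng.integers(0, 256, size=(4, 4, 3), dtype=np.uint8)

# Approach 1: keep one channel, zero the other two (still an RGB image)
red_channel = np.zeros_like(data)
red_channel[:, :, 0] = data[:, :, 0]

# Approach 2: take the channel as a 2D array and let a colormap color it
red_values = data[:, :, 0]

fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(red_channel)
ax2.imshow(red_values, cmap='Reds', vmin=0, vmax=255)
```

The second approach is handy when you want a colorbar or a custom colormap for a single channel. The first approach keeps the true RGB meaning of the pixels.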

While imshow() is commonly used to display images, it can be applied to many other types of data, not only RGB images. You can use it to visualize grayscale images, heatmaps, 2D arrays, or any data where a color mapping is meaningful. It's a versatile function in Matplotlib, suitable for a wide range of visualizations beyond photographs.

For instance, the next code generates two random datasets (x and y) with 500 points each, drawn from normal distributions. It then creates a side-by-side comparison of a scatter plot (on the left) and an imshow() plot (on the right) using Matplotlib.

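import numpy as np
import matplotlib.pyplot as plt

x = np.random.normal(10, 10, 500)
y = np.random.normal(5, 5, 500)

# Create a scatter plot
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.scatter(x, y, c='blue', alpha=0.7)
plt.title('Scatter Plot')

# Create an imshow plot of the 2D histogram.
# histogram2d indexes counts as H[x_bin, y_bin], so we transpose it
# (rows become y) and use origin='lower' with the bin edges as extent,
# which puts x on the horizontal axis and y on the vertical one,
# matching the scatter plot.
H, xedges, yedges = np.histogram2d(x, y, bins=10)
plt.subplot(1, 2, 2)
plt.imshow(H.T, origin='lower', aspect='auto',
           extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]])
plt.title('Imshow Plot')
plt.show()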

The imshow() plot visualizes the density of points in the 2D space, providing a different perspective than the scatter plot, which shows individual data points.

Using imshow() to display the density of data in 2D space
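One detail is worth knowing when combining these two functions. np.histogram2d returns counts indexed as H[x_bin, y_bin], whereas imshow() draws the first array index vertically. The tiny deterministic example below uses three hand-picked points to show why the histogram is usually transposed and drawn with origin='lower' before display.

```python
import numpy as np

# Three hand-picked points: two near the left edge, one near the right
x = np.array([0.1, 0.2, 0.9])
y = np.array([0.1, 0.1, 0.1])

H, xedges, yedges = np.histogram2d(x, y, bins=2, range=[[0, 1], [0, 1]])
print(H)
# H[i, j] counts points in x-bin i and y-bin j, so the rows follow x.
# H.T makes the rows follow y (the vertical direction in imshow), and
# origin='lower' keeps small y values at the bottom of the image.
print(H.T)
```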

matshow() is very similar to imshow(), with a few differences.

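matshow() is essentially imshow() with defaults tuned for matrices. It uses nearest-neighbor interpolation, so every cell is drawn as a crisp block. It also moves the x-axis tick labels to the top of the plot and restricts ticks to integer row and column positions. When called as plt.matshow(), it opens a new figure whose shape follows the matrix's aspect ratio. With imshow(), you can reproduce these settings yourself whenever your data needs them.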

Choose imshow() if:

  1. You need more flexibility and control over individual settings.
  2. You're visualizing arbitrary images beyond heatmaps and matrices.
  3. You prefer smooth interpolation for non-uniform data.

Use matshow() when:

  1. You're primarily working with heatmaps or representing 2D matrices.
  2. You want the convenience of preconfigured settings for heatmap visualization.
  3. You don't need fine-grained control over interpolation or aspect ratio.
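The sketch below draws the same small matrix with both functions, using a made-up 4×6 array. On the matshow() side you get its matrix-friendly defaults automatically. On the imshow() side, those defaults are approximated by hand with interpolation='nearest' and ax.xaxis.tick_top().

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

matrix = np.arange(24).reshape(4, 6)  # a small made-up 4 x 6 matrix

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# matshow(): matrix-friendly defaults out of the box
ms = ax1.matshow(matrix, cmap='viridis')
ax1.set_title('matshow()')

# imshow(): approximately the same picture, configured by hand
im = ax2.imshow(matrix, cmap='viridis', interpolation='nearest')
ax2.xaxis.tick_top()  # move the x tick labels above the image
ax2.set_title('imshow()')
```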

In the previous examples, we saw how imshow() and matshow() display 2D arrays. If you have raw data and want to create a heatmap or visualize its distribution, you first need to bin it, as I did with np.histogram2d, and then display the resulting 2D histogram with one of these functions. For a more automatic approach, hexbin() is a convenient alternative. It does the binning itself, using hexagonal bins, so you do not need the extra step of computing a histogram:

import numpy as np
import matplotlib.pyplot as plt

x = np.random.randn(1000)
y = np.random.randn(1000)

# Create a hexagonal binning plot (gridsize = number of hexagons along x)
plt.hexbin(x, y, gridsize=20, cmap='viridis')
plt.colorbar(label='Count')  # Add a colorbar for reference

# Set labels and title
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.title('Hexagonal Binning Plot')
plt.show()

hexbin is suitable for efficiently visualizing the density distribution of large datasets
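hexbin() can do more than count points. In the sketch below, mincnt=1 hides empty hexagons. Passing a third array C together with reduce_C_function=np.mean colors each hexagon by the mean of C over the points that fall into it, instead of by the count. The data and the choice of C = x + y are arbitrary.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)
x = rng.standard_normal(1000)
y = rng.standard_normal(1000)
c = x + y  # an arbitrary value attached to each point

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Left: point counts per hexagon; empty hexagons are not drawn
counts = ax1.hexbin(x, y, gridsize=20, mincnt=1, cmap='viridis')
fig.colorbar(counts, ax=ax1, label='Count')

# Right: mean of c inside each hexagon instead of the count
means = ax2.hexbin(x, y, C=c, gridsize=20, reduce_C_function=np.mean,
                   cmap='coolwarm')
fig.colorbar(means, ax=ax2, label='Mean of x + y')
```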

3D Visualization

For 3D plotting, we create an Axes with the projection="3d" parameter. On such an Axes, ordinary methods like scatter and plot accept a third coordinate, z. In addition to these general-purpose methods, there are functions tailored for 3D plotting: plot_surface for 3D surface plots, plot_wireframe for 3D wireframe plots, scatter3D for 3D scatter plots, and bar3d for 3D bar plots. These specialized functions provide more control and options for creating visually appealing and informative 3D visualizations.

Let's start with the scatter() method in 3D space:

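import numpy as np
import matplotlib.pyplot as plt

ax = plt.subplot(projection="3d")  # an Axes with a 3D projection
x = np.random.random(20)  # 20 random points in the unit cube
y = np.random.random(20)
z = np.random.random(20)
ax.scatter(x, y, z)
plt.show()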

This code creates a 3D scatter plot with randomly generated data.

3D scatter plot
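The same 3D Axes also accepts the ordinary plot() method, and the 3D scatter() takes the usual styling arguments. The sketch below shows both. On the left, plot() draws a helix, chosen only as an illustrative curve. On the right, arbitrary random points are colored by their z value through c and cmap, with a colorbar to explain the colors.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(10, 4))

# Left: the ordinary plot() on a 3D Axes, drawing a helix
ax1 = fig.add_subplot(1, 2, 1, projection="3d")
t = np.linspace(0, 4 * np.pi, 400)
lines = ax1.plot(np.cos(t), np.sin(t), t / (4 * np.pi))
ax1.set_title("plot() in 3D")

# Right: scatter() with points colored by their z value
ax2 = fig.add_subplot(1, 2, 2, projection="3d")
rng = np.random.default_rng(7)
x, y, z = rng.random((3, 50))  # 50 random points in the unit cube
points = ax2.scatter(x, y, z, c=z, cmap="plasma")
fig.colorbar(points, ax=ax2, label="z value")
ax2.set_title("scatter() colored by z")

for ax in (ax1, ax2):
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
```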

bar3d()

You can use bar3d() when you have two categorical variables and want to show how a quantity varies across them. For example, suppose you have data on sales quantity (z-axis) across different months (x-axis) and different regions (y-axis):

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

# A covariance matrix must be symmetric (matching off-diagonal entries)
mvn = multivariate_normal([8, 4], [[4, 0], [0, 2]])
months, regions = np.meshgrid(np.linspace(1, 12, 12), np.linspace(1, 5, 5))
points = np.c_[months.ravel(), regions.ravel()]
sales_quantity = mvn.pdf(points)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
# bar3d(x, y, z, dx, dy, dz): position, base height, width, depth, height
ax.bar3d(months.ravel(), regions.ravel(), np.zeros_like(months).ravel(),
         np.ones_like(months).ravel(), np.ones_like(months).ravel(),
         sales_quantity.ravel())
ax.set_title('Sales quantity')
ax.set_xlabel("Time (Month 1-12)")
ax.set_ylabel("Regions 1-5")
plt.show()

The provided code generates a 3D bar chart, using a bivariate normal distribution to model sales quantity over time (months) and regions. First, a bivariate normal distribution (multivariate_normal) is created with a mean of [8, 4] and a diagonal covariance matrix [[4, 0], [0, 2]]. Then, np.meshgrid creates arrays for months (1-12) and regions (1-5). Next, np.c_ combines the flattened meshgrid arrays into a single array of (month, region) points. After that, the probability density function (pdf) of the bivariate normal distribution is evaluated at each point. Finally, ax.bar3d() draws one bar per point. The x and y positions come from the meshgrid arrays, every bar starts at z = 0 and is 1 unit wide and deep, and each bar's height is the computed sales quantity.


A 3D bar chart is used to visualize data that spans three dimensions.

The resulting plot visually represents the sales quantity over time (months) and regions as a 3D bar chart. Each bar's height corresponds to the sales quantity at a specific combination of time and region according to the bivariate normal distribution.
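In a real chart, you would usually label the bars with the actual categories rather than numbers. The sketch below uses a small, entirely made-up sales table (3 regions × 4 months) and draws slightly narrower bars so they don't merge. It then replaces the numeric tick positions with month and region names, using set_xticks() and set_yticks() with a labels argument; that argument assumes Matplotlib 3.5 or newer.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical sales table: rows = regions, columns = months
months = ["Jan", "Feb", "Mar", "Apr"]
regions = ["North", "Center", "South"]
sales = np.array([[12, 15, 9, 20],
                  [7, 11, 14, 10],
                  [5, 8, 6, 13]])

# One (x, y) grid position per table cell; ravel keeps rows aligned with sales
xpos, ypos = np.meshgrid(np.arange(len(months)), np.arange(len(regions)))
xpos, ypos = xpos.ravel(), ypos.ravel()

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
# Bars start at z = 0, are 0.6 wide and deep, and are as tall as the sales value
bars = ax.bar3d(xpos - 0.3, ypos - 0.3, np.zeros(xpos.size),
                0.6, 0.6, sales.ravel())
ax.set_xticks(range(len(months)), labels=months)
ax.set_yticks(range(len(regions)), labels=regions)
ax.set_zlabel("Sales quantity")
```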

plot_surface()

plot_surface() is another method available on an Axes created with projection='3d'. While bar3d() is suitable for displaying categorical data in 3D space using bars, plot_surface() is ideal for visualizing continuous functions or datasets as a 3D surface.

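import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

# Mean [1, 2]; the covariance matrix must be symmetric
mvn = multivariate_normal([1, 2], [[8, 6], [6, 7]])
x, y = np.meshgrid(np.linspace(-10, 10, 100), np.linspace(-10, 10, 100))
points = np.c_[x.ravel(), y.ravel()]  # 10,000 (x, y) pairs
z = mvn.pdf(points)                   # 1D array of 10,000 densities

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.plot_surface(x, y, z.reshape(100, 100), cmap='viridis',
                edgecolor='green')
ax.set_title('Surface plot')
plt.show()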

The provided code generates a 3D surface plot of the probability density function (PDF) of a multivariate normal distribution. A 3D plot is appropriate here because this distribution is two-dimensional (its mean vector [1, 2] has two components). Its density therefore depends on both x and y, and the surface shows that density as height above the x-y plane. Note that the PDF is evaluated on the flattened grid points, so z comes back as a 1D array of 10,000 values. reshape(100, 100) restores the grid shape that plot_surface() expects, matching x and y.

3D surface plot visualizing the probability density function of a multivariate normal distribution
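Flattening and reshaping is not the only option. scipy's multivariate_normal.pdf() also accepts an array whose last axis holds the coordinates. Stacking the grids with np.dstack gives a (100, 100, 2) array, and the PDF comes back directly in grid shape, ready for plot_surface(). The sketch below reuses the mean and the symmetric covariance from the example above. It also checks numerically that the density sums to roughly 1 over this window.

```python
import numpy as np
from scipy.stats import multivariate_normal

mvn = multivariate_normal([1, 2], [[8, 6], [6, 7]])
x, y = np.meshgrid(np.linspace(-10, 10, 100), np.linspace(-10, 10, 100))

# Stack the grids along a new last axis: pos[i, j] == (x[i, j], y[i, j])
pos = np.dstack((x, y))
z = mvn.pdf(pos)  # evaluated point-wise, already shaped (100, 100)

# Sanity check: a Riemann sum of the density over the grid should be close
# to 1, because almost all of the probability mass lies inside [-10, 10]^2
dx = x[0, 1] - x[0, 0]
dy = y[1, 0] - y[0, 0]
total = z.sum() * dx * dy
print(z.shape, round(total, 3))
```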

plot_wireframe()

plot_wireframe() is very similar to plot_surface(). It connects points on the surface with lines without filling in the areas between them. This makes it ideal for emphasizing the overall structure of a surface. Here is an example:

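import numpy as np
import matplotlib.pyplot as plt

x, y = np.meshgrid(np.linspace(-10, 10, 100), np.linspace(-10, 10, 100))
z = x**2 + y**2  # already a 100x100 grid, no reshape needed
fig = plt.figure()

ax = fig.add_subplot(projection='3d')
# A wireframe is made of lines, so it takes a line color rather than a colormap
ax.plot_wireframe(x, y, z, color='green')
ax.set_title('Wireframe plot')
plt.show()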

This code generates a 3D wireframe plot of the surface defined by the equation z = x**2 + y**2. Here's the resulting plot:

plot_wireframe is ideal for representing the overall structure of a mathematical function.
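With a 100×100 grid, the wireframe can look crowded. plot_wireframe() accepts rcount and ccount, the maximum number of row and column lines to draw, to thin it out. The sketch below compares a dense and a sparse version of the same paraboloid; the values 50 and 15 are arbitrary.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x, y = np.meshgrid(np.linspace(-10, 10, 100), np.linspace(-10, 10, 100))
z = x**2 + y**2

fig = plt.figure(figsize=(10, 4))
ax1 = fig.add_subplot(1, 2, 1, projection="3d")
ax2 = fig.add_subplot(1, 2, 2, projection="3d")

# rcount/ccount cap how many row and column lines are drawn
dense = ax1.plot_wireframe(x, y, z, rcount=50, ccount=50, color="green")
sparse = ax2.plot_wireframe(x, y, z, rcount=15, ccount=15, color="green")
ax1.set_title("rcount=ccount=50")
ax2.set_title("rcount=ccount=15")
```

Fewer lines make the overall shape easier to read. More lines follow the surface more faithfully, at the cost of visual clutter.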

Conclusion

In conclusion, this article on Matplotlib covers various plotting techniques and functions for both 2D and 3D visualizations. I started by introducing the functional and object-oriented APIs, highlighting the benefits of the latter for intricate visualizations. The article delves into creating figures, subplots, and customizing them using Matplotlib.

I aimed to provide detailed explanations and examples for common 2D plotting functions like plot, scatter, bar, pie, and more. Special attention is given to histograms, boxplots, and annotations for effective data representation. The distinction between imshow(), hexbin(), and matshow() is clarified, offering insights into their specific use cases.

In the realm of 3D plotting, I introduced scatter plots, bar charts using bar3d(), and surface plots using plot_surface() and plot_wireframe(). I demonstrated how to leverage these functions for visualizing data that spans three dimensions.

I didn't mean to write a complete guide or a manual. I just wanted to share some ideas about data visualization with Matplotlib, along with some common challenges and their solutions, to help you get comfortable with the library. You will certainly encounter difficulties and ideas that I haven't mentioned in this article. Still, I honestly believe that the examples, code snippets, and practical tips here make this article a valuable resource for anyone working with Matplotlib for data visualization. Thanks for reading this guide!

