When exploring data and presenting results, it can be important to compare multiple visualizations side by side. Combining several subplots into one figure makes it much easier to compare and contrast results while keeping everything organized and digestible. One convenient way to do this in Python is the subplots() function from matplotlib. In this post, we’ll go over the basic syntax and a couple of examples that use subplots() with matplotlib visualizations, as well as with seaborn visualizations.
In the examples below, we use a Kaggle dataset of Adidas shoe sales in the United States, loaded into a pandas DataFrame named df.
Basic Syntax: fig, axs = plt.subplots(nrows, ncols)
The first thing to know about plt.subplots() is that it returns multiple objects: a Figure, usually named fig, and one or more Axes objects. When there is more than one Axes object, they are returned as a NumPy array, so you can index each one with square brackets, as you would any array. The code below creates a 2 x 2 Figure containing 4 Axes objects and gives each one a title.
import matplotlib.pyplot as plt
# Fig object, and array of Axes
fig, axs = plt.subplots(nrows = 2, ncols = 2)
# Upper left
axs[0,0].set_title("Operating Margin")
# Upper right
axs[0,1].set_title("Price per Unit")
# Bottom left
axs[1,0].set_title("Units Sold")
# Bottom right
axs[1,1].set_title("Total Sales")
plt.show()
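Output: [Figure: 2 x 2 grid of empty subplots titled Operating Margin, Price per Unit, Units Sold, and Total Sales]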

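To check what plt.subplots() actually returns, the sketch below inspects the Axes container. It assumes only matplotlib and NumPy are installed and selects the non-interactive Agg backend so it runs without a display. A grid with several rows and columns gives a 2-D NumPy array, which you can also loop over with .flat. A single row is squeezed to a 1-D array, and the default 1 x 1 call returns one Axes object rather than an array.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; no window is needed
import matplotlib.pyplot as plt
import numpy as np

# 2 x 2 grid: axs is a 2-D NumPy array of Axes objects
fig, axs = plt.subplots(nrows=2, ncols=2)
titles = ["Operating Margin", "Price per Unit", "Units Sold", "Total Sales"]

# .flat walks the grid row by row: upper left, upper right, bottom left, bottom right
for ax, title in zip(axs.flat, titles):
    ax.set_title(title)

print(type(axs).__name__, axs.shape)  # ndarray (2, 2)
print(axs[1, 0].get_title())          # Units Sold

# A single row is squeezed to a 1-D array, so it takes one index: row_axs[0]
fig_row, row_axs = plt.subplots(nrows=1, ncols=3)
print(row_axs.shape)                  # (3,)

# The default 1 x 1 call returns a single Axes, not an array
fig_one, single_ax = plt.subplots()
print(isinstance(single_ax, np.ndarray))  # False
```

This squeezing behavior explains why axs[0, 0] works for a 2 x 2 grid but a one-row grid needs axs[0].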
Basic Syntax v2: fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows = 2, ncols = 2)
If you want to unpack the Axes objects into separate variables, you can do so with nested tuples, using one inner tuple per row of the grid.
# Fig object, ((UL, UR), (BL, BR)) axes
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows = 2, ncols = 2)
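The nested-tuple form is ordinary Python unpacking applied to the 2-D Axes array. The sketch below shows the equivalence under the same assumptions as before (matplotlib installed, Agg backend). It also shows that the pattern must match the grid's shape. A 1 x 3 grid unpacks into a flat tuple, and a mismatched pattern raises a ValueError.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; no window is needed
import matplotlib.pyplot as plt

fig, axs = plt.subplots(nrows=2, ncols=2)

# Unpack the array row by row: first row -> (ax1, ax2), second row -> (ax3, ax4)
(ax1, ax2), (ax3, ax4) = axs
print(ax1 is axs[0, 0], ax4 is axs[1, 1])  # True True

# A single row is squeezed to a 1-D array, so the pattern has no inner tuples
fig_row, (left, middle, right) = plt.subplots(nrows=1, ncols=3)

# A pattern that doesn't match the grid fails like any other tuple unpacking
unpack_failed = False
try:
    fig_bad, (a, b) = plt.subplots(nrows=1, ncols=3)
except ValueError as err:
    unpack_failed = True
    print("ValueError:", err)
```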
Additional Figure arguments: figsize and dpi
In this example, we also use the figsize and dpi arguments to set the size (in inches) and resolution (dots per inch) of the final Figure object. This is particularly useful with subplots, since a single figure has to convey more information.
# Fig object, ((UL, UR), (BL, BR)) axes
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows = 2, ncols = 2, figsize = [10, 10], dpi = 300)
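Because figsize is measured in inches and dpi in dots (pixels) per inch, the pixel size of the rendered figure is their product. For example, figsize = [10, 10] at dpi = 300 gives a 3000 x 3000 pixel image. The sketch below checks this by saving to an in-memory PNG and reading the width and height from the file header. The io and struct usage is our own illustration, not part of the original post. It also assumes the default rcParams["savefig.dpi"] of "figure", which makes savefig reuse the figure's dpi.

```python
import io
import struct

import matplotlib
matplotlib.use("Agg")  # non-interactive backend; no window is needed
import matplotlib.pyplot as plt

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=[10, 10], dpi=300)

# Expected pixel size = size in inches * dots per inch
print(fig.get_size_inches() * fig.dpi)  # [3000. 3000.]

# Render to an in-memory PNG; with default settings savefig uses the figure's dpi
buf = io.BytesIO()
fig.savefig(buf, format="png")
png = buf.getvalue()

# The PNG IHDR chunk stores width and height as big-endian 32-bit ints at bytes 16-24
width, height = struct.unpack(">II", png[16:24])
print(width, height)  # 3000 3000
```

Higher dpi values give sharper exports but larger files, so 300 is typically reserved for figures destined for print or slides.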
Axes Syntax: ax1.hist(x = df["col"]) and ax1.set_title("Title")
Once you’ve created your Figure and Axes, you can draw a plot in each Axes object and add a title or other attributes, as below. Because these are matplotlib visualizations, we create them with normal matplotlib syntax, and any kind of plot available in the matplotlib library can be drawn on an Axes.
# Upper left
ax1.hist(x = df["Operating Margin"])
ax1.set_title("Operating Margin")
# Upper right
ax2.hist(x = df["Price per Unit"])
ax2.set_title("Price per Unit")
# Bottom left
ax3.hist(x = df["Units Sold"].dropna())
ax3.set_title("Units Sold")
# Bottom right
ax4.hist(x = df["Total Sales"].dropna())
ax4.set_title("Total Sales")
plt.show()
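Output: [Figure: 2 x 2 grid of matplotlib histograms of Operating Margin, Price per Unit, Units Sold, and Total Sales]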

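To illustrate that any matplotlib plot type can go in an Axes, here is a self-contained sketch that mixes a histogram, a scatter plot, a bar chart, and a box plot in one 2 x 2 figure. The Kaggle file is not loaded here. Instead, the sketch builds a small hypothetical DataFrame of randomly generated values under some of the post's column names. The numbers and the region labels are placeholders, not real Adidas sales. Each Axes plotting method returns the artists it drew, which is handy if you want to style them afterwards.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; no window is needed
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in data (random, NOT the Kaggle dataset)
rng = np.random.default_rng(seed=0)
df = pd.DataFrame({
    "Operating Margin": rng.uniform(0.1, 0.6, size=200),
    "Price per Unit": rng.uniform(10, 110, size=200),
    "Units Sold": rng.integers(0, 1000, size=200),
    "Region": rng.choice(["East", "South", "West"], size=200),  # hypothetical labels
})

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=[10, 10])

# Upper left: histogram; hist returns (counts, bin_edges, patches)
counts, edges, patches = ax1.hist(x=df["Operating Margin"], bins=10)
ax1.set_title("Operating Margin")

# Upper right: scatter plot of two numeric columns
points = ax2.scatter(df["Price per Unit"], df["Units Sold"], s=10)
ax2.set_title("Price per Unit vs. Units Sold")

# Bottom left: bar chart of row counts per region, sorted alphabetically
region_counts = df["Region"].value_counts().sort_index()
bars = ax3.bar(region_counts.index, region_counts.values)
ax3.set_title("Rows per Region")

# Bottom right: box plot of a single numeric column
ax4.boxplot(df["Price per Unit"].to_numpy())
ax4.set_title("Price per Unit")

# In an interactive session, call plt.show() to display the figure
```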
Using seaborn with matplotlib subplots
Since seaborn is built on top of the matplotlib library, it is no surprise that seaborn works well with the matplotlib subplots infrastructure. The key difference is that instead of calling a plotting method on the Axes object, you pass the Axes to the ax argument of the seaborn function itself (e.g. sns.histplot(ax = ax1, x = df["Operating Margin"], bins = 20)).
Note that in the example below, we customize and rotate the x-axis tick labels on individual plots using get_xticks(), get_xticklabels(), and set_xticks(). Passing labels to set_xticks() requires matplotlib 3.5 or later.
import seaborn as sns
sns.set_theme()
# Fig object, ((UL, UR), (BL, BR)) axes
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows = 2, ncols = 2, figsize = [10, 10], dpi = 300)
# Upper left
sns.histplot(ax = ax1, x = df["Operating Margin"], bins = 20)
# Upper right
sns.histplot(ax = ax2, x = df["Price per Unit"], bins = 20)
# Bottom left
sns.countplot(ax = ax3, x = df["Region"])
# Get and set x-axis tick locations and labels
xtick_loc = ax3.get_xticks()
xtick_labels = ax3.get_xticklabels()
ax3.set_xticks(ticks = xtick_loc, labels = xtick_labels, rotation = 45, ha = 'right', fontsize = 10)
# Bottom right
sns.countplot(ax = ax4, x = df["Retailer"])
# Get and set x-axis tick locations and labels
xtick_loc = ax4.get_xticks()
xtick_labels = ax4.get_xticklabels()
ax4.set_xticks(ticks = xtick_loc, labels = xtick_labels, rotation = 45, ha = 'right', fontsize = 10)
plt.show()
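Output: [Figure: seaborn histograms of Operating Margin and Price per Unit (top) and count plots of Region and Retailer with rotated x-axis labels (bottom)]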

Example with fig.tight_layout()
Unfortunately, the subplots above need more whitespace, so we can add one more line of code: fig.tight_layout(). This function adjusts the padding around and between the subplots so that titles and tick labels fit inside the figure. The padding is specified as a fraction of the font size, and you can customize it as needed.
# Fig object, ((UL, UR), (BL, BR)) axes
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows = 2, ncols = 2, figsize = [10, 10], dpi = 300)
# Upper left
sns.histplot(ax = ax1, x = df["Operating Margin"], bins = 20)
# Upper right
sns.histplot(ax = ax2, x = df["Price per Unit"], bins = 20)
# Bottom left
sns.countplot(ax = ax3, x = df["Region"])
# Get and set x-axis tick locations and labels
xtick_loc = ax3.get_xticks()
xtick_labels = ax3.get_xticklabels()
ax3.set_xticks(ticks = xtick_loc, labels = xtick_labels, rotation = 45, ha = 'right', fontsize = 10)
# Bottom right
sns.countplot(ax = ax4, x = df["Retailer"])
# Get and set x-axis tick locations and labels
xtick_loc = ax4.get_xticks()
xtick_labels = ax4.get_xticklabels()
ax4.set_xticks(ticks = xtick_loc, labels = xtick_labels, rotation = 45, ha = 'right', fontsize = 10)
# Pad layout
fig.tight_layout()
plt.show()
Output: [Figure: the same four seaborn plots, spaced out with fig.tight_layout()]

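If the default spacing is still too tight, fig.tight_layout() accepts pad, h_pad, and w_pad arguments, all expressed as fractions of the font size. The sketch below compares that with two other standard options. The first sets the gaps by hand with fig.subplots_adjust(), whose hspace and wspace are fractions of the average Axes height and width. The second lets matplotlib manage spacing from the start with layout="constrained", which requires matplotlib 3.5 or later. The make_grid() helper is hypothetical and only builds empty subplots with titles and labels that tend to collide.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; no window is needed
import matplotlib.pyplot as plt


def make_grid(**fig_kw):
    """Hypothetical helper: a 2 x 2 grid whose titles and labels tend to collide."""
    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=[6, 6], **fig_kw)
    for ax in axs.flat:
        ax.set_title("A fairly long subplot title")
        ax.set_xlabel("x label")
        ax.set_ylabel("y label")
    return fig, axs


# Option 1: tight_layout with extra padding (fractions of the font size)
fig_tight, _ = make_grid()
fig_tight.tight_layout(pad=2.0, h_pad=3.0, w_pad=2.0)

# Option 2: set the gaps by hand (fractions of the average Axes height/width)
fig_manual, _ = make_grid()
fig_manual.subplots_adjust(hspace=0.5, wspace=0.4)
print(fig_manual.subplotpars.hspace, fig_manual.subplotpars.wspace)  # 0.5 0.4

# Option 3: constrained layout, recomputed automatically whenever the figure is drawn
fig_constrained, _ = make_grid(layout="constrained")
fig_constrained.canvas.draw()
```

tight_layout() is a one-off adjustment applied when it is called, while constrained layout is re-applied on every draw, which can be more convenient if you resize the figure or add elements later.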
About
Einblick is an AI-native data science platform that provides data teams with an agile workflow to swiftly explore data, build predictive models, and deploy data apps. Founded in 2020, Einblick was developed based on six years of research at MIT and Brown University. Einblick is funded by Amplify Partners, Flybridge, Samsung Next, Dell Technologies Capital, and Intel Capital. For more information, please visit www.einblick.ai and follow us on LinkedIn and Twitter.