Making prettier plots is partly a matter of taste and partly an appreciation of visual perception. There are a number of simple things you can do to improve on matplotlib's defaults. The guiding philosophy for the advice below is that it's better to start with little, and add more elements to a plot only if they actually add information (see the work of Edward Tufte).
# we'll use matplotlib's pyplot interface
import matplotlib.pyplot as plt
# necessary for the notebook to render the plots inline
%matplotlib inline
import numpy as np
np.random.seed(42)
x = np.linspace(0, 40, 1000)
y = np.sin(np.linspace(0, 10*np.pi, 1000))
y += np.random.randn(len(x))
plt.plot(x, y)
The default plots created with matplotlib aren't bad, but they do have elements that are, at best, unnecessary. At worst, these elements detract from the display of quantitative information. We want to change this. First, let's change matplotlib's style with a built-in style sheet.
# this gives us a style and color palette similar to ggplot2
plt.style.use('ggplot')
plt.plot(x, y)
This produces something a bit more pleasing to the eye, with arguably better colors. However, it also replaced the box with a gray background and a white grid. Although a grid is useful for some plots, in particular panels of plots in which one needs to compare values across different plots, it is often just unnecessary noise.
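Using seaborn can help with this. Older versions of seaborn applied their own style as soon as they were imported, which is why the `seaborn.apionly` module existed; since seaborn 0.8, a plain import leaves matplotlib's style alone (and `seaborn.apionly` has since been deprecated and removed), so we can simply import it:
# since seaborn 0.8, importing seaborn does not change matplotlib's style
import seaborn as sns
# re-apply the ggplot style (harmless if it is already active)
plt.style.use('ggplot')
# replace ggplot's gray background and grid with seaborn's 'ticks' style;
# ggplot's color cycle is left in place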
sns.set_style('ticks')
plt.plot(x, y)
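If you would rather not depend on seaborn for this step, a roughly similar look can be sketched with matplotlib alone. This is an approximation, not what `sns.set_style('ticks')` actually sets: it keeps the 'ggplot' color cycle but overrides the grid, background, and edge color with a dict (the particular values are assumptions), and it uses `plt.style.context` so the change applies only inside the `with` block.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script; skip this line in a notebook
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 40, 1000)
y = np.sin(np.linspace(0, 10 * np.pi, 1000))

# Styles in a list are applied left to right, so the dict overrides the
# gray background, white edges, and grid that 'ggplot' turns on.
ggplot_ticks = [
    "ggplot",
    {"axes.grid": False, "axes.facecolor": "white", "axes.edgecolor": "0.15"},
]

default_color = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]

with plt.style.context(ggplot_ticks):
    fig, ax = plt.subplots()
    ax.plot(x, y)
    grid_inside = plt.rcParams["axes.grid"]
    facecolor_inside = ax.get_facecolor()
    # The first color of ggplot's cycle is still in effect.
    color_inside = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]

# Leaving the context restores the previous settings.
color_after = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]
plt.close(fig)
```

Using a context manager rather than `plt.style.use` keeps experiments like this from leaking into the rest of the notebook.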
It almost looks like we took two steps forward and one step back: now we have a box again. But seaborn provides a useful function for removing axis lines (by default, the top and right spines): `despine`.
plt.plot(x, y)
sns.despine()
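For reference, here is a matplotlib-only sketch of roughly what `despine()` does with its default arguments: hide the top and right spines, and keep ticks only on the sides that still have a spine. The helper name `remove_top_right_spines` is hypothetical, and this is not seaborn's actual implementation.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script; skip this line in a notebook
import matplotlib.pyplot as plt
import numpy as np

def remove_top_right_spines(ax):
    """Hide the top and right spines and draw ticks only on the bottom and left."""
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    # Ticks on a side without a spine would float in empty space.
    ax.xaxis.set_ticks_position("bottom")
    ax.yaxis.set_ticks_position("left")

x = np.linspace(0, 40, 1000)
y = np.sin(np.linspace(0, 10 * np.pi, 1000))
fig, ax = plt.subplots()
ax.plot(x, y)
remove_top_right_spines(ax)
```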
We can go further still, using the `offset` argument (a distance in points) to move the remaining axis lines slightly away from the data so they distract even less from it; the data should be front and center.
plt.plot(x, y)
sns.despine(offset=10)
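In plain matplotlib, this kind of displacement is expressed with `Spine.set_position(('outward', points))`, which moves a spine out from the data area by the given number of points. A minimal sketch, assuming that is the effect you want from `despine(offset=10)`; the helper name `offset_spines` is hypothetical:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script; skip this line in a notebook
import matplotlib.pyplot as plt
import numpy as np

def offset_spines(ax, points=10):
    """Hide the top/right spines and push the left/bottom spines outward.

    `points` is a distance in typographic points (1/72 inch).
    """
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    for side in ("left", "bottom"):
        ax.spines[side].set_position(("outward", points))

x = np.linspace(0, 40, 1000)
y = np.sin(np.linspace(0, 10 * np.pi, 1000))
fig, ax = plt.subplots()
ax.plot(x, y)
offset_spines(ax, points=10)
```

Because the offset is in points rather than data units, it stays visually the same size whatever the axis limits are.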
Now let's do some refining. Figures for exploratory work can be any size that's convenient, but when making figures for a publication, you must consider the figure's actual size in its final printed form. A U.S. letter page is 8.5 x 11 inches, so a figure meant to fit in a single column of the page should typically be no more than about 4 inches wide. We can set the figure size, in inches, by giving matplotlib a bit more detail:
fig = plt.figure(figsize=(4, 2))
ax = fig.add_subplot(1,1,1)
ax.plot(x, y)
sns.despine(offset=10, ax=ax)
# let's add some axis labels to boot
ax.set_ylabel(r'displacement ($\AA$)')
ax.set_xlabel('time (ns)')
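Because `figsize` is given in inches, it can help to derive it from the target column width and the desired height-to-width ratio. The helpers below are hypothetical conveniences, not matplotlib API; the 4-inch default follows the single-column rule of thumb above, and the millimetre conversion is for venues that state widths in millimetres.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script; skip this line in a notebook
import matplotlib.pyplot as plt

MM_PER_INCH = 25.4

def mm_to_inches(mm):
    """Convert a length in millimetres to inches."""
    return mm / MM_PER_INCH

def column_figsize(width_in=4.0, aspect=0.5):
    """Return (width, height) in inches; `aspect` is height / width."""
    return (width_in, width_in * aspect)

fig = plt.figure(figsize=column_figsize())
size_in = tuple(float(v) for v in fig.get_size_inches())
# For a raster format, the size in pixels is simply inches * dpi.
pixels_at_300_dpi = tuple(round(v * 300) for v in size_in)
plt.close(fig)
```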
We also added axis labels. Because this is a timeseries, we deliberately made the figure wider than it is tall: timeseries are difficult to interpret when their variations over time are squeezed together. A widely cited guideline, William Cleveland's "banking to 45$^\circ$" (which Tufte also discusses), is to choose the aspect ratio so that the line segments slope at about 45$^\circ$ on average; we would have a hard time doing that here with such noisy data, but going wider than tall is a step in the right direction.
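To make the 45$^\circ$ guideline concrete, here is a sketch assuming the median-absolute-slope form of banking: choose the height-to-width ratio of the plotting area so that the median line segment is drawn at 45$^\circ$. The function name `bank_to_45` is hypothetical, and the ratio refers to the axes area, not the whole figure with its labels.

```python
import numpy as np

def bank_to_45(x, y):
    """Height/width ratio that banks the median line segment to 45 degrees.

    A segment spanning dx and dy in data units covers dx / x_range of the
    plot width and |dy| / y_range of the plot height, so its on-screen slope
    is s * (height / width) with s = (|dy| / y_range) / (dx / x_range).
    Setting median(s) * (height / width) = 1 gives height / width = 1 / median(s).
    Assumes x is strictly increasing and the median slope is nonzero.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_range = x.max() - x.min()
    y_range = y.max() - y.min()
    s = (np.abs(np.diff(y)) / y_range) / (np.diff(x) / x_range)
    return 1.0 / np.median(s)

# A straight line spanning both axis ranges already sits at 45 degrees in a square area.
t = np.linspace(0, 10, 11)
line_aspect = bank_to_45(t, 2 * t)

# The noisy series from this notebook (same seed and construction).
np.random.seed(42)
x = np.linspace(0, 40, 1000)
y = np.sin(np.linspace(0, 10 * np.pi, 1000)) + np.random.randn(1000)
noisy_aspect = bank_to_45(x, y)
```

For the noisy series the ratio comes out far below the 0.5 used above, i.e. an impractically flat plotting area, which is why the guideline is hard to follow with data like this.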
We can save figures in a variety of formats. It's useful to save a PDF version so that the figure can be post-processed with vector graphics tools like Inkscape and Adobe Illustrator, but because vector graphics must be rendered by the viewer every time they are opened (which can be slow for plots with many points), it's useful to also write out a PNG.
PNGs are raster graphics: they are just a matrix of pixels, and in the PNGs matplotlib writes, each pixel has four components: red, green, blue, and alpha (transparency). This means they are quick to display in your favorite viewer, even if the plot originally had hundreds of thousands of points. However, they are not ideal for posters and final publication-quality figures, since, unlike vector graphics, they cannot be scaled to an arbitrary size without losing quality.
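To see the difference concretely, the sketch below writes the same figure as a PDF and as a 300-dpi PNG into in-memory buffers (standing in for files on disk), checks each format's signature bytes, and reads the PNG back: a 4 x 2 inch figure at 300 dpi becomes a 600 x 1200 array of pixels with four RGBA values each.

```python
import io
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script; skip this line in a notebook
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 40, 1000)
y = np.sin(np.linspace(0, 10 * np.pi, 1000))

fig = plt.figure(figsize=(4, 2))
ax = fig.add_subplot(1, 1, 1)
ax.plot(x, y)

# Write both formats to in-memory buffers instead of files on disk.
pdf_buf, png_buf = io.BytesIO(), io.BytesIO()
fig.savefig(pdf_buf, format="pdf")
fig.savefig(png_buf, format="png", dpi=300)
plt.close(fig)

pdf_header = pdf_buf.getvalue()[:4]  # every PDF file starts with b"%PDF"
png_header = png_buf.getvalue()[:8]  # the fixed 8-byte PNG signature

# Reading the PNG back gives the pixel matrix: rows x columns x RGBA.
png_buf.seek(0)
pixels = plt.imread(png_buf, format="png")
```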
fig = plt.figure(figsize=(4, 2))
ax = fig.add_subplot(1,1,1)
ax.plot(x, y)
sns.despine(offset=10, ax=ax)
# let's add some axis labels to boot
ax.set_ylabel(r'displacement ($\AA$)')
ax.set_xlabel('time (ns)')
fig.savefig('testfigure.pdf')
fig.savefig('testfigure.png', dpi=300)
We can view the resulting PNG directly:
from IPython.display import Image
Image(filename='testfigure.png')
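Whoa... something's wrong: the figure doesn't fit in the frame! By default, the axes occupy a fixed fraction of the figure, and matplotlib does not make extra room for the tick labels, the axis labels, or the offset spines, so on a figure this small some of them end up beyond the figure's edges. (The notebook's inline display hides this because it sizes each rendered figure to a tight bounding box around its contents; `savefig` does not do that by default.) We can usually fix this with a call to `plt.tight_layout()`, which adjusts the subplot margins so that everything fits in the figures we write out.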
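One way to check whether anything spills off the canvas is to measure the label positions after a draw. The sketch below (the helper `label_overhang` is hypothetical, and it uses a plain-matplotlib stand-in for `despine(offset=10)`) compares the overhang before and after `tight_layout`, and also shows the other common fix: passing `bbox_inches='tight'` to `savefig`, which adjusts the saved area to the drawn content instead of changing the layout.

```python
import io
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script; skip this line in a notebook
import matplotlib.pyplot as plt
import numpy as np

def label_overhang(fig, ax):
    """Pixels by which the x label falls below, and the y label left of, the canvas.

    Zero means the label is inside the canvas on that side; a positive value
    means part of it would be cut off if the figure were saved as-is.
    """
    fig.canvas.draw()  # label positions are only final after a draw
    renderer = fig.canvas.get_renderer()
    x_box = ax.xaxis.label.get_window_extent(renderer)
    y_box = ax.yaxis.label.get_window_extent(renderer)
    return {"xlabel": max(0.0, -x_box.y0), "ylabel": max(0.0, -y_box.x0)}

x = np.linspace(0, 40, 1000)
y = np.sin(np.linspace(0, 10 * np.pi, 1000))

fig = plt.figure(figsize=(4, 2))
ax = fig.add_subplot(1, 1, 1)
ax.plot(x, y)
# plain-matplotlib stand-in for sns.despine(offset=10)
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
for side in ("left", "bottom"):
    ax.spines[side].set_position(("outward", 10))
ax.set_ylabel("displacement (Å)")
ax.set_xlabel("time (ns)")

before = label_overhang(fig, ax)
fig.tight_layout()
after = label_overhang(fig, ax)

# Alternative: keep the layout and let savefig fit the saved area to the content.
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=300, bbox_inches="tight")
plt.close(fig)
```

Note that with `bbox_inches='tight'` the saved image is no longer exactly the requested 4 x 2 inches, whereas `tight_layout` keeps the figure size and shrinks the axes instead; for a publication with a fixed column width, the latter is usually what you want.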
fig = plt.figure(figsize=(4, 2))
ax = fig.add_subplot(1,1,1)
ax.plot(x, y)
sns.despine(offset=10, ax=ax)
# let's add some axis labels to boot
ax.set_ylabel(r'displacement ($\AA$)')
ax.set_xlabel('time (ns)')
# and now we'll also refine the y-axis ticks a bit:
# major ticks at -4, -2, 0, 2, 4 and minor ticks at -3, -1, 1, 3
ax.set_ylim(-4.5, 4.5)
ax.set_yticks(np.linspace(-4, 4, 5))
ax.set_yticks(np.linspace(-3, 3, 4), minor=True)
plt.tight_layout()
fig.savefig('testfigure.pdf')
fig.savefig('testfigure.png', dpi=300)
Image(filename='testfigure.png')
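Spelling out tick positions with `np.linspace` works well for a fixed range. A sketch of an alternative, assuming you would like the ticks to follow the axis limits automatically, uses the locators in `matplotlib.ticker`: `MultipleLocator(2)` places major ticks on multiples of 2, and `AutoMinorLocator(2)` splits each major interval in two, giving one minor tick in between.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for running as a script; skip this line in a notebook
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator, MultipleLocator

fig, ax = plt.subplots(figsize=(4, 2))
ax.set_ylim(-4.5, 4.5)
# Major ticks on multiples of 2, one minor tick between each pair of majors.
ax.yaxis.set_major_locator(MultipleLocator(2))
ax.yaxis.set_minor_locator(AutoMinorLocator(2))

lo, hi = ax.get_ylim()
# Locators may return positions just outside the view limits; keep the visible ones.
major = [t for t in ax.get_yticks() if lo <= t <= hi]
plt.close(fig)
```

The trade-off is control versus robustness: explicit positions never change, while locators adapt if you later change `set_ylim`.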