Modern Pandas (Part 6): Visualization


This is part 6 in my series on writing modern idiomatic pandas.

This post is available as a Jupyter notebook


A few weeks ago, the R community went through some hand-wringing about plotting packages. For outsiders (like me) the details aren't that important, but some brief background might be useful so we can transfer the takeaways to Python. The competing systems are "base R", which is the plotting system built into the language, and ggplot2, Hadley Wickham's implementation of the grammar of graphics. For those interested in more details, start with

The most important takeaways are that

  1. Either system is capable of producing anything the other can
  2. ggplot2 is usually better for exploratory analysis

Item 2 is not universally agreed upon, and it certainly isn't true for every type of chart, but we'll take it as fact for now. I'm not foolish enough to attempt a formal analogy here, like "matplotlib is Python's base R". But there's at least a rough comparison: like dplyr/tidyr and ggplot2, the combination of pandas and seaborn allows for fast iteration and exploration. When you need to, you can "drop down" into matplotlib for further refinement.

Overview

Here's a brief sketch of the plotting landscape as of April 2016. For some reason, plotting tools feel a bit more personal than other parts of this series so far, so I feel the need to blanket this whole discussion in a caveat: this is my personal take, shaped by my personal background and tastes. Also, I'm not at all an expert on visualization, just a consumer. For real advice, you should listen to the experts in this area. Take this all with an extra grain or two of salt.

Matplotlib

Matplotlib is an amazing project, and is the foundation of pandas' built-in plotting and Seaborn. It handles everything from the integration with various drawing backends, to several APIs handling drawing charts or adding and transforming individual glyphs (artists). I've found knowing the pyplot API useful. You're less likely to need things like Transforms or artists, but when you do, the documentation is there.

Matplotlib has built up something of a bad reputation for being verbose. I think that complaint is valid, but misplaced. Matplotlib lets you control essentially anything on the figure. An overly verbose API just means there's an opportunity for a higher-level, domain-specific package to exist (like seaborn for statistical graphics).

Pandas' Built-in Plotting

DataFrame and Series have a .plot namespace, with various chart types available (line, hist, scatter, etc.). Pandas objects provide additional metadata that can be used to enhance plots (the Index for a better automatic x-axis than range(n), or Index names as axis labels, for example).
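
As a rough sketch of what that metadata buys you, here's a line plot of a small made-up DataFrame. The data, the column names, and the "year" index name are all hypothetical; the point is only that the Index supplies the x-axis values and its name becomes the axis label.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import numpy as np
import pandas as pd

# Hypothetical data: two random walks on a named integer index.
rng = np.random.default_rng(0)
index = pd.Index(np.arange(2000, 2020), name="year")
walks = pd.DataFrame(
    rng.normal(size=(20, 2)).cumsum(axis=0),
    index=index,
    columns=["a", "b"],
)

# One line per column, the index on the x-axis, the index name as its label.
ax = walks.plot()
xlabel = ax.get_xlabel()
```

With plain matplotlib you would pass the x values and set the label yourself.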

And since pandas had fewer backwards-compatibility constraints, it had a bit better default aesthetics. The matplotlib 2.0 release will level this, and pandas has deprecated its custom plotting styles, in favor of matplotlib's [1].

At this point, I see pandas DataFrame.plot as a useful exploratory tool for quick throwaway plots.

Seaborn

Seaborn, created by Michael Waskom, "provides a high-level interface for drawing attractive statistical graphics." Seaborn gives a great API for quickly exploring different visual representations of your data. We'll be focusing on that today.

Bokeh

Bokeh is a (still under heavy development) visualization library that targets the browser.

Like matplotlib, Bokeh has a few APIs at various levels of abstraction. They have a glyph API for drawing single or arrays of glyphs (circles, rectangles, polygons, etc.). More recently they introduced a Charts API, for producing canned charts from data structures like dicts or DataFrames.

Other Libraries

This is a (probably incomplete) list of other visualization libraries that I don't know enough about to comment on

It's also possible to use JavaScript tools like D3 directly in the Jupyter notebook, but we won't go into those today.

Examples

I do want to pause and explain the type of work I'm doing with these packages. The vast majority of plots I create are for exploratory analysis, helping me understand the dataset I'm working with. They aren't intended for the client (whoever that is) to see. Occasionally that exploratory plot will evolve towards a final product that will be used to explain things to the client. In this case I'll either polish the exploratory plot, or rewrite it in another system more suitable for the final product (in D3 or Bokeh, say, if it needs to be an interactive document in the browser).

Now that we have a feel for the overall landscape (from my point of view), let's delve into a few examples. We'll use the diamonds dataset from ggplot2. You could use Vincent Arel-Bundock's RDatasets package to find it (pd.read_csv('http://vincentarelbundock.github.io/Rdatasets/csv/ggplot2/diamonds.csv')), but I wanted to check out feather.

%load_ext rpy2.ipython
%%R
suppressPackageStartupMessages(library(ggplot2))
library(feather)
write_feather(diamonds, 'diamonds.fthr')
import feather
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

df = feather.read_dataframe('diamonds.fthr')
df.head()
carat cut color clarity depth table price x y z
0 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43
1 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31
2 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31
3 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63
4 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 53940 entries, 0 to 53939
Data columns (total 10 columns):
carat      53940 non-null float64
cut        53940 non-null category
color      53940 non-null category
clarity    53940 non-null category
depth      53940 non-null float64
table      53940 non-null float64
price      53940 non-null int32
x          53940 non-null float64
y          53940 non-null float64
z          53940 non-null float64
dtypes: category(3), float64(6), int32(1)
memory usage: 2.8 MB

Notice that the categorical dtypes are preserved.

Bokeh

import bokeh.charts as bc
import bokeh.plotting as bk

Bokeh provides two APIs, a low-level glyph API and a higher-level Charts API. Here's an example of the charts API.

fig = (df.assign(xy = df.x / df.y)
         .sample(n=500)
         .pipe(bc.Scatter, "xy", "price"))
bk.show(fig)

It's not clear to me where the scientific community will come down on Bokeh for exploratory analysis. The ability to share interactive graphics is compelling. The trend towards more and more analysis and communication happening in the browser will only enhance this feature of Bokeh.

Personally though, I have a lot of inertia behind matplotlib so I haven't switched to Bokeh for day-to-day exploratory analysis.

I have greatly enjoyed Bokeh for building dashboards and webapps with Bokeh server. It's still young, and I've hit some rough edges, but I'm happy to put up with some awkwardness to avoid writing more JavaScript.

Matplotlib

Since it's relatively new, I should point out that matplotlib 1.5 added support for plotting labeled data.

fig, ax = plt.subplots()

ax.scatter(x='carat', y='depth', data=df, c='k', alpha=.15)

This isn't limited to just DataFrames. It supports anything that uses __getitem__ (square-brackets) with string keys. Other than that, I don't have much to add to the matplotlib documentation.
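
To make the "anything with string keys" point concrete, here's a minimal sketch that passes a plain dict as data. The four values are copied from the first rows of the diamonds table shown earlier; everything else is just matplotlib.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt

# A plain dict works because data["carat"] and data["depth"] are valid lookups.
data = {"carat": [0.23, 0.21, 0.29, 0.31], "depth": [61.5, 59.8, 62.4, 63.3]}

fig, ax = plt.subplots()
points = ax.scatter("carat", "depth", data=data, c="k", alpha=.5)
n_points = len(points.get_offsets())
```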

Pandas Built-in Plotting

The metadata in DataFrames gives a bit better defaults on plots.

df.plot.scatter(x='carat', y='depth', c='k', alpha=.15)

We get axis labels from the column names. Nothing major, just nice.
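
Here's a small side-by-side sketch of that difference, using a tiny made-up frame (the name `small` and its values are just for illustration). The same scatter is drawn once through matplotlib's data keyword and once through pandas, and we read back the axis labels.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import pandas as pd

small = pd.DataFrame({"carat": [0.23, 0.21, 0.29, 0.31],
                      "depth": [61.5, 59.8, 62.4, 63.3]})

fig, (left, right) = plt.subplots(ncols=2, figsize=(8, 3))
left.scatter(x="carat", y="depth", data=small, c="k")      # matplotlib
small.plot.scatter(x="carat", y="depth", c="k", ax=right)  # pandas

labels = {
    "matplotlib": (left.get_xlabel(), left.get_ylabel()),
    "pandas": (right.get_xlabel(), right.get_ylabel()),
}
```

The matplotlib axes come back with empty labels, while the pandas axes are labeled with the column names.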

Pandas can be more convenient for plotting a bunch of columns with a shared x-axis (the index), say several timeseries.

from pandas_datareader import fred

gdp = fred.FredReader(['GCEC96', 'GPDIC96'], start='2000-01-01').read()
gdp.rename(columns={"GCEC96": "Government Expenditure",
                    "GPDIC96": "Private Investment"}).plot(figsize=(12, 6))

Most DataFrame.plot methods will take a boolean subplots keyword, for drawing on separate axes within the same figure. This can be nicer than setting those up manually with matplotlib, though as we'll see next seaborn has an even better API here.
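
A quick sketch of the subplots keyword, using two made-up monthly series so it runs without a network connection (the names `series_a` and `series_b` are placeholders, not real economic data). Each column gets its own axes, and all of them live in a single figure.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import numpy as np
import pandas as pd

# Hypothetical data: two random walks on a shared monthly index.
rng = np.random.default_rng(1)
idx = pd.date_range("2000-01-01", periods=48, freq="MS")
series = pd.DataFrame(rng.normal(size=(48, 2)).cumsum(axis=0),
                      index=idx, columns=["series_a", "series_b"])

# With subplots=True, pandas returns an array holding one Axes per column.
axes = series.plot(subplots=True, figsize=(12, 6))
same_figure = axes[0].figure is axes[1].figure
```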

Seaborn

The rest of this post will focus on seaborn, and why I think it's especially great for exploratory analysis.

I would encourage you to read Seaborn's introductory notes, which describe its design philosophy and attempted goals. Some highlights:

"Seaborn aims to make visualization a central part of exploring and understanding data."

It does this through a consistent, understandable (to me anyway) API.

"The plotting functions try to do something useful when called with a minimal set of arguments, and they expose a number of customizable options through additional parameters."

Which works great for exploratory analysis, with the option to turn that into something more polished if it looks promising.

"Some of the functions plot directly into a matplotlib axes object, while others operate on an entire figure and produce plots with several panels."

The fact that seaborn is built on matplotlib means that if you are familiar with the pyplot API, your knowledge will still be useful.

Most seaborn plotting functions (one per chart type) take x, y, hue, and data arguments (only some are required, depending on the plot type). If you're working with DataFrames, you'll pass in strings referring to column names, and the DataFrame for data.
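
As a small illustration of that calling convention, including hue, here's a sketch on a made-up frame (the `toy` name and the prices are invented for the example). Every argument except data is a column name.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import pandas as pd
import seaborn as sns

# Hypothetical data: two observations for each (cut, color) pair.
toy = pd.DataFrame({
    "cut": ["Ideal", "Good", "Premium"] * 4,
    "color": ["E"] * 6 + ["J"] * 6,
    "price": [326, 327, 326, 340, 351, 342, 335, 334, 337, 360, 355, 349],
})

# x and y pick the columns to plot; hue splits each bar by a third column.
ax = sns.barplot(x="cut", y="price", hue="color", data=toy)
sns.despine()
axis_labels = (ax.get_xlabel(), ax.get_ylabel())
```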

sns.countplot(x='cut', data=df)
sns.despine()

sns.barplot(x='cut', y='price', data=df)
sns.despine()

Bivariate relationships can easily be explored, either one at a time:

sns.jointplot(x='carat', y='price', data=df, size=8, alpha=.25,
              color='k', marker='.')

Or many at once:

g = sns.pairplot(df, hue='cut')

pairplot is a convenience wrapper around PairGrid, and offers our first look at an important seaborn abstraction, the Grid. Seaborn Grids provide a link between a matplotlib Figure with multiple axes and features in your dataset.

There are two main ways of interacting with grids. First, seaborn provides convenience-wrapper functions like pairplot, that have good defaults for common tasks. If you need more flexibility, you can work with the Grid directly by mapping plotting functions over each axes.

def core(df, α=.05):
    # Keep rows where every column lies strictly inside its (α, 1 - α) quantile range.
    mask = (df > df.quantile(α)).all(axis=1) & (df < df.quantile(1 - α)).all(axis=1)
    return df[mask]
cmap = sns.cubehelix_palette(as_cmap=True, dark=0, light=1, reverse=True)

(df.select_dtypes(include=[np.number])
   .pipe(core)
   .pipe(sns.PairGrid)
   .map_upper(plt.scatter, marker='.', alpha=.25)
   .map_diag(sns.kdeplot)
   .map_lower(plt.hexbin, cmap=cmap, gridsize=20)
)
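
The core helper above appears intended to keep only the rows where every column falls strictly between its α and 1 - α quantiles, which trims the extreme tails before plotting. Here's a standalone sketch of that idea on a toy frame; the name `trim_to_core` and the data are hypothetical, chosen so the result is easy to check by hand.

```python
import pandas as pd

def trim_to_core(frame, alpha=.05):
    """Keep rows where every column is strictly inside its (alpha, 1 - alpha) quantiles."""
    lower = frame.quantile(alpha)      # Series of per-column lower bounds
    upper = frame.quantile(1 - alpha)  # Series of per-column upper bounds
    mask = (frame > lower).all(axis=1) & (frame < upper).all(axis=1)
    return frame[mask]

# Two identical columns holding 0..99: the bounds are 4.95 and 94.05,
# so the rows with values 5 through 94 survive.
toy = pd.DataFrame({"a": range(100), "b": range(100)})
trimmed = trim_to_core(toy)
```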

PairGrid is a special case of Grid faceting by each (x, y) combination. Next we'll use FacetGrid, another class for producing Grids, with control over how each facet (individual axes) gets determined. In this example, we'll facet by cut.

g = sns.FacetGrid(df, row='cut', aspect=4, size=1.76, margin_titles=True)
g.map(sns.kdeplot, 'price', shade=True, color='k')
for ax in g.axes.flat:
    ax.yaxis.set_visible(False)
sns.despine(left=True)
g.fig.subplots_adjust(hspace=0.1)
g.set(xlim=(0, 15000))

This last example shows the tight integration with matplotlib. g.axes is an array of matplotlib Axes objects and g.fig is a matplotlib Figure. This is a pretty common pattern when using seaborn: use a seaborn plotting method (or grid) to get a good start, and then adjust with matplotlib as needed.

I think (not an expert on this at all) that one thing people like about the grammar of graphics is its flexibility. You aren't limited to a fixed set of chart types defined by the library author. Instead, you construct your chart by layering scales, aesthetics and geometries. And using ggplot2 in R is a delight.

That said, I wouldn't really call what seaborn / matplotlib offer that limited. You can create pretty complex charts suited to your needs.

agged = df.groupby(['cut', 'color']).mean().sort_index().reset_index()

g = sns.PairGrid(agged, x_vars=agged.columns[2:], y_vars=['cut', 'color'],
                 size=5, aspect=.65)
g.map(sns.stripplot, orient="h", size=10, palette='Blues_d')

g = sns.FacetGrid(df, col='color', hue='color', col_wrap=4)
g.map(sns.regplot, 'carat', 'price')

Initially I had many more examples showing off seaborn, but I'll spare you. Seaborn's documentation is thorough (and just beautiful to look at).

We'll end with a nice scikit-learn integration for exploring the parameter-space on a GridSearch object.

from sklearn.ensemble import RandomForestClassifier
from sklearn.grid_search import GridSearchCV

For those unfamiliar with machine learning or scikit-learn, the basic idea is your algorithm (RandomForestClassifier) is trying to maximize some objective function (percent of correctly classified items in this case). There are various hyperparameters that affect the fit. We can search this space by trying out a bunch of possible values for each parameter with the GridSearchCV estimator.

df = sns.load_dataset('titanic')

clf = RandomForestClassifier()
param_grid = dict(max_depth=[1, 2, 5, 10, 20, 30, 40],
                  min_samples_split=[2, 5, 10],
                  min_samples_leaf=[2, 3, 5])
est = GridSearchCV(clf, param_grid=param_grid, n_jobs=4)

y = df['survived']
X = df.drop(['survived', 'who', 'alive'], axis=1)

X = pd.get_dummies(X, drop_first=True)
X = X.fillna(value=X.median())
est.fit(X, y)

Let's unpack the scores (a list of tuples) into a DataFrame.

scores = est.grid_scores_
rows = []
params = sorted(scores[0].parameters)
for row in scores:
    mean = row.mean_validation_score
    std = row.cv_validation_scores.std()
    rows.append([mean, std] + [row.parameters[k] for k in params])
scores = pd.DataFrame(rows, columns=['mean_', 'std_'] + params)
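
If you're on a newer scikit-learn, the grid_scores_ attribute is gone (it was removed in 0.20), and GridSearchCV now lives in sklearn.model_selection. A rough equivalent of the unpacking above, sketched on synthetic data with a deliberately small grid so it runs quickly, reads the cv_results_ dict instead. The dataset and the grid values here are assumptions for illustration, not the titanic setup.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Synthetic stand-in for the real features and target.
X, y = make_classification(n_samples=200, n_features=8, random_state=0)

param_grid = dict(max_depth=[1, 5, 10], min_samples_leaf=[2, 5])
clf = RandomForestClassifier(n_estimators=10, random_state=0)
est = GridSearchCV(clf, param_grid=param_grid, cv=3)
est.fit(X, y)

# cv_results_ is a dict of arrays with one entry per parameter combination.
results = est.cv_results_
scores = (pd.DataFrame(results["params"])
            .assign(mean_=results["mean_test_score"],
                    std_=results["std_test_score"]))
```

The resulting frame has one row per parameter combination, with the same mean_ and std_ columns used in the plot that follows.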

And visualize it, seeing that max_depth should probably be at least 10.

sns.factorplot(x='max_depth', y='mean_', data=scores, col='min_samples_split',
               hue='min_samples_leaf')

Thanks for reading! I want to reiterate at the end that this is just my way of doing data visualization. Your needs might differ, meaning you might need different tools. You can still use pandas to get it to the point where it's ready to be visualized!

As always, feedback is welcome.


  1. OK, technically I just broke it when fixing matplotlib 1.5 compatibility, so we deprecated it after the fact.