Better Heatmaps and Correlation Matrix Plots in Python

Finding the highest negative and positive correlations means finding the strongest red and green.

To do that, I need to carefully scan the entire grid.

Try to answer it again, and notice how your eyes jump around the plot, sometimes going to the legend.

Now consider the following plot.

In addition to color, we’ve added size as a parameter to our heatmap.

The size of each square corresponds to the magnitude of the correlation it represents, that is

size(c1, c2) ~ abs(corr(c1, c2))

Now try to answer the questions using the latter plot.

Notice how weak correlations visually disappear, and your eyes are immediately drawn to areas where there’s high correlation.

Also note that it’s now easier to compare magnitudes of negative vs positive values (lighter red vs lighter green), and we can also compare values that are further apart.

If we’re mapping magnitudes, it’s much more natural to link them to the size of the representing object than to its color.

That’s exactly why on bar charts you would use height to display measures, and colors to display categories, but not vice versa.
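To make the size mapping size(c1, c2) ~ abs(corr(c1, c2)) concrete, here is a minimal sketch that computes a marker size for every pair of columns as the absolute correlation times a constant. The data, the column names and the `size_scale` value are hypothetical. One design note: Matplotlib’s `scatter` treats its `s` argument as marker area in points squared, so feeding it these values makes each square’s area, not its side length, proportional to the correlation magnitude.

```python
import numpy as np
import pandas as pd

# Hypothetical data: "b" is a noisy negative copy of "a", "c" is independent noise
rng = np.random.default_rng(0)
a = rng.normal(size=200)
df = pd.DataFrame({
    "a": a,
    "b": -a + rng.normal(scale=0.5, size=200),
    "c": rng.normal(size=200),
})

corr = df.corr()  # Pearson correlations in [-1, 1]

# size(c1, c2) ~ abs(corr(c1, c2)); size_scale is an arbitrary display factor
size_scale = 500
sizes = corr.abs() * size_scale
print(sizes.round(1))
```

The strongly negative a/b pair gets a large square, while the near-zero a/c pair shrinks almost to nothing, which is exactly the effect described above.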

Discrete Joint Distributions

Let’s see how the cars in our data set are distributed according to horsepower and drivetrain layout.

That is, we want to visualize the following table. Consider the following two ways to do it. The second version, where we use square size to display counts, makes it effortless to determine which group is the largest or smallest.

It also gives some intuition about the marginal distributions, all without needing to refer to a color legend.
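The table itself is easy to build with pandas. The sketch below uses hypothetical toy data with `drive-wheels` and `horsepower` columns standing in for the car data set, and arbitrary horsepower bins; `pd.crosstab` counts the cars in each cell, and stacking the result gives the long format (one row per cell) that a scatter-based heatmap needs.

```python
import pandas as pd

# Hypothetical toy data standing in for the car data set
cars = pd.DataFrame({
    "drive-wheels": ["fwd", "fwd", "rwd", "4wd", "fwd", "rwd", "rwd", "fwd"],
    "horsepower": [70, 95, 160, 110, 88, 184, 121, 68],
})

# Bin horsepower into discrete ranges (the bin edges are an arbitrary choice)
cars["hp-bin"] = pd.cut(cars["horsepower"], bins=[0, 100, 150, 300],
                        labels=["<100", "100-150", ">150"])

# Count cars per (drive-wheels, hp-bin) cell: the table to visualize
table = pd.crosstab(cars["drive-wheels"], cars["hp-bin"])
print(table)

# Long format, one row per cell, ready for a scatter-based heatmap
counts = table.stack().reset_index(name="count")
```

Row and column sums of `table` give the marginal distributions that the sized-square version lets you eyeball.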


So how do I make these plots?

To make a regular heatmap, we simply used the Seaborn heatmap function, with a bit of additional styling.
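A sketch of what such a styled Seaborn heatmap call might look like follows. The random data is hypothetical, and the styling choices (a fixed [-1, 1] range centered at 0, a diverging palette passed as a list of colors, square cells, rotated tick labels) are our assumptions, not necessarily the exact styling used for the plots in this post.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical data with four numeric columns
rng = np.random.default_rng(1)
df = pd.DataFrame(rng.normal(size=(100, 4)), columns=["w", "x", "y", "z"])
corr = df.corr()

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(
    corr,
    vmin=-1, vmax=1, center=0,                   # full correlation range, 0 in the middle
    cmap=sns.diverging_palette(20, 220, n=200),  # diverging palette as a list of colors
    square=True,                                 # square cells
    ax=ax,
)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment="right")
```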

For the second kind, there’s no trivial way to make it using Matplotlib or Seaborn.

We could use corrplot from biokit, but it helps with correlations only and isn’t very useful for two-dimensional distributions.

Building a robust parametrized function that enables us to make heatmaps with sized markers is a nice exercise in Matplotlib, so I’ll show you how to do it step by step.

We’ll start by using a simple scatter plot with squares as markers.

Then we’ll fix some issues with it, add color and size as parameters, make it more general and robust to various types of input, and finally make a wrapper function corrplot that takes the result of the DataFrame.corr method and plots a correlation matrix, supplying all the necessary parameters to the more general heatmap function.

It’s just a scatter plot

If we want to plot elements on a grid made by two categorical axes, we can use a scatter plot.

Looks like we’re onto something.

But I said it’s just a scatterplot, and there’s quite a lot happening in the previous code snippet.

Since the scatterplot requires x and y to be numeric arrays, we need to map our column names to numbers.

And since we want our axis ticks to show column names instead of those numbers, we need to set custom ticks and ticklabels.
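Here is a minimal sketch of such a scatter-based heatmap, written for illustration under our own assumptions: labels are sorted alphabetically before being mapped to integer positions, `size_scale` is an arbitrary multiplier, and the call at the end uses made-up labels and values.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt


def heatmap(x, y, size, size_scale=500):
    """Draw a square at each (x, y) pair of category labels, sized by `size`."""
    _, ax = plt.subplots()

    # Map each distinct label to an integer coordinate
    x_labels = sorted(set(x))
    y_labels = sorted(set(y))
    x_to_num = {label: i for i, label in enumerate(x_labels)}
    y_to_num = {label: i for i, label in enumerate(y_labels)}

    ax.scatter(
        x=[x_to_num[v] for v in x],
        y=[y_to_num[v] for v in y],
        s=[v * size_scale for v in size],  # marker area in points squared
        marker="s",                        # square markers
    )

    # Label the integer positions with the original category names
    ax.set_xticks(list(range(len(x_labels))))
    ax.set_xticklabels(x_labels, rotation=45, horizontalalignment="right")
    ax.set_yticks(list(range(len(y_labels))))
    ax.set_yticklabels(y_labels)
    return ax


# Usage with made-up labels and sizes
ax = heatmap(["a", "a", "b", "b"], ["a", "b", "a", "b"], [1.0, 0.3, 0.3, 1.0])
```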

Finally, there’s code that loads the dataset, selects a subset of columns, calculates all the correlations, melts the data frame (the inverse of creating a pivot table) and feeds its columns to our heatmap function.
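That data preparation can be sketched with pandas like this. The random data and the column names are hypothetical stand-ins for the real data set; the last step pivots the melted frame back, to show that melting really is the inverse of pivoting.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the loaded data set
rng = np.random.default_rng(2)
data = pd.DataFrame(rng.normal(size=(50, 4)),
                    columns=["bore", "horsepower", "price", "stroke"])

columns = ["bore", "horsepower", "price"]  # the subset of columns to analyze
corr = data[columns].corr()

# Melt the square matrix into long format: one (x, y, value) row per pair
long = pd.melt(corr.reset_index(), id_vars="index")
long.columns = ["x", "y", "value"]

# Melting undoes a pivot: pivoting back recovers the original matrix
back = long.pivot(index="x", columns="y", values="value")
print(long.head())
```

The `x` and `y` columns of `long` are what gets passed to the heatmap function, with `value` (or its absolute value) as the size.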

You may have noticed that our squares are placed where our gridlines intersect, instead of being centered in their cells.

In order to move the squares to cell centers, we’ll actually move the grid.

And to move the grid, we’ll actually turn off major gridlines, and set minor gridlines to go right in between our axis ticks.
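In Matplotlib terms, a minimal sketch of this grid trick looks like the following, assuming a hypothetical 3 by 3 grid of squares: major gridlines are switched off, minor gridlines are switched on, and the minor ticks that carry them are placed half a unit after each major tick.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

n = 3  # hypothetical number of categories per axis
fig, ax = plt.subplots()
ax.scatter([i for i in range(n) for _ in range(n)],
           [j for _ in range(n) for j in range(n)], marker="s", s=400)

# Major ticks stay on the integer positions where the squares are drawn
ax.set_xticks(list(range(n)))
ax.set_yticks(list(range(n)))

# No gridlines on major ticks; gridlines on minor ticks halfway between them
ax.grid(False, "major")
ax.grid(True, "minor")
ax.set_xticks([t + 0.5 for t in ax.get_xticks()], minor=True)
ax.set_yticks([t + 0.5 for t in ax.get_yticks()], minor=True)
```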

That’s better.

But now the left and bottom sides look cropped.

That’s because our axes’ lower limits are set to 0.

We’ll sort this out by setting the lower limit for both axes to -0.5.

Remember, our points are displayed at integer coordinates, so our gridlines are at .5 coordinates.
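A sketch of the fix on the same kind of hypothetical 3 by 3 grid follows. Here both the lower and the upper limits get half a cell of padding; padding the upper end as well is our choice, so that the outermost cells are shown in full on every side.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

n = 3  # hypothetical number of categories per axis
fig, ax = plt.subplots()
ax.scatter([i for i in range(n) for _ in range(n)],
           [j for _ in range(n) for j in range(n)], marker="s", s=400)

# Squares sit on integer coordinates 0..n-1 and gridlines on the .5 coordinates,
# so start each axis half a cell before the first square and end half a cell after the last
ax.set_xlim([-0.5, n - 0.5])
ax.set_ylim([-0.5, n - 0.5])
```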

Give it some color

Now comes the fun part.

We need to map the possible range of values for correlation coefficients, [-1, 1], to a color palette.

We’ll use a diverging palette, going from red for -1, all the way to green for 1.

Looking at Seaborn color palettes, it seems that we’ll do just fine with something like sns.diverging_palette(220, 20, n=7). But let’s first flip the order of colors and make it smoother by adding more steps between red and green:

palette = sns.diverging_palette(20, 220, n=256)

Seaborn color palettes are just arrays of color components, so in order to map a correlation value to the appropriate color, we ultimately need to map it to an index in the palette array.

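It’s a simple mapping of one interval to another: [-1, 1] → [0, 1] → {0, 1, ..., 255}.

More precisely, the value is first normalized from [-1, 1] to a position in [0, 1], and that position is then scaled to an integer index between 0 and 255 in the palette.

Just what we wanted.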
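The two-step mapping can be sketched as a small function. `value_to_color` is a hypothetical name; `color_min`, `color_max` and `n_colors` match names used elsewhere in this post, while clamping out-of-range values and rounding to the nearest index are our additions. It is an illustration, not the post’s own code.

```python
import seaborn as sns

n_colors = 256
palette = sns.diverging_palette(20, 220, n=n_colors)  # list of RGB tuples
color_min, color_max = -1.0, 1.0  # range of correlation coefficients


def value_to_color(val):
    """Map a value in [color_min, color_max] to a palette entry."""
    # Step 1: [-1, 1] -> [0, 1]
    val_position = (val - color_min) / (color_max - color_min)
    # Clamp values that fall outside the range (our addition)
    val_position = min(max(val_position, 0.0), 1.0)
    # Step 2: [0, 1] -> integer index in {0, ..., n_colors - 1}
    ind = int(round(val_position * (n_colors - 1)))
    return palette[ind]
```

Each square’s color is then `value_to_color(v)` for its correlation `v`, which can be passed to Matplotlib’s `scatter` through its `c` argument.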

Let’s now add a color bar on the right side of the chart.

We’ll use GridSpec to set up a plot grid with 1 row and n columns.

Then we’ll use the rightmost column of the plot to display the color bar and the rest to display the heatmap.

There are multiple ways to display a color bar; here we’ll trick our eyes by using a really dense bar chart.

We’ll draw n_colors horizontal bars, each colored with its respective color from the palette.
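A sketch of the dense-bar color bar follows. Assumptions: a 1 by 15 GridSpec (the 15 is arbitrary), and a palette sampled from Matplotlib’s `RdYlGn` colormap as a stand-in so the example doesn’t depend on Seaborn; any list of colors, including the Seaborn palette above, can be used instead.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec

n_colors = 256
cmap = plt.get_cmap("RdYlGn")  # stand-in palette; a Seaborn palette works too
palette = [cmap(i / (n_colors - 1)) for i in range(n_colors)]
color_min, color_max = -1.0, 1.0

fig = plt.figure(figsize=(7, 5))
grid = GridSpec(1, 15, figure=fig, wspace=0.1)  # 1 row, 15 columns
ax = fig.add_subplot(grid[0, :-1])   # heatmap goes in the first 14 columns
cax = fig.add_subplot(grid[0, -1])   # color bar goes in the rightmost column

# One thin horizontal bar per palette color, stacked from color_min to color_max
bar_y = np.linspace(color_min, color_max, n_colors)
bar_height = bar_y[1] - bar_y[0]
cax.barh(y=bar_y, width=[5] * n_colors, left=0, height=bar_height,
         color=palette, linewidth=0)
cax.set_xlim(0, 5)
cax.set_ylim(color_min, color_max)
cax.set_xticks([])      # the bars' widths carry no meaning
cax.yaxis.tick_right()  # value labels on the right-hand side
```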

And we have our color bar.

We’re almost done.

Now we just need to flip the vertical axis so that the correlation of each variable with itself is shown on the main diagonal, make the squares a bit larger, and make the background just a tad lighter so that values around 0 are more visible.
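These three tweaks map onto a few Axes calls. The 3 by 3 grid, the marker size and the background color `#F1F1F1` below are hypothetical choices.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

n = 3  # hypothetical number of variables
fig, ax = plt.subplots()
xs = [i for i in range(n) for _ in range(n)]
ys = [j for _ in range(n) for j in range(n)]
ax.scatter(xs, ys, marker="s", s=900)  # larger squares

ax.set_xlim(-0.5, n - 0.5)
ax.set_ylim(-0.5, n - 0.5)
ax.invert_yaxis()            # y position 0 on top, so pairs (i, i) run down the main diagonal
ax.set_facecolor("#F1F1F1")  # light gray background (hypothetical value)
```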

But let’s first make the entire code more useful.

More parameters!

It would be great if we made our function able to accept more than just a correlation matrix. To do this, we’ll make the following changes:

- Be able to pass color_min, color_max and size_min, size_max as parameters so that we can map ranges other than [-1, 1] to color and size. This will enable us to use the heatmap beyond correlations.
- Use a sequential palette if no palette is specified, and use a single color if no color vector is provided.
- Use a constant size if no size vector is provided.
- Avoid mapping the lowest value to 0 size.
- Make x and y the only necessary parameters, and pass size, color, size_scale, size_range, color_range, palette, marker as kwargs. Provide sensible defaults for each of the parameters.
- Use list comprehensions instead of pandas apply and map methods, so we can pass any kind of arrays as x, y, color, size instead of just pandas.Series.
- Pass any other kwargs to the pyplot.scatter function.
- Make a wrapper function corrplot that accepts a corr() dataframe, melts it, and calls heatmap with a red-green diverging color palette and size/color min-max set to [-1, 1].

That’s quite a lot of boilerplate stuff to cover step by step, so here’s what it looks like when done.

You can also check it out in this Kaggle kernel.
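For readers who want something to experiment with locally, here is a minimal sketch of a `heatmap` function along the lines of the list above, plus a `corrplot` wrapper. It is an illustration under our own assumptions, not the author’s implementation: the sequential default palette is Matplotlib’s `Blues` sampled at 256 points, the red-to-green palette in `corrplot` is Matplotlib’s `RdYlGn` (a stand-in for the Seaborn diverging palette), and the smallest size value is mapped to 1% of `size_scale`, an arbitrary floor that keeps it from vanishing.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def heatmap(x, y, size=None, color=None, palette=None, color_range=None,
            size_range=None, size_scale=500, marker="s", ax=None, **kwargs):
    """Squares on a categorical grid, sized and colored by value (a sketch)."""
    x, y = list(x), list(y)
    n = len(x)
    # Constant size / single color when no vector is given
    size = [1.0] * n if size is None else list(size)
    color = [1.0] * n if color is None else list(color)
    # Ranges default to the data's own min and max
    c_min, c_max = color_range if color_range is not None else (min(color), max(color))
    s_min, s_max = size_range if size_range is not None else (min(size), max(size))
    if palette is None:
        cmap = plt.get_cmap("Blues")  # sequential default palette
        palette = [cmap(i / 255) for i in range(256)]

    def value_to_color(val):
        if c_max == c_min:
            return palette[-1]  # single color
        pos = min(max((val - c_min) / (c_max - c_min), 0.0), 1.0)
        return palette[int(round(pos * (len(palette) - 1)))]

    def value_to_size(val):
        if s_max == s_min:
            return size_scale  # constant size
        pos = min(max((val - s_min) / (s_max - s_min), 0.0), 1.0)
        return (0.01 + 0.99 * pos) * size_scale  # lowest value -> 1%, not 0

    # List comprehensions work for lists, NumPy arrays and pandas Series alike
    x_labels, y_labels = sorted(set(x)), sorted(set(y))
    x_to_num = {v: i for i, v in enumerate(x_labels)}
    y_to_num = {v: i for i, v in enumerate(y_labels)}

    if ax is None:
        _, ax = plt.subplots()
    ax.scatter(
        x=[x_to_num[v] for v in x],
        y=[y_to_num[v] for v in y],
        s=[value_to_size(v) for v in size],
        c=[value_to_color(v) for v in color],
        marker=marker,
        **kwargs,  # any other keyword arguments go straight to Axes.scatter
    )
    ax.set_xticks(list(range(len(x_labels))))
    ax.set_xticklabels(x_labels, rotation=45, horizontalalignment="right")
    ax.set_yticks(list(range(len(y_labels))))
    ax.set_yticklabels(y_labels)
    # Gridlines between cells, and half a cell of padding on every side
    ax.grid(False, "major")
    ax.grid(True, "minor")
    ax.set_xticks([t + 0.5 for t in ax.get_xticks()], minor=True)
    ax.set_yticks([t + 0.5 for t in ax.get_yticks()], minor=True)
    ax.set_xlim(-0.5, len(x_labels) - 0.5)
    ax.set_ylim(-0.5, len(y_labels) - 0.5)
    return ax


def corrplot(corr, size_scale=500, **kwargs):
    """Melt a DataFrame.corr() result and plot it with heatmap() (a sketch)."""
    long = corr.reset_index().melt(id_vars="index")
    long.columns = ["x", "y", "value"]
    cmap = plt.get_cmap("RdYlGn")  # stand-in red-to-green diverging palette
    ax = heatmap(
        long["x"], long["y"],
        size=long["value"].abs(), color=long["value"],
        palette=[cmap(i / 255) for i in range(256)],
        color_range=(-1, 1), size_range=(0, 1),
        size_scale=size_scale, **kwargs,
    )
    ax.invert_yaxis()            # self-correlations on the main diagonal
    ax.set_facecolor("#F1F1F1")  # lighter background for near-zero squares
    return ax


# Usage with hypothetical random data
rng = np.random.default_rng(3)
df = pd.DataFrame(rng.normal(size=(100, 3)), columns=["p", "q", "r"])
ax = corrplot(df.corr())
```

Note that corrplot maps size over [0, 1] because it feeds the absolute correlation to heatmap, while color uses the full [-1, 1] range.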

Finally

Now that we have our corrplot and heatmap functions, creating the correlation plot with sized squares, like the one at the beginning of this post, is straightforward. And just for fun, let’s make a plot showing how engine power is distributed among car brands in our data set.
