Pandas correlation matrix

I have a dataset with a huge number of features, so analysing the correlation matrix has become very difficult. I want to plot the correlation matrix that we get using the dataframe.corr() function from the pandas library. Is there a built-in function provided by the pandas library to plot this matrix?

+163
python matplotlib pandas data-visualization
Apr 03 '15 at 12:57
7 answers

You can use pyplot.matshow() from matplotlib:

    import matplotlib.pyplot as plt

    plt.matshow(dataframe.corr())
    plt.show()



Edit:

In the comments, someone asked how to change the axis tick labels. Here is a deluxe version that is drawn on a bigger figure, has axis labels matching the dataframe columns, and a colorbar legend for interpreting the color scale.

I include how to adjust the size and rotation of the labels, and I use a figure ratio that makes the colorbar and the main figure come out the same height.

    import matplotlib.pyplot as plt

    corr = df.corr()
    f = plt.figure(figsize=(19, 15))
    plt.matshow(corr, fignum=f.number)
    # tick labels come from corr, not df, so they always match the matrix
    plt.xticks(range(corr.shape[1]), corr.columns, fontsize=14, rotation=45)
    plt.yticks(range(corr.shape[1]), corr.columns, fontsize=14)
    cb = plt.colorbar()
    cb.ax.tick_params(labelsize=14)
    plt.title('Correlation Matrix', fontsize=16)

(Image: correlation plot example)
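A variant of the same plot using matplotlib's object-oriented API, sketched under a few assumptions: the helper name `plot_corr_matrix` is hypothetical, non-numeric columns are dropped with `select_dtypes`, and instead of tuning the figure ratio it uses `make_axes_locatable` from `mpl_toolkits.axes_grid1` (a different technique from the one above) to append a colorbar axis that always matches the matrix height.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mpl_toolkits.axes_grid1 import make_axes_locatable


def plot_corr_matrix(df, fontsize=12):
    """Hypothetical helper: draw df's correlation matrix with labelled axes."""
    # Keep only numeric columns so the tick labels match the matrix.
    corr = df.select_dtypes("number").corr()
    fig, ax = plt.subplots(figsize=(8, 8))
    # Correlations lie in [-1, 1], so fix the colour limits to that range.
    im = ax.matshow(corr, cmap="coolwarm", vmin=-1, vmax=1)
    ticks = np.arange(len(corr.columns))
    ax.set_xticks(ticks)
    ax.set_xticklabels(corr.columns, rotation=45, ha="left", fontsize=fontsize)
    ax.set_yticks(ticks)
    ax.set_yticklabels(corr.columns, fontsize=fontsize)
    # Append a colorbar axis to the right of the matrix with the same height.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.1)
    fig.colorbar(im, cax=cax)
    return fig, ax, corr


# Usage with made-up data; the string column is ignored by select_dtypes.
rng = np.random.default_rng(0)
demo = pd.DataFrame(rng.normal(size=(50, 4)), columns=list("abcd"))
demo["label"] = "x"
fig, ax, corr = plot_corr_matrix(demo)
```

Calling `plt.show()` or `fig.savefig(...)` afterwards displays or saves the figure as usual.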

+222
Apr 03 '15 at 13:04

If your main goal is to visualize the correlation matrix, rather than creating a plot per se, the convenient pandas styling options are a viable built-in solution:

    import numpy as np
    import pandas as pd

    rs = np.random.RandomState(0)
    df = pd.DataFrame(rs.rand(10, 10))
    corr = df.corr()
    corr.style.background_gradient(cmap='coolwarm')
    # 'RdBu_r' & 'BrBG' are other good diverging colormaps


Note that this needs to be run in a backend that supports rendering HTML, such as the JupyterLab Notebook. (The automatic light text on dark backgrounds comes from an existing PR and not from the latest released version, pandas 0.23.)




Styling

You can easily limit the displayed precision of the numbers:

    corr.style.background_gradient(cmap='coolwarm').set_precision(2)
    # set_precision() was removed in pandas 2.0; from pandas 1.3 on, use .format(precision=2)


Or get rid of numbers altogether if you prefer a matrix without annotations:

    corr.style.background_gradient(cmap='coolwarm').set_properties(**{'font-size': '0pt'})


The style documentation also contains instructions for more advanced styles, for example how to change the display of the cell the mouse pointer is hovering over. To save the output, you can return the HTML by appending the render() method and then write it to a file (or just take a screenshot for less formal purposes).




Time comparison

In my testing, style.background_gradient() was 4 times faster than plt.matshow(), and also faster than sns.heatmap(), with a 10x10 matrix. Unfortunately, it does not scale as well as plt.matshow(): the two take about the same time for a 100x100 matrix, and plt.matshow() is 10 times faster for a larger matrix.




Saving

There are several possible ways to save a stylized data frame:

  • Return the HTML by appending the render() method and then write the result to a file (from pandas 1.3 on, to_html() does this; render() was removed in pandas 2.0).
  • Save as a conditionally formatted .xlsx file by appending the to_excel() method.
  • Combine with imgkit to save a bitmap.
  • Take a screenshot (for less formal purposes).



Update for pandas >= 0.24

By setting axis=None, you can now compute the colors based on the entire matrix rather than per column or per row:

    corr.style.background_gradient(cmap='coolwarm', axis=None)


+116
Jun 05 '18 at 15:18

Try this function, which also displays the variable names for the correlation matrix:

    import matplotlib.pyplot as plt

    def plot_corr(df, size=10):
        """Plot a graphical correlation matrix for each pair of columns in the dataframe.

        Input:
            df: pandas DataFrame
            size: vertical and horizontal size of the plot
        """
        corr = df.corr()
        fig, ax = plt.subplots(figsize=(size, size))
        ax.matshow(corr)
        plt.xticks(range(len(corr.columns)), corr.columns)
        plt.yticks(range(len(corr.columns)), corr.columns)
+86
Jul 13 '15 at 13:10

Seaborn heatmap version:

    import seaborn as sns

    corr = dataframe.corr()
    sns.heatmap(corr,
                xticklabels=corr.columns.values,
                yticklabels=corr.columns.values)
+80
Oct 24 '16 at 10:45

You can observe the relationship between features either by drawing a heatmap with seaborn or a scatter matrix with pandas.

Scatter matrix:

    # pd.scatter_matrix was removed in pandas 1.0; it now lives in pd.plotting
    pd.plotting.scatter_matrix(dataframe, alpha=0.3, figsize=(14, 8), diagonal='kde')

If you also want to visualize each feature's skewness, use seaborn pairplots.

    sns.pairplot(dataframe)
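To put a number on the skewness that the pairplot diagonals show, pandas has `DataFrame.skew()`, which returns the sample skewness of each numeric column. A small sketch with made-up data (the column names are hypothetical):

```python
import pandas as pd

# Made-up columns: one symmetric, one with a long right tail.
df = pd.DataFrame({
    "symmetric": [1, 2, 3, 4, 5],
    "right_skewed": [1, 1, 1, 2, 10],
})

# Sample skewness per column: about 0 for symmetric data,
# positive when the right tail is longer.
skewness = df.skew()
```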

Seaborn heatmap:

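    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns

    f, ax = plt.subplots(figsize=(10, 8))
    corr = dataframe.corr()
    # np.bool was removed in NumPy 1.24; plain bool works everywhere.
    # An all-False mask hides no cells.
    sns.heatmap(corr, mask=np.zeros_like(corr, dtype=bool),
                cmap=sns.diverging_palette(220, 10, as_cmap=True),
                square=True, ax=ax)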

The output is a correlation map of the features.


The correlation between Grocery and Detergents is high. Similarly:

High correlation:
  • Grocery and Detergents.
Medium correlation:
  • Milk and Grocery
  • Milk and Detergents_Paper
Low correlation:
  • Milk and Deli
  • Frozen and Fresh.
  • Frozen and Deli.
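Instead of reading the pairs off the heatmap by eye, they can be listed programmatically. A sketch, assuming hypothetical helper names and arbitrary cut-offs of 0.3 and 0.7 on |r| for the low/medium/high buckets (the answer does not state any thresholds):

```python
from itertools import combinations

import pandas as pd


def correlated_pairs(df):
    """Hypothetical helper: each pair of numeric columns once, strongest |r| first."""
    corr = df.select_dtypes("number").corr()
    # Tuple keys give the Series a two-level (column, column) MultiIndex.
    pairs = pd.Series(
        {(a, b): corr.loc[a, b] for a, b in combinations(corr.columns, 2)}
    )
    return pairs.sort_values(key=lambda s: s.abs(), ascending=False)


def bucket_pairs(pairs, medium=0.3, high=0.7):
    """Hypothetical helper: label pairs by |r|; the cut-offs are assumptions."""
    return pd.cut(
        pairs.abs(),
        bins=[0, medium, high, float("inf")],  # inf guards against |r| = 1 + eps
        labels=["low", "medium", "high"],
        include_lowest=True,
    )


# Made-up data: b is a multiple of a; c is moderately (negatively) related.
df = pd.DataFrame({
    "a": [1, 2, 3, 4],
    "b": [2, 4, 6, 8],
    "c": [1, 0, 1, 0],
    "name": list("wxyz"),  # non-numeric, ignored
})
pairs = correlated_pairs(df)
buckets = bucket_pairs(pairs)
```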

From pairplots: you can observe the same set of relationships from pairplots or the scatter matrix. But from these we can also tell whether the data is normally distributed or not.


Note: The above plot is made from the same data used to draw the heatmap.
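Back to the heatmap: since a correlation matrix is symmetric, a common refinement is to blank out the diagonal and upper triangle so each pair is shown once. A matplotlib-only sketch of that idea (no seaborn), with hypothetical helper names; it relies on `imshow` leaving NaN cells transparent:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def lower_triangle(corr):
    """Hypothetical helper: copy of corr with diagonal and upper triangle set to NaN."""
    mask = np.triu(np.ones(corr.shape, dtype=bool))
    return corr.mask(mask)


def plot_lower_triangle(df):
    """Hypothetical helper: heatmap of the strict lower triangle of df.corr()."""
    corr = df.select_dtypes("number").corr()
    lower = lower_triangle(corr)
    fig, ax = plt.subplots(figsize=(6, 5))
    # NaN cells are masked by imshow, so the hidden half stays blank.
    im = ax.imshow(lower, cmap="coolwarm", vmin=-1, vmax=1)
    ticks = np.arange(len(corr.columns))
    ax.set_xticks(ticks)
    ax.set_xticklabels(corr.columns, rotation=90)
    ax.set_yticks(ticks)
    ax.set_yticklabels(corr.columns)
    fig.colorbar(im, ax=ax)
    return fig, lower


rng = np.random.default_rng(0)
demo = pd.DataFrame(rng.normal(size=(30, 4)), columns=list("abcd"))
fig, lower = plot_lower_triangle(demo)
```

With seaborn, the equivalent is passing `mask=np.triu(np.ones_like(corr, dtype=bool))` to `sns.heatmap`.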

+70
Mar 23 '17 at 13:48

You can use the imshow() method from matplotlib:

    import matplotlib.pyplot as plt

    plt.style.use('ggplot')
    corr = X.corr()  # X is your DataFrame
    plt.imshow(corr, cmap=plt.cm.Reds, interpolation='nearest')
    plt.colorbar()
    # label ticks from corr so they match the matrix even if X has non-numeric columns
    tick_marks = range(len(corr.columns))
    plt.xticks(tick_marks, corr.columns, rotation='vertical')
    plt.yticks(tick_marks, corr.columns)
    plt.show()
+6
Jun 28 '18 at 16:02

If you have a dataframe df, you can simply use:

    import matplotlib.pyplot as plt
    import seaborn as sns

    plt.figure(figsize=(15, 10))
    sns.heatmap(df.corr(), annot=True)
+4
Aug 15 '19 at 21:06


