Pandas / Pyplot scatter plots: how to plot by category

I am trying to make a simple scatter plot in pyplot using a Pandas DataFrame object, but I want an efficient way of plotting two variables while having the symbols dictated by a third column (key). I have tried various ways using df.groupby, but not successfully. A sample df script is below. It colors the markers according to "key1", but I'd like to see a legend with the "key1" categories. Am I close? Thanks.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Ten monthly rows of three normally distributed columns (mean 10, std 1).
df = pd.DataFrame(
    np.random.normal(10, 1, 30).reshape(10, 3),
    index=pd.date_range('2010-01-01', freq='M', periods=10),
    columns=('one', 'two', 'three'),
)
# Numeric category per row; passed as c= below, so it drives marker color.
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
# Color each marker by its key1 value (this produces no legend by itself).
ax1.scatter(df['one'], df['two'], marker='o', c=df['key1'], alpha=0.8)
plt.show()
+67
python matplotlib pandas
Feb 09 '14 at 2:51
source share
7 answers

You can use scatter for this, but that requires numerical values for key1, and you won't get a legend, as you noticed.

It is better to just use plot for discrete categories like this. For example:

 # Scatter plot colored by category: one marker-only plot() call per group.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

np.random.seed(1974)

# Generate data: 20 random (x, y) points, each tagged with a random label.
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame({'x': x, 'y': y, 'label': labels})

groups = df.groupby('label')

# Plot: each label is drawn as its own marker-only line, so it gets its own
# color from the cycle and its own legend entry.
fig, ax = plt.subplots()
ax.margins(0.05)  # Optional: adds 5% padding to the autoscaled limits.
for label_name, subset in groups:
    ax.plot(subset.x, subset.y, marker='o', linestyle='', ms=12,
            label=label_name)
ax.legend()

plt.show()

enter image description here

If you want everything to look like the default pandas style, just update rcParams with the pandas stylesheet and use its color generator. (I also tweak the legend a bit.)

 # Same grouped plot as above, restyled to resemble pandas' default look.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
# ``pd.tools.plotting.mpl_stylesheet`` / ``_get_standard_colors`` no longer
# exist in current pandas, and ``Axes.set_color_cycle`` was removed from
# matplotlib. Use matplotlib's built-in 'ggplot' style (a similar look) and
# take one color per group from that style's color cycle.
plt.style.use('ggplot')
colors = plt.rcParams['axes.prop_cycle'].by_key()['color'][:len(groups)]

fig, ax = plt.subplots()
ax.set_prop_cycle(color=colors)
ax.margins(0.05)  # Optional: adds 5% padding to the autoscaled limits.
for name, group in groups:
    # Marker-only "line" per group -> one color and one legend entry each.
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend(numpoints=1, loc='upper left')

plt.show()

enter image description here

+86
Feb 09 '14 at 4:23
source share

This is easy to do with Seaborn (`pip install seaborn`) as a one-liner:

`sns.pairplot(x_vars=["one"], y_vars=["two"], data=df, hue="key1", height=5)`

 # Seaborn: scatter of 'one' vs 'two', colored by key1, with a legend.
import seaborn as sns
import pandas as pd
import numpy as np

np.random.seed(1974)
df = pd.DataFrame(
    np.random.normal(10, 1, 30).reshape(10, 3),
    index=pd.date_range('2010-01-01', freq='M', periods=10),
    columns=('one', 'two', 'three'),
)
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

# ``hue`` colors the points by key1 and adds the legend. ``height`` (size of
# each facet, in inches) replaces the deprecated ``size`` argument, which was
# renamed in seaborn 0.9.
sns.pairplot(x_vars=["one"], y_vars=["two"], data=df, hue="key1", height=5)

enter image description here

Here is the data for reference:

enter image description here

Since your data has three variables, you can plot all pairwise combinations with:

 # Every pairwise scatter of the three columns, colored by key1.
# ``height`` replaces seaborn's deprecated ``size`` argument.
sns.pairplot(vars=["one", "two", "three"], data=df, hue="key1", height=5)

enter image description here

mlxtend's `category_scatter` (https://rasbt.github.io/mlxtend/user_guide/plotting/category_scatter/) is another option.

+42
Aug 31 '16 at 13:44
source share

With `plt.scatter`, I can only think of one way: use a proxy artist:

 # Keep plt.scatter and build the legend from proxy artists.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

df = pd.DataFrame(
    np.random.normal(10, 1, 30).reshape(10, 3),
    index=pd.date_range('2010-01-01', freq='M', periods=10),
    columns=('one', 'two', 'three'),
)
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
x = ax1.scatter(df['one'], df['two'], marker='o', c=df['key1'], alpha=0.8)

# Color each proxy the same way scatter colored the points
# (value -> collection norm -> colormap). Reusing the collection's own norm
# avoids hard-coding the 4..8 value range. The categories come from the data.
x.autoscale_None()  # ensure the norm limits are set from the c values
keys = np.unique(df['key1'])
ccm = x.get_cmap()
circles = [
    Line2D([0], [0], color='w', linestyle='', marker='o', markersize=10,
           markerfacecolor=item, alpha=0.8)
    for item in ccm(x.norm(keys))
]
leg = ax1.legend(circles, [str(k) for k in keys], loc="center left",
                 bbox_to_anchor=(1, 0.5), numpoints=1)
plt.show()

And the result:

enter image description here

+19
Feb 09 '14 at 4:19
source share

You can use `df.plot.scatter` and pass an array to the `c=` argument, which sets the color of each point:

 # pandas scatter with an explicit color per point, chosen by key1.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(
    np.random.normal(10, 1, 30).reshape(10, 3),
    index=pd.date_range('2010-01-01', freq='M', periods=10),
    columns=('one', 'two', 'three'),
)
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

# One explicit category -> color lookup instead of np.where with a '-'
# placeholder (not a valid color) plus one boolean mask per category.
# A key missing from this dict maps to NaN; extend the dict for new keys.
key_colors = {4: 'r', 6: 'g', 8: 'b'}
colors = df['key1'].map(key_colors).to_numpy()

df.plot.scatter(x="one", y="two", c=colors)
plt.show()

enter image description here

+6
Sep 17 '17 at 2:45 on
source share

You can also try Altair or ggplot, which focus on declarative visualization.

 # Sample data: 20 random points, each with a random category label.
import numpy as np
import pandas as pd

np.random.seed(1974)

# Generate Data: x and y drawn together from [0, 1), then one label per point.
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame({'x': x, 'y': y, 'label': labels})

Altair Code

 # Altair: circles positioned by x/y and colored by label.
from altair import Chart

c = Chart(df)
# Left as the final bare expression so a notebook renders the chart.
c.mark_circle().encode(
    x='x',
    y='y',
    color='label',
)

enter image description here

ggplot code

 # ggplot: points positioned by x/y and colored by label, black-and-white theme.
# NOTE(review): the yhat ``ggplot`` package appears to be unmaintained; the
# ``plotnine`` package offers a similar API (``ggplot(df, aes(...))``) if this
# import fails in your environment.
from ggplot import aes, geom_point, ggplot, theme_bw

# Parentheses instead of trailing backslashes. A stray space after ``\`` is a
# syntax error; parenthesized continuation is not whitespace-sensitive.
(
    ggplot(aes(x='x', y='y', color='label'), data=df)
    + geom_point(size=50)
    + theme_bw()
)

enter image description here

+3
Jul 03 '17 at 9:19 on
source share

This is pretty hacky, but you can use `one` as a Float64Index to do it all in one go:

 # Hacky one-liner: with 'one' as the float index, each key1 group's 'two'
# Series is plotted against 'one'. Sorting the index keeps each group's points
# in x order. Style '--o' draws markers joined by dashed lines.
# NOTE(review): legend entries may not show the key1 values — confirm.
(
    df.set_index('one')
    .sort_index()
    .groupby('key1')['two']
    .plot(style='--o', legend=True)
)

enter image description here

Note that as of pandas 0.20.3, sorting the index is necessary, and the legend is a bit awkward.

+2
Oct 21 '17 at 20:51 on
source share

Starting with matplotlib 3.1 you can use `.legend_elements()`. An example is shown in the "Automated legend creation" gallery example. The advantage is that you need only a single scatter call.

In this case:

 # matplotlib >= 3.1: legend built straight from the scatter's color mapping.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(
    np.random.normal(10, 1, 30).reshape(10, 3),
    index=pd.date_range('2010-01-01', freq='M', periods=10),
    columns=('one', 'two', 'three'),
)
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker='o', c=df['key1'], alpha=0.8)
# legend_elements() returns (handles, labels) derived from the c values.
# With only a few distinct values like these, that is one entry per value.
handles, legend_labels = sc.legend_elements()
ax.legend(handles, legend_labels)
plt.show()

enter image description here

If the keys were not given as numbers, it would look like this:

 # Same approach for string categories: encode labels as ints for c=.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(
    np.random.normal(10, 1, 30).reshape(10, 3),
    index=pd.date_range('2010-01-01', freq='M', periods=10),
    columns=('one', 'two', 'three'),
)
df['key1'] = list("AAABBBCCCC")

# np.unique gives the sorted distinct labels, plus each row's position in
# that array (an integer code usable as a color value).
labels, index = np.unique(df["key1"], return_inverse=True)

fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker='o', c=index, alpha=0.8)
# Keep only the handles and pair them with the original string labels
# rather than the numeric codes.
handles = sc.legend_elements()[0]
ax.legend(handles, labels)
plt.show()

enter image description here

0
Jun 08 '19 at 14:42
source share



All Articles