如何手动为分类变量的类型指定颜色?

2024-10-03 17:19:23 发布

您现在位置:Python中文网/ 问答频道 /正文

我在下面的代码创建2个绘图。我在JobDomain列中有categories值

  • 第1类
  • 第2类
  • 第三类

下面的代码为上述类别生成两个不同颜色的绘图。我需要保持这两个绘图有这3类相同的颜色。你知道吗

colors = ["#F28E2B", "#4E79A7","#79706E"]

edu = (df.groupby(['JobDomain'])['sal']
                         .value_counts(normalize=True)
                         .rename('Percentage')
                         .mul(100)
                         .reset_index()
                         .sort_values('sal'))

coding = (df.groupby(['JobDomain'])['sal2']
                         .value_counts(normalize=True)
                         .rename('Percentage')
                         .mul(100)
                         .reset_index()
                         .sort_values('sal2'))

fig, axs = plt.subplots(ncols=2,figsize=(20, 6),sharey=True)

plt.subplots_adjust(wspace=0.4)

p=sns.barplot(x="sal",y="Percentage",hue="JobDomain",data=edu,
              ax=axs[0],palette=sns.color_palette(colors))
q=sns.barplot(x="sal2",y="Percentage",hue="JobDomain",data=coding,
              ax=axs[1],palette=sns.color_palette(colors))

Tags: 代码true绘图df颜色colorsedugroupby
1条回答
网友
1楼 · 发布于 2024-10-03 17:19:23

通过创建一个字典,将每个cathegory映射到一种颜色(并将其传递给palette,而不调用sns.color_palette)。举个例子:

import seaborn as sns
from pandas import DataFrame
from matplotlib import pyplot as plt

df = DataFrame({'JobDomain': ['Cat1', 'Cat2', 'Cat3', 'Cat1', 'Cat3'],
                'sal':       [  110,     90,    100,    200,    130],
                'sal2':      [  100,    280,    320,    240,    440]
                })

colors = {'Cat1': "#F28E2B", 'Cat2': "#4E79A7", 'Cat3': "#79706E"}

edu = (df.groupby(['JobDomain'])['sal']
                         .value_counts(normalize=True)
                         .rename('Percentage')
                         .mul(100)
                         .reset_index()
                         .sort_values('sal'))
coding = (df.groupby(['JobDomain'])['sal2']
                         .value_counts(normalize=True)
                         .rename('Percentage')
                         .mul(100)
                         .reset_index()
                         .sort_values('sal2'))

fig, axs = plt.subplots(ncols=2,figsize=(20, 6),sharey=True)
plt.subplots_adjust(wspace=0.4)
p = sns.barplot(x="sal",y="Percentage",hue="JobDomain",data=edu,ax=axs[0],palette=colors)
q = sns.barplot(x="sal2",y="Percentage",hue="JobDomain",data=coding,ax=axs[1],palette=colors)

h, l = p.get_legend_handles_labels()
l, h = zip(*sorted(zip(l, h)))
p.legend(h, l, title="Job Domain")
q.legend(h, l, title="Job Domain")

plt.show()

PS:要再次排序图例,请在plt.show()之前插入:

h, l = p.get_legend_handles_labels()
l, h = zip(*sorted(zip(l, h)))
p.legend(h, l, title="Job Domain")
q.legend(h, l, title="Job Domain")

Example plot

相关问题 更多 >