标识seaborn或matplotlib使用的默认调色板的名称

2024-10-01 09:18:38 发布

您现在位置:Python中文网/ 问答频道 /正文

enter image description here 这是调色板,seaborn在默认情况下使用带有分类变量的列为分散点着色。
有没有办法获取所用调色板的名称或颜色? 我在一开始就得到了这个配色方案,但一旦我在绘图中使用了一个diff方案,我就不能在同一个图表中使用这个调色板。 这不是来自sns.color_palette的方案。这也可以是matplotlib颜色方案

最小可再现示例

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px

# download data
df = pd.read_csv("https://www.statlearning.com/s/Auto.csv")
df.head()

# remove rows with "?"
df.drop(df.index[~df.horsepower.str.isnumeric()], axis=0, inplace=True)
df['horsepower'] = pd.to_numeric(df.horsepower, errors='coerce')

# plot 1 (gives the desired color-palette)
fig = sns.PairGrid(df, vars=df.columns[~df.columns.isin(['cylinders','origin','name'])].tolist(), hue='cylinders')
plt.gcf().set_size_inches(17,15)
fig.map_diag(sns.histplot)
fig.map_upper(sns.scatterplot)
fig.map_lower(sns.kdeplot)
fig.add_legend(ncol=5, loc=1, bbox_to_anchor=(0.5, 1.05), borderaxespad=0, frameon=False);

# plot 2
# Converting column cylinder to factor before using for 'color'
df.cylinders = df.cylinders.astype('category')

# Scatter plot - Cylinders as hue
pal = ['#fdc086','#386cb0','#beaed4','#33a02c','#f0027f']
col_map = dict(zip(sorted(df.cylinders.unique()), pal))
fig = px.scatter(df, y='mpg', x='year', color='cylinders', 
                 color_discrete_map=col_map, 
                 hover_data=['name','origin'])
fig.update_layout(width=800, height=400, plot_bgcolor='#fff')
fig.update_traces(marker=dict(size=8, line=dict(width=0.2,color='DarkSlateGrey')),
                  selector=dict(mode='markers'))
fig.show()

# plot 1 run again
fig = sns.PairGrid(df, vars=df.columns[~df.columns.isin(['cylinders','origin','name'])].tolist(), hue='cylinders')
plt.gcf().set_size_inches(17,15)
fig.map_diag(sns.histplot)
fig.map_upper(sns.scatterplot)
fig.map_lower(sns.kdeplot)
fig.add_legend(ncol=5, loc=1, bbox_to_anchor=(0.5, 1.05), borderaxespad=0, frameon=False);


Tags: columnstoimportmapdfplotasfig
2条回答

一种方法是使用sns.set().返回,但这并没有告诉我们配色方案的名称

在第一个图形中,cylinders是一个类型为int64的连续变量,seaborn使用单一颜色,在本例中为紫色,并对其进行着色以指示值的比例,因此8个圆柱体的颜色将比4暗。这是故意的,所以你可以很容易地通过颜色的阴影来分辨什么是什么

一旦转换为“分类”,圆柱体值之间就不再存在这种关系,即8个圆柱体不再是4个圆柱体的两倍,它们本质上是两个完全不同的类别。为了避免将颜色的明暗度与变量的比例相关联(因为值不再连续且关系不存在),默认情况下将使用分类调色板,以便每种颜色都不同于其他颜色

为了解决您的问题,在使用运行最终图表之前,您需要将cylinders转换回int64

df.cylinders = df.cylinders.astype('int64')

这将使变量恢复为连续变量,并允许seaborn使用相同颜色的渐变来表示值的大小,最终的绘图将与第一个一样

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
import warnings
warnings.filterwarnings("ignore")

# download data
df = pd.read_csv("https://www.statlearning.com/s/Auto.csv")
df.head()

# remove rows with "?"
df.drop(df.index[~df.horsepower.str.isnumeric()], axis=0, inplace=True)
df['horsepower'] = pd.to_numeric(df.horsepower, errors='coerce')

# plot 1 (gives the desired color-palette)
fig = sns.PairGrid(df, vars=df.columns[~df.columns.isin(['cylinders','origin','name'])].tolist(), hue='cylinders')
plt.gcf().set_size_inches(17,15)
fig.map_diag(sns.histplot)
fig.map_upper(sns.scatterplot)
fig.map_lower(sns.kdeplot)
fig.add_legend(ncol=5, loc=1, bbox_to_anchor=(0.5, 1.05), borderaxespad=0, frameon=False);

# plot 2
# Converting column cylinder to factor before using for 'color'
df.cylinders = df.cylinders.astype('category')

# Scatter plot - Cylinders as hue
pal = ['#fdc086','#386cb0','#beaed4','#33a02c','#f0027f']
col_map = dict(zip(sorted(df.cylinders.unique()), pal))
fig = px.scatter(df, y='mpg', x='year', color='cylinders', 
                 color_discrete_map=col_map, 
                 hover_data=['name','origin'])
fig.update_layout(width=800, height=400, plot_bgcolor='#fff')
fig.update_traces(marker=dict(size=8, line=dict(width=0.2,color='DarkSlateGrey')),
                  selector=dict(mode='markers'))
fig.show()

# plot 1 run again
df.cylinders = df.cylinders.astype('int64')
fig = sns.PairGrid(df, vars=df.columns[~df.columns.isin(['cylinders','origin','name'])].tolist(), hue='cylinders')
plt.gcf().set_size_inches(17,15)
fig.map_diag(sns.histplot)
fig.map_upper(sns.scatterplot)
fig.map_lower(sns.kdeplot)
fig.add_legend(ncol=5, loc=1, bbox_to_anchor=(0.5, 1.05), borderaxespad=0, frameon=False);

输出

相关问题 更多 >