<p>在第一个图形中,<code>cylinders</code>是一个类型为<code>int64</code>的连续变量,seaborn使用单一颜色,在本例中为紫色,并对其进行着色以指示值的比例,因此8个圆柱体的颜色将比4暗。这是故意的,所以你可以很容易地通过颜色的阴影来分辨什么是什么</p>
<p>一旦转换为“分类”,圆柱体值之间就不再存在这种关系,即8个圆柱体不再是4个圆柱体的两倍,它们本质上是两个完全不同的类别。为了避免将颜色的明暗度与变量的比例相关联(因为值不再连续且关系不存在),默认情况下将使用分类调色板,以便每种颜色都不同于其他颜色</p>
<p>为了解决您的问题,在使用运行最终图表之前,您需要将<code>cylinders</code>转换回<code>int64</code></p>
<p><code>df.cylinders = df.cylinders.astype('int64')</code></p>
<p>这将使变量恢复为连续变量,并允许seaborn使用相同颜色的渐变来表示值的大小,最终的绘图将与第一个一样</p>
<pre><code>import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
import warnings
warnings.filterwarnings("ignore")
# download data
df = pd.read_csv("https://www.statlearning.com/s/Auto.csv")
df.head()
# remove rows with "?"
df.drop(df.index[~df.horsepower.str.isnumeric()], axis=0, inplace=True)
df['horsepower'] = pd.to_numeric(df.horsepower, errors='coerce')
# plot 1 (gives the desired color-palette)
fig = sns.PairGrid(df, vars=df.columns[~df.columns.isin(['cylinders','origin','name'])].tolist(), hue='cylinders')
plt.gcf().set_size_inches(17,15)
fig.map_diag(sns.histplot)
fig.map_upper(sns.scatterplot)
fig.map_lower(sns.kdeplot)
fig.add_legend(ncol=5, loc=1, bbox_to_anchor=(0.5, 1.05), borderaxespad=0, frameon=False);
# plot 2
# Converting column cylinder to factor before using for 'color'
df.cylinders = df.cylinders.astype('category')
# Scatter plot - Cylinders as hue
pal = ['#fdc086','#386cb0','#beaed4','#33a02c','#f0027f']
col_map = dict(zip(sorted(df.cylinders.unique()), pal))
fig = px.scatter(df, y='mpg', x='year', color='cylinders',
color_discrete_map=col_map,
hover_data=['name','origin'])
fig.update_layout(width=800, height=400, plot_bgcolor='#fff')
fig.update_traces(marker=dict(size=8, line=dict(width=0.2,color='DarkSlateGrey')),
selector=dict(mode='markers'))
fig.show()
# plot 1 run again
df.cylinders = df.cylinders.astype('int64')
fig = sns.PairGrid(df, vars=df.columns[~df.columns.isin(['cylinders','origin','name'])].tolist(), hue='cylinders')
plt.gcf().set_size_inches(17,15)
fig.map_diag(sns.histplot)
fig.map_upper(sns.scatterplot)
fig.map_lower(sns.kdeplot)
fig.add_legend(ncol=5, loc=1, bbox_to_anchor=(0.5, 1.05), borderaxespad=0, frameon=False);
</code></pre>
<p>输出</p>