Adjusting the color intervals of a heatmap and removing the colorbar ticks

Posted 2024-09-29 17:15:48


I am trying to make this heatmap with the following code:


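import matplotlib.colors
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.datasets import load_breast_cancer

breast_cancer = load_breast_cancer()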
data = breast_cancer.data
features = breast_cancer.feature_names
df = pd.DataFrame(data, columns = features)
df_small = df.iloc[:,:6]
correlation_mat = df_small.corr()

# Create color palette:

def NonLinCdict(steps, hexcol_array):
    cdict = {'red': (), 'green': (), 'blue': ()}
    for s, hexcol in zip(steps, hexcol_array):
        rgb = matplotlib.colors.hex2color(hexcol)
        cdict['red'] = cdict['red'] + ((s, rgb[0], rgb[0]),)
        cdict['green'] = cdict['green'] + ((s, rgb[1], rgb[1]),)
        cdict['blue'] = cdict['blue'] + ((s, rgb[2], rgb[2]),)
    return cdict
 
#https://www.december.com/html/spec/colorshades.html

hc = ['#e5e5ff', '#acacdf', '#7272bf', '#39399f', '#000080','#344152']
th = [0, 0.2, 0.4, 0.6, 0.8,1]

cdict = NonLinCdict(th, hc)
cm = matplotlib.colors.LinearSegmentedColormap('test', cdict)


#plot correlation matrix:

plt.figure(figsize = (12,10))

ax=sns.heatmap(correlation_mat,center=0, linewidths=.2, annot = True,cmap=cm , vmin=-1, vmax=1,cbar=True)

plt.title("title", y=-1.5,fontsize = 18)

plt.xlabel("X_parameters",fontsize = 18)

plt.ylabel("Y_parameters",fontsize = 18)

ax.tick_params(axis='both', which='both', length=0)

# TODO: choose colors
# TODO: change tick size and remove colorbar ticks
# TODO: add saving option
# TODO: change to five portions instead of four (0.2, 0.4, 0.6, 0.8)

plt.show()

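I have two unresolved questions:
1. How do I remove the colorbar ticks?
2. With vmax=1 and vmin=-1, how do I set the color steps of the heatmap and the colorbar to (-1,-0.8,-0.6,-0.4,-0.2,0,0.2,0.6,0.8,1)?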

My current output is shown in the attached image. [image: Output]


1 answer

Answer 1 · posted 2024-09-29 17:15:48

You can get the colorbar via ax.collections[0].colorbar. From there you can change the tick properties.
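A minimal sketch of just that step, reduced to the colorbar handling. The random data, the seed, and the built-in "coolwarm" colormap are placeholders chosen for illustration, not taken from the question:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove to open a window
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Placeholder data (assumption): random values in [-1, 1]
values = np.random.default_rng(0).uniform(-1, 1, (6, 6))

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(values, vmin=-1, vmax=1, cmap="coolwarm", cbar=True, ax=ax)

# seaborn draws the heatmap as the first collection on the axes;
# the colorbar it created is attached to that collection
cbar = ax.collections[0].colorbar

# length=0 hides the tick marks but keeps the numeric labels
cbar.ax.tick_params(which="both", length=0)

# Alternative: remove both the marks and the labels
# cbar.set_ticks([])
```

Setting the tick length to zero keeps the labels readable, which is usually what "removing the ticks" means for a correlation plot; `cbar.set_ticks([])` is the option if the labels should go as well.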

Here is a simple example:

import matplotlib.colors  # needed for hex2color and LinearSegmentedColormap
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def NonLinCdict(steps, hexcol_array):
    cdict = {'red': (), 'green': (), 'blue': ()}
    for s, hexcol in zip(steps, hexcol_array):
        rgb = matplotlib.colors.hex2color(hexcol)
        cdict['red'] = cdict['red'] + ((s, rgb[0], rgb[0]),)
        cdict['green'] = cdict['green'] + ((s, rgb[1], rgb[1]),)
        cdict['blue'] = cdict['blue'] + ((s, rgb[2], rgb[2]),)
    return cdict

hc = ['#e5e5ff', '#acacdf', '#7272bf', '#39399f', '#000080', '#344152']
hc = hc[:0:-1] + hc  # prepend a reversed copy, but without repeating the central value
cdict = NonLinCdict(np.linspace(0, 1, len(hc)), hc)
cm = matplotlib.colors.LinearSegmentedColormap('test', cdict)

fig = plt.figure(figsize=(8, 6))
ax = sns.heatmap(np.random.uniform(-1, 1, (10, 10)), center=0, linewidths=.2, annot=True, fmt='.2f', cmap=cm, vmin=-1, vmax=1, cbar=True)
ax.tick_params(axis='both', which='both', length=0)

cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=15, axis='both', which='both', length=0)
cbar.set_ticks(np.linspace(-1, 1, 11))
# cbar.set_label('correlation')

plt.show()
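As a possible simplification (not part of the original answer), the hand-built cdict can probably be replaced by `LinearSegmentedColormap.from_list`, which accepts a list of (position, color) pairs. The sketch below assumes the same mirrored color list as above:

```python
import numpy as np
from matplotlib.colors import LinearSegmentedColormap, to_hex

hc = ['#e5e5ff', '#acacdf', '#7272bf', '#39399f', '#000080', '#344152']
hc = hc[:0:-1] + hc  # mirrored list: dark ends, light center

# Positions must start at 0 and end at 1; here they are evenly spaced
steps = np.linspace(0, 1, len(hc))
cm = LinearSegmentedColormap.from_list('test', list(zip(steps, hc)))

# Both ends of the colormap map to the darkest color
print(to_hex(cm(0.0)), to_hex(cm(1.0)))
```

The resulting `cm` can be passed to `sns.heatmap(..., cmap=cm)` in the same way as the colormap built with `NonLinCdict`.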

[image: example plot]
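The colormap above is still a smooth gradient. If distinct bands at 0.2 intervals are wanted, one option is a `ListedColormap` combined with `BoundaryNorm`. This is a sketch under assumptions: it uses plain matplotlib `pcolormesh` rather than seaborn, random placeholder data, and only the first five shades from the question, mirrored to give ten bands.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove to open a window
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import BoundaryNorm, ListedColormap, to_hex

# Five shades (assumption: the first five from the question), mirrored
hc = ['#e5e5ff', '#acacdf', '#7272bf', '#39399f', '#000080']
colors = hc[::-1] + hc             # 10 bands: dark at -1, light near 0, dark at +1

bounds = np.linspace(-1, 1, 11)    # band edges: -1, -0.8, ..., 0.8, 1
cmap = ListedColormap(colors)
norm = BoundaryNorm(bounds, cmap.N)  # maps each 0.2-wide interval to one color

# Placeholder data (assumption): random values in [-1, 1]
data = np.random.default_rng(0).uniform(-1, 1, (10, 10))

fig, ax = plt.subplots(figsize=(8, 6))
mesh = ax.pcolormesh(data, cmap=cmap, norm=norm, edgecolors='white', linewidth=0.2)
ax.invert_yaxis()                  # first row on top, like a heatmap

cbar = fig.colorbar(mesh, ax=ax, ticks=bounds)  # one label per band edge
cbar.ax.tick_params(length=0)                   # hide the tick marks
```

With this norm, every value between 0 and 0.2 gets the same lightest shade, every value between 0.2 and 0.4 the next one, and so on, so the colorbar shows ten flat blocks instead of a continuous ramp.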
