Python热图:扭曲颜色映射

2024-09-28 20:17:48 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一组价值观,比如(想象一下这些都是小部件销售):

year | Sarah | Elizabeth | Jones | Robert |
-------------------------------------------
2003 | 11    | 0         |  0    | 0      |
2004 | 16    | 0         |  0    | 6      |
2005 | 12    | 0         |  4    | 11     |
2006 | 33    | 0         |  0    | 3      |
2007 | 18    | 0         |  0    | 0      |
2008 | 18    | 0         |  0    | 0      |
2009 | 110   | 0         |  0    | 0      |
2010 | 83    | 0         |  0    | 0      |
2011 | 1553  | 20        |  25   | 0      |
2012 | 785   | 27        |  0    | 186    |
2013 | 561   | 73        |  0    | 3      |

我用seaborn和matplotlib做了一个热图,但是不幸的是,大量的数字占据了主导地位,所以它看起来就像是一个非常黑暗的方块,在非常苍白的方块中。在

有没有一种方法可以用分段函数进行颜色映射,例如,将由[0,200],然后(550,1600]组成的整个范围映射成一个线性的、连续的颜色值?不幸的是,到目前为止我看到的所有颜色贴图都是预先设置好的。在


Tags: matplotlib颜色部件数字seabornyearrobert方块
1条回答
网友
1楼 · 发布于 2024-09-28 20:17:48

似乎您正在为Colormap Normalization寻找discrete bounds方法。下面是一个使用您的数据的示例:

import StringIO

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np
import pandas as pd
import seaborn as sns

s = """year Sarah Elizabeth Jones Robert
2003 11 0 0 0
2004 16 0 0 6
2005 12 0 4 11
2006 33 0 0 3
2007 18 0 0 0
2008 18 0 0 0
2009 110 0 0 0
2010 83 0 0 0
2011 1553 20 25 0
2012 785 27 0 186
2013 561 73 0 3"""
df = pd.read_table(StringIO.StringIO(s), sep=' ', header=0, index_col=0)

fig = plt.figure(figsize=(16, 6))
sns.set(font_scale=2) 

ax1 = plt.subplot(121)
sns.heatmap(df, ax=ax1)
ylabels = ax1.get_yticklabels()
plt.setp(ylabels, rotation=0)

ax2=plt.subplot(122)
bounds = np.concatenate((np.linspace(0, 200, 1050), np.linspace(550, 1600, 1050)))
norm = colors.BoundaryNorm(boundaries=bounds, ncolors=256)
sns.heatmap(df, ax=ax2, norm=norm)
ylabels = ax2.get_yticklabels()
plt.setp(ylabels, rotation=0)
cbar = ax2.collections[0].colorbar
cbar_ticks = np.concatenate((np.arange(0, 201, 50), np.arange(700, 1601, 200)))
cbar.set_ticks(cbar_ticks)
cbar.set_ticklabels(cbar_ticks)

fig.tight_layout()
plt.show()

enter image description here

相关问题 更多 >