无法固定多个子图中的色条位置
我有一段很大的代码,它会输出多个图,每个函数运行后都会生成一个图。我把这些图都保存在一个输出的 .png
文件里。根据不同的情况,有时候代码会保存所有函数的图,有时候只保存其中的一部分。
其中有一个图包含一个颜色条,我希望这个颜色条无论是处理了所有函数还是只处理了几个函数,都能固定在这个图的同一个位置。
我尝试了很多方法,但就是无法让颜色条在最终的 .png
文件中保持不动,尤其是当只有几个图时(更准确地说,就是当颜色条下面的图没有生成时)。
这里有一张图片来说明我的意思:
下面是一个简单的例子。为了生成第一个文件,我绘制了所有的图,而第二个文件我只是把最后八个 ax*
块注释掉了。
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.gridspec as gridspec
# Generate random data.
x = np.random.randn(60)
y = np.random.randn(60)
z = [np.random.random() for _ in range(60)]
fig = plt.figure(figsize=(20, 35)) # create the top-level container
gs = gridspec.GridSpec(14, 8) # create a GridSpec object
ax0 = plt.subplot(gs[0:2, 0:2])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax1 = plt.subplot(gs[0:2, 2:4])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax2 = plt.subplot(gs[0:2, 4:6])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax3 = plt.subplot(gs[0:2, 6:8])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax4 = plt.subplot(gs[2:4, 0:2])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax5 = plt.subplot(gs[2:4, 2:4])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax6 = plt.subplot(gs[2:4, 4:6])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax7 = plt.subplot(gs[2:4, 6:8])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax8 = plt.subplot(gs[4:6, 0:2])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax9 = plt.subplot(gs[4:6, 2:4])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax10 = plt.subplot(gs[4:6, 4:6])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax11 = plt.subplot(gs[4:6, 6:8])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax12 = plt.subplot(gs[6:8, 0:2])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax13 = plt.subplot(gs[6:8, 2:4])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax14 = plt.subplot(gs[6:8, 4:6])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax15 = plt.subplot(gs[6:8, 6:8])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax16 = plt.subplot(gs[8:10, 0:2])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax17 = plt.subplot(gs[8:10, 2:4])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax18 = plt.subplot(gs[8:10, 4:6])
cm = plt.cm.get_cmap('RdYlBu_r')
plt.scatter(x, y, s=20, c=z, cmap=cm, vmin=0, vmax=1)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
# Plot colorbar.
box = ax18.get_position()
cbar_posit = [box.x1 * 0.93, box.y1 * 0.94, 0.04, 0.005]
cbaxes = fig.add_axes(cbar_posit)
cbar = plt.colorbar(cax=cbaxes, ticks=[0, 1], orientation='horizontal')
cbar.ax.tick_params(labelsize=9)
ax19 = plt.subplot(gs[8:10, 6:8])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax20 = plt.subplot(gs[10:12, 0:2])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax21 = plt.subplot(gs[10:12, 2:4])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax22 = plt.subplot(gs[10:12, 4:6])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax23 = plt.subplot(gs[10:12, 6:8])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax24 = plt.subplot(gs[12:14, 0:2])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax25 = plt.subplot(gs[12:14, 2:4])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax26 = plt.subplot(gs[12:14, 4:6])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
ax27 = plt.subplot(gs[12:14, 6:8])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
fig.tight_layout()
out_png = 'colorbar.png'
plt.savefig(out_png, dpi=150)
plt.close()
这个 MWE
复现了我在实际代码中生成输出 .png
文件的方式,所以代码看起来似乎有点多余。
因为颜色条是根据 ax18
图的位置来放置的,我本以为它应该总是能正确地放在里面,但显然不是这样。
我只想让这个颜色条在 ax18
图里面是固定的,无论我周围生成了多少其他图。这样做可能吗?
补充
好的,我把问题缩小到 fig.tight_layout()
这个调用上。当我把它注释掉时,颜色条的位置就能完美保持,无论生成多少图。缺点是最终的图看起来会差很多,邻近图的坐标轴会重叠。
有没有办法在保持 fig.tight_layout()
的同时,仍然让颜色条的位置正确呢?
2 个回答
在使用 fig.tight_layout()
之后,直接使用 fig.canvas.draw()
。即使出现 UserWarning
警告,它仍然可以正常工作。
fig.tight_layout()
fig.canvas.draw()
cax = fig.add_axes([ax.get_position().x0, ax.get_position().y0 - 0.12, ax.get_position().width, 0.02])
cbar = plt.colorbar(self.cs, orientation='horizontal', cax=cax)
我觉得这里有几个问题。
如果我理解得没错,你是想让颜色条(colorbar)相对于它所在的坐标轴保持一个固定的位置。
fig.tight_layout()
是用来调整“常规”坐标轴的位置的。这里的“常规”是指你手动创建的颜色条,它不会受到这个函数的影响。根据我所知,目前没有办法把一个坐标轴固定在另一个坐标轴上,也就是说如果你移动第一个,第二个也会跟着移动,以保持它们之间的相对位置。这就引出了一个重点:你需要在调用 tight_layout
之后设置颜色条的位置。
不过,这样做你还是会遇到问题,因为你是在设置颜色条的绝对位置,而你想要的是相对于坐标轴的固定位置。你需要指定颜色条相对于坐标轴的位置。
那么你该怎么做呢:
## ...
## ...
## Your MWE starts up there...
ax18 = plt.subplot(gs[8:10, 4:6])
cm = plt.cm.get_cmap('RdYlBu_r')
sca = plt.scatter(x, y, s=20, c=z, cmap=cm, vmin=0, vmax=1)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
# Plot colorbar.
#box = ax18.get_position()
#cbar_posit = [box.x1 * 0.93, box.y1 * 0.94, 0.04, 0.005]
#cbaxes = fig.add_axes(cbar_posit)
#cbar = plt.colorbar(cax=cbaxes, ticks=[0, 1], orientation='horizontal')
#cbar.ax.tick_params(labelsize=9)
## You keep creating and filling axes...
## ...
## ...
注意我把散点图的返回值保存在 sca
里。我们需要这个值来创建颜色条。此外,你可以删除那些注释的行,因为它们只是为了说明你不想再保留它们。
然后,在你的脚本最后,你可以创建颜色条,并指定它在坐标轴内的位置和大小。
## ...
## ...
## All the plots are created above...
fig.tight_layout() # You call fig.tight_layout BEFORE creating the colorbar
import matplotlib
# You input the POSITION AND DIMENSIONS RELATIVE TO THE AXES
x0, y0, width, height = [0.6, 0.9, 0.2, 0.04]
# and transform them after to get the ABSOLUTE POSITION AND DIMENSIONS
Bbox = matplotlib.transforms.Bbox.from_bounds(x0, y0, width, height)
trans = ax18.transAxes + fig.transFigure.inverted()
l, b, w, h = matplotlib.transforms.TransformedBbox(Bbox, trans).bounds
# Now just create the axes and the colorbar
cbaxes = fig.add_axes([l, b, w, h])
cbar = plt.colorbar(sca, cax=cbaxes, ticks=[0, 1], orientation='horizontal')
cbar.ax.tick_params(labelsize=9)
out_png = 'colorbar.png'
plt.savefig(out_png, dpi=150)
plt.close()
我很久以前在这里的一个回答中找到了坐标转换的代码,并把它保存到某个脚本里。我试着给那个回答致谢,但现在找不到了,抱歉。