Matplotlib:imshow图像的共享轴

2024-09-29 18:41:35 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图用Matplotlib的imshow()方法绘制多个图像,并让它们共享一个y轴。虽然图像具有相同数量的y像素,但图像的高度并不相同

演示代码


import numpy as np
import matplotlib.pyplot as plt

from scipy.stats import poisson


def ibp_oneparam(alpha, N):
    """One-parameter IBP"""

    # First customer
    Z = np.array([np.ones(poisson(alpha).rvs(1))], dtype=int)

    # ith customer
    for i in range(2, N+1):

        # Customer walks along previously sampled dishes
        z_i = []
        for previously_sampled_dish in Z.T:
            m_k = np.sum(previously_sampled_dish)
            if np.random.rand() >= m_k / i:
                # Customer decides to sample this dish
                z_i.append(1.0)
            else:
                # Customer decides to skip this dish
                z_i.append(0.0)

        # Customer decides to try some new dishes
        z_i.extend(np.ones(poisson(alpha / i).rvs(1)))
        z_i = np.array(z_i)

        # Add this customer to Z
        Z_new = np.zeros((
            Z.shape[0] + 1,
            max(Z.shape[1], len(z_i))
        ))
        Z_new[0:Z.shape[0], 0:Z.shape[1]] = Z
        Z = Z_new
        Z[i-1, :] = z_i

    return Z


np.random.seed(3)

N = 10
alpha = 2.0

#plt.figure(dpi=100)
fig, (ax1, ax2, ax3) = plt.subplots(
    1,
    3,
    dpi=100,
    sharey=True
)

Z = ibp_oneparam(alpha, N)
plt.sca(ax1)
plt.imshow(
    Z,
    extent=(0.5, Z.shape[1] + 0.5, len(Z) + 0.5, 0.5),
    cmap='Greys_r'
)
plt.ylabel("Customers")
plt.xlabel("Dishes")
plt.xticks(range(1, Z.shape[1] + 1))
plt.yticks(range(1, Z.shape[0] + 1))

Z = ibp_oneparam(alpha, N)
plt.sca(ax2)
plt.imshow(
    Z,
    extent=(0.5, Z.shape[1] + 0.5, len(Z) + 0.5, 0.5),
    cmap='Greys_r'
)
plt.xlabel("Dishes")
plt.xticks(range(1, Z.shape[1] + 1))

Z = ibp_oneparam(alpha, N)
plt.sca(ax3)
plt.imshow(
    Z,
    extent=(0.5, Z.shape[1] + 0.5, len(Z) + 0.5, 0.5),
    cmap='Greys_r'
)
plt.xlabel("Dishes")
plt.xticks(range(1, Z.shape[1] + 1))

plt.show()

产出

Three subplots each showing a binary image

我希望这些图像具有相同的高度和不同的宽度。我如何才能做到这一点?

旁白:上面的代码演示了Indian Buffet Process。为了这篇文章的目的,考虑这三个图像是具有相同行数但可变数目的列的随机二进制矩阵。

谢谢,


Tags: to图像importalphanewlennprange
1条回答
网友
1楼 · 发布于 2024-09-29 18:41:35

我得到了一个不错的结果,网格规格宽度比

"""fig, (ax1, ax2, ax3) = plt.subplots(
    1,
    3,
    dpi=100,
    sharey=True,
    constrained_layout=True
)"""

# I commented the above code and replaced with below.

import matplotlib.gridspec as gridspec
fig = plt.figure(constrained_layout=True)
gs = gridspec.GridSpec(ncols=3, nrows=1, figure=fig, width_ratios=[7./4.,1,6./4.])
ax1 = fig.add_subplot(gs[0,0])
ax2 = fig.add_subplot(gs[0,1])
ax3 = fig.add_subplot(gs[0,2])

Resulting

需要使用宽度比来调整高度有些违反直觉,但在具有多行的网格环境中,只能按宽度独立缩放列是有意义的。按高度独立排列。 https://matplotlib.org/tutorials/intermediate/gridspec.html

相关问题 更多 >

    热门问题