在点悬停时将二维散射与图例和图像显示相结合

2024-06-28 12:14:22 发布

您现在位置:Python中文网/ 问答频道 /正文

背景

使用CNN自动编码器,我观察图像数据集的潜在空间的投影。我想将鼠标悬停在2D散点图上,并显示相应的图像。我也有图像真实的标签,并希望它作为图例(颜色分散点)

设置

我的原始图像包含在一个3D数组X_plot中,我的PCA简化数据集位于X,我有一系列与y中的图像对应的标签

X_plot.shape  = (n, 64, 64)  # n images of 64x64
X.shape       = (n, 2)       # list of 2D coordinates for each image 
y.shape       = (n, )        # n labels
# Example code to reproduce
from matplotlib import pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import numpy as np

n = 20
num_classes = 4
X_plot = np.random.rand(n, 64, 64)
X = np.random.rand(n, 2)
y = np.random.randint(num_classes, size=n)

现行代码

在图像显示悬停时散射

这主要是受this answer on StackOverFlow的启发

# Split 2D coordinates into list of xs and ys
xx, yy = zip(*X)
# create figure and plot scatter
fig = plt.figure()
ax = fig.add_subplot(111)
line, = ax.plot(xx, yy, ls="", marker=".")

# create the annotations box
im = OffsetImage(X_plot[0,:,:], zoom=1, cmap='gray')
xybox=(50., 50.)
ab = AnnotationBbox(im, (0,0), xybox=xybox, xycoords='data',
        boxcoords="offset points",  pad=0.3,  arrowprops=dict(arrowstyle="->"))
# add it to the axes and make it invisible
ax.add_artist(ab)
ab.set_visible(False)

def hover(event):
    # if the mouse is over the scatter points
    if line.contains(event)[0]:
        # find out the index within the array from the event
        ind, = line.contains(event)[1]["ind"]
        # get the figure size
        w,h = fig.get_size_inches()*fig.dpi
        ws = (event.x > w/2.)*-1 + (event.x <= w/2.) 
        hs = (event.y > h/2.)*-1 + (event.y <= h/2.)
        # if event occurs in the top or right quadrant of the figure,
        # change the annotation box position relative to mouse.
        ab.xybox = (xybox[0]*ws, xybox[1]*hs)
        # make annotation box visible
        ab.set_visible(True)
        # place it at the position of the hovered scatter point
        ab.xy =(xx[ind], yy[ind])
        # set the image corresponding to that point
        im.set_data(X_plot[ind,:,:])
    else:
        #if the mouse is not over a scatter point
        ab.set_visible(False)
    fig.canvas.draw_idle()

# add callback for mouse moves
fig.canvas.mpl_connect('motion_notify_event', hover)           
plt.show()

用图例分散

如果我想显示点上色并标有y的二维散射,我使用以下代码:

fig = plt.figure()
ax = fig.add_subplot(111)

labels = np.unique(y)
for label in labels:
    filtered_by_label = X[y == label]
    ax.scatter(*zip(*filtered_by_label), s=12, marker='.', alpha=0.9, label=label)

ax.legend()
ax.axis('off')

挑战

我无法将上面的两段代码合并在一起ax.plot似乎不接受图例列表作为参数。使用第二个子解决方案中的labels循环,我需要创建line函数中使用的hover对象。然而,我考虑合并其中几个,但没有成功

有什么建议吗?谢谢


Tags: ofthe图像eventaddabplotnp
1条回答
网友
1楼 · 发布于 2024-06-28 12:14:22

我通过叠加我的两个图找到了解决办法

在以下部分中(悬浮分散):

ax = fig.add_subplot(111)
line, = ax.plot(xx, yy, ls="", marker=".")

只需添加带有图例的多个散点图

ax = fig.add_subplot(111)
line, = ax.plot(xx, yy, ls="", marker="") # no marker for this one
labels = np.unique(y)
for label in labels:
    filtered_by_label = X[y == label]
    ax.scatter(*zip(*filtered_by_label), s=12, marker='.', alpha=0.9, label=label)

line对象仍然可以通过hover函数访问,并且点以彩色显示

相关问题 更多 >