使用CNN自动编码器,我观察图像数据集的潜在空间的投影。我想将鼠标悬停在2D散点图上,并显示相应的图像。我也有图像真实的标签,并希望它作为图例(颜色分散点)
我的原始图像包含在一个3D数组X_plot
中,我的PCA简化数据集位于X
,我有一系列与y
中的图像对应的标签
X_plot.shape = (n, 64, 64) # n images of 64x64
X.shape = (n, 2) # list of 2D coordinates for each image
y.shape = (n, ) # n labels
# Example code to reproduce
from matplotlib import pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import numpy as np
n = 20
num_classes = 4
X_plot = np.random.rand(n, 64, 64)
X = np.random.rand(n, 2)
y = np.random.randint(num_classes, size=n)
这主要是受this answer on StackOverFlow的启发
# Split 2D coordinates into list of xs and ys
xx, yy = zip(*X)
# create figure and plot scatter
fig = plt.figure()
ax = fig.add_subplot(111)
line, = ax.plot(xx, yy, ls="", marker=".")
# create the annotations box
im = OffsetImage(X_plot[0,:,:], zoom=1, cmap='gray')
xybox=(50., 50.)
ab = AnnotationBbox(im, (0,0), xybox=xybox, xycoords='data',
boxcoords="offset points", pad=0.3, arrowprops=dict(arrowstyle="->"))
# add it to the axes and make it invisible
ax.add_artist(ab)
ab.set_visible(False)
def hover(event):
# if the mouse is over the scatter points
if line.contains(event)[0]:
# find out the index within the array from the event
ind, = line.contains(event)[1]["ind"]
# get the figure size
w,h = fig.get_size_inches()*fig.dpi
ws = (event.x > w/2.)*-1 + (event.x <= w/2.)
hs = (event.y > h/2.)*-1 + (event.y <= h/2.)
# if event occurs in the top or right quadrant of the figure,
# change the annotation box position relative to mouse.
ab.xybox = (xybox[0]*ws, xybox[1]*hs)
# make annotation box visible
ab.set_visible(True)
# place it at the position of the hovered scatter point
ab.xy =(xx[ind], yy[ind])
# set the image corresponding to that point
im.set_data(X_plot[ind,:,:])
else:
#if the mouse is not over a scatter point
ab.set_visible(False)
fig.canvas.draw_idle()
# add callback for mouse moves
fig.canvas.mpl_connect('motion_notify_event', hover)
plt.show()
如果我想显示点上色并标有y的二维散射,我使用以下代码:
fig = plt.figure()
ax = fig.add_subplot(111)
labels = np.unique(y)
for label in labels:
filtered_by_label = X[y == label]
ax.scatter(*zip(*filtered_by_label), s=12, marker='.', alpha=0.9, label=label)
ax.legend()
ax.axis('off')
我无法将上面的两段代码合并在一起ax.plot
似乎不接受图例列表作为参数。使用第二个子解决方案中的labels循环,我需要创建line
函数中使用的hover
对象。然而,我考虑合并其中几个,但没有成功
有什么建议吗?谢谢
我通过叠加我的两个图找到了解决办法
在以下部分中(悬浮分散):
只需添加带有图例的多个散点图
line
对象仍然可以通过hover
函数访问,并且点以彩色显示相关问题 更多 >
编程相关推荐