我正在使用MNIST数据的较小子集的降维算法,我想从每个图像所在的图中进行检查。这可以根据StackOverflow主题的答案来完成:Python show image upon hovering over a point
我想用jupyter笔记本来实现这一点,这样一次运行就能得到所有的数据。问题是,我可能需要为hover函数定义一个数字以使其正常工作。我一点也不确定这是否可能。我将用以下简单的例子来说明我的问题:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.manifold import MDS, Isomap
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib as mpl
plt.rcParams['figure.figsize'] = (15, 10)
%matplotlib notebook
# Load MNIST data
data, Y = datasets.fetch_openml('mnist_784', version=1, return_X_y=True)
data, Y = np.array(data, 'int16'), np.array(Y, 'int')
data, Y = data[:100], Y[:100]
def transform_data(data):
"""
This function makes dimensional rediction with PCA, MDS and ISOMAP and returns them in numpy list
"""
data_pca = PCA(n_components=2).fit_transform(data)
data_mds = MDS(n_jobs=-1).fit_transform(data)
data_isomap = Isomap().fit_transform(data)
transformed_datas = np.array([data_pca, data_mds, data_isomap])
return transformed_datas
def hover(event):
if line.contains(event)[0]:
inds = line.contains(event)[1]['ind']
ind = inds[0]
w,h = fig.get_size_inches()*fig.dpi
ws = (event.x > w/2.)*-1 + (event.x <= w/2.)
hs = (event.y > h/2.)*-1 + (event.y <= h/2.)
ab.xybox = (xybox[0]*ws, xybox[1]*hs)
ab.set_visible(True)
ab.xy =(x_coords[ind], y_coords[ind])
offset_image.set_data(arr[ind,:,:])
else:
ab.set_visible(False)
fig.canvas.draw_idle()
transformed_datas = transform_data(data)
algorithms_text = ['PCA','MDS','ISOMAP']
colors = np.unique(Y, return_inverse=True)[1].tolist()
arr = np.reshape(data, (100,28,28))
for i in range(transformed_datas.shape[0]):
fig, ax = plt.subplots(figsize=(10,6))
x_coords = transformed_datas[i][:,0]
y_coords = transformed_datas[i][:,1]
line = plt.scatter(x_coords,y_coords, s=30, c=colors, cmap='jet', edgecolor='k')
offset_image = OffsetImage(arr[0,:,:], zoom=2, cmap=plt.cm.gray_r)
xybox = (40, 40)
ab = AnnotationBbox(offset_image, (0,0), xybox=xybox, xycoords='data', boxcoords='offset points', pad=0.3, arrowprops=dict(arrowstyle='->'))
ax.add_artist(ab)
ab.set_visible(False)
plt.title('2D-visualization of MNIST data with {} algorithm'.format(algorithms_text[i]), fontsize=10)
ax.set_yticklabels([])
ax.set_xticklabels([])
fig.canvas.mpl_connect('motion_notify_event', hover)
plt.show()
目前没有回答
相关问题 更多 >
编程相关推荐