Jupyter笔记本互动悬停为许多数字

2024-06-28 12:25:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用MNIST数据的较小子集的降维算法,我想从每个图像所在的图中进行检查。这可以根据StackOverflow主题的答案来完成:Python show image upon hovering over a point

我想用jupyter笔记本来实现这一点,这样一次运行就能得到所有的数据。问题是,我可能需要为hover函数定义一个数字以使其正常工作。我一点也不确定这是否可能。我将用以下简单的例子来说明我的问题:

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.manifold import MDS, Isomap
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib as mpl
plt.rcParams['figure.figsize'] = (15, 10)
%matplotlib notebook

# Load MNIST data
data, Y = datasets.fetch_openml('mnist_784', version=1, return_X_y=True)
data, Y = np.array(data, 'int16'), np.array(Y, 'int')
data, Y = data[:100], Y[:100]

def transform_data(data):
    """
    This function makes dimensional rediction with PCA, MDS and ISOMAP and returns them in numpy list
    """
    data_pca = PCA(n_components=2).fit_transform(data)
    data_mds = MDS(n_jobs=-1).fit_transform(data)
    data_isomap = Isomap().fit_transform(data)
    transformed_datas = np.array([data_pca, data_mds, data_isomap])
    return transformed_datas

def hover(event):
    if line.contains(event)[0]:
        inds = line.contains(event)[1]['ind']
        ind = inds[0]
        w,h = fig.get_size_inches()*fig.dpi
        ws = (event.x > w/2.)*-1 + (event.x <= w/2.) 
        hs = (event.y > h/2.)*-1 + (event.y <= h/2.)
        ab.xybox = (xybox[0]*ws, xybox[1]*hs)
        ab.set_visible(True)
        ab.xy =(x_coords[ind], y_coords[ind])
        offset_image.set_data(arr[ind,:,:])
    else:
        ab.set_visible(False)
    fig.canvas.draw_idle()


transformed_datas = transform_data(data)
algorithms_text = ['PCA','MDS','ISOMAP']
colors = np.unique(Y, return_inverse=True)[1].tolist()
arr = np.reshape(data, (100,28,28))

for i in range(transformed_datas.shape[0]):
    fig, ax = plt.subplots(figsize=(10,6))
    x_coords = transformed_datas[i][:,0]
    y_coords = transformed_datas[i][:,1]
    line = plt.scatter(x_coords,y_coords, s=30, c=colors,  cmap='jet', edgecolor='k')

    offset_image = OffsetImage(arr[0,:,:], zoom=2, cmap=plt.cm.gray_r)
    xybox = (40, 40)
    ab = AnnotationBbox(offset_image, (0,0), xybox=xybox, xycoords='data', boxcoords='offset points',  pad=0.3,  arrowprops=dict(arrowstyle='->'))
    ax.add_artist(ab)
    ab.set_visible(False)

    plt.title('2D-visualization of MNIST data with {} algorithm'.format(algorithms_text[i]), fontsize=10)
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    fig.canvas.mpl_connect('motion_notify_event', hover)
    plt.show()

Tags: imageimporteventdataabnpfigtransform