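I'm trying to visualize VGG-16's predictions for a cat image. I want to compute the top-5 scores (the 5 classes with the highest probability) and, for each of those 5 scores, print the corresponding label and probability.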
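The top-5 step described above can be sketched with NumPy alone. This is a minimal illustration under stated assumptions, not Keras's own code: `top_k`, the six labels, and the probabilities are all hypothetical, and a real VGG16 output vector has 1000 ImageNet classes. Keras's `decode_predictions` performs a similar ranking and additionally maps each index to its ImageNet class ID and name.

```python
import numpy as np

def top_k(probs, labels, k=5):
    """Return the k (label, probability) pairs with the highest probability."""
    probs = np.asarray(probs)
    # argsort is ascending, so take the last k indices and reverse them
    idx = np.argsort(probs)[-k:][::-1]
    return [(labels[i], float(probs[i])) for i in idx]

# Hypothetical 6-class example (illustrative values, not real model output)
labels = ['tabby', 'tiger_cat', 'Egyptian_cat', 'lynx', 'Persian_cat', 'dog']
probs = [0.40, 0.25, 0.15, 0.10, 0.06, 0.04]
for label, proba in top_k(probs, labels):
    print(label, 'with probability', proba)
```

Sorting the whole vector is cheap for 1000 classes; for much larger vectors, `np.argpartition` would avoid a full sort.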
import numpy as np
import matplotlib.pyplot as plt
from keras.applications.vgg16 import preprocess_input
from keras.preprocessing import image
# load the cat image and resize it to VGG16's 224x224 input size
img = image.load_img('cat.jpg', target_size=(224, 224))
# convert to a numpy array of shape (224, 224, 3)
x = image.img_to_array(img)
# add a batch dimension for TensorFlow: (1, 224, 224, 3)
x = np.expand_dims(x, axis=0)
# subtract the ImageNet channel means, as in the original VGG16 network
x = preprocess_input(x)
# make the prediction using VGG16 (`model` is defined elsewhere in the script, not shown here)
output = model.predict(x)
print('model prediction output', output)
# plot the prediction
plt.plot(output[0], '-')
plt.show()
# decode the prediction into (class_id, label, probability) tuples
from keras.applications.vgg16 import decode_predictions
top5 = decode_predictions(output, top=5)
for _, label, proba in top5[0]:
    print(label, 'with probability', proba)
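The `preprocess_input` call from `keras.applications.vgg16` applies Caffe-style ImageNet preprocessing: it flips the channels from RGB to BGR and subtracts the per-channel ImageNet mean, without rescaling to [0, 1]. The sketch below reproduces that idea in plain NumPy so the shapes and values can be inspected; `vgg16_preprocess` is a hypothetical helper, not a Keras function, and the constant gray image is only a stand-in for `cat.jpg`.

```python
import numpy as np

# ImageNet per-channel means in BGR order, as used by Caffe-style VGG16 preprocessing
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def vgg16_preprocess(batch_rgb):
    """Sketch of VGG16-style preprocessing for an (N, H, W, 3) RGB batch in [0, 255]."""
    x = np.asarray(batch_rgb, dtype=np.float32)
    x = x[..., ::-1]               # reorder channels RGB -> BGR
    return x - IMAGENET_MEAN_BGR   # zero-center each channel (no scaling)

# a single 224x224 RGB image, with the batch dimension added as in the post
img = np.full((224, 224, 3), 128.0, dtype=np.float32)
batch = np.expand_dims(img, axis=0)
out = vgg16_preprocess(batch)
print(batch.shape, out.shape)  # (1, 224, 224, 3) (1, 224, 224, 3)
```

Preprocessing changes values but never shape, so the array passed to `model.predict` stays (1, 224, 224, 3).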
I get the following error. Any help would be appreciated.
File "C:\Users\mwaqa\Desktop\Spyder\E8 Q2,3,4,5.py", line 75, in <module>
    plt.plot(output, '-')
File "c:\users\mwaqa\miniconda3\lib\site-packages\matplotlib\pyplot.py", line 2789, in plot
    is not None else {}), **kwargs)
File "c:\users\mwaqa\miniconda3\lib\site-packages\matplotlib\axes\_axes.py", line 1665, in plot
    lines = [*self._get_lines(*args, data=data, **kwargs)]
File "c:\users\mwaqa\miniconda3\lib\site-packages\matplotlib\axes\_base.py", line 225, in __call__
    yield from self._plot_args(this, kwargs)
File "c:\users\mwaqa\miniconda3\lib\site-packages\matplotlib\axes\_base.py", line 391, in _plot_args
    x, y = self._xy_from_xy(x, y)
File "c:\users\mwaqa\miniconda3\lib\site-packages\matplotlib\axes\_base.py", line 273, in _xy_from_xy
    "shapes {} and {}".format(x.shape, y.shape))
ValueError: x and y can be no greater than 2-D, but have shapes (1,) and (1, 224, 224, 3)
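The error comes from matplotlib's `plot`, which only accepts 1-D or 2-D data. For the ImageNet VGG16 classifier with its top layers, `model.predict` on one image returns shape (1, 1000), so `output[0]` is a 1-D vector that plots fine. The traceback instead reports `y` with shape (1, 224, 224, 3), the same shape as the input batch `x`. The sketch below makes that shape expectation explicit before plotting; `plottable_1d` is a hypothetical helper, and NumPy zero arrays stand in for real predictions.

```python
import numpy as np

def plottable_1d(pred):
    """Return the first row of a (1, num_classes) prediction as a 1-D vector.

    Raises ValueError if the array is not shaped like a classifier output,
    e.g. the (1, 224, 224, 3) array reported in the traceback.
    """
    pred = np.asarray(pred)
    if pred.ndim != 2 or pred.shape[0] != 1:
        raise ValueError(f"expected shape (1, num_classes), got {pred.shape}")
    return pred[0]

# Stand-in for a 1000-class ImageNet prediction on a single image
print(plottable_1d(np.zeros((1, 1000))).shape)  # (1000,)
try:
    plottable_1d(np.zeros((1, 224, 224, 3)))
except ValueError as err:
    print(err)  # expected shape (1, num_classes), got (1, 224, 224, 3)
```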