使用ImageDataGenerator,matplotlib抛出TypeError:图像数据的无效形状(1、256、256、3)

2024-06-18 11:53:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我有15张汽车图像,使用数据增强,我想用它们创建一个数据集。然而,当我使用Keras的ImageDataGenerator并试图绘制生成的图像时,我得到一个错误,即

TypeError: Invalid shape (1, 256, 256, 3) for image data.

我附加的代码,以及,请让我知道我如何可以解决这个问题

datagen = ImageDataGenerator(rescale=1./255, zoom_range=0.1, rotation_range=25, width_shift_range=0.1, height_shift_range=0.1, shear_range=0.1, horizontal_flip=True)

ite = datagen.flow_from_directory("Car Images", batch_size=1)

for i in range(9):

    # define subplot
    plt.subplot(330 + 1 + i)

    # generate batch of images
    batch = ite.next()

    # convert to unsigned integers for viewing
    image = batch[0].astype('uint8')

    # plot raw pixel data
    plt.imshow(image)

# show the figure
plt.show()

错误指向plt.imshow()

This is showing when I use np.squeeze() or np.reshape()


Tags: 数据图像imagefordatashift错误batch
2条回答

对于Invalid Shape错误,应该删除批处理维度。因此,需要重塑或使用np.squeeze()

此外,由于您正在通过1./255重新缩放图像,因此图像数据在范围[0,1]内,将其转换为uint8将使所有图像都为零。因此,将for循环中的最后两行更改如下:

image = batch[0]                  #remove astype('uint8')
# plot raw pixel data
plt.imshow(np.squeeze(image))     #remove batch dimension

你需要重塑图像,试试这个

datagen = ImageDataGenerator(rescale=1./255, zoom_range=0.1, rotation_range=25, width_shift_range=0.1, height_shift_range=0.1, shear_range=0.1, horizontal_flip=True)

ite = datagen.flow_from_directory("Car Images", batch_size=1)

for i in range(9):

    # define subplot
    plt.subplot(330 + 1 + i)

    # generate batch of images
    batch = ite.next()

    # convert to unsigned integers for viewing
    image = batch[0].astype('uint8')

    image = np.reshape(256,256,3)

    # plot raw pixel data
    plt.imshow(image)

# show the figure
plt.show()

相关问题 更多 >