ImageDataGenerator不喜欢我的时尚主义者数据集。它需要哪些输入?

2024-09-30 01:32:30 发布

您现在位置:Python中文网/ 问答频道 /正文

我从tensorflow基本图像分类指南中得到了一组名为train\u images和train\u labels的图像:

https://www.tensorflow.org/tutorials/keras/classification

我加载数据集时使用:

fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

这两个列表的形状分别是: (60000,28,28)(60000,)

之后,我想使用ImageDataGenerator水平翻转一些图像,但是当我将模型与我的火车列表相匹配时,它会返回一个错误,说x应该是一个秩4的数组

我已经试过了

train_images = (np.expand_dims(train_images,0))

所以形状变成(160000,28,28) (我必须这样做,才能让模型检查单个图像) 但它不适合模型

以下是代码的其余部分:

aug = ImageDataGenerator(rotation_range=20, horizontal_flip=True)

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28,28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
    ])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
    )

BS=32
EPOCHS=10
H = model.fit_generator(
    aug.flow(train_images, train_labels, batch_size=BS),
    validation_data=(test_images, test_labels),
    steps_per_epoch=len(train_images) // BS,
    epochs=EPOCHS)

这就是产生的错误:

---------------------------------------------------------------------------

ValueError                                Traceback (most recent call last)

<ipython-input-65-e49da92bcb89> in <module>()
      5 #train_images.shape
      6 H = model.fit_generator(
----> 7         aug.flow(train_images, train_labels, batch_size=BS),
      8         validation_data=(test_images, test_labels),
      9         steps_per_epoch=len(train_images) // BS,

1 frames
/usr/local/lib/python3.6/dist-packages/keras_preprocessing/image/numpy_array_iterator.py in __init__(self, x, y, image_data_generator, batch_size, shuffle, sample_weight, seed, data_format, save_to_dir, save_prefix, save_format, subset, dtype)

    115             raise ValueError('Input data in `NumpyArrayIterator` '
    116                              'should have rank 4. You passed an array '
--> 117                              'with shape', self.x.shape)
    118         channels_axis = 3 if data_format == 'channels_last' else 1
    119         if self.x.shape[channels_axis] not in {1, 3, 4}:

ValueError: ('Input data in `NumpyArrayIterator` should have rank 4. You passed an array with shape', (60000, 28, 28))

实际上,列车图像是(N°的图像,宽度,高度)什么是第四轴它正在等待? 如何执行此操作?你知道吗


Tags: in模型test图像datalabelsmodelbs
2条回答

通道应该是4D张量的最后一个维度。因此,请尝试使用train_images = (np.expand_dims(train_images, -1)),而不是train_images = (np.expand_dims(train_images,0))。希望能有帮助。你知道吗

你应该把你的图像转换成4D张量。现在您有了NHW格式(批量尺寸、高度、宽度)。错误提示您应该使用NHWC格式—批处理、高度、宽度、通道。所以你需要

train_images = (np.expand_dims(train_images, axis=3))

这将添加一个通道尺寸(大小为1),得到的形状将是(60000,28,28,1),它应该可以解决您的问题。你知道吗

相关问题 更多 >

    热门问题