如何使用ImageDataGenerator洗牌批次?

2024-10-01 10:19:33 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用ImageDataGenerator和flow_from_dataframe加载数据集

使用flow_from_dataframeshuffle=True对数据集中的图像进行洗牌

我想洗牌。如果我有12个图像和batch_size=3,那么我有4个批次:

batch1 = [image1, image2, image3]
batch2 = [image4, image5, image6]
batch3 = [image7, image8, image9]
batch4 = [image10, image11, image12]

我希望在不洗牌每个批次中的图像的情况下洗牌批次,以便获得例如:

batch2 = [image4, image5, image6]
batch1 = [image1, image2, image3]
batch4 = [image10, image11, image12]
batch3 = [image7, image8, image9]

使用ImageDataGenerator和来自数据帧的流可以吗?我可以使用预处理功能吗


Tags: 数据from图像dataframeflowimage1image2image3
1条回答
网友
1楼 · 发布于 2024-10-01 10:19:33

考虑使用^{} API。您可以在洗牌之前执行批处理操作

import tensorflow as tf

file_names = [f'image_{i}' for i in range(1, 10)]

ds = tf.data.Dataset.from_tensor_slices(file_names).batch(3).shuffle(3)

for _ in range(3):
    for batch in ds:
        print(batch.numpy())
    print()
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']
[b'image_1' b'image_2' b'image_3']

[b'image_1' b'image_2' b'image_3']
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']

[b'image_1' b'image_2' b'image_3']
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']

然后,可以使用映射操作从文件名加载图像:

def read_image(file_name):
  image = tf.io.read_file(file_name)
  image = tf.image.decode_image(image)
  image = tf.image.convert_image_dtype(image, tf.float32)
  image = tf.image.resize_with_crop_or_pad(image, target_height=224, target_width=224)
  label = tf.strings.split(file_path, os.sep)[0]
  label = tf.cast(tf.equal(label, class_categories), tf.int32)
  return image, label

ds = ds.map(read_image)

相关问题 更多 >