为什么我的图像加载到数据集中时都是白色的?

2024-05-18 19:55:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用图像创建了一个数据集:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

dataset = tf.keras.preprocessing.image_dataset_from_directory(
    <directory>,
    label_mode=None,
    seed=1,
    subset='training',
    validation_split=0.1,
    image_size=(900, 900))

images = next(iter(dataset))
print(tf.shape(images))

我得到输出: 找到属于1个类的209个文件。 使用189个文件进行培训。 tf.张量([329003],shape=(4,),dtype=int32)

现在,我想看一幅带有以下内容的图像:

plt.imshow(images[19])
plt.show()

作为输出,我得到: 使用RGB数据将输入数据剪裁到imshow的有效范围([0..1]表示浮点数,[0..255]表示整数)

以及作为输出的纯白色图像

我确信加载到数据集中的图像不是纯白色的。有人能帮我吗


Tags: 文件数据图像imageimporttftensorflowas
3条回答

您的dataset是一个tf.data.Dataset,因此您可以使用此可视化功能。 https://www.tensorflow.org/datasets/api_docs/python/tfds/visualization/show_examples

tfds.visualization.show_examples(
    ds: tf.data.Dataset,
    ds_info: tfds.core.DatasetInfo,
    **options_kwargs
)

以这种方式实施:

import tensorflow_datasets as tfds

ds, info = tfds.load(<your dataset name>, split='train', with_info=True)
fig = tfds.show_examples(ds, info)

将它们转移到numpy,并转换到uint8。 这是我用来检查输入和输出的函数,我使用32的批量大小,但只打印其中的8个

def generate_images(test_input):
  prediction = decoder(encoder(test_input), training=True)
  plt.figure(figsize=(15, 15))
  for i in range(8):
    plt.subplot(4, 8, i*2+1)
    plt.imshow(test_input[i].numpy().astype("uint8") )
    plt.subplot(4, 8, i*2+2)
    plt.imshow(prediction[i] )
  plt.axis('off')
  plt.show() 

要从数据集调用此

for imgs, labels in train_ds.take(1):
  generate_images( imgs)

不是对你的问题的直接回答,但franky我更喜欢使用来自目录的ImageDataGenerator.flow_,它允许你重新缩放图像,增强图像,并且生成器的输出易于使用。文档为here.对于您的应用程序,代码为:

data_gen=tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255,
    validation_split=.01)
train_gen=data_gen.flow_from_directory(
    directory= r'c:\your_directory',
    target_size=(900,900),
    class_mode="categorical",
    batch_size=32,
    shuffle=True,
    seed=123,    
    subset='training')
valid_gen=data_gen.flow_from_directory(
    directory= r'c:\your_directory',
    target_size=(900,900),
    class_mode="categorical",
    batch_size=32,
    shuffle=False,
    seed=123,    
    subset='validation')
tr_images, tr_labels=next (train_gen) # generate a batch of 32 training images, and labels
val_images, val_labels=next(valid_gen)
print('tr_images.shape = ', tr_images.shape)
# result will be tr_images.shape =  (32, 900, 900, 3)
image1=tr_images[0]
plt.imshow(image1)
plt.show()
# result will be showing the image
# other useful things available
class_dict=train_gen.class_indices # a dictionary where key is the text class name and value is integer label of the class
print (class_dict)
labels= train_gen.labels # a sequential list of all the generator labels
file_names= test_gen.filenames  # a sequential list of all the generator file names
# hope this helps

相关问题 更多 >

    热门问题