自动编码器(autoencoder)训练报错:ValueError: logits and labels must have the same shape ((None, 328, 328, 3) vs (None, 1))

2024-03-29 05:14:19 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试用以下代码构建一个自动编码器

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import sys
import os
import PIL
import PIL.Image
import tensorflow as tf
import tensorflow_datasets as tfds

import keras
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D, MaxPool2D, Flatten, BatchNormalization
from keras.layers import Conv1D, MaxPool1D, Reshape
from keras.layers import Input, Dense, Dropout, Activation, Add, Concatenate
from keras import regularizers
from keras.models import Model, Sequential
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from keras.optimizers import SGD, Adam, RMSprop, Adadelta
import keras.backend as K
from keras.objectives import mean_squared_error
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import np_utils

def create_block(input, chs):
    """Apply two Conv2D -> ReLU -> BatchNormalization stages to *input*.

    Args:
        input: Keras tensor fed into the first convolution.
        chs: number of filters used by both 3x3 convolutions.

    Returns:
        The output tensor of the second BatchNormalization layer.
    """
    out = input
    for _stage in range(2):
        # padding="same" keeps the height and width of the feature map unchanged.
        conv = Conv2D(chs, 3, padding="same")(out)
        activated = Activation("relu")(conv)
        out = BatchNormalization()(activated)
    return out
# Build a convolutional autoencoder for 328x328 RGB images. The two
# MaxPool2D(2) stages halve the spatial size twice (328 -> 164 -> 82) and the
# two UpSampling2D((2,2)) stages restore it (82 -> 164 -> 328), so the output
# has the same (328, 328, 3) shape as the input. The round trip is only exact
# when height and width are divisible by 4.
input_img = Input(shape=(328, 328, 3))

# Encoder: 32 filters at 328x328, pool to 164x164, then 64 filters.
block1 = create_block(input_img, 32)
x = MaxPool2D(2)(block1)
block2 = create_block(x, 64)

# Middle (bottleneck): pool to 82x82, then 128 filters.
x = MaxPool2D(2)(block2)
middle = create_block(x, 128)

# Decoder: 64 filters at 82x82 -> upsample to 164x164, 32 filters ->
# upsample back to 328x328.
block3 = create_block(middle, 64)
up1 = UpSampling2D((2,2))(block3)
block4 = create_block(up1, 32)
up2 = UpSampling2D((2,2))(block4)

# Output: a 1x1 convolution maps to 3 channels; sigmoid keeps every value in (0, 1).
x = Conv2D(3, 1)(up2)
output = Activation("sigmoid")(x)


autoencoder = Model(input_img, output)
# SGD(1e-3, 0.9): learning rate 1e-3, momentum 0.9 (positional arguments).
# NOTE: binary_crossentropy against a sigmoid output expects targets in
# [0, 1]; for an autoencoder the targets are the input images themselves, so
# the data pipeline must both supply them as targets and scale them to [0, 1].
autoencoder.compile(SGD(1e-3, 0.9), loss='binary_crossentropy')
autoencoder.summary()

对于我的训练数据,我使用:

img_height = 328
img_width = 328


def _to_autoencoder_pair(images, _labels):
    """Turn an (images, labels) batch into an (inputs, targets) autoencoder pair.

    image_dataset_from_directory yields class labels inferred from the
    sub-directory names. A reconstruction model must instead learn to
    reproduce its input, so the labels are dropped and the image batch is
    used as both input and target. Feeding the class labels as targets is
    what caused the "(None, 328, 328, 3) vs (None, 1)" shape error.

    Pixels arrive as float32 in [0, 255]. They are divided by 255 so that
    inputs and targets lie in [0, 1], the range expected by the model's
    sigmoid output and binary_crossentropy loss.
    """
    images = images / 255.0
    return images, images


# NOTE(review): data_dir and batch_size are assumed to be defined earlier.
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    # Same seed and validation_split as val_ds so the two subsets are disjoint.
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
).map(_to_autoencoder_pair)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
).map(_to_autoencoder_pair)

当我尝试用 autoencoder.fit(train_ds, validation_data=val_ds, epochs=50) 训练时,得到了 ValueError: logits and labels must have the same shape ((None, 328, 328, 3) vs (None, 1))

有人知道怎么解决这个问题吗?


Tags: fromimageimportimginputsizelayerstf
1条回答
网友
1楼 · 发布于 2024-03-29 05:14:19

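对于自动编码器,训练目标(标签)应当就是输入本身:要最小化的是重建损失,因此模型输出要与原始输入图像进行比较。默认情况下,Keras 的目录迭代器返回的目标是根据子目录推断出的类别标签(例如 0 和 1),而不是输入图像,这正是形状 (None, 328, 328, 3) 与 (None, 1) 不匹配的原因。另外,由于输出层使用 sigmoid 并配合 binary_crossentropy,输入图像(同时也是目标)应缩放到 [0, 1] 区间(例如 rescale=1./255)。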

generator = tf.keras.preprocessing.image.ImageDataGenerator(
    # Scale pixels from [0, 255] to [0, 1]. The autoencoder ends in a sigmoid
    # and is trained with binary_crossentropy, so its targets (which, with
    # class_mode='input', are these same rescaled images) must lie in [0, 1].
    rescale=1.0 / 255,
    # Hold out 20% of the images; choose a split via subset="training" or
    # subset="validation" in flow_from_directory.
    validation_split=0.2,
)

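然后在 flow_from_directory 中指定 class_mode='input',这样迭代器会把输入图像本身作为目标返回。请注意,flow_from_directory 用 target_size 参数指定图像尺寸;image_size 是 image_dataset_from_directory 的参数,在这里需要替换为 target_size 才能运行。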

# class_mode='input' makes the iterator yield (images, images) batches, so the
# reconstruction target has the same shape as the model output.
# flow_from_directory sizes images with target_size; image_size belongs to
# image_dataset_from_directory and is rejected here.
# NOTE(review): data_dir, batch_size, img_height and img_width are assumed to
# be defined earlier.
train_ds = generator.flow_from_directory(
    data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='input',
    subset="training",
    seed=123,
)

# The validation split must also use class_mode='input'; otherwise the
# validation batches still carry class labels and fail with the same
# "logits and labels must have the same shape" error during fit().
val_ds = generator.flow_from_directory(
    data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='input',
    subset="validation",
    seed=123,
)

下面是一个使用随机生成数据的完整可运行示例,沿用了您提供的自动编码器结构:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Conv2D, UpSampling2D, MaxPool2D
from tensorflow.keras.layers import Activation, Dense, Input, BatchNormalization
from tensorflow.keras import Model, Sequential

def create_block(input, chs):
    """Return *input* passed through two (Conv2D, ReLU, BatchNorm) stages.

    Args:
        input: Keras tensor to transform.
        chs: filter count for both 3x3 "same"-padded convolutions.

    Returns:
        Output tensor of the final BatchNormalization layer; its height and
        width match *input* because of the "same" padding.
    """
    tensor = input
    for _repeat in range(2):
        tensor = Conv2D(chs, 3, padding="same")(tensor)
        tensor = Activation("relu")(tensor)
        tensor = BatchNormalization()(tensor)
    return tensor
# Same architecture as the question: 328 -> 164 -> 82 via two MaxPool2D(2)
# stages, then 82 -> 164 -> 328 via two UpSampling2D((2,2)) stages, so the
# output shape (328, 328, 3) equals the input shape.
input_img = Input(shape=(328, 328, 3))

# Encoder: 32 filters at 328x328, pool to 164x164, then 64 filters.
block1 = create_block(input_img, 32)
x = MaxPool2D(2)(block1)
block2 = create_block(x, 64)

# Middle (bottleneck): pool to 82x82, then 128 filters.
x = MaxPool2D(2)(block2)
middle = create_block(x, 128)

# Decoder: upsample back to 164x164 and then to 328x328.
block3 = create_block(middle, 64)
up1 = UpSampling2D((2,2))(block3)
block4 = create_block(up1, 32)
up2 = UpSampling2D((2,2))(block4)

# Output: 1x1 convolution to 3 channels, sigmoid keeps values in (0, 1).
x = Conv2D(3, 1)(up2)
output = Activation("sigmoid")(x)

# Synthetic data: 8 images with values uniform in [0, 1), already in the
# target range of the sigmoid + binary_crossentropy setup.
X = np.random.rand(8, 328, 328, 3).astype(np.float32)

autoencoder = Model(input_img, output)
autoencoder.compile('adam', loss='binary_crossentropy')
autoencoder.summary()

# ImageDataGenerator() with no arguments applies no augmentation or rescaling.
# flow(x=X, y=X) yields (batch, batch) pairs, i.e. every image is its own
# reconstruction target, which is the key point for an autoencoder.
# flow defaults to batch_size=32, so the 8 samples form a single step.
generator = tf.keras.preprocessing.image.ImageDataGenerator()

train_ds = generator.flow(
  x=X,
  y=X
)

history = autoencoder.fit(train_ds)
Train for 1 steps
1/1 [==============================] - 9s 9s/step - loss: 0.8827

相关问题 更多 >