应定义“密集”输入的最后一个维度。找到`None`

2024-10-03 06:20:20 发布

您现在位置:Python中文网/ 问答频道 /正文

我对tensorflow非常陌生,并尝试为我自己的图像集制作一个简单的二元分类器。它们都是灰度226x226的PNG图像。我一直得到错误“ValueError:应该定义Dense输入的最后一个维度。找到None“。我已经坚持了好几天了,似乎无法以一种有效的方式来塑造我的模型/数据集。有人能帮一个新手吗?任何可能相关的代码应该在下面。提前谢谢。在

##img parser
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_png(image_string)
  image_decoded = tf.image.resize_images(image_decoded,[226,226])
  return image_decoded, label

#img processor function
#input: dir
#output: dataset
def imgPrcs(dir):
    labelArr = [];
    filenames = [];
    src = dir;

    for fname in os.listdir(src):
        png = os.path.join(src, fname);
        filenames.append(png);
        if os.path.isfile(png):
            #extract label
            with open(png, 'rb') as fobj:
                data = fobj.read()
            data_arr = [];
            for chunk_type, chunk_data in chunk_iter(data):
                if   chunk_type == b'iTXt':
                    data_arr.append(chunk_data.decode());
            label = int(data_arr[1][-1:]);

            #add label
            labelArr.append(label);

    labels = tf.constant(labelArr)
    filename_q = tf.constant(filenames)

    dataset = tf.data.Dataset.from_tensor_slices((filename_q, labels))
    dataset = dataset.map(_parse_function)

    #return variables
    return dataset;

#create labels and datasets
print('Compiling images and labels...\n');
trainData = imgPrcs('./train/');
testData = imgPrcs('./test/');
valData = imgPrcs('./validate/');


#Create Model
print('Creating Model...\n');
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(226, 226, None)),
    keras.layers.Dense(128, kernel_initializer='normal', activation='relu'),
    keras.layers.Dense(1,kernel_initializer='normal', activation='sigmoid')
])

print('compile...\n')
model.compile(optimizer='adam', 
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy']);


print('train..\n')
#Train Model
model.fit(trainData.make_one_shot_iterator(), epochs=5, steps_per_epoch=385)

print('test')
#Test Model
test_loss, test_acc = model.evaluate(testData.make_one_shot_iterator());

print('Test accuracy:', test_acc);

Tags: testimagedatalabelsmodelpngtffilename