如何创建“input\u fn”用作“.fit()”中的参数?

2024-05-18 11:05:46 发布

您现在位置:Python中文网/ 问答频道 /正文

我正试着用我标注的一些图片来训练cnn的模特。我不熟悉TensorFlow。以下是我所做的:

def read_labeled_image_list(image_list_file):
    f = open(image_list_file, 'r')
    filenames = []
    labels = []
    for line in f:
        filename, label = line[:-1].split(' ')
        filenames.append(filename)
        index0 = 1 if int(label) == 0 else 0
        index1 = 1 if int(label) == 1 else 0
        labels.append([index0, index1])
    return filenames, labels

def read_images_from_disk(input_queue):
    label = input_queue[1]
    file_contents = tf.read_file(input_queue[0])
    example = tf.image.decode_jpeg(file_contents, channels=1)
    return example, label

使用“从磁盘读取图像”作为输入:

image_list, label_list = 
          read_labeled_image_list("./images_training/training_list.txt")

images = tf.constant(image_list, dtype=tf.string)
labels = tf.constant(label_list, dtype=tf.int32)

# Makes an input queue
input_queue = tf.train.slice_input_producer([images, labels],
                                            num_epochs=30,
                                                shuffle=True)

image, label = read_images_from_disk(input_queue)

# Train the model
graph_classifier.fit(
    input_fn=read_images_from_disk(input_queue),
    steps=20000,
    monitors=[logging_hook])

我得到以下错误:

features, labels = input_fn()
TypeError: 'tuple' object is not callable

Tags: fromimagereadinputlabelsqueuetfdef
1条回答
网友
1楼 · 发布于 2024-05-18 11:05:46

错误的原因是fit方法中的input_fn参数应该是可调用的。然后您可以尝试:

def read_images_from_disk(input_queue):
    label = input_queue[1]
    file_contents = tf.read_file(input_queue[0])
    example = tf.image.decode_jpeg(file_contents, channels=1)
    return example, label

def my_input_func():
 return read_images_from_disk(input_queue)

# Train the model
graph_classifier.fit(
    input_fn=my_input_func,
    steps=20000,
    monitors=[logging_hook])

我也建议你仔细阅读 the official docinput_func上。你知道吗

相关问题 更多 >