在TensorF中对多个运行使用图像批处理

2024-06-13 12:54:31 发布

您现在位置:Python中文网/ 问答频道 /正文

我尝试在TensorFlow中实现一个输入管道,它保存了由于不同网络部件中的多个权重更新而进行的多个图形运行的输入批。在

我想我可以用一个条件包装输入管道:

# flag to skip image fetch 
forwarding_network = tf.placeholder(tf.bool, [], name='forwarding_network')

input_images = None # image queue from input pipeline, must be set in real
input_labels = None # label queue from input pipeline, must be set in real

INPUT_HEIGHT = 64 # Height of the images/labels
WIDTH_HEIGHT = 64 # Width of the images/labels

# Fetch new batch from input pipeline
def forwardIR():
    image_batch_fetch, label_batch_fetch = tf.train.batch([input_images, input_labels], \
                                                          batchsize=32, capacity=64)

    with tf.variable_scope('im_reader_forward'):
        image_batch = tf.get_variable("image_batch ", shape=[32, INPUT_HEIGHT, INPUT_WIDTH, 3], \
                      dtype=tf.float32, trainable=False, \
                      initializer=tf.constant_initializer(0.0))

        image_batch = tf.assign(image_batch, image_batch_fetch)

        label_batch = tf.get_variable("label_batch ", shape=[32, INPUT_HEIGHT, INPUT_WIDTH, 1], \
                      dtype=tf.uint8, trainable=False, \
                      initializer=tf.constant_initializer(0.0))

        label_batch = tf.assign(label_batch, label_batch_fetch)
    return image_batch, label_batch

# Hold last batch, no new fetch from pipeline
def holdIR():
    with tf.variable_scope('im_reader_forward', reuse=True):
        return tf.get_variable('image_batch', dtype=tf.float32), \
               tf.get_variable('label_batch', dtype=tf.uint8)

# Switch: If forwarding_network == True, fetch new images from queue; else not)
image_batch, label_batch = tf.cond(forwarding_network, lambda: forwardIR(), lambda: holdIR())

# calculate loss with batch
net = Model(image_batch)
loss = net.predict()

我的问题是,训练开始时没有任何错误或失败,但什么也没有发生。也许变量和网络操作之间没有联系?条件的输出直接输入网络模型。在


Tags: fromimageinputlabelspipelinetfbatchnetwork
2条回答

好吧,比我想象的要容易得多。-.—:天

首先运行一个tf会话,然后通过占位符将输出输入到训练迭代中,从而评估图像/标签获取的部件。在

## define input pipeline, network, loss calculation, session, ...

image_batch_out, label_batch_out = sess.run([image_batch_ir, label_batch_ir])

feed_dict = { image_batch : image_batch_out, label_batch : label_batch_out }

loss_1, _ = sess.run([loss_val_1, train_op_1], feed_dict=feed_dict)
loss_2, _ = sess.run([loss_val_2, train_op_2], feed_dict=feed_dict)
loss_3, _ = sess.run([loss_val_3, train_op_3], feed_dict=feed_dict)

正如评论中提到的,根本不需要变量。:)

为了补充您的答案,您可以直接将image_batch_ir和{}张量提供给使用占位符作为输入的ops。 例如,如果您的旧代码是:

image_batch_ir, label_batch_ir = ...
image_batch = tf.placeholder(...)
label_batch = tf.placeholder(...)
loss_val = some_ops(image_batch, label_batch)
image_batch_out, label_batch_out = sess.run([image_batch_ir, label_batch_ir])

feed_dict = { image_batch : image_batch_out, label_batch : label_batch_out }

loss = sess.run([loss_val], feed_dict=feed_dict)

您可以改为:

^{pr2}$

相关问题 更多 >