"如何在Keras功能模型API中处理批量归一化的更新操作?"

2024-05-19 06:46:58 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试Keras模型的函数API,并尝试使用相同的模型和权重共享来设置两个数据集流,这也包括批处理规范化。当我在第一个数据流上创建模型,然后在第二个数据流上调用它时,我可以在tensorboard中看到额外的更新操作,但是模型不包括这些新创建的操作。我的问题是,在仍然能够使用经典的tensorflow会话、数据集迭代器和自定义损失和优化器的情况下,对于这种情况,什么是一种好的编码方法?在

我附加了一些代码,展示了我所请求的行为实际上是如何为BatchNormalization层本身工作的。不过,在这种情况下,对更新操作有更多的控制也是很好的,例如,不要迭代列表并检查名称,直接将更新操作与层本身的调用连接或返回,这样就可以将更新操作直接关联到正确的数据流。在

import tensorflow as tf
import shutil
import os.path as osp


shutil.rmtree(osp.join('/tmp', 'keras_model_testtb'), ignore_errors=True)
tb_saver = tf.summary.FileWriter(osp.join(
    '/tmp', 'keras_model_testtb',
))

input1 = tf.keras.layers.Input(shape=(None, 3), dtype=tf.float32)
batchnorm = tf.keras.layers.BatchNormalization()
output1 = batchnorm(input1, training=tf.constant(True))
# following print shows:
# [
#   <tf.Operation 'batch_normalization_v1/AssignMovingAvg/AssignSubVariableOp' type=AssignSubVariableOp>,
#   <tf.Operation 'batch_normalization_v1/AssignMovingAvg_1/AssignSubVariableOp' type=AssignSubVariableOp>
# ]
print(batchnorm.updates)
input2 = tf.keras.layers.Input(shape=(None, 3), dtype=tf.float32)
output2 = batchnorm(input2, training=tf.constant(True))
# following print shows:
# [
#   <tf.Operation 'batch_normalization_v1/AssignMovingAvg/AssignSubVariableOp' type=AssignSubVariableOp>,
#   <tf.Operation 'batch_normalization_v1/AssignMovingAvg_1/AssignSubVariableOp' type=AssignSubVariableOp>,
#   <tf.Operation 'batch_normalization_v1_1/AssignMovingAvg/AssignSubVariableOp' type=AssignSubVariableOp>,
#   <tf.Operation 'batch_normalization_v1_1/AssignMovingAvg_1/AssignSubVariableOp' type=AssignSubVariableOp>
# ]
# update ops of both layer calls are merged into one list and one has to check the name of the ops
# to use them correctly with optimizers
print(batchnorm.updates)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    tb_saver.add_graph(session.graph)

下面是一些使用Keras模型API的示例代码。如您所见,model updates操作确实有额外的2个更新操作,如前一个示例所示,仅适用于使用keras输入层的情况。对于数据集迭代器,它似乎做了一些不同的事情。在

^{pr2}$

Tags: 数据模型tftypebatch情况operation数据流

热门问题