tf.keras中的批归一化不计算平均均值和平均方差

2024-05-19 07:41:12 发布

您现在位置:Python中文网/ 问答频道 /正文

有人问了一个类似的未回答问题here。 我正在测试一个深度强化学习算法,它使用了tensorflow中的keras后端。我不是很熟悉特斯拉斯,但仍希望添加批处理规范化层。因此,我尝试使用tf.keras.layers.BatchNormalization(),但它不更新平均值和方差,因为{}是空的。在

使用常规的tf.layers.batch_normalization似乎可以很好地工作。但是,因为完整的算法有点复杂,我需要找到一种使用tf.keras的方法。在

标准tfbatch_normed = tf.layers.batch_normalization(hidden, training=True)更新平均值,因为update_ops不是空的:

[
    <tf.Operation 'batch_normalization/AssignMovingAvg' type=AssignSub>, 
    <tf.Operation 'batch_normalization/AssignMovingAvg_1' type=AssignSub>, 
    <tf.Operation 'batch_normalization_1/AssignMovingAvg' type=AssignSub>, 
    <tf.Operation 'batch_normalization_1/AssignMovingAvg_1' type=AssignSub>
]

不起作用的最小示例:

^{pr2}$

输出如下(可以看到移动平均值和移动方差不变):

{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0, 'bn_Q_moving_variance': 4.0}

而预期的输出如下所示(用batch_normed演算对行进行注释,并取消对下面的行的注释):

{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0148749575, 'bn_Q_moving_variance': 3.966927}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.029601166, 'bn_Q_moving_variance': 3.934192}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.04418011, 'bn_Q_moving_variance': 3.9017918}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.05861327, 'bn_Q_moving_variance': 3.8697228}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.0729021, 'bn_Q_moving_variance': 3.8379822}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.08704803, 'bn_Q_moving_variance': 3.8065662}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.10105251, 'bn_Q_moving_variance': 3.7754717}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.11491694, 'bn_Q_moving_variance': 3.7446957}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.12864274, 'bn_Q_moving_variance': 3.7142346}
{'bn_Q_gamma': 4.0, 'bn_Q_beta': 0.0, 'bn_Q_moving_mean': 0.14223127, 'bn_Q_moving_variance': 3.6840856}

注意

即使使用tf.layers.batch_normalization,仍然有一些可疑的东西。tf.control_dependencies的标准tf方法:

    with tf.control_dependencies(update_ops):
        sess.run(out, {inputs[0]: data})

我把它放在上面的代码中,而不是下面的两行:

    sess.run(update_ops,  {inputs[0]: data})
    sess.run(out, {inputs[0]: data})

产生bn_Q_moving_mean = 0.0和{}


Tags: layerstftypebatchmeanoperationbetakeras
2条回答

这是因为tf.keras.layers.BatchNormalization继承了tf.keras.layers.Layer。kerasapi将更新操作作为其fit和evaluate循环的一部分进行处理。这反过来意味着,如果没有它,它将无法更新tf.GraphKeys.UPDATE_OPS集合。在

所以为了让它工作,你需要手动更新它

hidden = tf.keras.layers.Dense(units, activation=None)(out)
batch_normed = tf.keras.layers.BatchNormalization(trainable=True) 
layer = batch_normed(hidden)

这将创建单独的类实例

^{pr2}$

这个更新需要收集。同时看一下https://github.com/tensorflow/tensorflow/issues/25525

tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates[0])
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates[1])
updates_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

这可以解决

^{pr2}$

错误问题。在

如果使用

tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, batch_normed.updates)

回归

tf.get_collection(tf.GraphKeys.UPDATE_OPS)

是列表中的列表,就像[[某物]]

和使用

tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates[0])
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates[1])
updates_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

回归

tf.get_collection(tf.GraphKeys.UPDATE_OPS)

是[something1,something2,…]

我认为这是解决办法。在

但结果不同,我不知道哪个是真的。在

相关问题 更多 >

    热门问题