如何在tfp.layers.DistributionLambda内实例化variablecontaining分布

2024-09-30 20:33:29 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试使用tfp.layers.DistributionLambda构建一个tf.keras.Sequential模型。我正在跟踪^{} example,但是希望用一个包含tfd.TransformedDistributionRealNVPbijector的变量替换tfd.Normal

import tensorflow as tf
import tensorflow_probability as tfp


model = tf.keras.Sequential((
    tf.keras.layers.Lambda(lambda x: tf.shape(x)[-1]),
    tfp.layers.DistributionLambda(lambda t: (
        tfp.distributions.TransformedDistribution(
            distribution=(
                tfp.distributions.MultivariateNormalDiag(loc=tf.zeros(t))),
            bijector=tfp.bijectors.RealNVP(
                num_masked=2,
                shift_and_log_scale_fn=tfp.bijectors.real_nvp_default_template(
                    hidden_layers=[32, 32]))))),
))


x = tf.random.uniform((5, 3))
distribution = model(x)

但是,这在以下情况下失败:

The layer cannot safely ensure proper Variable reuse across multiple calls, and consquently this behavior is disallowed for safety. Lambda layers are not well suited to stateful computation; instead, writing a subclassed Layer is the recommend way to define layers with Variables.

请注意RealNVPbijector的变量必须在bijector中初始化,这与Normal分布在^{} example中的变量不同,后者是在^{的顶层创建的

我想知道是否有一种方法可以将DistributionLambda用于必须在发行版中创建变量的这种设置?如果是这样,处理DistributionLambda层内变量的正确方法是什么?如果不是,那么建议用什么方法构建这样的模型


Tags: 方法模型importexamplelayerstftensorflowkeras