自定义DistributionLambda层中的变量不会更新

2024-09-29 23:29:22 发布

您现在位置:Python中文网/ 问答频道 /正文

我想在tensorflow-probability中构建一个自定义层,然后用它来构建DenseVariational层的后部

作为第一步,我已经建立了如下后验,它相当于tutorial中使用的平均场后验,但不是学习正态分布的参数,而是学习两个双射体的参数

def posterior_trainable_bijector(kernel_size, bias_size=0, dtype=None):
    n = kernel_size + bias_size
    c = np.log(np.expm1(1.0))

    return tf.keras.Sequential(
        [
            tfp.layers.VariableLayer(2 * n, dtype=dtype),
            tfp.layers.DistributionLambda(
                lambda t: tfp.distributions.TransformedDistribution(
                    tfd.Independent(
                        tfd.Normal(loc=tf.zeros(n), scale=tf.ones(n)),
                        reinterpreted_batch_ndims=1,
                    ),
                    tfp.bijectors.Chain(
                        bijectors=[
                            tfp.bijectors.Shift(t[..., n:]),
                            tfp.bijectors.Scale(
                                1e-5 + 0.01 * tf.math.softplus(c + t[..., :n])
                            ),
                        ]
                    ),
                )
            ),
        ]
    )

作为下一步,我认为对DistributionLambda层进行子类化是一个好主意,因为这样可以设置更复杂的双射体。 不幸的是,我的初稿似乎不起作用。更具体地说,我仍然能够运行我的代码,但似乎loc_params/scale_params在培训期间没有更新,但我不明白为什么会出现这种情况。 有什么建议吗

class LocScaleBijectorLayer(tfp.layers.DistributionLambda):
    def __init__(
        self,
        event_shape=(),
        convert_to_tensor_fn=tfd.Distribution.sample,
        validate_args=False,
        name="LocScaleBijectorLayer",
        **kwargs,
    ):

        c = np.log(np.expm1(1.0))
        with tf.name_scope(name) as name:
            loc_params = tf.Variable(
                tf.zeros(event_shape), name="loc_var", trainable=True
            )
            scale_params = tf.Variable(
                tf.ones(event_shape), name="scale_var", trainable=True
            )

            self.base_distribution = tfd.Independent(
                tfd.Normal(loc=tf.zeros(event_shape), scale=tf.ones(event_shape)),
                reinterpreted_batch_ndims=-1,
            )

            self.bijector = tfp.bijectors.Chain(
                bijectors=[
                    tfp.bijectors.Shift(loc_params),
                    tfp.bijectors.Scale(
                        1e-5 + 0.01 * tf.math.softplus(c + scale_params)
                    ),
                ]
            )

            super(LocScaleBijectorLayer, self).__init__(
                lambda t: tfp.distributions.TransformedDistribution(
                    self.base_distribution, self.bijector
                ),
                convert_to_tensor_fn,
                name=name,
                **kwargs,
            )

Tags: nameselfeventsizetfnpparamsloc

热门问题