如何在Tensorflow中实现条件双射体(主要是如何使用kwargs问题,因为这会让我出错)

2024-09-29 23:25:06 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试实现一个条件bijector。如果您不知道它是什么并不重要,但本质上我的代码是:

import tensorflow as tf
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
import numpy as np
from math import log, exp
tfb = tfp.bijectors
import pickle as pk
import statsmodels.api as sm
import statsmodels.formula.api as smf
import pandas as pd
import os

class varFamBij2(tf.keras.models.Model):
    def __init__(self, *, output_dim, **kwargs): #** additional arguments for the super class
        super().__init__(**kwargs)
        self.output_dim = output_dim
        num_bijectors = 5
        bijectors=[]
        for i in range(num_bijectors):
            bijectors.append(tfb.MaskedAutoregressiveFlow(tfp.bijectors.AutoregressiveNetwork(1, event_shape=self.output_dim, hidden_units=[32, 32], conditional=True, conditional_event_shape= 13)))
            bijectors.append(tfb.Permute(permutation=[1,0]))
        bijectors.append(tfb.MaskedAutoregressiveFlow(tfp.bijectors.AutoregressiveNetwork(1, event_shape=self.output_dim, hidden_units=[32, 32], conditional=True, conditional_event_shape= 13)))

        #A bijector is formed by chaining together many layers of bijectors
        self.bijector = tfb.Chain(bijectors)


          
x1 = tf.ones([2])
x2 = tf.ones([13])
mod11 = varFamBij2(output_dim=2)
predictions = mod11.bijector.forward(x1, conditional_input = x2)

conditional_input = x2是一个夸格。基本上我得到了这个错误:

/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/bijectors/masked_autoregressive.py in call(self, x, conditional_input)
   1049       if self._conditional:
   1050         if conditional_input is None:
-> 1051           raise ValueError('`conditional_input` must be passed as a named '
   1052                            'argument')
   1053         conditional_input = tf.convert_to_tensor(

ValueError: 'conditional_input' must be passed as a named argument

问题是函数call(self, x, conditional_input)有一个条件输入,这个条件输入应该按照TF文档以**kwargs的形式输入(至少根据我对kwargs非常糟糕的理解),我认为**kwargs没有作为参数输入到条件输入(因为条件_输入的默认值是None,我认为这会引起错误)

我不认为有必要对TensorFlow有一个非常详细的了解来回答这个问题。我认为我无法理解和使用kwargs是导致这个程序不起作用的原因。很好奇是否有人能建议使用kwargs(或其他方法)的方法这样调用方法将接受我的条件输入。谢谢

卡梅隆


Tags: importselfinputoutputtftensorflowas条件

热门问题