我正在尝试实现一个条件bijector。如果您不知道它是什么并不重要,但本质上我的代码是:
import tensorflow as tf
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
import numpy as np
from math import log, exp
tfb = tfp.bijectors
import pickle as pk
import statsmodels.api as sm
import statsmodels.formula.api as smf
import pandas as pd
import os
class varFamBij2(tf.keras.models.Model):
def __init__(self, *, output_dim, **kwargs): #** additional arguments for the super class
super().__init__(**kwargs)
self.output_dim = output_dim
num_bijectors = 5
bijectors=[]
for i in range(num_bijectors):
bijectors.append(tfb.MaskedAutoregressiveFlow(tfp.bijectors.AutoregressiveNetwork(1, event_shape=self.output_dim, hidden_units=[32, 32], conditional=True, conditional_event_shape= 13)))
bijectors.append(tfb.Permute(permutation=[1,0]))
bijectors.append(tfb.MaskedAutoregressiveFlow(tfp.bijectors.AutoregressiveNetwork(1, event_shape=self.output_dim, hidden_units=[32, 32], conditional=True, conditional_event_shape= 13)))
#A bijector is formed by chaining together many layers of bijectors
self.bijector = tfb.Chain(bijectors)
x1 = tf.ones([2])
x2 = tf.ones([13])
mod11 = varFamBij2(output_dim=2)
predictions = mod11.bijector.forward(x1, conditional_input = x2)
conditional_input = x2
是一个夸格。基本上我得到了这个错误:
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/bijectors/masked_autoregressive.py in call(self, x, conditional_input)
1049 if self._conditional:
1050 if conditional_input is None:
-> 1051 raise ValueError('`conditional_input` must be passed as a named '
1052 'argument')
1053 conditional_input = tf.convert_to_tensor(
ValueError: 'conditional_input' must be passed as a named argument
问题是函数call(self, x, conditional_input)
有一个条件输入,这个条件输入应该按照TF文档以**kwargs的形式输入(至少根据我对kwargs非常糟糕的理解),我认为**kwargs没有作为参数输入到条件输入(因为条件_输入的默认值是None,我认为这会引起错误)
我不认为有必要对TensorFlow有一个非常详细的了解来回答这个问题。我认为我无法理解和使用kwargs是导致这个程序不起作用的原因。很好奇是否有人能建议使用kwargs(或其他方法)的方法这样调用方法将接受我的条件输入。谢谢
卡梅隆
请参见TensorFlow GitHub页面上SiegeLordEx的回答。这解决了它:https://github.com/tensorflow/probability/issues/1159
基本上,我们必须为条件双射体编写以下代码:
相关问题 更多 >
编程相关推荐