错误消息:
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a double tensor but is a float tensor [Op:AddV2]
在我的代码中,我创建了一个tensorflow distribution MixtureSameFamily对象,并使用网络的输出作为参数。然而,当我试图计算一系列值的概率以生成概率密度函数时,我收到了这个错误
我的代码:
gm = tfd.MixtureSameFamily(
mixture_distribution=tfd.Categorical(probs=alphas),
components_distribution=tfd.Normal(
loc=mus,
scale=sigmas
)
)
x = np.linspace(-2,2,int(1000), dtype=np.double)
print(x.dtype)
pyx = gm.prob(x)
print(x.dtype)
的结果是“dtype:'float'”
据我所知,tensorflow不支持the documentation.中的浮点数据类型
因为这个原因,我特别困惑。任何帮助都将不胜感激
似乎是最新tensorflow概率模块中的一个bug。它只适用于
float32
解决方法
显式地将参数强制转换为
float32
相关问题 更多 >
编程相关推荐