Tensorflow无法计算Addv2,因为输入#1(基于零)应为双张量,但它是浮点张量[Op:Addv]

2024-09-28 23:38:51 发布

您现在位置:Python中文网/ 问答频道 /正文

错误消息:

tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a double tensor but is a float tensor [Op:AddV2]

在我的代码中,我创建了一个tensorflow distribution MixtureSameFamily对象,并使用网络的输出作为参数。然而,当我试图计算一系列值的概率以生成概率密度函数时,我收到了这个错误

我的代码:

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=alphas),
    components_distribution=tfd.Normal(
        loc=mus,
        scale=sigmas
    )
)

x = np.linspace(-2,2,int(1000), dtype=np.double)
print(x.dtype)
pyx = gm.prob(x)

print(x.dtype)的结果是“dtype:'float'”

据我所知,tensorflow不支持the documentation.中的浮点数据类型

因为这个原因,我特别困惑。任何帮助都将不胜感激


Tags: 代码消息tensorflow错误npfloatdistributiondouble
1条回答
网友
1楼 · 发布于 2024-09-28 23:38:51

似乎是最新tensorflow概率模块中的一个bug。它只适用于float32

解决方法

显式地将参数强制转换为float32

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=alphas.astype('float32')),
    components_distribution=tfd.Normal(
        loc=mus.astype('float32'),
        scale=sigmas.astype('float32')
    )
)

x = np.linspace(-2,2,int(1000), dtype='float32')
pyx = gm.prob(x)

相关问题 更多 >