Tensorflow概率:InvalidArgumentError:所需的可广播形状

2024-05-20 20:26:22 发布

您现在位置:Python中文网/ 问答频道 /正文

我的数据具有第一种格式。当我使用tensorflow概率层时,我得到以下错误:

下面是一个示例,其中输入形状为[1,28,28],可复制代码:Gist(请确保您正在GPU上运行代码。)

InvalidArgumentError:  required broadcastable shapes
     [[node gradient_tape/lambda/model_3_mixture_same_family_4_MixtureSameFamily_MixtureSameFamily/log_prob/model_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Independentmodel_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Normal/log_prob/model_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Normal/log_prob/truediv/RealDiv_1 (defined at <ipython-input-22-243a182981d9>:9) ]] [Op:__inference_train_function_7663]

Errors may have originated from an input operation.
Input Source operations connected to node gradient_tape/lambda/model_3_mixture_same_family_4_MixtureSameFamily_MixtureSameFamily/log_prob/model_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Independentmodel_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Normal/log_prob/model_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Normal/log_prob/truediv/RealDiv_1:
 model_3/mixture_same_family_4/MixtureSameFamily/independent_normal_4/IndependentNormal/Softplus (defined at /usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/layers/distribution_layer.py:988)

Function call stack:
train_function

我不知道如何更改源代码,使其与通道第一输入形状一起工作。有人能帮我吗


Tags: 代码lognodemodeltensorflowfamily形状same
1条回答
网友
1楼 · 发布于 2024-05-20 20:26:22

您的preprocess函数返回的是image, image,而不是image, sample['label']。如果你改变这个,它应该会工作

我认为你也可以在你的损失中放弃K.cast

更新:事实上,当我运行这个时,我得到了nan的损失。可能是出了什么问题。但至少它克服了形状错误!🤷‍♂️

相关问题 更多 >