我正在尝试实现一个自定义的损失函数
def lossFunction(self,y_true,y_pred):
maxi=K.argmax(y_true)
return K.mean((K.max(y_true) -(K.gather(y_pred,maxi)))**2)
训练时会出现以下错误
InvalidArgumentError (see above for traceback): indices[5] = 51 is not in [0, 32) [[Node: loss/dense_3_loss/Gather = Gather[Tindices=DT_INT64, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](dense_3/BiasAdd, metrics/acc/ArgMax)]]
模型概要
Argmax从最后一个轴获取,而gather从第一个轴获取。两个轴上的元素数量不同,所以这是意料之中的。在
如果只处理类,请使用最后一个轴,因此我们将围绕聚集方法进行研究:
相关问题 更多 >
编程相关推荐