如何在损失函数中切片张量?

2024-09-28 03:24:48 发布

您现在位置:Python中文网/ 问答频道 /正文

在我的自定义丢失函数中,我尝试使用.numpy()将张量转换为numpy数组,但没有成功。经过一些搜索,似乎不可能在损失函数中将张量转换为numpy数组。所以我决定使用后端方法。正如您在下面的代码中所看到的,我使用了K.argmax()并返回了最大值的索引张量。然后我想切片y_真,y_pred张量

def my_mse_loss(y_true, y_pred):

    y_true_index = K.argmax(y_true, axis=-1)
    y_true_startcounter = y_true_index-3
    y_true_stopcounter = y_true_index+3

    y_pred_index = K.argmax(y_pred, axis=-1)
    y_pred_startcounter = y_pred_index-3
    y_pred_stopcounter = y_pred_index+3

    y_true_pkrange = y_true[:,y_true_index:y_true_index+6]
    y_pred_pkrange = y_pred[:,y_pred_index:y_pred_index+6]

    return K.mean(K.square(y_pred_pkrange-y_true_pkrange), axis=-1)

我在安装模型时遇到的错误: ValueError:形状的秩必须相等,但为0和2 将形状0与其他形状合并。对于“{node my_mse_loss/stack_slice/stack_2}}=Pack[N=2,T=DT_INT64,axis=0](my_mse_loss/stack_slice/stack_2/values_0,my_mse_loss/ArgMax),输入形状为:[],[?,384]


Tags: 函数numpytrueindexstackmy数组形状

热门问题