我试图制作一个自定义损失函数,该函数计算MSE,但忽略真相低于某个阈值(接近0)的所有点。我可以通过以下方式使用numpy数组实现这一点
import numpy as np
a = np.random.normal(size=(4,4))
b = np.random.normal(size=(4,4))
temp_a = a[np.where(a>0.5)] # Your threshold condition
temp_b = b[np.where(a>0.5)]
mse = mean_squared_error(temp_a, temp_b)
但我不知道如何使用keras后端实现这一点。我的自定义损失函数不起作用,因为numpy不能对张量进行操作
def customMSE(y_true, y_pred):
'''
Correct predictions of 0 do not affect performance.
'''
y_true_ = y_true[tf.where(y_true>0.1)] # Your threshold condition
y_pred_ = y_pred[tf.where(y_true>0.1)]
mse = K.mean(K.square(y_pred_ - y_true_), axis=1)
return mse
但当我这样做时,返回的是错误
ValueError: Shape must be rank 1 but is rank 3 for '{{node customMSE/strided_slice}} = StridedSlice[Index=DT_INT64, T=DT_FLOAT, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1](cond_2/Identity_1, customMSE/strided_slice/stack, customMSE/strided_slice/stack_1, customMSE/strided_slice/Cast)' with input shapes: [?,?,?,?], [1,?,4], [1,?,4], [1].```
在loss函数中可以使用
tf.where
而不是np.where
如果你想使预测丢失,而它们对应的地面真值低于阈值,你可以编写如下自定义函数:
我将loss函数包装到另一个函数中,所以您可以将threshold作为超参数传递给模型编译
相关问题 更多 >
编程相关推荐