在keras中,如何应用过滤器(where)功能?

2024-06-28 11:04:45 发布

您现在位置:Python中文网/ 问答频道 /正文

像functools包中的filter函数一样,我想在张量中找到0.5以上的元素。在

这是代码,但不是工作。在

def pred_overhalf(y_true, y_pred):

    return K.count_params( filter( lambda x : x > 0.5 , y_pred ) )
model.compile(optimizer = "adam" , loss = "mse", metrics = [ pred_overhalf])

有什么办法解决这个问题吗?我搜索了keras后端文档,但找不到任何解决方案


Tags: lambda函数代码true元素modelreturndef
1条回答
网友
1楼 · 发布于 2024-06-28 11:04:45
def pred_overhalf(y_true,y_pred):
    out = K.greater(y_pred,0.5)
    out = K.cast(out,K.floatx())

    #option 1
    return K.mean(out) #fraction of items greater than 0.5

    #option 2
    return K.sum(out) #total count (beware: this will consider all samples)

相关问题 更多 >