TF2.x自定义丢失函数,numpy

2024-06-28 19:52:17 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试使用一个自定义损失函数。我现在使用TF2.x,其中默认情况下启用了“急切执行”。我尝试了TF1.x,但是遇到了太多的问题。除了用tf.py_function()包装我的函数外,还有其他选择吗?如果不是,我该如何包装

通用:具有自定义损耗功能的自动编码器,该功能是围绕不寻常的排名差异构建的。目前我只使用scipy stats rankdata,但将来会有所改变

张量形状n,x,x,1

n图像,每个dimx,x。 因此,我想对所有n图像的每对orig,pred运行此自定义丢失函数

通用算法:

import scipy.stats as ss
def rank_loss(orig, pred):
    orig_arr = orig.numpy()  # want x,x,1
    pred_arr = pred.numpy()
    orig_rank = (ss.rankdata(orig_arr))  # returns flat array of length size of array
    pred_rank = (ss.rankdata(pred_arr))
    distance_diff = 0
    for i in range(len(orig_rank)): # gets sum of rank differences
        distance_diff = abs(orig_rank[i] - pred_rank[i])
    return distance_diff

如果我不能做到这一点,我是否仅限于可用的tf.<funcs>或者如何将张量作为数组的某种形式提取出来,以便在两个张量之间运行比较计算

我还研究了tf.make_ndarray,但这似乎不适用


Tags: of函数图像功能tfstatsdiffscipy