Tensor`当未启用紧急执行时,对象是不可iterable的。若要迭代此张量,请使用“tf.map”`

2024-10-06 13:51:02 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图创建自己的损失函数:

def custom_mse(y_true, y_pred):
    tmp = 10000000000
    a = list(itertools.permutations(y_pred))
    for i in range(0, len(a)): 
     t = K.mean(K.square(a[i] - y_true), axis=-1)
     if t < tmp :
        tmp = t
     return tmp

它应该创建预测向量的置换,并返回最小的损失。

   "`Tensor` objects are not iterable when eager execution is not "
TypeError: `Tensor` objects are not iterable when eager execution is not enabled. To iterate over this tensor use `tf.map_fn`.

错误。我找不到这个错误的任何来源。为什么会这样?

谢谢你的帮助。


Tags: 函数trueobjectsis错误notiterabletmp
1条回答
网友
1楼 · 发布于 2024-10-06 13:51:02

发生此错误的原因是y_pred是一个张量(在没有紧急执行的情况下不可iterable),并且itertools.permutations期望iterable从中创建置换。此外,计算最小损耗的部分也不起作用,因为张量t的值在图形创建时是未知的。

我不会排列张量,而是创建索引的排列(这是在图形创建时可以做的事情),然后从张量收集排列的索引。假设您的Keras后端是TensorFlow,并且y_true/y_pred是二维的,那么您的loss函数可以实现如下:

def custom_mse(y_true, y_pred):
    batch_size, n_elems = y_pred.get_shape()
    idxs = list(itertools.permutations(range(n_elems)))
    permutations = tf.gather(y_pred, idxs, axis=-1)  # Shape=(batch_size, n_permutations, n_elems)
    mse = K.square(permutations - y_true[:, None, :])  # Shape=(batch_size, n_permutations, n_elems)
    mean_mse = K.mean(mse, axis=-1)  # Shape=(batch_size, n_permutations)
    min_mse = K.min(mean_mse, axis=-1)  # Shape=(batch_size,)
    return min_mse

相关问题 更多 >