Unsorted segment argmax: how to solve it

Posted 2024-09-26 18:04:26


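I am trying to create a boolean mask (for tf.boolean_mask) that filters duplicate indices out of a tensor based on the value belonging to each index. If an index's value is greater than the values of its duplicates, that index should be kept and the duplicates discarded. If both the index and the value are identical, only one of them should be kept: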

[Pseudocode]
for index in indices:
    if index is unique:
        keep index = True
    else:
        if val[index] > val of every other duplicate of index:
            keep index = True
        elif val[index] < val of any other duplicate of index:
            keep index = False
        elif val[index] == val of the other (maximal) duplicates:
            keep only a single one of the equal indices (it doesn't matter which)
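For reference, the pseudocode can be written as a plain Python loop. This is a minimal sketch: the function name keep_mask is hypothetical, and ties are broken by keeping the earliest position, which is one of the two acceptable answers in the example below.

```python
def keep_mask(indices, values):
    """Return one boolean per position: for every distinct index, keep only
    the position with the largest value (the earliest one if values tie)."""
    best = {}  # distinct index -> position of the best value seen so far
    for pos, (idx, val) in enumerate(zip(indices, values)):
        # Strict '>' means a later duplicate with an equal value never wins.
        if idx not in best or val > values[best[idx]]:
            best[idx] = pos
    kept = set(best.values())
    return [pos in kept for pos in range(len(indices))]


print(keep_mask([10, 5, 20, 20, 30, 30], [1.0, 0.0, 2.0, 0.0, 0.0, 0.0]))
# → [True, True, True, False, True, False]
```

The dictionary makes this a single pass, but it runs in Python rather than as TensorFlow ops, which is why the question asks for a tensor-only version.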

A short example of the problem:

import tensorflow as tf
tf.enable_eager_execution()

index = tf.convert_to_tensor([  10,    5,   20,    20,    30,    30])
value = tf.convert_to_tensor([  1.,   0.,   2.,    0.,    0.,    0.])
# bool_mask =                [True, True, True, False,  True, False]
# or                         [True, True, True, False, False,  True]
# Position 3 is filtered out because position 2 has the same index (20)
# and a greater value (2 vs. 0).
# Positions 4 and 5 share index 30 and have equal values, so either one
# may be kept, but at most one of them.


...
bool_mask = ?

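My current approach successfully removes duplicates whose values differ, but it fails on duplicates with equal values. Unfortunately, this edge case does occur in my data: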

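import tensorflow as tf

# Unique indices y, and for every element the id of its segment in y
y, idx = tf.unique(index)
num_segments = tf.shape(y)[0]
# Largest value within each segment
maximum_vals = tf.unsorted_segment_max(value, idx, num_segments)

# (unique index, segment max) pairs and (index, value) pairs per element
fused_filt = tf.stack([tf.cast(y, tf.float32), maximum_vals], axis=1)
fused_orig = tf.stack([tf.cast(index, tf.float32), value], axis=1)

# Repeat every element pair once per segment: shape [N, num_segments, 2]
fused_orig_tiled = tf.tile(fused_orig, [1, tf.shape(fused_filt)[0]])
fused_orig_res = tf.reshape(fused_orig_tiled, [-1, tf.shape(fused_filt)[0], 2])

# Keep an element if its (index, value) pair equals some (index, max) pair
comp_1 = tf.equal(fused_orig_res, fused_filt)
comp_2 = tf.reduce_all(comp_1, -1)
comp_3 = tf.reduce_any(comp_2, -1)
# comp_3 = [True, True, True, False, True, True]
# Ties survive: positions 4 and 5 both equal their segment maximum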
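Because y is unique, an element's (index, value) pair can only match the pair of its own segment idx[i], so the tile/reshape comparison reduces to gathering each segment's maximum back to its elements. The sketch below assumes TensorFlow 2.x (eager by default, ops under tf.math); it reproduces comp_3 and shows why ties cannot be resolved this way: every element equal to its segment maximum passes.

```python
import tensorflow as tf  # assumes TensorFlow 2.x, eager execution by default

index = tf.constant([10, 5, 20, 20, 30, 30])
value = tf.constant([1., 0., 2., 0., 0., 0.])

y, idx = tf.unique(index)
maximum_vals = tf.math.unsorted_segment_max(value, idx, tf.shape(y)[0])

# Broadcast each segment's maximum back to every element of that segment.
comp_3 = tf.equal(value, tf.gather(maximum_vals, idx))
print(comp_3.numpy().tolist())  # → [True, True, True, False, True, True]
```

An equality test against the maximum carries no information about which of several equal elements to choose, so a tie-breaker (such as the element's position) is needed on top of it.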

A pure TensorFlow solution would be nice, since the for loop over the indices is straightforward to implement. Thank you very much.
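One possible pure-TensorFlow approach is a segment argmax built from two segment reductions: first the maximum value per segment, then, among the elements that reach that maximum, the smallest position per segment. This is a sketch assuming TensorFlow 2.x; the function name first_max_mask is hypothetical, and breaking ties by the earliest position is an arbitrary but deterministic choice.

```python
import tensorflow as tf  # assumes TensorFlow 2.x, eager execution by default


def first_max_mask(index, value):
    """Boolean mask keeping exactly one position per distinct index: the one
    holding the segment's maximum value, the earliest position on ties."""
    y, idx = tf.unique(index)
    num_segments = tf.shape(y)[0]
    n = tf.shape(index)[0]
    positions = tf.range(n)

    # 1) Maximum value of every segment, broadcast back to its elements.
    seg_max = tf.math.unsorted_segment_max(value, idx, num_segments)
    is_max = tf.equal(value, tf.gather(seg_max, idx))

    # 2) Among the maximal elements, take the smallest position per segment.
    #    Non-maximal elements get the sentinel n, larger than any position.
    sentinel = tf.fill(tf.shape(positions), n)
    candidates = tf.where(is_max, positions, sentinel)
    winner = tf.math.unsorted_segment_min(candidates, idx, num_segments)

    # 3) Keep an element only if it is the winner of its segment.
    return tf.equal(positions, tf.gather(winner, idx))


index = tf.constant([10, 5, 20, 20, 30, 30])
value = tf.constant([1., 0., 2., 0., 0., 0.])
bool_mask = first_max_mask(index, value)
print(bool_mask.numpy().tolist())  # → [True, True, True, False, True, False]
print(tf.boolean_mask(index, bool_mask).numpy().tolist())  # → [10, 5, 20, 30]
```

Every segment contains at least one element equal to its own maximum, so the sentinel never wins (this assumes value contains no NaN). The same ops are also exposed as tf.unsorted_segment_max and tf.unsorted_segment_min in TF 1.x, as used in the question, so the sketch should carry over with those names.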


