Tensorflow根据索引从每行中删除元素(与聚集相反)

2024-09-30 10:41:11 发布

您现在位置:Python中文网/ 问答频道 /正文

首先是一些背景。我目前正在为我的数据输入管道编写一个自定义TensorFlow 2.x预处理函数。最终,我将map在一个批上。本质上,该函数接收一批行,并通过复制行并根据条件删除每行中的元素来生成一个更大的批。例如,如果输入批次看起来像

[[4,  1, 10, 10,  2],
 [10, 7,  9, 10, 10],
 [6,  8, 10,  3,  5]]

然后,函数应根据没有10的位置生成新样本。对于非-10的每次出现,将删除这些元素,例如,从第一个样品(新样品)中删除4个,从另一个样品中删除1个,…,从最后一个样品中删除5个。从输入批次中,我们将有9个样本:

[[1, 10, 10, 2],
 [4, 10, 10, 2],
 [4, 1, 10, 10],
 [10, 9, 10, 10],
 [10, 7, 10, 10],
 [8, 10, 3, 5],
 [6, 10, 3, 5],
 [6, 8, 10, 5],
 [6, 8, 10, 3]]

现在谈谈我的职责。通过使用tf.wheretf.gathertf.unique_with_countstf.repeat,我能够以正确的次数复制原始行:

def myFunction(data):
    # Returns a 2-column tensor, with each row
    # being the index pair...
    presentIndices = tf.where(data != 10)
    # Grab the 1st column (rows) and count how many
    # times each row appears...
    rows = tf.gather(presentIndices, indices=0, axis=1)
    _, _, counts = tf.unique_with_counts(rows)
    # Repeat each row according to counts...
    data = tf.repeat(data, repeats=counts, axis=0)
    # data now has 1st row copied 3 times, 2nd row copied twice, etc.

然而,考虑到我在presentIndices中有索引,我现在被困于如何从每一行中删除适当的元素。使用numpy,我可以简单地索引data并进行相应的重塑,但TensorFlow似乎没有很好的索引多维张量的能力

我已经研究了tf.boolean_mask,但是我还是需要在适当的位置分配False。我能找到的最接近的东西是tf.gather_nd,但是提取了给定索引的数据。相反,我需要这个函数的否定。给定索引,提取除这些索引处的之外的所有数据

有没有办法利用现有的TensorFlow函数来获得我想要的功能

谢谢


Tags: 数据函数元素datatftensorflowwith样品
2条回答

您可以使用tf.boolean\u masktf.scatter\u nd为(重复的)数据创建布尔向量。 首先,创建索引张量以指示要遮罩的值:

row = tf.constant([0,1,2,3,4,5,6,7,8] ,dtype = tf.int64)
mask_for_each_row = tf.stack([row ,presentIndices[: , 1]],axis = 1 )

然后将每行的掩码用作tf.scatter\u和方法中的索引:

samples =tf.boolean_mask(data ,~tf.scatter_nd(mask_for_each_row , 
            tf.ones((9,),dtype = tf.bool),(9,5)))
samples = tf.reshape(samples ,(9,4))

样本张量:

      <tf.Tensor: shape=(9, 4), dtype=int32, numpy=
      array([[ 1, 10, 10,  2],
             [ 4, 10, 10,  2],
             [ 4,  1, 10, 10],
             [10,  9, 10, 10],
             [10,  7, 10, 10],
             [ 8, 10,  3,  5],
             [ 6, 10,  3,  5],
             [ 6,  8, 10,  5],
             [ 6,  8, 10,  3]])> 

您可以执行以下操作。我知道这可能有点头晕。最简单的方法就是使用此代码作为参考来做一个示例

def f(data):
    
    # Boolean mask where it's not 10
    a = (data != 10)
    # Repeat and reshape to n x 5 x 5
    a = tf.reshape(tf.repeat(a, 5), [-1, 5, 5])
    # Create a identity matrix of size 1 x 5 x 5
    eye = tf.reshape(tf.eye(5), [1,5,5])
    # Create a mask of size n x 5 x 5. This basically forces a to have only a single false value for each row
    # This single false element is the element to be removed
    mask = ~tf.cast(tf.reshape(tf.cast(a,'int32')* tf.cast(eye, 'int32'), [-1, 5]), 'bool')

    # Remove all the rows with all elements True. This ensures at least one element is removed from all existing rows
    mask = tf.cast(mask, 'int32') * tf.cast(~tf.reduce_all(mask, axis=1, keepdims=True), 'int32')
    mask = tf.cast(mask, 'bool')
    
    # Get the required rows and discard others and reshape
    res = tf.boolean_mask(tf.repeat(data, 5, axis=0), mask)     
    res = tf.reshape(res, [-1,4])

    return res

这就产生了,

tf.Tensor(
[[ 1 10 10  2]
 [ 4 10 10  2]
 [ 4  1 10 10]
 [10  9 10 10]
 [10  7 10 10]
 [ 8 10  3  5]
 [ 6 10  3  5]
 [ 6  8 10  5]
 [ 6  8 10  3]], shape=(9, 4), dtype=int32)

相关问题 更多 >

    热门问题