首先是一些背景。我目前正在为我的数据输入管道编写一个自定义TensorFlow 2.x预处理函数。最终,我将map
在一个批上。本质上,该函数接收一批行,并通过复制行并根据条件删除每行中的元素来生成一个更大的批。例如,如果输入批次看起来像
[[4, 1, 10, 10, 2],
[10, 7, 9, 10, 10],
[6, 8, 10, 3, 5]]
然后,函数应根据没有10的位置生成新样本。对于非-10的每次出现,将删除这些元素,例如,从第一个样品(新样品)中删除4个,从另一个样品中删除1个,…,从最后一个样品中删除5个。从输入批次中,我们将有9个样本:
[[1, 10, 10, 2],
[4, 10, 10, 2],
[4, 1, 10, 10],
[10, 9, 10, 10],
[10, 7, 10, 10],
[8, 10, 3, 5],
[6, 10, 3, 5],
[6, 8, 10, 5],
[6, 8, 10, 3]]
现在谈谈我的职责。通过使用tf.where
、tf.gather
、tf.unique_with_counts
和tf.repeat
,我能够以正确的次数复制原始行:
def myFunction(data):
# Returns a 2-column tensor, with each row
# being the index pair...
presentIndices = tf.where(data != 10)
# Grab the 1st column (rows) and count how many
# times each row appears...
rows = tf.gather(presentIndices, indices=0, axis=1)
_, _, counts = tf.unique_with_counts(rows)
# Repeat each row according to counts...
data = tf.repeat(data, repeats=counts, axis=0)
# data now has 1st row copied 3 times, 2nd row copied twice, etc.
然而,考虑到我在presentIndices
中有索引,我现在被困于如何从每一行中删除适当的元素。使用numpy
,我可以简单地索引data
并进行相应的重塑,但TensorFlow似乎没有很好的索引多维张量的能力
我已经研究了tf.boolean_mask
,但是我还是需要在适当的位置分配False
。我能找到的最接近的东西是tf.gather_nd
,但是提取了给定索引的数据。相反,我需要这个函数的否定。给定索引,提取除这些索引处的之外的所有数据
有没有办法利用现有的TensorFlow函数来获得我想要的功能
谢谢
您可以使用tf.boolean\u mask和tf.scatter\u nd为(重复的)数据创建布尔向量。 首先,创建索引张量以指示要遮罩的值:
然后将每行的掩码用作tf.scatter\u和方法中的索引:
样本张量:
您可以执行以下操作。我知道这可能有点头晕。最简单的方法就是使用此代码作为参考来做一个示例
这就产生了,
相关问题 更多 >
编程相关推荐