我有一个三维的张量[batch_size, sequence_length, number_of_tokens]
。
最后一个维度是一个热编码维度。我想得到一个二维张量,其中sequence_length
由number_of_tokens
维的索引位置'1'组成
例如,要旋转形状为(2, 3, 4)
的张量:
[[[0, 1, 0, 0]
[1, 0, 0, 0]
[0, 0, 0, 1]]
[[1, 0, 0, 0]
[1, 0, 0, 0]
[0, 0, 1, 0]]]
转换为(2, 3)
形状的张量,其中number_of_tokens
维转换为1
的位置:
[[1, 0, 3]
[0, 0, 2]]
我这样做是为了准备模型结果,以便在计算损失时与参考答案进行比较,我希望这是正确的方法
简单地做:
如果原始张量是your previous question中指定的,则可以绕过一个热编码,直接使用argmax:
通过连续的列表理解,您可以做您想做的事情:
它在主张量的元素上迭代,并在每个元素上返回该元素组件中的“1”索引列表
相关问题 更多 >
编程相关推荐