将一个热编码维度转换为1位置的索引

2024-05-18 10:52:59 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个三维的张量[batch_size, sequence_length, number_of_tokens]。 最后一个维度是一个热编码维度。我想得到一个二维张量,其中sequence_lengthnumber_of_tokens维的索引位置'1'组成

例如,要旋转形状为(2, 3, 4)的张量:

[[[0, 1, 0, 0]
[1, 0, 0, 0]
[0, 0, 0, 1]]
[[1, 0, 0, 0]
[1, 0, 0, 0]
[0, 0, 1, 0]]]

转换为(2, 3)形状的张量,其中number_of_tokens维转换为1的位置:

[[1, 0, 3]
[0, 0, 2]]

我这样做是为了准备模型结果,以便在计算损失时与参考答案进行比较,我希望这是正确的方法


Tags: of方法模型number编码sizebatchlength
3条回答

简单地做:

res = x.argmax(axis = 2)

如果原始张量是your previous question中指定的,则可以绕过一个热编码,直接使用argmax:

t = torch.rand(2, 3, 4)
t = t.argmax(dim=2)

通过连续的列表理解,您可以做您想做的事情:

x=[[[0, 1, 0, 0],
[1, 0, 0, 0],
[0, 0, 0, 1]],
[[1, 0, 0, 0],
[1, 0, 0, 0],
[0, 0, 1, 0]]]

y=[[ell2.index(1) for ell2 in ell1] for ell1 in x]

print(y) # prints [[1, 0, 3], [0, 0, 2]]

它在主张量的元素上迭代,并在每个元素上返回该元素组件中的“1”索引列表

相关问题 更多 >