跨多个维度的Pyrotch argmax

2024-10-02 20:40:57 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个4D张量,我想得到最后二维的argmaxtorch.argmax只接受整数作为“dim”参数,而不接受元组

我怎样才能做到这一点

这是我的想法,但我不知道如何匹配我的两个“指数”张量的维数original_array是形状[1,512,37,59]

max_vals, indices_r = torch.max(original_array, dim=2)
max_vals, indices_c = torch.max(max_vals, dim=2)
indices = torch.hstack((indices_r, indices_c))

Tags: 参数整数torch指数arraymax元组形状
2条回答

正如其他人提到的,最好将最后两个维度展平并应用argmax

original_array = torch.rand(1, 512, 37, 59)
original_flatten = original_array.view(1, 512, -1)
_, max_ind = original_flatten.max(-1)

。。您将获得最大值的线性索引。如果需要最大值的二维索引,可以使用列数“取消设置”索引

# 59 is the number of columns for the (37, 59) part
torch.stack([max_ind // 59, max_ind % 59], -1)

这将为您提供一个(1, 512, 2),其中最后两个dim包含2D坐标

您可以使用^{}展平最后两个维度并在其上应用^{}

>>> x = torch.rand(2,3,100,100)
>>> x.flatten(-2).argmax(-1)
tensor([[2660, 6328, 8166],
        [5934, 5494, 9717]])

相关问题 更多 >