如何使用torch.argmax并仅获取索引值

2024-09-27 19:29:19 发布

您现在位置:Python中文网/ 问答频道 /正文

使用火炬==1.7.1

MRE:

lst = [torch.rand(4).reshape(1,-1) for _ in range(5)]

对于每个张量,我想得到最大值的索引

max_indexes = [torch.argmax(tensor) for tensor in lst]

哪个输出

[tensor(0), tensor(0), tensor(2), tensor(2), tensor(2)]

我想使用它的位置能够从操作列表中获取适当的值。如何去掉张量,得到max_指数

所需的输出如下所示:

[0, 0, 2, 2, 2]

谢谢


Tags: in列表forrangetorch指数maxmre

热门问题