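<p>You can use <a href="https://pytorch.org/docs/stable/generated/torch.flatten.html" rel="nofollow noreferrer"><code>torch.flatten</code></a> to merge the last two dimensions into one and then apply <a href="https://pytorch.org/docs/stable/generated/torch.argmax.html" rel="nofollow noreferrer"><code>torch.argmax</code></a> along that merged dimension. For a tensor of shape <code>(N, C, H, W)</code> this gives an <code>(N, C)</code> tensor of flat indices into each <code>H × W</code> plane:</p>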
<pre><code>>>> import torch
>>> x = torch.rand(2,3,100,100)
>>> x.flatten(-2).argmax(-1)
tensor([[2660, 6328, 8166],
        [5934, 5494, 9717]])
</code></pre>
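<p>The values above are flat indices. Because <code>flatten</code> lays out the last two dimensions in row-major order, a flat index <code>i</code> in an <code>H × W</code> plane sits at row <code>i // W</code> and column <code>i % W</code> (for example, 2660 with <code>W = 100</code> is row 26, column 60). Below is a minimal sketch that recovers both coordinates. <code>argmax_2d</code> is a hypothetical helper name, not part of PyTorch, and the sketch assumes a PyTorch version whose <code>torch.div</code> accepts <code>rounding_mode</code> (1.8 or later).</p>

```python
import torch
from typing import Tuple


def argmax_2d(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: (row, col) of the max over the last two dims.

    Works for any tensor with at least 2 dims, e.g. shape (N, C, H, W);
    the returned tensors have the leading shape, e.g. (N, C).
    """
    w = x.shape[-1]
    # Merge H and W into one axis and take the argmax over it.
    flat_idx = x.flatten(-2).argmax(-1)
    # Row-major layout: flat index i is at row i // W, column i % W.
    rows = torch.div(flat_idx, w, rounding_mode="floor")
    cols = flat_idx % w
    return rows, cols


x = torch.rand(2, 3, 100, 100)
rows, cols = argmax_2d(x)
print(rows.shape, cols.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```

<p>The sketch uses <code>torch.div(..., rounding_mode="floor")</code> rather than <code>//</code> because some older PyTorch releases warned about or truncated tensor floor division; for the non-negative indices here the result is the same either way.</p>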