Pyrotch聚集问题(3D计算机视觉)

2024-10-05 12:24:56 发布

您现在位置:Python中文网/ 问答频道 /正文

我有N组C维点。每组有M个点。所以,有一个张量(N,M,C)。让我们称之为功能

我通过M维计算最大元素和索引,以找到每个C维的最大点(最大池操作),得到最大张量(N,1,C)和索引张量(N,1,C)

我有另一个形状的张量(N,M,3)存储这些N*M高维点的几何坐标。现在,我想使用每个C维中最大点的索引,来获得所有这些最大点的坐标

例如,N=2,M=4,C=6

坐标张量,其形状为(2,4,3):

[[[1, 2, 3]
  [4, 5, 6]
  [7, 8, 9]
  [8, 7, 6]]

 [11, 12, 13]
 [14, 15, 16]
 [17, 18, 19]
 [18, 17, 16]]]

指数张量,其形状为(2,1,6):

[[[0, 1, 2, 1, 2, 3]]
 [[1, 2, 3, 2, 1, 0]]]

例如,索引中的第一个元素是0,我想从坐标张量中获取[1,2,3]。对于第二个元素(1),我想把[4,5,6]拿出来。对于下一维度(3)中的第三个元素,我想把[18,17,16]拿出来

结果张量如下所示:

[[[1, 2, 3]  # 0
  [4, 5, 6]  # 1
  [7, 8, 9]  # 2
  [4, 5, 6]  # 1
  [7, 8, 9]  # 2
  [8, 7, 6]] # 3

 [[14, 15, 16] # 1
  [17, 18, 19] # 2
  [18, 17, 16] # 3
  [17, 18, 19] # 2
  [14, 15, 16] # 1
  [11, 12, 13]]]# 0

其形状为(2,6,3)

我试着用手电筒,但没法用。我编写了一个简单的算法来枚举所有N个组,但实际上它很慢,即使使用TorchScript的JIT。那么,如何在pytorch中高效地编写这篇文章呢


Tags: 功能算法元素pytorch指数jit形状篇文章
1条回答
网友
1楼 · 发布于 2024-10-05 12:24:56

您可以使用integer array indexingbroadcasting semantics组合来获得结果

import torch

x = torch.tensor([
    [[1, 2, 3], 
     [4, 5, 6], 
     [7, 8, 9], 
     [8, 7, 6]],
    [[11, 12, 13],
     [14, 15, 16],
     [17, 18, 19],
     [18, 17, 16]],
])

i = torch.tensor([[[0, 1, 2, 1, 2, 3]],
                  [[1, 2, 3, 2, 1, 0]]])

# rows is shape [2, 1], cols is shape [2, 6]
rows = torch.arange(x.shape[0]).type_as(i).unsqueeze(1)
cols = i.squeeze(1)

# y is [2, 6, ...]
y = x[rows, cols]

相关问题 更多 >

    热门问题