得到多个维度上的对角线元素

2024-09-30 02:35:37 发布

您现在位置:Python中文网/ 问答频道 /正文

我想把一个大小为(n x n x m x m)的张量T转换成一个大小为(n x m x m)的张量U,同时只检索(NxN)块上T的对角线元素(即Uikl=Tiikl)。diag()只适用于二维张量,我真的不知道如何在不循环元素索引的情况下做到这一点(鉴于我认为这在计算上是低效的,所以我想避免这样做)。总之,我想将以下代码矢量化:

U = torch.zeros(n, m, m)
for i in range(n):
    for k in range(m):
        for l in range(m):
            U[i][k][l] = T[i][i][k][l]

我对pytorch完全陌生,我尝试了很多函数的组合,但是没有一个能给我一个满意的结果。有人知道吗


Tags: 代码in元素forzeros情况rangetorch
1条回答
网友
1楼 · 发布于 2024-09-30 02:35:37

可以使用np.meshgrid生成索引

i, k, l = np.meshgrid(range(n), range(m), range(m))
U[i, k, l] = T[i, i, k, l]

为了完整起见,我做了:

n = 3
m = 5

T = torch.arange(n * n * m * m).view(n, n, m, m)
U = torch.zeros(n, m, m)
U_ = torch.zeros(n, m, m)

i, k, l = np.meshgrid(range(n), range(m), range(m))

U_[i, k, l] = T[i, i, k, l]

for i in range(n):
    for k in range(m):
        for l in range(m):
            U[i][k][l] = T[i][i][k][l]

U = U.view(-1)
U_ = U_.view(-1)

print ((U == U_).all())

输出是True,所以我假设它是正确的

相关问题 更多 >

    热门问题