我想把一个大小为(n x n x m x m)的张量T转换成一个大小为(n x m x m)的张量U,同时只检索(NxN)块上T的对角线元素(即Uikl=Tiikl)。diag()只适用于二维张量,我真的不知道如何在不循环元素索引的情况下做到这一点(鉴于我认为这在计算上是低效的,所以我想避免这样做)。总之,我想将以下代码矢量化:
U = torch.zeros(n, m, m)
for i in range(n):
for k in range(m):
for l in range(m):
U[i][k][l] = T[i][i][k][l]
我对pytorch完全陌生,我尝试了很多函数的组合,但是没有一个能给我一个满意的结果。有人知道吗
可以使用
np.meshgrid
生成索引为了完整起见,我做了:
输出是
True
,所以我假设它是正确的相关问题 更多 >
编程相关推荐