未找到Pytorch广播命令

2024-05-19 10:22:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我的代码中有以下嵌套for循环段。嵌套循环正在减慢我的完整执行

对于形状为[batchSize,nClass*repeat]的火炬张量extended_output和尺寸为[batchSize,nClass]的火炬张量,我希望聚合如下:

for q in range(nClass):
    for u in range(repeat):
        output[:,q]=output[:,q]+extended_output[:,(q+u*nClass)]

这里,nClassrepeat都是分别具有值14008的整数变量

使用pytorch广播可以避免这种嵌套for循环吗?任何帮助都是非常有用的

一个示例工作cpode可能是这样的

import torch
nClass=1400
repeat=8
batchSize=64
output=torch.zeros([batchSize,nClass])
extended_output=torch.rand([batchSize,nClass*repeat])

for q in range(nClass):
    for u in range(repeat):
        output[:,q]=output[:,q]+extended_output[:,(q+u*nClass)]

Tags: 代码inextendedforoutput尺寸range整数
1条回答
网友
1楼 · 发布于 2024-05-19 10:22:00

对不起,这个例子很简短,可能过于简化了。我担心一个更大的会更难想象。但我希望这符合你的目的。下面是我要做的:

import torch
nClass    = 3
repeat    = 2
batchSize = 4

torch.manual_seed(0)

output          = torch.zeros([batchSize,nClass])
extended_output = torch.rand([batchSize,nClass*repeat])


for q in range(nClass):
    for u in range(repeat):
        output[:,q]=output[:,q]+extended_output[:,(q+u*nClass)]

idxs = (torch.arange(repeat)*nClass).unsqueeze(0)
idxs = idxs + torch.arange(nClass).unsqueeze(1)
output_vectorized = extended_output[:, idxs].sum(2)

输出:

extended_output = 
tensor([[0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341],
        [0.4901, 0.8964, 0.4556, 0.6323, 0.3489, 0.4017],
        [0.0223, 0.1689, 0.2939, 0.5185, 0.6977, 0.8000],
        [0.1610, 0.2823, 0.6816, 0.9152, 0.3971, 0.8742]])
output = 
tensor([[0.6283, 1.0756, 0.7226],
        [1.1224, 1.2453, 0.8573],
        [0.5408, 0.8665, 1.0939],
        [1.0762, 0.6794, 1.5558]])
output_vectorized = 
tensor([[0.6283, 1.0756, 0.7226],
        [1.1224, 1.2453, 0.8573],
        [0.5408, 0.8665, 1.0939],
        [1.0762, 0.6794, 1.5558]])

相关问题 更多 >

    热门问题