我的代码中有以下嵌套for循环段。嵌套循环正在减慢我的完整执行
对于形状为[batchSize,nClass*repeat]
的火炬张量extended_output
和尺寸为[batchSize,nClass]
的火炬张量,我希望聚合如下:
for q in range(nClass):
for u in range(repeat):
output[:,q]=output[:,q]+extended_output[:,(q+u*nClass)]
这里,nClass
,repeat
都是分别具有值1400
和8
的整数变量
使用pytorch广播可以避免这种嵌套for循环吗?任何帮助都是非常有用的
一个示例工作cpode可能是这样的
import torch
nClass=1400
repeat=8
batchSize=64
output=torch.zeros([batchSize,nClass])
extended_output=torch.rand([batchSize,nClass*repeat])
for q in range(nClass):
for u in range(repeat):
output[:,q]=output[:,q]+extended_output[:,(q+u*nClass)]
对不起,这个例子很简短,可能过于简化了。我担心一个更大的会更难想象。但我希望这符合你的目的。下面是我要做的:
输出:
相关问题 更多 >
编程相关推荐