Pythorch中的映射函数

2024-10-01 07:49:34 发布

您现在位置:Python中文网/ 问答频道 /正文

Pythorch中有地图功能吗?(类似于python中的map)。在

我需要将1xDxhxw张量变量映射到1x(9D)xhxw张量,以增加每个像素的8个相邻嵌入。Pythorch中有什么功能可以让我高效地做到这一点?在

我尝试在Python中使用map,方法是:

n, d, h, w = embedding.size()
padder = nn.ReflectionPad2d(padding=1)
embedding = padder(embedding)
embedding = map(lambda i, j, M: M[:, :, i-1:i+2, j-1:j+2], range(1, h), range(1, w), embedding)

但它对w > 2h > 2无效。在


Tags: 方法lambda功能mapsizerange像素nn
1条回答
网友
1楼 · 发布于 2024-10-01 07:49:34

从你的问题来看,你想达到什么目的还不清楚。在

请注意,PyTorch支持完整的pythorch,但是您要做的是在最后一行代码中创建一个map对象。以下内容适用于您的目的(?)?我猜)虽然:

import torch
import torch.nn as nn
n, d, h, w = 20, 3, 32, 32
embedding = torch.randn(n, d, h, w)
padder = nn.ReflectionPad2d(padding=1)
embedding = padder(embedding)

new = [embedding[:,:, (i-1):(i+2), (j-1):(j+2)] for i, j in zip(range(1,h), range(1,w))]

然而,请注意,有更优雅的方法来组合张量(例如torch.chunk())或对具有卷积的补丁进行操作(例如torch.nn.Conv2d

相关问题 更多 >