关于变张量形状的一个问题

2024-10-01 02:33:25 发布

您现在位置:Python中文网/ 问答频道 /正文

当我想将图像分割成多个补丁时,我想到的第一件事就是使用pytorch view()函数。例如,一个形状为(1,3256256)(pytorch样式)的图像,并将其拆分为8x8=64个面片,每个面片的高度和宽度为32。对于这幅图像,我们可以得到256/32=8的行补丁和8的列补丁,所以我们总共有8x8=64个补丁

我想将图像(1,3256256)分割成面片,每个面片的形状是(1,3,32,32),并将这些张量重塑成(1,8x8,32x32x3)的形状,这里8x8是面片的数量,32x32x3是面片的高度*宽度*通道

下面的代码使用einops重排函数可以得到正确的答案,但是当我使用view函数时,我得到了正确的形状但张量值不正确。谁能告诉我如何通过查看功能实现此操作

from einops.layers.torch import Rearrange
img = torch.randn(1, 3, 256, 256)
import copy
img2 = copy.deepcopy(img)
b, c, h, w = img.size()
p=32
to_patch_embedding = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32),
)
img2 = img2.view(b, h // p * w // p, c * p * p)


print(img2.shape)
print(img2==to_patch_embedding(img))

----------------------------------------output--------------------------------------------
torch.Size([1, 64, 3072])
tensor([[[ True, False, False,  ..., False, False, False],
     [False, False, False,  ..., False, False, False],
     [False, False, False,  ..., False, False, False],
     ...,
     [False, False, False,  ..., False, False, False],
     [False, False, False,  ..., False, False, False],
     [False, False, False,  ..., False, False,  True]]])

Tags: 函数图像viewfalseimg宽度高度torch
1条回答
网友
1楼 · 发布于 2024-10-01 02:33:25

您可以按照pytorch discuss中提到的方法尝试这种方法

import torch
a = torch.randn(1, 3, 256, 256)
a = a.unfold(2, 32, 32).unfold(3, 32, 32)
a = a.contiguous().view(a.size(0), a.size(2)*a.size(3), a.size(1)*a.size(-1)*a.size(-2))
print(a.shape)

输出形状与您预期的一样:

torch.Size([1, 64, 3072])

希望这对你有用

相关问题 更多 >