当我想将图像分割成多个补丁时,我想到的第一件事就是使用pytorch view()函数。例如,一个形状为(1,3256256)(pytorch样式)的图像,并将其拆分为8x8=64个面片,每个面片的高度和宽度为32。对于这幅图像,我们可以得到256/32=8的行补丁和8的列补丁,所以我们总共有8x8=64个补丁
我想将图像(1,3256256)分割成面片,每个面片的形状是(1,3,32,32),并将这些张量重塑成(1,8x8,32x32x3)的形状,这里8x8是面片的数量,32x32x3是面片的高度*宽度*通道
下面的代码使用einops重排函数可以得到正确的答案,但是当我使用view函数时,我得到了正确的形状但张量值不正确。谁能告诉我如何通过查看功能实现此操作
from einops.layers.torch import Rearrange
img = torch.randn(1, 3, 256, 256)
import copy
img2 = copy.deepcopy(img)
b, c, h, w = img.size()
p=32
to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32),
)
img2 = img2.view(b, h // p * w // p, c * p * p)
print(img2.shape)
print(img2==to_patch_embedding(img))
----------------------------------------output--------------------------------------------
torch.Size([1, 64, 3072])
tensor([[[ True, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, True]]])
您可以按照pytorch discuss中提到的方法尝试这种方法
输出形状与您预期的一样:
希望这对你有用
相关问题 更多 >
编程相关推荐