PyTorch中批量张量的索引元素

2024-09-26 18:01:29 发布

您现在位置:Python中文网/ 问答频道 /正文

比如说,我在PyTorch中有一批图像。对于每个图像,我还有一个像素位置,比如(x, y)。对于一个图像,可以使用img[x, y]读取像素值。我正在尝试读取批处理中每个图像的像素值。请参见下面的代码片段:

import torch

# create tensors to represent random images in torch format
img_1 = torch.rand(1, 200, 300)
img_2 = torch.rand(1, 200, 300)
img_3 = torch.rand(1, 200, 300)
img_4 = torch.rand(1, 200, 300)

# for each image, x-y value are know, so creating a tuple
img1_xy = (0, 10, 70)
img2_xy = (0, 40, 20)
img3_xy = (0, 30, 50)
img4_xy = (0, 80, 60)

# this is what I am doing right now
imgs = [img_1, img_2, img_3, img_4]
imgs_xy = [img1_xy, img2_xy, img3_xy, img4_xy]
x = [img[xy] for img, xy in zip(imgs, imgs_xy)]
x = torch.as_tensor(x)

我的关注和问题

  1. 在每个图像中,像素位置即(x, y)是已知的。但是,我必须创建一个包含多个元素的元组,即0,以确保元组与图像的形状匹配。有优雅的方式吗
  2. 我们不能用张量代替tuple,然后得到像素值吗
  3. 可以将所有图像连接起来,作为img_batch = torch.cat((img_1, img_2, img_3, img_4))进行批处理。但是tuple呢

Tags: in图像imgfor像素torchimg1元组
1条回答
网友
1楼 · 发布于 2024-09-26 18:01:29

您可以将图像合并成(4, 200, 300)形的堆叠张量。然后,我们可以使用每个图像的已知(x, y)对对此进行索引,如下所示:第一个图像需要[0, x1, y1],第二个图像需要[1, x2, y2],第三个图像需要[2, x3, y3],等等。这些可以通过“奇特的索引”实现:

# stacking as you did
>>> stacked_imgs = torch.cat(imgs)
>>> stacked_imgs.shape
(4, 200, 300)

# no need for 0s in front
>>> imgs_xy = [(10, 70), (40, 20), (30, 50), (80, 60)]

# need xs together and ys together: take transpose of `imgs_xy`
>>> inds_x, inds_y = torch.tensor(imgs_xy).T

>>> inds_x
tensor([10, 40, 30, 80])

>>> inds_y
tensor([70, 20, 50, 60])

# now we index into the batch
>>> num_imgs = len(imgs)
>>> result = stacked_imgs[range(num_imgs), inds_x, inds_y]
>>> result
tensor([0.5359, 0.4863, 0.6942, 0.6071])

我们可以检查结果:

>>> torch.tensor([img[0, x, y] for img, (x, y) in zip(imgs, imgs_xy)])

tensor([0.5359, 0.4863, 0.6942, 0.6071])

回答您的问题:

1:由于我们对图像进行了堆叠,所以这个问题得到了缓解,我们使用range(4)索引到每个单独的图像中

是的,我们确实把x, y位置变成了张量

3:我们将它们分离成张量后直接使用它们进行索引

相关问题 更多 >

    热门问题