The expanded size of the tensor (64) must match the existing size (66) at non-singleton dimension 2

Posted 2024-10-01 07:35:53


When running the following (unchanged) code on stable PyTorch 1.9, I get the error below. How can I diagnose and fix it?

(fashcomp) [jalal@goku fashion-compatibility]$ python main.py --test --l2_embed --resume runs/nondisjoint_l2norm/model_best.pth.tar --datadir ../../../data/fashion/
/scratch3/venv/fashcomp/lib/python3.8/site-packages/torchvision/transforms/transforms.py:310: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
  warnings.warn("The use of the transforms.Scale transform is deprecated, " +
=> loading checkpoint 'runs/nondisjoint_l2norm/model_best.pth.tar'
=> loaded checkpoint 'runs/nondisjoint_l2norm/model_best.pth.tar' (epoch 5)
/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Traceback (most recent call last):
  File "main.py", line 312, in <module>
    main()    
  File "main.py", line 149, in main
    test_acc = test(test_loader, tnet)
  File "main.py", line 244, in test
    embeddings.append(tnet.embeddingnet(images).data)
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/scratch3/research/code/fashion/fashion-compatibility/type_specific_network.py", line 119, in forward
    masked_embedding = masked_embedding / norm.expand_as(masked_embedding)
RuntimeError: The expanded size of the tensor (64) must match the existing size (66) at non-singleton dimension 2.  Target sizes: [256, 66, 64].  Tensor sizes: [256, 66]

Related issue: https://github.com/mvasil/fashion-compatibility/pull/13


1 Answer

#1 · Posted 2024-10-01 07:35:53

The problem is that norm is missing an extra dimension: masked_embedding is [256, 66, 64], while norm is [256, 66]. You can fix this by adding that extra dimension to norm (making it [256, 66, 1]):

masked_embedding = masked_embedding / norm.unsqueeze(-1).expand_as(masked_embedding)

Alternatively, modify the call that generates norm (this one):

norm = torch.norm(masked_embedding, p=2, dim=2, keepdim=True) + 1e-10
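The two fixes are equivalent: both restore the reduced dimension as a singleton so the division can broadcast. A minimal sketch, using the shapes from the error message (the tensor contents here are random placeholders, not the project's actual embeddings):

```python
import torch

# Stand-in for the model output with the shape from the traceback.
masked_embedding = torch.randn(256, 66, 64)

# Without keepdim, torch.norm drops dim 2, leaving shape [256, 66].
# expand_as cannot stretch that to [256, 66, 64], because expand only
# works on singleton (size-1) dimensions -- hence the RuntimeError.
norm = torch.norm(masked_embedding, p=2, dim=2) + 1e-10

# Fix 1: re-insert the singleton dimension before expanding.
out1 = masked_embedding / norm.unsqueeze(-1).expand_as(masked_embedding)

# Fix 2: keep the reduced dimension as size 1 at the source; plain
# broadcasting then handles the division, no explicit expand needed.
norm_kd = torch.norm(masked_embedding, p=2, dim=2, keepdim=True) + 1e-10
out2 = masked_embedding / norm_kd

assert out1.shape == (256, 66, 64)
assert torch.allclose(out1, out2)
```

Fix 2 is slightly tidier, since `keepdim=True` makes the intent explicit at the point where the norm is computed.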
