TypeError: tensor is not a torch image

Posted 2024-05-20 23:17:19

I ran into this error in the transfer learning section of Udacity's AI course. Here is the code that is probably causing it:

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models

data_dir = 'filename'

# TODO: Define transforms for the training data and testing data
train_transforms= transforms.Compose([transforms.Resize((224,224)), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), transforms.ToTensor()])
test_transforms= transforms.Compose([transforms.Resize((224,224)), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), transforms.ToTensor()])

# Pass transforms in here, then run the next cell to see how the transforms look
train_data = datasets.ImageFolder(data_dir + '/train', transform=train_transforms)
test_data = datasets.ImageFolder(data_dir + '/test', transform=test_transforms)

trainloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32)

2 answers

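The problem is the order of the transforms. ToTensor must come before Normalize, because Normalize expects a tensor, whereas Resize returns an image rather than a tensor. Here is the corrected code, with the faulty lines changed: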

train_transforms = transforms.Compose([
    transforms.Resize((224,224)), 
    transforms.ToTensor(), 
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
test_transforms = transforms.Compose([
    transforms.Resize((224,224)), 
    transforms.ToTensor(), 
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

Another, less elegant solution (assuming the image was loaded with OpenCV and is therefore in BGR order):

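from torchvision import transforms

# img is assumed to be an H x W x 3 uint8 array, e.g. as returned by cv2.imread
t_ = transforms.Compose([transforms.ToPILImage(),
                         transforms.Resize((224,224)),
                         transforms.ToTensor()])  # scales pixel values to [0, 1]

# Per-channel means on a 0-255 scale; a std of 1 leaves the scale unchanged
norm_ = transforms.Normalize([103.939, 116.779, 123.68],[1,1,1])
img = 255*t_(img)  # back to the 0-255 range before subtracting the means
img = norm_(img)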
