pytorch: retain_graph=True error, even though I added it

Posted 2024-10-02 22:35:45


I keep getting this error: "Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling .backward() or autograd.grad() the first time."

At first the code did not have retain_graph=True and I got the error, so I added it to backward(), but I still get the same error.

I have read similar questions, but nothing helped. Hoping for some help.
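
For reference, the message describes autograd's normal behavior: the first .backward() call frees the intermediate tensors saved during the forward pass, so a second backward pass over the same graph fails. A minimal toy sketch of this, separate from my training code below (x and y here are illustration only):

import torch

x = torch.randn(3, requires_grad=True)
y = (x ** 2).sum()
y.backward()         # the first backward frees the graph's saved intermediates
# y.backward()       # a second backward over the same graph raises the error above

y = (x ** 2).sum()             # rebuilding the graph each iteration avoids the problem
y.backward(retain_graph=True)  # or retain the graph for a repeated backward call
y.backward()                   # succeeds because the graph was retained

Since retain_graph=True makes this toy case pass, the error persisting despite it usually means some part of the graph is built once and then reused across backward calls (for example, a tensor computed outside the training loop).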

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

# `net` is the CNN trained earlier (defined elsewhere in my setup)
trained_cnnfmnist_model = net

class CNNFMnist2(nn.Module):
    def __init__(self, trained_cnnfmnist_model):
        super(CNNFMnist2, self).__init__()
        self.trained_cnnfmnist_model = trained_cnnfmnist_model
        # a few fully connected layers on top of the trained model
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(32, 16)
        self.fc3 = nn.Linear(16, 10)

    def forward(self, x):
        x = self.trained_cnnfmnist_model(x)
        # the trained model returns a tuple; use its first element
        x = F.relu(self.fc1(x[0]))
        print(x.shape)  # debugging print
        # x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
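
As an aside, if the trained model is only meant to serve as a fixed feature extractor (an assumption, not something stated above), its parameters can be frozen so the backward pass touches only the new fully connected layers:

# Sketch, assuming the trained backbone should stay fixed:
for param in trained_cnnfmnist_model.parameters():
    param.requires_grad = False  # keep the backbone weights out of gradient updates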

trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,
                                        download=True, transform=transforms.ToTensor())

testset = torchvision.datasets.FashionMNIST(root='./data', train=False,
                                       download=True, transform=transforms.ToTensor())

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True)

testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False)

net2 = CNNFMnist2(trained_cnnfmnist_model).cuda()
optimizer = torch.optim.SGD(net2.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()  # assumed: a standard choice for 10-class classification

for epoch in range(2):  

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        
        inputs = inputs.cuda() # -- For GPU
        labels = labels.cuda() # -- For GPU

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net2(inputs)
        loss = criterion(outputs, labels)
        loss.backward(retain_graph=True)
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if (i+1) % 2000 == 0:    
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')
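
testloader is defined above but never used; for completeness, a typical evaluation pass over it could look like this (a sketch, not part of my training code):

# Sketch of an evaluation pass over the test set
correct = 0
total = 0
net2.eval()
with torch.no_grad():  # no graph is built here, so no backward issues
    for inputs, labels in testloader:
        inputs, labels = inputs.cuda(), labels.cuda()
        outputs = net2(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Test accuracy: %.2f %%' % (100.0 * correct / total))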
