加载Pyrotch NN模型的检查点时出现异常

2024-10-06 07:38:25 发布

您现在位置:Python中文网/ 问答频道 /正文

调用以下单元格中定义的函数时,会引发异常“TypeError:forward()接受2个位置参数,但给出了9个”document提供了进一步的细节

def load_checkpoint(chkptJP):
checkpoint = torch.load(chkptJP)
model2 = model1(checkpoint['input_size'],
              checkpoint['output_size'],
              checkpoint['fc1'],
              checkpoint['fc2'],
              checkpoint['optimizer_state_dict'],
              checkpoint['epoch'],
              checkpoint['class_to_idx'],
              checkpoint['learning_rate'])
model2.load_state_dict(checkpoint['state_dict'])
return model2

写出检查点的代码如下所示:

checkpoint ={'input_size':512,
         'output_size':102,
         'fc1':256,
         'fc2':102,
         'state_dict': model.state_dict(),
         'optimizer_state_dict': optimizer.state_dict(),
         'epoch': epoch+1,
         'class_to_idx': model.class_to_idx,
         'learning_rate': 0.003}
torch.save(checkpoint,chkptJP)

Tags: toinputoutputsizeloadtorchdictclass
1条回答
网友
1楼 · 发布于 2024-10-06 07:38:25

您的错误表明model1是一个已经实例化的网络,而它应该是一个类。有关综合信息,请参见official documentation about saving(如有疑问,请随时查阅)。我会在整个答案中链接到它,所以一定要查看它并了解发生了什么

保存常规检查点

您的代码保存一个general checkpoint。你可以用这种方式保存任何字典和你想要的任何信息(它基本上是Python's pickle,你也可以用类似的方式调整它)。您有很多信息,其中一些与模型本身无关

加载一般检查点

正如您所做的,您可以通过torch.load加载所有这些数据。由于您已经保存了state_dict(权重),而不是整个Model(代码的外观),因此您必须使用随机权重创建一个新的模型,然后加载它们

此代码应该可以:

def load_checkpoint(chkptJP):
    checkpoint = torch.load(chkptJP)
    model = ModelClass(
        checkpoint["input_size"],
        checkpoint["output_size"],
        checkpoint["fc1"],
        checkpoint["fc2"],
        checkpoint["optimizer_state_dict"],
        checkpoint["epoch"],
        checkpoint["class_to_idx"],
        checkpoint["learning_rate"],
    )
    model.load_state_dict(checkpoint["state_dict"])
    return model

请注意,ModelClass必须是类,而不是像您在这里所做的那样是对象。如果model1是一个对象,那么运行model1(arg1, ..., arg9)将调用它的__call__方法,如果model1torch.nn.Module的实例,那么它又是一个包装的forward方法ModelClass在代码中应该是这样的(可能在某个地方定义):

import torch


class ModelClass(torch.nn.Module):
    def __init__(
        self,
        input_size,
        output_size,
        fc1,
        fc2,
        optimizer_state_dict,
        epoch,
        class_to_idx,
        learning_rate,
    ):
        # Your initialization code here
        ...

    def forward(tensor):
        # Your forward pass here
        ...

如果您没有ModelClass任何地方,您必须单独保存整个模型(例如torch.save(model)而不是torch.save(model.state_dict())),并将其作为一个整体加载(torch.load(PATH)而不是chkp=torch.load(PATH),然后在实例上调用model.load_state_dict

相关问题 更多 >