调用以下单元格中定义的函数时,会引发异常“TypeError:forward()接受2个位置参数,但给出了9个” 本document提供了进一步的细节
def load_checkpoint(chkptJP):
checkpoint = torch.load(chkptJP)
model2 = model1(checkpoint['input_size'],
checkpoint['output_size'],
checkpoint['fc1'],
checkpoint['fc2'],
checkpoint['optimizer_state_dict'],
checkpoint['epoch'],
checkpoint['class_to_idx'],
checkpoint['learning_rate'])
model2.load_state_dict(checkpoint['state_dict'])
return model2
写出检查点的代码如下所示:
checkpoint ={'input_size':512,
'output_size':102,
'fc1':256,
'fc2':102,
'state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch+1,
'class_to_idx': model.class_to_idx,
'learning_rate': 0.003}
torch.save(checkpoint,chkptJP)
您的错误表明
model1
是一个已经实例化的网络,而它应该是一个类。有关综合信息,请参见official documentation about saving(如有疑问,请随时查阅)。我会在整个答案中链接到它,所以一定要查看它并了解发生了什么保存常规检查点
您的代码保存一个general checkpoint。你可以用这种方式保存任何字典和你想要的任何信息(它基本上是Python's pickle,你也可以用类似的方式调整它)。您有很多信息,其中一些与模型本身无关
加载一般检查点
正如您所做的,您可以通过
torch.load
加载所有这些数据。由于您已经保存了state_dict
(权重),而不是整个Model
(代码的外观),因此您必须使用随机权重创建一个新的模型,然后加载它们此代码应该可以:
请注意,
ModelClass
必须是类,而不是像您在这里所做的那样是对象。如果model1
是一个对象,那么运行model1(arg1, ..., arg9)
将调用它的__call__
方法,如果model1
是torch.nn.Module
的实例,那么它又是一个包装的forward
方法ModelClass
在代码中应该是这样的(可能在某个地方定义):如果您没有
ModelClass
任何地方,您必须单独保存整个模型(例如torch.save(model)
而不是torch.save(model.state_dict()))
,并将其作为一个整体加载(torch.load(PATH)
而不是chkp=torch.load(PATH)
,然后在实例上调用model.load_state_dict
)相关问题 更多 >
编程相关推荐