用火把拯救一个受过训练的模特的最佳方法?

2024-05-09 02:10:45 发布

您现在位置:Python中文网/ 问答频道 /正文

我在寻找另一种方法来拯救一个在PyTorch受过训练的模特。到目前为止,我已经找到了两种选择。

  1. torch.save()保存模型,torch.load()加载模型。
  2. model.state_dict()保存已训练的模型,model.load_state_dict()加载已保存的模型。

我遇到了这个discussion,这里推荐方法2而不是方法1。

我的问题是,为什么选择第二种方法?这仅仅是因为torch.nn模块有这两个功能,我们被鼓励使用它们吗?


Tags: 模块方法模型功能modelsaveloadnn
3条回答

一个常见的PyTorch约定是使用.pt或.pth文件扩展名保存模型。

保存/加载整个模型 保存:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

加载:

必须在某处定义模型类

model = torch.load(PATH)
model.eval()

我在他们的github repo上找到了this page,我将在这里粘贴内容。


保存模型的推荐方法

序列化和恢复模型有两种主要方法。

第一个(建议)仅保存和加载模型参数:

torch.save(the_model.state_dict(), PATH)

之后:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

第二个保存并加载整个模型:

torch.save(the_model, PATH)

之后:

the_model = torch.load(PATH)

但是在这种情况下,序列化的数据被绑定到特定的类 以及使用的确切目录结构,因此当 在其他项目中使用,或者在一些严重的重构之后使用。

这取决于你想做什么。

案例#1:保存模型,以便自己使用它进行推理:保存模型,还原模型,然后将模型更改为评估模式。之所以这样做,是因为通常有BatchNormDropout层,默认情况下它们在构造时处于列车模式:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

案例#2:保存模型以便以后继续培训:如果需要继续培训要保存的模型,则需要保存的不仅仅是模型。您还需要保存优化器、时间段、分数等的状态。您可以这样做:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

要恢复训练,您需要执行如下操作:state = torch.load(filepath),然后,要恢复每个单独对象的状态,请执行以下操作:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

由于您正在恢复训练,因此在加载时恢复状态后,不要调用model.eval()

案例#3:其他人无法访问您的代码时使用的模型: 在Tensorflow中,您可以创建一个.pb文件,该文件定义了模型的体系结构和权重。这非常方便,特别是在使用Tensorflow serve时。在Pytorch中,同样的方法是:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

这种方式仍然不是防弹的,因为Pythorch仍在经历很多变化,我不推荐它。

相关问题 更多 >