2024-05-09 02:10:45 发布
网友
我在寻找另一种方法来拯救一个在PyTorch受过训练的模特。到目前为止,我已经找到了两种选择。
我遇到了这个discussion,这里推荐方法2而不是方法1。
我的问题是,为什么选择第二种方法?这仅仅是因为torch.nn模块有这两个功能,我们被鼓励使用它们吗?
一个常见的PyTorch约定是使用.pt或.pth文件扩展名保存模型。
保存/加载整个模型 保存:
path = "username/directory/lstmmodelgpu.pth" torch.save(trainer, path)
加载:
model = torch.load(PATH) model.eval()
我在他们的github repo上找到了this page,我将在这里粘贴内容。
序列化和恢复模型有两种主要方法。
第一个(建议)仅保存和加载模型参数:
torch.save(the_model.state_dict(), PATH)
之后:
the_model = TheModelClass(*args, **kwargs) the_model.load_state_dict(torch.load(PATH))
第二个保存并加载整个模型:
torch.save(the_model, PATH)
the_model = torch.load(PATH)
但是在这种情况下,序列化的数据被绑定到特定的类 以及使用的确切目录结构,因此当 在其他项目中使用,或者在一些严重的重构之后使用。
这取决于你想做什么。
案例#1:保存模型,以便自己使用它进行推理:保存模型,还原模型,然后将模型更改为评估模式。之所以这样做,是因为通常有BatchNorm和Dropout层,默认情况下它们在构造时处于列车模式:
BatchNorm
Dropout
torch.save(model.state_dict(), filepath) #Later to restore: model.load_state_dict(torch.load(filepath)) model.eval()
案例#2:保存模型以便以后继续培训:如果需要继续培训要保存的模型,则需要保存的不仅仅是模型。您还需要保存优化器、时间段、分数等的状态。您可以这样做:
state = { 'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), ... } torch.save(state, filepath)
要恢复训练,您需要执行如下操作:state = torch.load(filepath),然后,要恢复每个单独对象的状态,请执行以下操作:
state = torch.load(filepath)
model.load_state_dict(state['state_dict']) optimizer.load_state_dict(state['optimizer'])
由于您正在恢复训练,因此在加载时恢复状态后,不要调用model.eval()。
model.eval()
案例#3:其他人无法访问您的代码时使用的模型: 在Tensorflow中,您可以创建一个.pb文件,该文件定义了模型的体系结构和权重。这非常方便,特别是在使用Tensorflow serve时。在Pytorch中,同样的方法是:
.pb
Tensorflow serve
torch.save(model, filepath) # Then later: model = torch.load(filepath)
这种方式仍然不是防弹的,因为Pythorch仍在经历很多变化,我不推荐它。
一个常见的PyTorch约定是使用.pt或.pth文件扩展名保存模型。
保存/加载整个模型 保存:
加载:
必须在某处定义模型类
我在他们的github repo上找到了this page,我将在这里粘贴内容。
保存模型的推荐方法
序列化和恢复模型有两种主要方法。
第一个(建议)仅保存和加载模型参数:
之后:
第二个保存并加载整个模型:
之后:
但是在这种情况下,序列化的数据被绑定到特定的类 以及使用的确切目录结构,因此当 在其他项目中使用,或者在一些严重的重构之后使用。
这取决于你想做什么。
案例#1:保存模型,以便自己使用它进行推理:保存模型,还原模型,然后将模型更改为评估模式。之所以这样做,是因为通常有
BatchNorm
和Dropout
层,默认情况下它们在构造时处于列车模式:案例#2:保存模型以便以后继续培训:如果需要继续培训要保存的模型,则需要保存的不仅仅是模型。您还需要保存优化器、时间段、分数等的状态。您可以这样做:
要恢复训练,您需要执行如下操作:
state = torch.load(filepath)
,然后,要恢复每个单独对象的状态,请执行以下操作:由于您正在恢复训练,因此在加载时恢复状态后,不要调用
model.eval()
。案例#3:其他人无法访问您的代码时使用的模型: 在Tensorflow中,您可以创建一个
.pb
文件,该文件定义了模型的体系结构和权重。这非常方便,特别是在使用Tensorflow serve
时。在Pytorch中,同样的方法是:这种方式仍然不是防弹的,因为Pythorch仍在经历很多变化,我不推荐它。
相关问题 更多 >
编程相关推荐