将Pythorch模型从0.4.1加载到0.4.0?

2024-09-30 22:09:55 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用pytorch0.4.1(GPU)训练了一个DENSENET161模型,在测试环境中,我必须将其加载到pytorch版本0.4.0(CPU)中。我已经在使用model.cpu() 但是当我加载静态字典时model.load_state_dict(checkpoint['state_dict'])

我收到以下错误:

RuntimeError: Error(s) in loading state_dict for DenseNet: Unexpected key(s) in state_dict: "features.norm0.num_batches_tracked", "features.denseblock1.denselayer1.norm1.num_batches_tracked", "features.denseblock1.denselayer1.norm2.num_batches_tracked", "features.denseblock1.denselayer2.norm1.num_batches_tracked",...


Tags: in模型modelgpu测试环境batchesnumdict
1条回答
网友
1楼 · 发布于 2024-09-30 22:09:55

这似乎源于Pythorc0.4.1和0.4之间规范化层实现的差异——前者跟踪一些称为num_batches_tracked的状态变量,Pythorc0.4并不期望这样做。假设只有意外的键,没有丢失的键(这我不能确定,因为您已经剪辑了错误消息),您可以删除无关的键,希望模型可以加载。所以试试看

model_dict = checkpoint['state_dict']
filtered = {
    k: v for k, v in model_dict.items() if 'num_batches_tracked' not in k
}
model.load_state_dict(filtered)

请注意,除了您在这里看到的之外,规范化的内部结构可能已经发生了更改,因此即使此修复消除了异常,模型也可能会默默地表现出错误行为。在

相关问题 更多 >