Pytork负载_不兼容键

2024-09-28 23:41:48 发布

您现在位置:Python中文网/ 问答频道 /正文

我培训了一个Efficent-b6模型(架构如下所示):

https://github.com/lukemelas/EfficientNet-PyTorch

现在,我尝试加载一个我用它训练的模型:

checkpoint  = torch.load('model.pth', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint, strict=False)

但是我得到了以下错误:

_IncompatibleKeys
missing_keys=['_conv_stem.weight', '_bn0.weight', '_bn0.bias', ...]
unexpected_keys=['module._conv_stem.weight', 'module._bn0.weight', 'module._bn0.bias', ...]

请让我知道如何修复,我遗漏了什么? 谢谢大家!


Tags: 模型model架构loadtorchkeysmoduleweight
1条回答
网友
1楼 · 发布于 2024-09-28 23:41:48

如果比较missing_keysunexpected_keys,您可能会意识到发生了什么

missing_keys=['_conv_stem.weight', '_bn0.weight', '_bn0.bias', ...]
unexpected_keys=['module._conv_stem.weight', 'module._bn0.weight', 'module._bn0.bias', ...]

如您所见,模型权重以module.前缀保存。 这是因为你已经用DataParallel训练了模型

现在,要在不使用DataParallel的情况下加载模型权重,可以执行以下操作

# original saved file with DataParallel
checkpoint = torch.load(path, map_location=torch.device('cpu'))

# create new OrderedDict that does not contain `module.`
from collections import OrderedDict

new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    name = key.replace("module.", "") # remove `module.`
    new_state_dict[name] = v

# load params
model.load_state_dict(new_state_dict, strict=False)

或者,如果使用DataParallel包装模型,则不需要上述方法

checkpoint  = torch.load('model.pth', map_location=torch.device('cpu'))
model = torch.nn.DataParallel(model)
model.load_state_dict(checkpoint, strict=False)

尽管不鼓励第二种方法(因为在许多情况下,您可能不需要DataParallel

相关问题 更多 >