为什么在将astate_dict
加载到同一模型体系结构的新实例中之后,序列化pytorchstate_dict
所获得的字节会发生变化
看看:
import binascii
import torch.nn as nn
import pickle
lin1 = nn.Linear(1, 1, bias=False)
lin1s = pickle.dumps(lin1.state_dict())
print("--- original model ---")
print(f"hash of state dict: {hex(binascii.crc32(lin1s))}")
print(f"weight: {lin1.state_dict()['weight'].item()}")
lin2 = nn.Linear(1, 1, bias=False)
lin2.load_state_dict(pickle.loads(lin1s))
lin2s = pickle.dumps(lin2.state_dict())
print("\n--- model from deserialized state dict ---")
print(f"hash of state dict: {hex(binascii.crc32(lin2s))}")
print(f"weight: {lin2.state_dict()['weight'].item()}")
印刷品
--- original model ---
hash of state dict: 0x4806e6b6
weight: -0.30337071418762207
--- model from deserialized state dict ---
hash of state dict: 0xe2881422
weight: -0.30337071418762207
如您所见,(pickles of the)state_dict
的散列是不同的,而权重是正确复制的。我假设新模型中的state_dict
在各个方面都等于旧模型。表面上看,它不是,因此不同的散列
这可能是因为pickle不希望生成适合散列的repr(请参见Using pickle.dumps to hash mutable objects)。比较键,然后比较dict键中存储的张量是否相等/接近,这可能是一个更好的主意
下面是这个想法的一个粗略实现
但是,如果您仍然希望散列状态dict并避免使用上面的
isclose
之类的比较,那么可以使用下面的函数相关问题 更多 >
编程相关推荐