I used this PyTorch reimplementation to train a ProGAN agent and saved it as a .pth file. Now I need to convert the agent to .onnx format, and I am using this script:
from torch.autograd import Variable
import torch.onnx
import torchvision
import torch
device = torch.device("cuda")
dummy_input = torch.randn(1, 3, 64, 64)
state_dict = torch.load("GAN_agent.pth", map_location = device)
torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")
As soon as I run it, I get the error AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict' (full traceback below). As far as I can tell, the problem is that converting the agent to .onnx requires more information than I am providing. Am I missing something?
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-2-c64481d4eddd> in <module>
10 state_dict = torch.load("GAN_agent.pth", map_location = device)
11
---> 12 torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
146 operator_export_type, opset_version, _retain_param_name,
147 do_constant_folding, example_outputs,
--> 148 strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
149
150
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
64 _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
65 example_outputs=example_outputs, strip_doc_string=strip_doc_string,
---> 66 dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs)
67
68
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size)
414 example_outputs, propagate,
415 _retain_param_name, do_constant_folding,
--> 416 fixed_batch_size=fixed_batch_size)
417
418 # TODO: Don't allocate a in-memory string for the protobuf
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _model_to_graph(model, args, verbose, training, input_names, output_names, operator_export_type, example_outputs, propagate, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size)
277 model.graph, tuple(in_vars), False, propagate)
278 else:
--> 279 graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
280 state_dict = _unique_state_dict(model)
281 params = list(state_dict.values())
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _trace_and_get_graph_from_model(model, args, training)
226 # A basic sanity check: make sure the state_dict keys are the same
227 # before and after running the model. Fail fast!
--> 228 orig_state_dict_keys = _unique_state_dict(model).keys()
229
230 # By default, training=False, which is good because running a model in
~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\jit\__init__.py in _unique_state_dict(module, keep_vars)
283 # id(v) doesn't work with it. So we always get the Parameter or Buffer
284 # as values, and deduplicate the params using Parameters and Buffers
--> 285 state_dict = module.state_dict(keep_vars=True)
286 filtered_dict = type(state_dict)()
287 seen_ids = set()
AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'
The file you have contains a state_dict, which simply maps layer names to tensors of weights, biases and the like (see here for a more thorough introduction). That means you need the model itself, so the saved weights and biases can be mapped onto its layers. But first things first:
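To make that concrete, here is a self-contained illustration (not taken from the question's code) that saves and reloads a tiny model's state_dict. It shows the loaded object is nothing but an OrderedDict of name-to-tensor pairs, which is exactly why passing it to torch.onnx.export fails:

```python
import torch
import torch.nn as nn

# A tiny model saved the same way the checkpoint was: torch.save of
# its state_dict, which is just an OrderedDict of name -> tensor.
model = nn.Linear(3, 2)
torch.save(model.state_dict(), "tiny.pth")

loaded = torch.load("tiny.pth")
print(type(loaded).__name__)   # OrderedDict, hence the AttributeError
print(list(loaded.keys()))     # ['weight', 'bias']
```

torch.onnx.export expects a Module whose state_dict() it can call; a bare OrderedDict has no such method, so you must first instantiate the model and load the weights into it.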
1. Model preparation

Clone the repository where the model definitions are located and open the file /pro_gan_pytorch/pro_gan_pytorch/PRO_GAN.py. It needs a few modifications to work with onnx. The onnx exporter requires that input be passed only as torch.tensor (or a list/dict of those), while the Generator class needs int and float arguments. The simple solution is to modify the forward function (line 80 in the file; you can verify it on GitHub) slightly, so that the only added change is unpacking via item(). Every input that is not of Tensor type should be packed as one in the function definition and unpacked as soon as possible at the top of the function. This will not break the checkpoint you created, so no worries, as it is just a layer-to-weight mapping.
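A sketch of that change, using a toy module (the real forward in PRO_GAN.py has more arguments and logic; only the item() unpacking pattern shown here is the point):

```python
import torch

class ToyGenerator(torch.nn.Module):
    """Minimal stand-in illustrating the signature change only."""
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(512, 16)

    def forward(self, x, depth, alpha):
        # depth and alpha now arrive as tensors (which is what the
        # ONNX exporter requires) and are unpacked right at the top:
        depth = depth.item()
        alpha = alpha.item()
        # ...the rest of the original body stays unchanged, using
        # depth and alpha as plain Python numbers from here on...
        return alpha * self.fc(x)

gen = ToyGenerator()
out = gen(torch.randn(1, 512), torch.tensor([5]), torch.tensor([0.5]))
```

During tracing, item() folds the values into constants, so the exported graph stays valid for the depth/alpha you trace with.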
2. Model export

Place a short export script in /pro_gan_pytorch (the directory containing README.md). A few things to note about it:
- The state_dict is loaded into torch.nn.DataParallel, which is needed because that is how the model was trained (not sure about your case, so adjust accordingly). After loading, we can get the module itself via the module attribute.
- CPU is used; I don't think GPU is needed here. If you insist, you could cast everything to the GPU instead.
- The dummy input to the generator is noise with 512 elements.

Run it, and your .onnx file should appear. Oh, and since you are after different checkpoints, you may want to follow a similar procedure for them; there is no guarantee everything will work, although it does seem to.