如何在tensorflow中保存/冻结任意pickle模型?

2024-09-27 19:12:26 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用的代码从zalando实现作为一个进步的GAN由Nvidia。参见:https://github.com/zalandoresearch/disentangling_conditional_gans

他们在训练时使用3个网络:GD和{}。这三个模型都是https://github.com/zalandoresearch/disentangling_conditional_gans/blob/master/tfutil.py#L424中定义的Network类的一个实例

这些模型使用许多helper函数进行存储和加载,这些函数使用python的pickle格式将3个模型保存为*.pkl。在

我只对导出Gs模型感兴趣。在

如何将其转换为保存的模型(因为代码不使用tf.节电器)最后是一个冻结的模型,这样我就可以很容易地推断了。在

加载模型后,我会:

allvars = [n.name for n in tf.get_default_graph().as_graph_def().node]
Gs_vars = [i for i in allvars if i.split('/')[0] == 'Gs']

但是,运行此程序时:

^{pr2}$

它会抛出一个错误:

^{3}$

使用Gs模型的正确实现是:

images = Gs.run(latents, labels, masks, minibatch_size=minibatch_size, num_gpus=config.num_gpus, out_mul=127.5, out_add=127.5, out_shrink=image_shrink, out_dtype=np.uint8)

Gs需要3个输入,它们存储为Gs/latents_inGs/masks_in和{}。在


Tags: 函数代码inhttps模型githubcomgs

热门问题