如何在TensorF中导入模型

2024-05-19 11:04:08 发布

您现在位置:Python中文网/ 问答频道 /正文

我在恢复保存的模型时遇到困难。我正在对CNN进行MNIST数据集的培训,所有这些都是根据MNIST教程Deep MNIST for Experts进行的,我用以下代码保存我的模型:

saver.save(sess, './Tensorflow_MNIST', global_step=max_steps)

这将创建以下文件:

  • Tensorflow_MNIST-1000.data-00000-of-00001
  • Tensorflow_MNIST-1000.索引
  • Tensorflow_MNIST-1000.meta
  • 检查点

稍后,我希望加载模型并继续使用以下内容进行培训:

^{pr2}$

但这将返回一个错误:

NameError: name 'train_step' is not defined

因此,图及其变量和操作似乎没有正确加载。我做错什么了?在


Tags: 数据代码模型forsavetensorflowstep教程
3条回答

当使用saver.save()时,TensorFlow保存由张量组成的计算图,即TensorFlow的对象。

它不会保存您使用的所有变量。特别是,不是tf.Tensor的任何内容都不会被保存。

您可能希望拥有自己的数据结构来保存任何其他信息。

为了方便起见,您可以使用JSON格式,甚至可以使用pickle,这在python中使用起来非常简单,但不能手工编辑。

希望有帮助

保存时:

saver = tf.train.Saver(...variables...)
# Remember the training_op we want to run by adding it to a collection.
tf.add_to_collection('train_step', train_step)

恢复时:

^{pr2}$

如果您想重用该模型,我想将sess.run(train_step...)更改为 train_step(...)应该可以

按照import meta_graph中的描述,调用添加了“”和:0的所有张量似乎可以做到这一点。例如,计算精度的调用变成:

test_accuracy = sess.run("accuracy:0", feed_dict={"x:0": mnist.test.images, "y_:0": mnist.test.labels, "keep_prob:0": 1.0})

相关问题 更多 >

    热门问题