TensorFlow on Jupyter: Can't restore variables

Posted 2024-10-01 15:40:41


When using TensorFlow in a Jupyter notebook, I can't seem to restore saved variables. The result is saved in a new …

If I run the model, save it to params.ckpt, then Close and Halt the notebook, and later try to load it again, I get the following error:

---------------------------------------------------------------------------
StatusNotOK                               Traceback (most recent call last)
StatusNotOK: Not found: Tensor name "Variable/Adam" not found in checkpoint files params.ckpt
     [[Node: save/restore_slice_1 = RestoreSlice[dt=DT_FLOAT, preferred_shard=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/restore_slice_1/tensor_name, save/restore_slice_1/shape_and_slice)]]

During handling of the above exception, another exception occurred:

SystemError                               Traceback (most recent call last)
<ipython-input-6-39ae6b7641bd> in <module>()
----> 1 saver.restore(sess, "params.ckpt")

/usr/local/lib/python3.5/site-packages/tensorflow/python/training/saver.py in restore(self, sess, save_path)
    889       save_path: Path where parameters were previously saved.
    890     """
--> 891     sess.run([self._restore_op_name], {self._filename_tensor_name: save_path})
    892 
    893 

/usr/local/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict)
    366 
    367     # Run request and get response.
--> 368     results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
    369 
    370     # User may have fetched the same tensor multiple times, but we

/usr/local/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, target_list, fetch_list, feed_dict)
    426 
    427       return tf_session.TF_Run(self._session, feed_dict, fetch_list,
--> 428                                target_list)
    429 
    430     except tf_session.StatusNotOK as e:

SystemError: <built-in function delete_Status> returned a result with an error set

My training code is:

^{pr2}$

Am I doing something wrong? Why can't I restore my variables?


1 Answer
Answer 1 · Posted 2024-10-01 15:40:41

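It looks like you are building your model in Jupyter. When a tf.train.Saver is constructed with default arguments, one possible problem is that it uses the (auto-generated) names of the variables as the keys in the checkpoint. Because it is easy to re-execute a code cell several times in Jupyter, the session you save can end up with multiple copies of the variable nodes. See my answer to this question for an explanation of what can go wrong.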
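As a rough illustration of the naming problem, here is a toy model in plain Python. It is not TensorFlow code and not how TensorFlow is implemented; ToyGraph, unique_name, and build_model are hypothetical names made up for this sketch. It only mimics the idea that a long-lived default graph hands out a fresh auto-generated name every time the same cell is executed, so the names built on a second run no longer match the keys written on the first run.

```python
# Toy model (NOT TensorFlow itself): a graph that hands out unique
# auto-generated names, and a checkpoint keyed by those names.


class ToyGraph:
    """Hypothetical stand-in for a default graph that persists across cell runs."""

    def __init__(self):
        self._counts = {}

    def unique_name(self, base="Variable"):
        # First use yields "Variable", later uses "Variable_1", "Variable_2", ...
        n = self._counts.get(base, 0)
        self._counts[base] = n + 1
        return base if n == 0 else "{}_{}".format(base, n)


def build_model(graph):
    """Simulates one execution of a model-building cell with two unnamed variables."""
    return [graph.unique_name(), graph.unique_name()]


graph = ToyGraph()

first_run = build_model(graph)    # names produced by the first execution of the cell
checkpoint_keys = set(first_run)  # a name-keyed checkpoint would store these keys

second_run = build_model(graph)   # the same cell executed again in the same graph
missing = [name for name in second_run if name not in checkpoint_keys]

print("first run: ", first_run)
print("second run:", second_run)
print("not in checkpoint:", missing)
```

In this toy setup every name from the second run is absent from the checkpoint keys, which is the same kind of mismatch as the "not found in checkpoint" error in the question.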

There are a few possible solutions. Here are the simplest:

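  • Call tf.reset_default_graph() before building your model (and the Saver). This ensures that the variables get the names you intended, but it invalidates any previously created graphs.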

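  • Use the explicit arguments to tf.train.Saver to specify persistent names for the variables. For your example this shouldn't be too hard (although it becomes unwieldy for larger models):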

    saver = tf.train.Saver(var_list={"b1": b1, "W1": W1, "b2": b2, "W2": W2})
    
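  • Create a new tf.Graph() and make it the default each time you create the model. This can be tricky in Jupyter, because it forces you to put all of the model-building code in one cell, but it works well for scripts: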

    with tf.Graph().as_default():
      # Model building and training/evaluation code goes here.
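The options above can be sketched together as follows. This is a minimal sketch under stated assumptions, not the asker's code: it assumes a TensorFlow installation that provides the tensorflow.compat.v1 module, and build_model, the variable shapes, the temporary checkpoint directory, and the key "W1" for a single variable are placeholders chosen for illustration.

```python
import os
import tempfile

# Assumption: a TensorFlow install that ships the compat.v1 API.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # graph mode, matching the Session-based code in the question


def build_model():
    """Hypothetical stand-in for the question's model: one unnamed variable."""
    return tf.Variable(tf.zeros([2]))


# Re-running the build step in the same default graph changes the auto-generated name.
tf.reset_default_graph()
name_first = build_model().name
name_rerun = build_model().name

# Option 1: resetting the default graph before building gives the original name back.
tf.reset_default_graph()
name_after_reset = build_model().name
print(name_first, name_rerun, name_after_reset)

# Options 2 and 3: explicit checkpoint keys, with a fresh graph per build.
ckpt_path = os.path.join(tempfile.mkdtemp(), "params.ckpt")  # illustrative location

with tf.Graph().as_default():
    w = tf.Variable(tf.ones([2]))
    saver = tf.train.Saver(var_list={"W1": w})  # "W1" is the key stored in the checkpoint
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_path)

with tf.Graph().as_default():
    _unused = tf.Variable(tf.zeros([2]))  # shifts the auto-generated name of the next variable
    w = tf.Variable(tf.zeros([2]))
    saver = tf.train.Saver(var_list={"W1": w})  # same explicit key, so the names no longer matter
    with tf.Session() as sess:
        saver.restore(sess, ckpt_path)
        restored = sess.run(w).tolist()

print(restored)
```

The second graph deliberately creates an extra variable first, so the auto-generated name of w differs from the one used when saving; the restore still succeeds because the checkpoint is keyed by "W1" rather than by the generated name.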
    
