Tensorflow恢复时忽略作用域名称或进入新的作用域nam

2024-10-01 11:29:32 发布

您现在位置:Python中文网/ 问答频道 /正文

我首先训练了网络N,并用保存程序将其保存到检查点Checkpoint_N。在N中定义了一些变量范围。在

现在,我想用这个训练的网络N建立一个暹罗网络,如下所示:

with tf.variable_scope('siameseN',reuse=False) as scope:
  networkN = N()
  embedding_1 = networkN.buildN() 
  # this defines the network graph and all the variables.
  tf.train.Saver().restore(session_variable,Checkpoint_N)
  scope.reuse_variables()
  embedding_2 = networkN.buildN()
  # define 2nd branch of the Siamese, by reusing previously restored variables.

当我执行上述操作时,restore语句抛出一个Key Error,在N图中的每个变量的检查点文件中找不到{}。在

有没有一种方法可以在不改变N的代码的情况下做到这一点?我基本上只是给N中的每个变量和操作添加了一个父作用域。我可以通过告诉tensorflow忽略父作用域或其他东西来恢复权重到正确的变量吗?在


Tags: the程序网络tfrestoreembeddingvariables作用域
2条回答

我不得不修改一下代码,编写自己的恢复函数。我决定将检查点文件作为字典加载,变量名作为键,相应的numpy数组作为值,如下所示:

checkpoint_path = '/path/to/checkpoint'
from tensorflow.python import pywrap_tensorflow

reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()

key_to_numpy = {}
for key in var_to_shape_map:
  key_to_numpy[key] = reader.get_tensor(key)

我已经有了一个函数,其中创建了所有变量,并从图N调用该函数,并使用所需的名称。我修改了它以使用从字典查找获得的numpy数组初始化变量。为了使查找成功,我只删除了添加的父名称范围,如下所示:

^{pr2}$

这是一种更为老套的方法。我没有使用@edit的答案,因为我已经编写了上面的代码。另外,我的所有权重都是在一个函数中创建的,该函数将这些权重赋给变量var,并返回它。因为这类似于函数式编程,变量var不断被覆盖。var从不向更高级别的函数公开。要使用@edit的答案,我必须在每次初始化时使用不同的张量变量名,并以某种方式将其公开给更高级别的函数,以便保存程序可以在它们的答案中将它们用作var_to_be_restored_to。在

但是@edit的解决方案是不那么老套的解决方案,因为它遵循文档化的用法。所以我接受这个答案。我所做的可能是另一种解决办法。在

这与:How to restore weights with different names but same shapes Tensorflow?

tf.train.Saver(var_list={'variable_name_in_checkpoint':var_to_be_restored_to,...'})

可以获取要还原的变量列表或字典

(e.g. 'variable_name_in_checkpoint':var_to_be_restored_to,...)

您可以通过遍历当前会话变量中的所有变量来准备上述字典,并使用会话变量作为值并获取当前变量的名称,并从变量名中删除“siameseN/”并将其用作键。从理论上讲,这是可行的。在

相关问题 更多 >