我加载一个图,并希望访问图中定义为h1
,h2
,h3
的权重。在
我可以很容易地用手对每一个权重张量h
进行以下操作:
sess = tf.Session()
graph = tf.get_default_graph()
h1 = sess.graph.get_tensor_by_name("h1:0")
h2 = sess.graph.get_tensor_by_name("h2:0")
我不喜欢这种方法,因为对于一个大的图形来说,它会很难看。我更喜欢像一个循环超过所有的重量张量,把他们放入一个列表。在
我确实在堆栈溢出问题上找到了另外两个问题(here和here),但它们没有帮助我解决这个问题。在
我尝试了以下方法,但有两个问题:
^{pr2}$第一个问题:我必须定义图中权重张量的数量,这使得代码变得不灵活。第二个问题:get_tensor_by_name()
的参数是静态的。在
有没有办法把所有的张量都取出来,然后放到一个列表中?在
如果只关心可以优化的权重,可以调用
tf.trainable_variables()
。它返回一个将trainable
参数设置为True
的所有变量的列表。在印刷品:
^{pr2}$另一方面,
tf.global_variables()
返回所有变量的列表:更新
为了更好地控制您想要接收的变量,有几种方法可以过滤它们。一种方法是openmark建议的。在这种情况下,您可以根据变量范围前缀过滤它们。在
但是,如果这还不够,例如,如果您希望同时访问多个组,还有其他方法。您可以简单地按名称筛选,即:
但是,您必须了解tensorflow变量的命名约定。例如,
:0
后缀,变量范围前缀等等。在第二种方法,较少涉及,是创建自己的集合。例如,如果我对以可被2整除的数结尾的变量感兴趣,而在代码中的其他地方,我对名称以可被4整除的数结尾的变量感兴趣,我可以这样做:
然后我可以简单地调用
tf.get_collection()
函数:或者
您可以尝试tf.get_collection():
它返回由}指定包含不可训练变量的全局变量列表。查看上面的链接以获取可用密钥类型的列表。您还可以指定
^{pr2}$key
和scope
指定的集合中的项目列表。key
是标准图集合中的键tf.GraphKeys,例如,tf.GraphKeys.TRAINABLE_VARIABLES
指定由优化器训练的变量子集,而{scope
参数来筛选结果列表,以仅返回特定名称范围内的项,下面是一个小示例:相关问题 更多 >
编程相关推荐