Tensorflow:通过nam得到所有的权重张量

2024-05-07 20:35:36 发布

您现在位置:Python中文网/ 问答频道 /正文

我加载一个图,并希望访问图中定义为h1h2h3的权重。在

我可以很容易地用手对每一个权重张量h进行以下操作:

sess = tf.Session()
graph = tf.get_default_graph()
h1 = sess.graph.get_tensor_by_name("h1:0")
h2 = sess.graph.get_tensor_by_name("h2:0")

我不喜欢这种方法,因为对于一个大的图形来说,它会很难看。我更喜欢像一个循环超过所有的重量张量,把他们放入一个列表。在

我确实在堆栈溢出问题上找到了另外两个问题(herehere),但它们没有帮助我解决这个问题。在

我尝试了以下方法,但有两个问题:

^{pr2}$

第一个问题:我必须定义图中权重张量的数量,这使得代码变得不灵活。第二个问题:get_tensor_by_name()的参数是静态的。在

有没有办法把所有的张量都取出来,然后放到一个列表中?在


Tags: 方法name列表getbyhere定义tf
2条回答

如果只关心可以优化的权重,可以调用tf.trainable_variables()。它返回一个将trainable参数设置为True的所有变量的列表。在

tf.reset_default_graph()

# These can be optimized
for i in range(5):
    tf.Variable(tf.random_normal(dtype=tf.float32, shape=[32,32]), name="h{}".format(i))

# These cannot be optimized
for i in range(5):
    tf.Variable(tf.random_normal(dtype=tf.float32, shape=[32,32]), name="n{}".format(i), trainable=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    graph = tf.get_default_graph()
    for t_var in tf.trainable_variables():
        print(t_var)

印刷品:

^{pr2}$

另一方面,tf.global_variables()返回所有变量的列表:

for g_var in tf.global_variables():
    print(g_var)
<tf.Variable 'h0:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h1:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h2:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h3:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h4:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n0:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n1:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n2:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n3:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n4:0' shape=(32, 32) dtype=float32_ref>

更新

为了更好地控制您想要接收的变量,有几种方法可以过滤它们。一种方法是openmark建议的。在这种情况下,您可以根据变量范围前缀过滤它们。在

但是,如果这还不够,例如,如果您希望同时访问多个组,还有其他方法。您可以简单地按名称筛选,即:

for g_var in tf.global_variables():
  if g_var.name.beginswith('h'):
    print(g_var) 

但是,您必须了解tensorflow变量的命名约定。例如,:0后缀,变量范围前缀等等。在

第二种方法,较少涉及,是创建自己的集合。例如,如果我对以可被2整除的数结尾的变量感兴趣,而在代码中的其他地方,我对名称以可被4整除的数结尾的变量感兴趣,我可以这样做:

# These can be optimized
for i in range(5):
    h_var = tf.Variable(tf.random_normal(dtype=tf.float32, shape=[32,32]), name="h{}".format(i))
    if i % 2 == 0:
      tf.add_to_collection('vars_divisible_by_2', h_var)
    if i % 4 == 0:
      tf.add_to_collection('vars_divisible_by_4', h_var)

然后我可以简单地调用tf.get_collection()函数:

tf.get_collection('vars_divisible_by_2)
[<tf.Variable 'h0:0' shape=(32, 32) dtype=float32_ref>,
 <tf.Variable 'h2:0' shape=(32, 32) dtype=float32_ref>,
 <tf.Variable 'h4:0' shape=(32, 32) dtype=float32_ref>]

或者

tf.get_collection('vars_divisible_by_4'):
[<tf.Variable 'h0:0' shape=(32, 32) dtype=float32_ref>,
 <tf.Variable 'h4:0' shape=(32, 32) dtype=float32_ref>]

您可以尝试tf.get_collection()

tf.get_collection(
key,
scope=None)

它返回由keyscope指定的集合中的项目列表。key是标准图集合中的键tf.GraphKeys,例如,tf.GraphKeys.TRAINABLE_VARIABLES指定由优化器训练的变量子集,而{}指定包含不可训练变量的全局变量列表。查看上面的链接以获取可用密钥类型的列表。您还可以指定scope参数来筛选结果列表,以仅返回特定名称范围内的项,下面是一个小示例:

^{pr2}$

相关问题 更多 >