<p>如果只关心可以优化的权重,可以调用<code>tf.trainable_variables()</code>。它返回一个将<code>trainable</code>参数设置为<code>True</code>的所有变量的列表。在</p>
<pre><code>tf.reset_default_graph()
# These can be optimized
for i in range(5):
tf.Variable(tf.random_normal(dtype=tf.float32, shape=[32,32]), name="h{}".format(i))
# These cannot be optimized
for i in range(5):
tf.Variable(tf.random_normal(dtype=tf.float32, shape=[32,32]), name="n{}".format(i), trainable=False)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
graph = tf.get_default_graph()
for t_var in tf.trainable_variables():
print(t_var)
</code></pre>
<p>印刷品:</p>
^{pr2}$
<p>另一方面,<code>tf.global_variables()</code>返回所有变量的列表:</p>
<pre><code>for g_var in tf.global_variables():
print(g_var)
</code></pre>
<pre><code><tf.Variable 'h0:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h1:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h2:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h3:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'h4:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n0:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n1:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n2:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n3:0' shape=(32, 32) dtype=float32_ref>
<tf.Variable 'n4:0' shape=(32, 32) dtype=float32_ref>
</code></pre>
<p><strong>更新</strong></p>
<p>为了更好地控制您想要接收的变量,有几种方法可以过滤它们。一种方法是<a href="https://stackoverflow.com/users/9535747/openmark">openmark</a>建议的。在这种情况下,您可以根据变量范围前缀过滤它们。在</p>
<p>但是,如果这还不够,例如,如果您希望同时访问多个组,还有其他方法。您可以简单地按名称筛选,即:</p>
<pre><code>for g_var in tf.global_variables():
if g_var.name.beginswith('h'):
print(g_var)
</code></pre>
<p>但是,您必须了解tensorflow变量的命名约定。例如,<code>:0</code>后缀,变量范围前缀等等。在</p>
<p>第二种方法,较少涉及,是创建自己的集合。例如,如果我对以可被2整除的数结尾的变量感兴趣,而在代码中的其他地方,我对名称以可被4整除的数结尾的变量感兴趣,我可以这样做:</p>
<pre><code># These can be optimized
for i in range(5):
h_var = tf.Variable(tf.random_normal(dtype=tf.float32, shape=[32,32]), name="h{}".format(i))
if i % 2 == 0:
tf.add_to_collection('vars_divisible_by_2', h_var)
if i % 4 == 0:
tf.add_to_collection('vars_divisible_by_4', h_var)
</code></pre>
<p>然后我可以简单地调用<code>tf.get_collection()</code>函数:</p>
<pre><code>tf.get_collection('vars_divisible_by_2)
</code></pre>
<pre><code>[<tf.Variable 'h0:0' shape=(32, 32) dtype=float32_ref>,
<tf.Variable 'h2:0' shape=(32, 32) dtype=float32_ref>,
<tf.Variable 'h4:0' shape=(32, 32) dtype=float32_ref>]
</code></pre>
<p>或者</p>
<pre><code>tf.get_collection('vars_divisible_by_4'):
</code></pre>
<pre><code>[<tf.Variable 'h0:0' shape=(32, 32) dtype=float32_ref>,
<tf.Variable 'h4:0' shape=(32, 32) dtype=float32_ref>]
</code></pre>