如何在恢复保存的模型后获取/打印张量值?

2024-10-02 12:32:16 发布

您现在位置:Python中文网/ 问答频道 /正文

我在TF中建立CNN模型。
我保存了一些变量

wc1 = tf.Variable(tf.random_normal([5, 5, 1, 32]), name='wc1')
wc2 = tf.Variable(tf.random_normal([5, 5, 32, 64]), name='wc2')

通过

^{pr2}$

当我在另一个会话中恢复保存的模型并使用tf.打印(),无法打印。下面的代码用于恢复模型

sess = tf.Session()
saver = tf.train.import_meta_graph("./cnn_model.meta")
saver.restore(sess, './cnn_model')
wc1 = tf.get_default_graph().get_tensor_by_name("wc1:0")
wc2 = tf.get_default_graph().get_tensor_by_name("wc2:0")
while some_step:
    sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
    wc1 = tf.Print(wc1, [wc1], 'WC1 is: ')

如何打印/获取保存模型的张量值?


Tags: name模型getmodeltfrandomvariablecnn
1条回答
网友
1楼 · 发布于 2024-10-02 12:32:16

只需执行sess.run(wc1),就可以获得模型的值。下面是我的代码示例:

>>> import tensorflow as tf
>>> wc1 = tf.Variable(tf.random_normal([5, 5, 1, 32]), name='wc1')
>>> saver = tf.train.Saver([wc1])
>>> with tf.Session('') as sess:
...   tf.global_variables_initializer().run(session=sess)
...   saver.save(sess, './cnn_model')
'./cnn_model'
>>> sess = tf.Session('')
>>> saver = tf.train.import_meta_graph("./cnn_model.meta")
>>> saver.restore(sess, './cnn_model')
INFO:tensorflow:Restoring parameters from ./cnn_model
>>> wc_r1 = tf.get_default_graph().get_tensor_by_name('wc1:0')
>>> sess.run(wc_r1)
array([[[[ 0.82639563, -0.33938187,  0.26812711, -0.32433796,  1.2584244 ,
          -0.25379655, -0.16618967,  0.27060306,  1.53495347,  0.75791109,
          -0.87073582,  1.48225808,  1.13401747, -1.80606318,  1.0940119 ,
           0.52464408, -0.24058162, -1.36783814, -0.04032131,  0.82713342,
           1.32288456, -1.32494891,  0.93615007, -0.74220407,  1.13950729,
           0.39443189,  1.81868839,  0.91872966,  1.73204434, -1.26066136,
          -1.12299716, -1.26222265]],
    ...

相关问题 更多 >

    热门问题