Tensorflow:如何修改十位数的值

2024-09-30 06:32:16 发布

您现在位置:Python中文网/ 问答频道 /正文

由于在使用Tensorflow训练模型之前需要为数据编写一些预处理,因此需要对tensor进行一些修改。但是,我不知道如何像使用numpy那样修改tensor中的值。

最好的方法是它能够直接修改tensor。然而,在当前版本的Tensorflow中,这似乎是不可能的。另一种方法是将进程的tensor更改为ndarray,然后使用tf.convert_to_tensor重新更改。

关键是如何将tensor更改为ndarray
1) tf.contrib.util.make_ndarray(tensor)https://www.tensorflow.org/versions/r0.8/api_docs/python/contrib.util.html#make_ndarray
根据文档,这似乎是最简单的方法,但是在当前版本的Tensorflow中找不到这个函数。其次,它的输入是TensorProto,而不是tensor
2) 使用a.eval()a复制到另一个ndarray
然而,它只能在笔记本中使用tf.InteractiveSession()

下面是一个带有代码的简单例子。这段代码的目的是使tfc在处理后具有与npc相同的输出。

提示
你应该认为tfcnpc彼此独立。这就满足了这样的情况:首先,检索到的训练数据是tensor格式的,带有^{}


源代码

import numpy as np
import tensorflow as tf
tf.InteractiveSession()

tfc = tf.constant([[1.,2.],[3.,4.]])
npc = np.array([[1.,2.],[3.,4.]])
row = np.array([[.1,.2]])
print('tfc:\n', tfc.eval())
print('npc:\n', npc)
for i in range(2):
    for j in range(2):
        npc[i,j] += row[0,j]

print('modified tfc:\n', tfc.eval())
print('modified npc:\n', npc)

输出:

tfc:
[[1。2.]
〔3〕。4.]]
npc:
[[1。2.]
〔3〕。4.]]
修改后的tfc:
[[1。2.]
〔3〕。4.]]
修改后的npc:
[[1.1 2.2]
[3.1 4.2]]


Tags: 数据方法版本numpytftensorflowutileval
2条回答

我挣扎了一会儿。给出的答案将向图中添加assign操作(因此,如果随后保存检查点,则不必要地增加.meta的大小)。更好的解决方案是使用tf.keras.backend.set_value。我们可以用原始的tensorflow来模拟:

    for x, value in zip(tf.global_variables(), values_npfmt):
      if hasattr(x, '_assign_placeholder'):
        assign_placeholder = x._assign_placeholder
        assign_op = x._assign_op
      else:
        assign_placeholder = array_ops.placeholder(tf_dtype, shape=value.shape)
        assign_op = x.assign(assign_placeholder)
        x._assign_placeholder = assign_placeholder
        x._assign_op = assign_op
      get_session().run(assign_op, feed_dict={assign_placeholder: value})

使用assign和eval(或sess.run)赋值:

import numpy as np
import tensorflow as tf

npc = np.array([[1.,2.],[3.,4.]])
tfc = tf.Variable(npc) # Use variable 

row = np.array([[.1,.2]])

with tf.Session() as sess:   
    tf.initialize_all_variables().run() # need to initialize all variables

    print('tfc:\n', tfc.eval())
    print('npc:\n', npc)
    for i in range(2):
        for j in range(2):
            npc[i,j] += row[0,j]
    tfc.assign(npc).eval() # assign_sub/assign_add is also available.
    print('modified tfc:\n', tfc.eval())
    print('modified npc:\n', npc)

它输出:

tfc:
 [[ 1.  2.]
 [ 3.  4.]]
npc:
 [[ 1.  2.]
 [ 3.  4.]]
modified tfc:
 [[ 1.1  2.2]
 [ 3.1  4.2]]
modified npc:
 [[ 1.1  2.2]
 [ 3.1  4.2]]

相关问题 更多 >

    热门问题