在python for循环中,Tensorflow太慢

2024-09-30 14:24:13 发布

您现在位置:Python中文网/ 问答频道 /正文

我想在Tensorflow中创建一个函数,对于给定数据X的每一行,只对一些采样类应用softmax函数,比如K个总类中的2个,并返回一个矩阵S,其中S.shape = (N,K)(N:给定数据的行,K是总类)。在

矩阵S最终将包含零,以及由采样类为每行定义的索引中的非零值。在

在简单的python中,我使用高级索引,但在Tensorflow中,我不知道如何创建它。我最初的问题是this, where I present the numpy code。在

因此,我试图在张量流中找到一个解决方案,其主要思想不是将S用作二维矩阵,而是用作一维数组。代码是这样的:

num_samps = 2
S = tf.Variable(tf.zeros(shape=(N*K)))
W = tf.Variable(tf.random_uniform((K,D)))
tfx = tf.placeholder(tf.float32,shape=(None,D))
sampled_ind = tf.random_uniform(dtype=tf.int32, minval=0, maxval=K-1, shape=[num_samps])
ar_to_sof = tf.matmul(tfx,tf.gather(W,sampled_ind),transpose_b=True)
updates = tf.reshape(tf.nn.softmax(ar_to_sof),shape=(num_samps,))
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for line in range(N):
    inds_new = sampled_ind + line*K
    sess.run(tf.scatter_update(S,inds_new,updates), feed_dict={tfx: X[line:line+1]})

S = tf.reshape(S,shape=(N,K))

这是有效的,结果是预期的。但它的运行速度非常慢。为什么会这样?我怎样才能更快地完成任务?在


Tags: 数据函数tftensorflowline矩阵numsess
2条回答

在tensorflow中编程时,了解定义操作和执行操作之间的区别是至关重要的。在python中运行时,大多数以tf.开头的函数都会向计算图添加操作。在

例如,当您执行以下操作时:

tf.scatter_update(S,inds_new,updates)

以及:

^{pr2}$

很多次,你的计算图增长超出了需要的范围,填满了所有的内存,并极大地减慢了速度。在

您应该做的是在循环之前定义一次计算:

init = tf.initialize_all_variables()
inds_new = sampled_ind + line*K
update_op = tf.scatter_update(S, inds_new, updates)
sess = tf.Session()
sess.run(init)
for line in range(N):
    sess.run(update_op, feed_dict={tfx: X[line:line+1]})

这样,计算图只包含inds_newupdate_op的一个副本。请注意,当您执行update_op时,inds_new也将被隐式执行,因为它在计算图中是其父级。在

{cd3>你应该知道,每一次运行的结果都是不同的。在

另外,调试这类问题的一个很好的方法是使用张量板可视化计算图。在代码中添加:

summary_writer = tf.train.SummaryWriter('some_logdir', sess.graph_def)

然后在控制台中运行:

tensorboard  logdir=some_logdir

在服务的html页面上会有一个计算图的图片,在这里你可以检查你的张量。在

请记住tf.scatter_更新将返回张量S,这意味着在会话运行中会有一个大内存拷贝,甚至在分布式环境中会有网络拷贝。根据@sygi的回答,解决方案是:

update_op = tf.scatter_update(S, inds_new, updates)
update_op_op = update_op.op

然后在会话运行中,你这样做

^{pr2}$

这将避免复制大张量S

相关问题 更多 >