用十个数字一次将多个切片切片

2024-10-01 15:41:45 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在为tensorflow RNN准备输入张量。
目前我正在做以下工作

rnn_format = list()
for each in range(batch_size):
    rnn_format.append(tf.slice(input2Dpadded,[each,0],[max_steps,10]))
lstm_input = tf.stack(rnn_format)

有没有可能在没有循环的情况下,用张量流函数,一次完成这个任务吗?在


Tags: informatforsizetftensorflowbatchslice
2条回答

正如peterhawkins建议的那样,您可以使用gather_nd和适当的索引来达到这个目的。在

内部维度上的统一裁剪可以在调用gather_nd之前完成。在

示例:

import tensorflow as tf
import numpy as np

sess = tf.InteractiveSession()

# integer image simply because it is more readable to me
im0 = np.random.randint(10, size=(20,20))
im = tf.constant(im0)

max_steps = 3
batch_size = 10

# create the appropriate indices here
indices = (np.arange(max_steps) +
    np.arange(batch_size)[:,np.newaxis])[...,np.newaxis]
# crop then call gather_nd
res = tf.gather_nd(im[:,:10], indices).eval()

# check that the resulting tensors are equal to what you had previously
for each in range(batch_size):
  assert(np.all(tf.slice(im, [each,0],[max_steps,10]).eval() == res[each]))

编辑

如果切片索引在张量中,则在创建indices时,只需将numpy的操作替换为tensorflow的操作:

^{pr2}$

进一步说明:

  • indices是在加法过程中利用broadcasting创建的:数组实际上是平铺的,以便它们的维度匹配。广播由numpy和tensorflow以类似的方式支持。在
  • 省略号...是标准numpy slicing notation的一部分,它基本上填充了其他切片索引留下的所有剩余维度。所以[..., newaxis]基本上等同于expand_dims(·, -1)。在

尝试tf.splittf.split_v。请看这里:

https://www.tensorflow.org/api_docs/python/tf/split

这有帮助吗?在

相关问题 更多 >

    热门问题