函数tf.squeze和tf.nn.rnn做什么?

2024-05-07 15:20:39 发布

您现在位置:Python中文网/ 问答频道 /正文

函数tf.squeze和tf.nn.rnn做什么?

我搜索了这些API,但找不到参数、示例等。 另外,以下代码使用tf.squeeze形成的p_inputs的形状是什么,使用tf.nn.rnn的含义和情况是什么?

batch_num = 10
step_num = 2000
elem_num = 26

p_input = tf.placeholder(tf.float32, [batch_num, step_num, elem_num])
p_inputs = [tf.squeeze(t, [1]) for t in tf.split(1, step_num, p_input)]

Tags: 函数api示例input参数tfstepbatch
2条回答

像这样的问题最好的答案是TensorFlow API documentation。您提到的两个函数在数据流图中创建操作和符号张量。特别是:

  • ^{}函数返回一个张量,其值与其第一个参数相同,但形状不同。它删除大小为1的维度。例如,如果t是形状为[batch_num, 1, elem_num]的张量(如您所问),则tf.squeeze(t, [1])将返回内容相同但大小为[batch_num, elem_num]的张量。

  • ^{}函数返回一对结果,其中第一个元素表示某个给定输入的递归神经网络的输出,第二个元素表示该输入的该网络的最终状态。TensorFlow网站有一个tutorial on recurrent neural networks包含更多细节。

挤压移除大小为“1”的变形。下面的示例将显示挤压的用法。

import tensorflow as tf
tf.enable_eager_execution() ##if using TF1.4 for TF2.0 eager mode is the default mode.
####example 1
a = tf.constant(value=[1,3,4,5],shape=(1,4))
print(a)
Output : tf.Tensor([[1 3 4 5]], shape=(1, 4), dtype=int32)

#after applying tf.squeeze shape has been changed from  (4,1) to (4, )
b = tf.squeeze(input=a)
print(b)
output: tf.Tensor([1 3 4 5], shape=(4,), dtype=int32)
####example2
a = tf.constant(value=[1,3,4,5,4,6], shape=(3,1,2))
print(a)
Output:tf.Tensor(
[[[1 3]]
 [[4 5]]
 [[4 6]]], shape=(3, 1, 2), dtype=int32)

#after applying tf.squeeze shape has been chnaged from (3, 1, 2) to (3, 2)
b = tf.squeeze(input=a)
print(b)
Output:tf.Tensor(
[[1 3]
 [4 5]
 [4 6]], shape=(3, 2), dtype=int32)

相关问题 更多 >