在张量流中,如何遍历存储在张量中的输入序列?

2024-09-23 04:30:34 发布

您现在位置:Python中文网/ 问答频道 /正文

我在一个变长多元序列分类问题上尝试RNN。

我定义了以下函数来获得序列的输出(即,在序列的最终输入被馈送之后,RNN单元的输出)

def get_sequence_output(x_sequence, initial_hidden_state):
    previous_hidden_state = initial_hidden_state
    for x_single in x_sequence:
        hidden_state = gru_unit(previous_hidden_state, x_single)
        previous_hidden_state = hidden_state
    final_hidden_state = hidden_state
    return final_hidden_state

这里x_sequence是形状的张量(?, ?, 10)首先在哪里?是批量和第二批?表示序列长度,每个输入元素的长度为10。gru函数接受上一个隐藏状态和当前输入,并输出下一个隐藏状态(标准选通递归单元)。

我得到一个错误:'Tensor' object is not iterable. 如何按顺序迭代张量(一次读取单个元素)?

我的目标是对序列中的每个输入应用gru函数,并获得最终的隐藏状态。


Tags: 函数元素状态分类序列hiddeninitial单元
2条回答

在TF>;=1.0中,tf.packtf.unpack分别重命名为tf.stacktf.unstack

可以使用unpack函数将张量转换为列表,该函数将第一个维度转换为列表。还有一个split函数可以做类似的事情。我在我正在研究的RNN模型中使用unstack。

y = tf.unstack(tf.transpose(y, (1, 0, 2)))

在本例中,y从shape(BATCH_SIZE,TIME_STEPS,128)开始,我将其转置,使TIME STEPS成为外部维度,然后将其解压成一个张量列表,每次解压一个张量。现在y列表中的每个元素如果是形状的(BATCH_SIZE,128),我可以将其输入RNN。

相关问题 更多 >