培训LSTM时从csv加载文件时出错

2024-10-02 14:26:57 发布

您现在位置:Python中文网/ 问答频道 /正文

我对tensorflow还比较陌生,正在尝试在tensorflow中为我爬网并保存到CSV中的一些数据训练一个两层LSTM。但是,在我使用tensorflow网站上显示的方法之后,我不断得到以下错误:

TypeError: inputs must be a sequence

原始代码为:

file = tf.train.string_input_producer(['players_raw.csv'],
num_epochs=100, shuffle=False)
reader = tf.TextLineReader()
key, val = reader.read(file)
gameNum, age, team, homeAway, opponent, pointDiff, secs, orb, drb, ast, stl, blk, to, pts, fanPts = tf.decode_csv(val, record_defaults=defaults)
features = tf.pack([gameNum, age, team, homeAway, opponent, pointDiff, secs, orb, drb, ast, stl, blk, to, pts])
label = tf.pack([fanPts]);

lstmCell = rnn_cell.LSTMCell(NUM_FEATURES)
stacked = rnn_cell.MultiRNNCell([lstmCell] * 2)
outputs, states = rnn.rnn(stacked, features, dtype=tf.float32)

最后一行是导致错误的原因。我想我知道问题出在哪里,但我不知道该如何着手解决它


Tags: csvagetftensorflow错误valteamreader