使用3D输入训练Keras LSTM

2024-09-28 17:27:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图用NFL tracking data来预测比赛结束时的进攻码数结果。为此,我的输入将给train_x,其中包含一个1次播放的跟踪数据数组,与具有码数结果的y_train中的浮点配对

我应该如何安排数据来训练LSTM模型I have been trying to use this tutorial 但它不使用3d输入。LSTM模型可以处理我正在尝试做的事情吗

到目前为止,我已经尝试过:

def isolatePlay(data, gameNum, playNum):
    MAX_X_YARDS = 120
    MAX_Y_YARDS = 53.3
    d = data[data['gameId'] == gameNum]
    d = d[d['playId'] == playNum].fillna(0)
    #normalize x ,y...
    sub = d[["x","y", "s", "a", "dis", "o", "dir"]].to_numpy()
    norm = Normalizer().fit(sub)

    return norm.transform(sub)



print("creating ML training, test, and validation datasets")
first = True
for  rows in plays.itertuples():
    #print(getattr(rows, 'gameId'), gameMax)
    play = isolatePlay(week, getattr(rows, 'gameId'), getattr(rows, 'playId'))
    if (first):
        x = [play]
        y = [[getattr(rows, 'offensePlayResult')]]
        first = False
    else:
        x.append(play)
        y.append([getattr(rows, 'offensePlayResult')])
        
train_x, test_x, train_y, test_y = train_test_split(np.array(x), np.array(y), test_size=0.3)
test_x, val_x, test_y, val_y = train_test_split(test_x, test_y, test_size=0.5)
print("x data:[0]", train_x[0])
print("x data:[1]", train_x[1])

print("ML Dataset Preparation Complete")



 # create the model
embedding_vecor_length = 32
model = Sequential()
model.add(LSTM(100))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(train_x, train_y, validation_data=(val_x, val_y), epochs=3, batch_size=64)
print(model.summary())

# Final evaluation of the model
scores = model.evaluate(test_x, test_y, verbose=0)
print("Accuracy: %.2f%%" % (scores[1]*100))

但这并没有奏效

ValueError: Failed to convert a NumPy array to a Tensor (Unsupported object type list).

我认为这与将我的输入重塑为表单[samples,timestep,features]有关,但我不确定这对3d输入意味着什么。时间步长在跟踪数据中,并且特征数与跟踪数据行数不同(每1个结果约1500行)。任何涉及与输入数据维度相似的神经网络项目的参考资料都将不胜感激,因为我一直很难找到任何参考资料

我要传的剧本是这样的

shape (1296, 7)
normalized tracking data
 [[ 1.50585604e-01  9.04000519e-02  3.59760829e-03 ...  3.51645923e-04
   7.88471309e-01  5.89439716e-01]
 [ 2.58656328e-01  1.55707513e-01  1.74663657e-03 ...  5.13716639e-04
   5.48443883e-01  7.79770486e-01]
 [ 1.55811974e-01  1.32418437e-01  1.44941371e-04 ...  0.00000000e+00
   2.80084706e-01  9.37944603e-01]
 ...
 [ 6.18136846e-01  5.88895155e-03  1.02887429e-02 ...  1.35378197e-03
   3.76622143e-01  6.89278088e-01]
 [ 3.58395476e-01  1.52762473e-01  2.09288701e-02 ...  2.10202625e-03
   6.65474093e-01  6.36319903e-01]
 [ 9.97562556e-01  3.33489046e-03  6.60523466e-02 ...  6.56220381e-03
  -1.07577112e-02 -1.07577112e-02]]
shape (832, 7)
normalized tracking data
 [[ 2.37484336e-01  6.72201854e-02  1.79624884e-03 ...  1.26496397e-04
   8.06945816e-01  5.36572417e-01]
 [ 5.23191345e-01  1.45525783e-01  0.00000000e+00 ...  0.00000000e+00
   5.63912408e-01  6.22170278e-01]
 [ 2.46464025e-01  2.66599975e-02  3.22175196e-04 ...  8.05437991e-05
   3.36028730e-01  9.08641445e-01]
 ...
 [ 4.92441881e-01  7.33912651e-03  2.67000124e-02 ...  2.70151896e-03
   6.38594057e-01  5.90552045e-01]
 [ 4.38248969e-01  6.56688483e-02  1.62183495e-02 ...  1.63509245e-03
   6.60179626e-01  6.06221575e-01]
 [ 9.97770139e-01  1.13661959e-02  6.40144152e-02 ...  6.45599926e-03
  -9.09295670e-03 -9.09295670e-03]]

Tags: to数据testplaydatamodeltrainval