我试图训练一个神经网络来弹起一个球,但我在预测球的运动时遇到了问题,得到了错误ValueError: Error when checking input: expected dense_1_input to have shape (8,) but got array with shape (1,)
我的代码:
from keras.models import Sequential
from keras.layers import Dense
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import csv
import pygame
running=True
def main():
# load training data
data_path = 'x.dat'
with open(data_path, 'r') as f:
reader = csv.reader(f, delimiter=',')
headers = next(reader)
x_train = np.array(list(reader)).astype(float)
data_path = 'y.dat'
with open(data_path, 'r') as f:
reader = csv.reader(f, delimiter=',')
headers = next(reader)
y_train = np.array(list(reader)).astype(float)
# debug print statement
print(x_train)
# define the keras model
model = Sequential()
model.add(Dense(8, input_shape=(8,), activation='relu'))
model.add(Dense(10, activation='relu'))
model.add(Dense(4, activation='relu'))
# compile the keras model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# fit the keras model on the dataset
model.fit(x_train, y_train, epochs=1, batch_size=1000)
# pygame initialization to visualize the ball
global running, screen
pygame.init()
screen = pygame.display.set_mode((200, 200))
pygame.display.set_caption("BallPhysics")
screen.fill((255,255,255))
pygame.display.update()
# ball info
x = 25
y = 10
xVel = -2
yVel = 0
gravity = 0.1
elasticity = 0.9999
radius = 10
friction = 0.999
while running:
screen.fill((255,255,255))
ev = pygame.event.get()
# draw ball
pygame.draw.circle(screen, (255, 0, 0), (x,y), radius)
pygame.display.update()
# make input array
inp = [x/200,y/200,((xVel/10)+1)/2,((yVel/10)+1)/2,gravity,elasticity,radius/50.0,friction]
print(inp)
out = model.predict(inp)
# set ball position and velocity to NN output
x = out[0][0]
y = out[0][1]
xVel = out[0][2]
yVel = out[0][3]
# event handling
for event in ev:
if event.type == pygame.QUIT:
running = False
pygame.display.quit()
pygame.quit()
pygame.display.flip()
main()
调试打印语句将打印出来
[[0.025568 0.131659 0.755605 ... 0.414219 0.094692 0.678865]
...
[0.08742 0.08742 0.5 ... 0.250432 0.699359 0.179118]]
这有点让人困惑,因为我注意到打印出来的数组没有逗号,而我制作的数组有逗号。这可能与此有关,但我不知道是什么原因。 感谢您的帮助
堆栈跟踪:
Traceback (most recent call last):
File "/Users/grimtin10/Documents/Python Projects/BallPhysics/BallPhysics.py", line 78, in <module>
main()
File "/Users/grimtin10/Documents/Python Projects/BallPhysics/BallPhysics.py", line 63, in main
out = model.predict(inp)
File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training.py", line 1149, in predict
x, _, _ = self._standardize_user_data(x)
File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training.py", line 751, in _standardize_user_data
exception_prefix='input')
File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training_utils.py", line 138, in standardize_input_data
str(data_shape))
ValueError: Error when checking input: expected dense_1_input to have shape (8,) but got array with shape (1,)
好吧,我想起来了。 事实证明,NN预期的是批处理输入,导致单个数组无法工作。我对tf和keras不太熟悉,所以这就是问题所在。 使用
np.reshape(inp,(1,8))
修复了它相关问题 更多 >
编程相关推荐