在《超级马里奥兄弟》这样的游戏中,我怎样才能奖励一个向前移动的智能体(agent)?我能拿到的数据只有分数和生命数,有没有办法获得智能体的坐标?我用 NEAT 来训练智能体,下面是我的代码。目前我以获得尽可能高的分数作为奖励;而奖励它按下正确的按钮并不起作用,因为它只会一直撞墙,直到计时器耗尽为止。
import retro
import numpy as np
import cv2
import neat
import pickle
# Shared emulator environment used by eval_genomes for every episode.
env = retro.make(game='SuperMarioWorld-Snes', state='Start.state')

# Module-level placeholders. eval_genomes assigns its own local `imgarray`,
# and `xpos_end` is only referenced by the commented-out position reward.
imgarray = []
xpos_end = 0
def eval_genomes(genomes, config):
    """Evaluate each genome by letting it play one episode in ``env``.

    For every ``(genome_id, genome)`` pair a recurrent network is created
    from the genome. Each frame is downscaled, converted to grayscale,
    flattened and fed to the network; the network outputs (truncated to
    integers and clamped at zero) are passed to ``env.step`` as the action.
    The rewards returned by the environment are summed and stored in
    ``genome.fitness``.

    An episode ends when the environment reports ``done`` or when the
    accumulated reward has not improved for 250 consecutive steps.

    Args:
        genomes: Iterable of ``(genome_id, genome)`` pairs.
        config: Configuration object passed to ``RecurrentNetwork.create``.
    """
    # The observation shape does not depend on the genome, so the downscaled
    # dimensions are computed once instead of once per genome.
    inx, iny, _ = env.observation_space.shape
    inx = inx // 8
    iny = iny // 8

    for genome_id, genome in genomes:
        ob = env.reset()
        net = neat.nn.recurrent.RecurrentNetwork.create(genome, config)

        current_max_fitness = 0
        fitness_current = 0
        counter = 0  # steps since the accumulated reward last improved
        done = False

        # cv2.namedWindow("main", cv2.WINDOW_NORMAL)
        while not done:
            env.render()

            # scaledimg = cv2.cvtColor(ob, cv2.COLOR_BGR2RGB)
            # scaledimg = cv2.resize(scaledimg, (iny, inx))
            ob = cv2.resize(ob, (inx, iny))
            ob = cv2.cvtColor(ob, cv2.COLOR_BGR2GRAY)
            ob = np.reshape(ob, (inx, iny))
            # cv2.imshow('main', scaledimg)
            # cv2.waitKey(1)

            imgarray = ob.flatten()

            # Truncate every network output to an int and clamp negatives
            # to 0 before handing the list to the environment.
            nnOutput = [max(int(value), 0) for value in net.activate(imgarray)]

            ob, rew, done, info = env.step(nnOutput)

            # NOTE(review): disabled position-based reward. It needs `info`
            # to expose 'x' and 'screen_x_end' -- confirm the game
            # integration provides them. Re-enabling it also requires
            # initialising `xpos_max = 0` for each genome.
            # xpos = info['x']
            # xpos_end = info['screen_x_end']
            # if xpos > xpos_max:
            #     fitness_current += 1
            #     xpos_max = xpos
            # if xpos == xpos_end and xpos > 500:
            #     fitness_current += 100000
            #     done = True

            fitness_current += rew
            print(env.statename)

            if fitness_current > current_max_fitness:
                current_max_fitness = fitness_current
                counter = 0
            else:
                counter += 1

            # Give up on this genome once it has stalled for 250 steps.
            if done or counter == 250:
                done = True
                print(genome_id, fitness_current)

        # The loop body always runs at least once, so this is the final
        # accumulated reward of the episode.
        genome.fitness = fitness_current
# Build the NEAT configuration from 'config.txt' using the library's default
# genome, reproduction, species-set and stagnation implementations.
config = neat.Config(
    neat.DefaultGenome,
    neat.DefaultReproduction,
    neat.DefaultSpeciesSet,
    neat.DefaultStagnation,
    'config.txt',
)

p = neat.Population(config)

# Attach reporters: console output, statistics collection and checkpointing.
p.add_reporter(neat.StdOutReporter(True))
stats = neat.StatisticsReporter()
p.add_reporter(stats)
p.add_reporter(neat.Checkpointer(10))

# Evolve the population with eval_genomes as the fitness function.
winner = p.run(eval_genomes)

# Persist the genome returned by the run (pickle protocol 1).
with open('winner.pkl', 'wb') as output:
    pickle.dump(winner, output, 1)
使用
print( retro.__file__ )
我找到了 retro 模块所在的文件夹,并检查了其中 SuperMarioWorld 文件夹的所有子文件夹。在我的 Linux 上是这样的:
其中有一个文件 data.json,它定义了 retro 如何在 ROM 中查找 score 和 lives。
在 OpenAI-Retro-SuperMarioWorld-SNES 中,我找到了另一个 data.json,它还包含 x、y 等变量的信息。如果我用它替换原来的 data.json,就可以在代码中获得 info["x"]。
但是我不确定这个文件是否适用于 Super Mario 的每个版本。我用在以下地址找到的 Super Mario World (Europe) (Rev 1) 进行了测试:
https://ia800201.us.archive.org/view_archive.php?archive=/7/items/No-Intro-Collection_2016-01-03_Fixed/Nintendo%20-%20Super%20Nintendo%20Entertainment%20System.zip
但还存在其他版本——欧洲版、美国版、日本版。
相关问题 更多 >
编程相关推荐