Gym Retro: reward for moving forward

Posted 2024-05-18 07:34:45


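How can I reward an agent for moving forward in a game like Super Mario Bros? The only data I have is the score and the lives. Is there a way to get the agent's coordinates? I am using NEAT to train my agent, and this is the code. Right now I reward it for getting the highest possible score. Rewarding it for pressing the right button does not work, because it just runs into a wall until the timer runs out.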

import retro
import numpy as np
import cv2
import neat
import pickle

env = retro.make('SuperMarioWorld-Snes', 'Start.state')

imgarray = []

xpos_end = 0


def eval_genomes(genomes, config):
    for genome_id, genome in genomes:
        ob = env.reset()
        ac = env.action_space.sample()

        # Observation shape is (height, width, channels); downscale both by 8.
        inx, iny, inc = env.observation_space.shape

        inx = inx // 8
        iny = iny // 8

        net = neat.nn.recurrent.RecurrentNetwork.create(genome, config)

        current_max_fitness = 0
        fitness_current = 0
        frame = 0
        counter = 0
        xpos = 0
        xpos_max = 0

        done = False
        # cv2.namedWindow("main", cv2.WINDOW_NORMAL)

        while not done:

            env.render()
            frame += 1
            # scaledimg = cv2.cvtColor(ob, cv2.COLOR_BGR2RGB)
            # scaledimg = cv2.resize(scaledimg, (iny, inx))
            # cv2.resize expects (width, height), so the result has shape (inx, iny).
            ob = cv2.resize(ob, (iny, inx))
            ob = cv2.cvtColor(ob, cv2.COLOR_BGR2GRAY)
            ob = np.reshape(ob, (inx, iny))
            # cv2.imshow('main', scaledimg)
            # cv2.waitKey(1)

            imgarray = np.ndarray.flatten(ob)

            # Truncate each network output to an int and clamp negatives to 0.
            nnOutput = [max(0, int(v)) for v in net.activate(imgarray)]


            ob, rew, done, info = env.step(nnOutput)

            # xpos = info['x']
            # xpos_end = info['screen_x_end']

            # if xpos > xpos_max:
            #     fitness_current += 1
            #     xpos_max = xpos

            # if xpos == xpos_end and xpos > 500:
            #     fitness_current += 100000
            #     done = True

            fitness_current += rew
            print(env.statename)
            if fitness_current > current_max_fitness:
                current_max_fitness = fitness_current
                counter = 0
            else:
                counter += 1

            if done or counter == 250:
                done = True
                print(genome_id, fitness_current)

            genome.fitness = fitness_current


config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
                     neat.DefaultSpeciesSet, neat.DefaultStagnation,
                     'config.txt')

p = neat.Population(config)

p.add_reporter(neat.StdOutReporter(True))
stats = neat.StatisticsReporter()
p.add_reporter(stats)
p.add_reporter(neat.Checkpointer(10))

winner = p.run(eval_genomes)

with open('winner.pkl', 'wb') as output:
    pickle.dump(winner, output, 1)




1 answer
Answer 1 · Posted 2024-05-18 07:34:45

Using print(retro.__file__) I found the folder where the retro module is installed, and I checked all the subfolders of the SuperMarioWorld folder I found there.

On my Linux system it is:

/usr/local/lib/python3.8/dist-packages/retro/data/stable/SuperMarioWorld-Snes
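A minimal sketch of how that folder can be derived from the module location. The helper name game_data_dir is hypothetical (it is not part of retro), and the data/stable layout is taken only from the path shown above, so other installs may differ.

```python
import os


def game_data_dir(module_file, game, integration="stable"):
    """Build the per-game data folder from the path of retro/__init__.py.

    Assumes the <package>/data/<integration>/<game> layout seen in the
    path above; this helper is illustrative, not part of retro itself.
    """
    package_dir = os.path.dirname(os.path.abspath(module_file))
    return os.path.join(package_dir, "data", integration, game)


# Usage, assuming retro is installed:
#   import retro
#   print(game_data_dir(retro.__file__, "SuperMarioWorld-Snes"))
```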

That folder contains a file data.json, which defines where retro finds values such as score and lives in the game's memory.
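For orientation, here is a sketch of reading such a file. The layout (an "info" object mapping each variable name to an "address" and a "type") is how I understand retro's data.json format, so verify it against your own copy. The addresses in the sample are placeholders, not real SuperMarioWorld values, and list_info_variables is a hypothetical helper.

```python
import json

# Placeholder content in the shape of a retro data.json file.
# The addresses are made up for illustration only.
SAMPLE = """
{
  "info": {
    "score": {"address": 1000, "type": "<u4"},
    "lives": {"address": 2000, "type": "|u1"}
  }
}
"""


def list_info_variables(text):
    """Return the sorted variable names that would show up as keys of `info`."""
    data = json.loads(text)
    return sorted(data.get("info", {}))


print(list_info_variables(SAMPLE))
```

Every name listed this way should be available as a key in the info dict returned by env.step().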

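In the OpenAI-Retro-SuperMarioWorld-SNES repository I found a data.json that also has entries for x, y, and so on.

If I replace data.json with that one, then I can get info["x"] in the code.

But I am not sure whether this file works for every version of Super Mario World.

I tested it with Super Mario World (Europe) (Rev 1), which I found at: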
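Once info["x"] is available, forward progress can feed the fitness. The following is only a sketch: ProgressTracker is a hypothetical helper, and it rewards the distance gained each time a new maximum x is reached, which differs from the commented-out code in the question (that adds 1 per new maximum). It also falls back gracefully when "x" is missing from info.

```python
class ProgressTracker:
    """Reward new maximum x positions and count frames without progress."""

    def __init__(self, stall_limit=250):
        self.x_max = 0
        self.fitness = 0.0
        self.stalled_frames = 0
        self.stall_limit = stall_limit

    def update(self, info):
        """Update from one step's info dict; return True once stalled too long."""
        x = info.get("x")  # None if data.json does not define "x"
        if x is not None and x > self.x_max:
            # Reward only the newly covered distance.
            self.fitness += x - self.x_max
            self.x_max = x
            self.stalled_frames = 0
        else:
            self.stalled_frames += 1
        return self.stalled_frames >= self.stall_limit


# Inside the evaluation loop from the question, after env.step(...):
#   tracker = ProgressTracker()          # create once per genome
#   ob, rew, done, info = env.step(nnOutput)
#   done = done or tracker.update(info)
#   genome.fitness = tracker.fitness
```

Because standing still or walking into a wall never raises the maximum x, the stall counter ends the episode instead of waiting for the in-game timer.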

https://ia800201.us.archive.org/view_archive.php?archive=/7/items/No-Intro-Collection_2016-01-03_Fixed/Nintendo%20-%20Super%20Nintendo%20Entertainment%20System.zip

But there are other versions as well: Europe, USA, Japan.
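One way to tell which dump you have is to hash it. This is a sketch under assumptions: as far as I can tell, each game folder in retro also holds a rom.sha file with the SHA-1 of the ROM the integration was built for, so treat that file name and its one-hash-per-line format as unverified and check your own install. The helper names are hypothetical.

```python
import hashlib


def sha1_hex(data):
    """Return the SHA-1 digest of ROM bytes as lowercase hex."""
    return hashlib.sha1(data).hexdigest()


def matches_known_hash(rom_bytes, sha_file_text):
    """Check ROM bytes against hashes listed one per line in sha_file_text."""
    known = {
        line.strip().lower()
        for line in sha_file_text.splitlines()
        if line.strip()
    }
    return sha1_hex(rom_bytes) in known


# Usage, with paths you supply yourself:
#   with open(rom_path, "rb") as f:
#       rom = f.read()
#   with open(sha_path) as f:
#       print(matches_known_hash(rom, f.read()))
```

If the hash does not match, the memory addresses in data.json may not line up with that ROM version.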
