几种常用强化学习模型的实现
rlmodels的Python项目详细描述
rlmodels:强化学习库
本计画是针对强化学习问题的一些流行优化演算法的集合。目前提供的型号是:
- dqn
- DDPG
- cmaes
- 空调
以后还会有更多的。
它与pytorch模型和openai健身房等环境类一起工作。任何模拟其基本功能的环境类包装器都应该是好的,但下面将详细介绍。
开始
先决条件
该项目使用python 3.6
和torch 1.1.0
。
安装
它可以直接从pip安装,如
pip install rlmodels
用法
下面是程序工作原理的总结。要查看完整文档,请单击here
初始化
下面是一个使用双q网络的流行cartpole环境的示例。首先是设置。
importnumpyasnpimporttorchimporttorch.optimasoptimimportgymfromrlmodels.models.DQNimport*fromrlmodels.netsimportVanillaNetimportlogging#logger parametersFORMAT='%(asctime)-15s: %(message)s'logging.basicConfig(level=logging.INFO,format=FORMAT,filename="model_fit.log",filemode="a")max_ep_ts=200env=gym.make('CartPole-v0')env._max_episode_steps=max_ep_tsenv.seed(1)np.random.seed(1)torch.manual_seed(1)
事件和时间步骤编号以及平均奖励跟踪记录到文件model_fit.log
。将记录级别设置为logging.DEBUG
也将记录有关梯度下降步骤的信息。
该库还有一个基本的网络定义vanillanet,我们只需要指定隐藏层的数量和大小、输入和输出大小以及最后激活函数。默认情况下,它在其他地方使用relu。
让我们创建基本对象
dqn_scheduler=DQNScheduler(batch_size=lambdat:200,#constantexploration_rate=lambdat:max(0.01,0.05-0.01*int(t/2500)),#decrease exploration down to 1% after 10,000 stepsPER_alpha=lambdat:1,#constantPER_beta=lambdat:1,#constanttau=lambdat:100,#constantagent_lr_scheduler_fn=lambdat:1.25**(-int(t/1000)),#decrease step size every 2,500 steps,steps_per_update=lambdat:1)#constantagent_lr=0.5#initial learning rateagent_model=VanillaNet([60],4,2,None)agent_opt=optim.SGD(agent_model.parameters(),lr=agent_lr,weight_decay=0,momentum=0)agent=Agent(agent_model,agent_opt)
这些模型以调度器对象作为参数,允许在运行时根据用户定义的规则更改参数。例如,如上所述,在一定次数的迭代之后降低学习率和探索率。最后,所有基于梯度的算法接收一个包含网络定义和优化算法的Agent
实例作为输入。一旦一切就绪,我们就可以出发了。
dqn=DQN(agent,env,dqn_scheduler)dqn.fit(n_episodes=170,max_ts_by_episode=max_ep_ts,max_memory_size=2000,td_steps=1)
一旦特工接受了训练,我们就能看到奖赏的痕迹。如果我们使用带有呈现方法的环境(比如openai),我们还可以将经过训练的代理可视化。我们也可以使用ddq
对象的forward
方法来使用经过训练的模型,或者使用ddq.agent
dqn.plot()#plot reward tracesdqn.play(n=200)#observe the agent play
有关其他算法的类似用法,请参见example
文件夹。
环境
对于自定义环境或自定义奖励,可以制作一个包装器来模拟健身房环境的s te p()和reset()函数的行为
classMyCustomEnv(object):def__init__(self,env):self.env=envdefstep(self,action):## get next state s, reward, termination flag (boolean) and any additional inforeturns,r,terminated,info#need to output these 4 things (info can be None)defreset(self):#somethingdefseed(self):#something