前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >OpenAI gym——一款开发和比较RL算法的工具包

OpenAI gym——一款开发和比较RL算法的工具包

作者头像
CristianoC
发布2020-05-31 15:16:27
1.1K0
发布2020-05-31 15:16:27
举报

一、前言

手动编写环境是一件很耗时间的事情,所以如果可以直接使用比人编写好的环境,可以节约我们很多时间。OpenAI gym就是这样一个模块,他提供给我们很多优秀的模拟环境。我们的各种强化学习算法都能使用这些环境。之前的环境都是用tkinter来手动编写,或者想玩玩更厉害的,像OpenAI一样,使用pyglet模块来编写。 OpenAI gym官网:https://gym.openai.com/

我们可以先看看OpenAI gym有哪些游戏: 有2D的:

也有3D的:

本次将会以CartPole和MountainCar两个经典例子来给大家说明。

二、安装

笔者电脑是Ubuntu16.04,可以直接复制下面代码安装:

代码语言:javascript
复制
1# python 2.7, 复制下面
2$ pip install gym
3
4# python 3.5, 复制下面
5$ pip3 install gym

如果没有报错那就安装好gym(基本款),可以玩以下游戏:

  • algorithmic
  • toy_text
  • classic_control(这个需要pyglet模块)

如果你想玩gym提供的全套游戏,则使用以下代码:

代码语言:javascript
复制
1# python 2.7, 复制下面
2$ pip install gym[all]
3
4# python 3.5, 复制下面
5$ pip3 install gym[all]

三、CartPole例子

这个游戏的目的是让小车尽量不偏离中心以及棍子尽量垂直,我们可以看下面的示例图,经过训练后小车就会尽量呆在中间棍子也基本保持垂直。

主循环

我们还是采用DQN的方式来实现RL,完整代码最后会给我的github链接。

代码语言:javascript
复制
 1import gym
 2from RL_brain import DeepQNetwork
 3
 4env = gym.make('CartPole-v0') #定义使用gym库中的哪一个环境
 5env = env.unwrapped #还原env的原始设置,env外包了一层防作弊层
 6
 7print(env.action_space) #查看这个环境可用的action有多少个
 8print(env.observation_space) #查看这个环境中可用的state的observation有多少个
 9print(env.observation_space.high) #查看observation最高取值
10print(env.observation_space.low) #查看observation最低取值
11
12RL = DeepQNetwork(n_actions=env.action_space.n,
13                  n_features=env.observation_space.shape[0],
14                  learning_rate=0.01, e_greedy=0.9,
15                  replace_target_iter=100, memory_size=2000,
16                  e_greedy_increment=0.001,)
17
18total_steps = 0
19
20
21for i_episode in range(100):
22    #获取回合i_episode第一个observation
23    observation = env.reset()
24    ep_r = 0
25    while True:
26        env.render()#刷新环境
27
28        action = RL.choose_action(observation)#选行为
29
30        observation_, reward, done, info = env.step(action)#获取下一个state
31
32        # x是车的水平位移,所以r1是车越偏离中心,得分(reward)越少
33        # theta是棒子离垂直的角度,角度越大,越不垂直。所以r2是棒越垂直,分(reward)越高
34        # 总reward是r1和r2的结合,既考虑位置也考虑角度,这样DQN学习更有效率
35        x, x_dot, theta, theta_dot = observation_
36
37        r1 = (env.x_threshold - abs(x))/env.x_threshold - 0.8
38        r2 = (env.theta_threshold_radians - abs(theta))/env.theta_threshold_radians - 0.5
39        reward = r1 + r2
40
41        #保存这一组记忆
42        RL.store_transition(observation, action, reward, observation_)
43
44        ep_r += reward
45        if total_steps > 1000:
46            RL.learn()
47
48        if done:
49            print('episode: ', i_episode,
50                  'ep_r: ', round(ep_r, 2),
51                  ' epsilon: ', round(RL.epsilon, 2))
52            break
53
54        observation = observation_
55        total_steps += 1
56
57RL.plot_cost()

这是更为典型的RL cost曲线:

四、MountainCar例子

小车经过谷底的震荡,慢慢地就可以爬到山顶拿旗子了。

代码和上面差不多,只是定义的reward不同:

代码语言:javascript
复制
 1import gym
 2from RL_brain import DeepQNetwork
 3
 4env = gym.make('MountainCar-v0')
 5env = env.unwrapped
 6
 7print(env.action_space)
 8print(env.observation_space)
 9print(env.observation_space.high)
10print(env.observation_space.low)
11
12RL = DeepQNetwork(n_actions=3, n_features=2, learning_rate=0.001, e_greedy=0.9,
13                  replace_target_iter=300, memory_size=3000,
14                  e_greedy_increment=0.0002,)
15
16total_steps = 0
17
18
19for i_episode in range(10):
20
21    observation = env.reset()
22    ep_r = 0
23    while True:
24        env.render()
25
26        action = RL.choose_action(observation)
27
28        observation_, reward, done, info = env.step(action)
29
30        position, velocity = observation_
31
32        # 车开的越高reward越大
33        reward = abs(position - (-0.5))     # r in [0, 1]
34
35        RL.store_transition(observation, action, reward, observation_)
36
37        if total_steps > 1000:
38            RL.learn()
39
40        ep_r += reward
41        if done:
42            get = '| Get' if observation_[0] >= env.unwrapped.goal_position else '| ----'
43            print('Epi: ', i_episode,
44                  get,
45                  '| Ep_r: ', round(ep_r, 4),
46                  '| Epsilon: ', round(RL.epsilon, 2))
47            break
48
49        observation = observation_
50        total_steps += 1
51
52RL.plot_cost()

出来的cost曲线是这样的:

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-07-22,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 计算机视觉漫谈 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、前言
  • 二、安装
  • 三、CartPole例子
    • 主循环
    • 四、MountainCar例子
    相关产品与服务
    对象存储
    对象存储(Cloud Object Storage,COS)是由腾讯云推出的无目录层次结构、无数据格式限制,可容纳海量数据且支持 HTTP/HTTPS 协议访问的分布式存储服务。腾讯云 COS 的存储桶空间无容量上限,无需分区管理,适用于 CDN 数据分发、数据万象处理或大数据计算与分析的数据湖等多种场景。
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档