前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >强化学习系列(三)-gym介绍和实例

强化学习系列(三)-gym介绍和实例

原创
作者头像
languageX
修改2022-02-16 16:53:42
4.1K0
修改2022-02-16 16:53:42
举报
文章被收录于专栏:计算机视觉CV计算机视觉CV

gym是openAI下的一个开发和对比强化学习算法的工具包,内部提供了强化学习需要的环境。

官方文档:https://gym.openai.com/docs/

gym库安装

我是在window下进行安装的

代码语言:shell
复制
conda create -n gym
pip install gym
pip install pyglet

gym--hello world代码

我们参考官方文档执行gym的“hello world”代码。

代码语言:python
复制
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import gym
env = gym.make('CartPole-v0')
env.reset()
for _ in range(1000):
    env.render()
    # take a random action
    env.step(env.action_space.sample()) 
env.close()

运行以上代码,如果出现一个倒立摆问题的动画,说明gym库安装成功了,我们跑通了gym最简单的hello world代码。

简单介绍下上面代码主要实现的功能:

  1. env = gym.make('CartPole-v0')运创建一个cartpole问题的环境,对于cartpole问题下文会进行详细介绍。
  2. env.reset() 对环境进行重置,得到初始的observation
  3. env.render()渲染物体状态的UI,这里调用了gym的渲染接口,我们不做深究
  4. env.action_space.sample()指从动作空间中随机选取一个动作
  5. env.step() 指在环境中采取选择的动作,这里会返回reward等信息

也就是首先创建一个环境,对环境进行重置。然后循环迭代1000次,每个迭代中我们从环境的动作空间中选择一个动作进行执行,进入下一个状态。

我们在实现自己的算法时,最看重的就是选取动作和策略这一步,也就是根据现在的状态,如何选择下一步的动作。

gym实例--CartPole

通过上面简单的demo可能对整个环境的理解以及状态空间,状态空间,step返回还是不太理解。本节就对demo进行更详细的讲解。

CartPole介绍

首先,我们得知道CartPole是要解决什么问题。CartPole是一个小车倒立摆的问题。如下图所示:

小车倒立摆示意图
小车倒立摆示意图

一个小车上有一根杆,随着小车的移动,杆会产生倾斜,当倾斜到一定程度由于重力杆就会倒下。我们的目标就是让杆一直朝上不会倒下。所以需要在杆的每个状态给一个执行动作,让小车向左或者向右移动,让杆能保持平衡。

CartPole环境介绍

对于CartPole环境中的状态空间和动作空间,可以通过源码结合我们在代码中的日志进行了解。

CartPole类源码:

https://github.com/openai/gym/blob/0cd9266d986d470ed9c0dd87a41cd680b65cfe1c/gym/envs/classic_control/cartpole.py

在demo代码中打印了解数据形式:

代码语言:javascript
复制
 print("env.action_space", env.action_space)
>> Discrete(2)

动作空间是一个离散数据: 状态空间值{0,1},0--表示左移动,1--表示右移动

代码语言:javascript
复制
print("env.observation_space", env.observation_space)
>>Box(4,) 

状态空间是一个多维空间,四个维度分别表示:小车在轨道上的位置,杆子和竖直方向的夹角,小车速度,角度变化率。

每执行一个step(),返回四个值:

observation, reward, done, info = env.step(sample_action)

observation:当前状态值

reward: 每步给出奖励1

done: 本轮探索是否结束

info: 调试信息

重新了解demo逻辑,首先初始化环境observation(状态),选取一个动作,然后返回在环境中执行动作后的observation,每次step的reward都是1,当杆倒下done就为False,杆向上的时间越长,得到的reward就越大。

我们学习的目标就是保持杆一直竖直朝上,杆由于重力原因会一直倾斜,当杆倾斜到一定程度就会倒下,此时需要朝左或者右移动杆保证它不会倒下来。

下面代码为将demo代码进行一些日志扩充,让我们对CartPole-v0环境有个更充分的认识。

代码语言:python
复制
import gym
# 创建一个CartPole-v0(小车倒立摆模型)
env = gym.make('CartPole-v0')
for i_episode in range(1000):
    # 环境重置,得到一个初始observation
    observation = env.reset()
    for t in range(1000):
        # 渲染引擎显示物体状态
        env.render()
        # 动作空间,{0,1} 0-左移动, 1-右移动
        print("env.action_space", env.action_space)
        # >>Discrete(2) 一个离散空间
        # 状态空间
        print("env.observation_space", env.observation_space)
        # >>Box(4,) 多维空间
        # 奖励范围和状态空间范围
        # print("env.reward_range", env.reward_range)
        # print("env.observation_space.high", env.observation_space.high)
        # print("env.observation_space.low", env.observation_space.low)
        # print("env.observation_space.bounded_above", env.observation_space.bounded_above)
        # print("env.observation_space.bounded_below",  env.observation_space.bounded_below)
        # 随机选取动作
        sample_action = env.action_space.sample()
        """
        observation:当前观察的object的状态值
        小车在轨道上的位置,杆子和竖直方向的夹角,小车速度,角度变化率)
        
        reward: 执行上一个action后的奖,每一步给出奖励1
        
        done:本轮探索是否结束,是否需要reset环境
        达到下列条件之一片段结束:
        杆子与竖直方向角度超过12度
        小车位置距离中心超过2.4(小车中心超出画面)
        考虑片段长度超过200
        考虑连续100次尝试的平均奖励大于等于195。
        
        info:调试信息
        """
        observation, reward, done, info = env.step(sample_action)
        print('observation:{}, reward:{}, done:{}, info:{}'.format(observation, reward, done, info))
        # 如果结束, 则退出循环
        if done:
            print("Episode finished after {} timesteps".format(t + 1))
            break
env.close()

CartPole问题--Random Guessing Algorithm & Hill Climbing Algorithm

在以上代码中,我们最重要的就是如何将学习过程中的sample_action修改为通过策略决定出一个合适的action。本文介绍简单的Random Guessing Algorithm & Hill Climbing Algorithm进行策略。

通过策略选择的action的取值即0 or 1,observation是一个四维向量,如果对这个向量求它的加权和,就可以得到一个值,再根据加权和的符号来决定action。

Random Guessing Algorithm就是每次加权权重为随机给出的值。

Hill Climbing Algorithm就是每次将上一次得到的最好权重加上一个小的随机变化更新。

主要流程和逻辑在一下代码中注释进行了讲解。

主要代码如下:

代码语言:python
复制
import numpy as np
import gym
import time

def get_action(weights, observation):
    """ 根据weights,对observation进行加权求和,根据值决定动作策略
    """
    wxb = np.dot(weights[:4], observation) + weights[4]
    if wxb >= 0:
        return 1
    else:
        return 0

def get_sum_reward_by_weights(env, weights):
    """ 根据当前策略,计算本次探索的累计reward
    """
    observation = env.reset()
    sum_reward = 0
    for t in range(1000):
        # time.sleep(0.01)
        # env.render()
        action = get_action(weights, observation)
        observation, reward, done, info = env.step(action)
        sum_reward += reward
        if done:
            break
    return sum_reward


def get_weights_by_random_guess():
    """ 使用 Random Guessing Algorithm 返回weights, 没轮随机给出权重
    """
    return np.random.rand(5)

def get_weights_by_hill_climbing(best_weights):
    """ 使用 hill_climbing Algorithm 返回weights, 每次最好的权重加上一点随机变化
    """
    return best_weights + np.random.normal(0, 0.1, 5)

def get_best_result(algo="random_guess"):
    env = gym.make("CartPole-v0")
    np.random.seed(10)
    best_reward = 0
    best_weights = np.random.rand(5)
    # 进行100轮探索
    for iter in range(10000):
        # 选择策略算法
        if algo == "hill_climbing":
            cur_weights = get_weights_by_hill_climbing(best_weights)
        else:
            cur_weights = get_weights_by_random_guess()
        # 使用本轮探索得到的策略权重,得到累计reward
        cur_sum_reward = get_sum_reward_by_weights(env, cur_weights)

        # 保存探索过程中最好的reward对应的权重值
        if cur_sum_reward > best_reward:
            best_reward = cur_sum_reward
            best_weights = cur_weights

        if best_reward >= 200:
            break
    print(iter, best_reward, best_weights)
    return best_reward, best_weights

if __name__ == '__main__':
    get_best_result("hill_climbing")

参考资料:

https://www.fashici.com/tech/836.html

中国女足真牛逼!~

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • gym库安装
  • gym--hello world代码
  • gym实例--CartPole
    • CartPole介绍
      • CartPole环境介绍
      • CartPole问题--Random Guessing Algorithm & Hill Climbing Algorithm
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档