专栏首页专知【AlphaGo Zero 核心技术-深度强化学习教程代码实战04】Agent类和SARSA算法实现

【AlphaGo Zero 核心技术-深度强化学习教程代码实战04】Agent类和SARSA算法实现

【导读】Google DeepMind在Nature上发表最新论文,介绍了迄今最强最新的版本AlphaGo Zero,不使用人类先验知识,使用纯强化学习,将价值网络和策略网络整合为一个架构,3天训练后就以100比0击败了上一版本的AlphaGo。Alpha Zero的背后核心技术是深度强化学习,为此,专知有幸邀请到叶强博士根据DeepMind AlphaGo的研究人员David Silver《深度强化学习》视频公开课进行创作的中文学习笔记,在专知发布推荐给大家!(关注专知公众号,获取强化学习pdf资料,详情文章末尾查看!)

叶博士创作的David Silver的《强化学习》学习笔记包括以下:

笔记序言:【教程】AlphaGo Zero 核心技术 - David Silver深度强化学习课程中文学习笔记

  1. 《强化学习》第一讲 简介
  2. 《强化学习》第二讲 马尔科夫决策过程
  3. 《强化学习》第三讲 动态规划寻找最优策略
  4. 《强化学习》第四讲 不基于模型的预测
  5. 《强化学习》第五讲 不基于模型的控制
  6. 《强化学习》第六讲 价值函数的近似表示
  7. 《强化学习》第七讲 策略梯度
  8. 《强化学习》第八讲 整合学习与规划
  9. 《强化学习》第九讲 探索与利用

以及包括也叶博士独家创作的强化学习实践系列!

  • 强化学习实践一 迭代法评估4*4方格世界下的随机策略
  • 强化学习实践二 理解gym的建模思想
  • 强化学习实践三 编写通用的格子世界环境类
  • 强化学习实践四 Agent类和SARSA算法实现
  • 强化学习实践五 SARSA(λ)算法实现
  • 强化学习实践六 给Agent添加记忆功能
  • 强化学习实践七 DQN的实现

今天实践四 Agent类和SARSA算法实现。

声明:本文及涉及代码为本人原创,代码适用协议MIT。

通过前几次强化学习实践讲解,我们基本上理解了个体与环境的建模思想,特别是对gym库有了一定的了解。在本讲中,我们将尝试编写一个简单的Agent类,并且使它能够和我们之前编写的格子世界环境类进行交互。然后我们将实现SARSA算法,确切地说是SARSA(0)。我们将看看它在简单格子世界中的训练效果。

由于蒙特卡洛学习和单纯的基于价值的TD学习实际应用不多,在实践环节我也不打算实现这两个算法,当然了SARSA也是TD学习的一种形式。掌握了SARSA算法,再去实现类似的Q学习算法、MC或单纯的TD算法也不会有什么难度。在下一讲中我们将实现

算法,同时完善我们的Agent类。

为了帮助加深强化学习的各种概念,我将按照解决问题的思路一步一步编写代码,希望这种写法可以让读者加深对强化学习基本概念的理解。当Agent编写完成后,我们可以不做任何修改地使用它训练各种格子世界的环境,这也体现了DeepMind公司一直所宣扬的通用智能算法(General Intelligent Algorithms)理念。

一个简单的Agent类

我们在实践二讲gym的建模思想中简单提到了个体类的建模。本讲我们将围绕SARSA(0)算法的实现来具体实现个体(Agent)类应该具备的功能。SARSA(0)算法流程如下:

我们暂且就用Agent给个体类命名,由于要使用到之前编写的格子世界环境类,我们需要导入相关的包:

from random import random    # 随机策略时用到
from gym import Env
import gym
from gridworld import *      # 可以导入各种格子世界环境

我们已经知道强化学习中个体(Agent)遵循当前一个策略得到一个行为,通过施加这个行为给环境并分析环境(Environment)的反馈信息(个体观测值的变化和获得的即时奖励)来优化策略。同时SARSA算法要求个体类维护一个状态行为价值表 Q(s,a),也就是状态行为价值函数,针对一个状态 s ,在该状态下采取一个行为a,个体要能查得出该状态行为对的价值。因此Agent类要包括如下功能:

class Agent():
    def __init__(self, env: Env):
        self.env = env      # 个体持有环境的引用
        self.Q = {}         # 个体维护一张行为价值表Q
        self.state = None   # 个体当前的观测,最好写成obs.

    def performPolicy(self, state): pass # 执行一个策略

    def act(self, a):       # 执行一个行为
        return self.env.step(a)

    def learning(self): pass   # 学习过程

这里有两点要说明下:

  • 对于Q表,我使用的是字典套字典的数据结构,即字典里的每一个键对应于状态名,其值对应于另一个新字典,这个新字典的键值是行为名,值则对应相应的行为价值。这么设计主要是为了体现个体与环境交互的特点:个体并不掌握环境的具体信息,从与环境交互过程中个体得到的只是一个观测,我们不能事先假定观测的数据格式,用字典比较稳妥。
  • 我们把执行策略和执行行为分开成两个方法,是考虑到SARSA算法有两个地方需要使用个体的策略:一是执行一个动作前,另一处是用在观察新状态S'在遵循当前策略时产生的行为时,而后面这个行为不是马上执行的,有的算法里干脆就是不执行的,因此把执行策略生成行为和执行行为分开写也就比较合理了。执行行为方法很简单,调用环境的step方法就可以了,原封不动的把step的返回值作为自己的返回值,由于该方法的直接返回了个体需要额观测,我们就不需要单独写一个观测的方法了。

有了最基本的框架,事情就简单了。接下来我们设计几个辅助的私有方法。前面提到要把观测转化为一个字典的键,因此需要一个方法来完成此事:

  def _get_state_name(self, state):  
        return str(state)              

这里偷了个懒,仅仅把状态转化为字符串就完成了,不过这个仅适用离散观测空间的环境。

由于需要频繁检索和更新Q值,同时又要完成一些初始化工作,确保我们检索的时候避免发生键不存在的情况,因此我设计了以下几个私有方法:

  def _is_state_in_Q(self, s): # 判断s的Q值是否存在
        return self.Q.get(s) is not None

    def _init_state_value(self, s_name, randomized = True): # 初始化某状态的Q值
        if not self._is_state_in_Q(s_name):
            self.Q[s_name] = {}
            for action in range(self.env.action_space.n): # 针对其所有可能行为
                default_v = random() / 10 if randomized is True else 0.0
                self.Q[s_name][action] = default_v

    def _assert_state_in_Q(self, s, randomized=True): # 确保某状态Q值存在
        # cann't find the state
        if not self._is_state_in_Q(s):
            self._init_state_value(s, randomized)

    def _get_Q(self, s, a): # 获取Q(s,a)
        self._assert_state_in_Q(s, randomized=True)
        return self.Q[s][a]

    def _set_Q(self, s, a, value): # 设置Q(s,a)
        self._assert_state_in_Q(s, randomized=True)
        self.Q[s][a] = value

外围工作基本就这么多,接下来我们为个体实现两个主要功能:一个是策略方法、一个是训练过程。SARSA遵循的策略是

策略,为了能够使得个体随着训练次数的增多而减少产生不确定行为的几率

,进而收敛至最优策略,可以将其就改为衰减的

,这里采用的办法是将当前训练的Episode次数作为参数传递给策略函数,在此基础上确定

值。同时使用gym库提供的针对空间的sample()方法来生成随机行为,这样策略函数的参数和代码可以是如下样子:

def performPolicy(self, s, episode_num, use_epsilon):
        epsilon = 1.00 / (episode_num+1)
        Q_s = self.Q[s]
        str_act = "unknown"
        rand_value = random()
        action = None
        if use_epsilon and rand_value < epsilon:  
            action = self.env.action_space.sample()
        else:
            str_act = max(Q_s, key=Q_s.get)
            action = int(str_act)
        return action    

我们为执行策略方法增加了一个use_epsilon参数,使得我们可以随时切换是否使用

。通过这样设置,今后可以很容易将SARSA算法修改为Q学习算法。最后我们来实现SARSA算法的核心。

SARSA算法实现

SARSA的核心代码全部在learning方法里,我们为learning方法设计了几个参数来做适当的控制,具体代码如下,结合注释和算法流程,很容易理解:

# sarsa learning
    def learning(self, gamma, alpha, max_episode_num):
        # self.Position_t_name, self.reward_t1 = self.observe(env)
        total_time, time_in_episode, num_episode = 0, 0, 0
        while num_episode < max_episode_num: # 设置终止条件
            self.state = self.env.reset()    # 环境初始化
            s0 = self._get_state_name(self.state) # 获取个体对于观测的命名
            self.env.render()                # 显示UI界面
            a0 = self.performPolicy(s0, num_episode, use_epsilon = True)

            time_in_episode = 0
            is_done = False
            while not is_done:               # 针对一个Episode内部
                # a0 = self.performPolicy(s0, num_episode)
                s1, r1, is_done, info = self.act(a0) # 执行行为
                self.env.render()            # 更新UI界面
                s1 = self._get_state_name(s1)# 获取个体对于新状态的命名
                self._assert_state_in_Q(s1, randomized = True)
                # 获得A'
                a1 = self.performPolicy(s1, num_episode, use_epsilon=True)
                old_q = self._get_Q(s0, a0)  
                q_prime = self._get_Q(s1, a1)
                td_target = r1 + gamma * q_prime  
                #alpha = alpha / num_episode
                new_q = old_q + alpha * (td_target - old_q)
                self._set_Q(s0, a0, new_q)

                if num_episode == max_episode_num: # 终端显示最后Episode的信息
                    print("t:{0:>2}: s:{1}, a:{2:2}, s1:{3}".\                        format(time_in_episode, s0, a0, s1))

                s0, a0 = s1, a1
                time_in_episode += 1

            print("Episode {0} takes {1} steps.".format(
                num_episode, time_in_episode)) # 显示每一个Episode花费了多少步
            total_time += time_in_episode
            num_episode += 1
        return

SARSA(0)算法这就这就完成了。我们拿一个简单格子世界来试试效果如何,为此可以添加如下代码:

def main():
    env = SimpleGridWorld()
    agent = Agent(env)
    print("Learning...")  
    agent.learning(gamma=0.9, 
                   alpha=0.1, 
                   max_episode_num=800)if __name__ == "__main__":
    main()

设置衰减系数

, 学习率

,最大学习的完整episode数量为800次。

将所有代码保存在一个文件里,通过终端调用执行该文件,可以一边显示UI界面,一边在终端看到一些输出信息。

下图是训练早期的个体表现

下图是个体训练了近400次完整Episodes的时候基本上已经锁定最优路径了。

下面是其中的一个Episode视频片段:

https://www.zhihu.com/video/873802894154006528

该程序完整的代码可从这里下载到。

至此,针对离散观测空间和离散行为空间的SARSA算法我们就完成了。在此基础上Q学习算法只要修改1-2行代码就可以了。下一次实践我们将很容易地实现

算法,同时完善我们的个体(Agent)类。敬请期待。

敬请关注专知公众号(扫一扫最下方二维码或者最上方专知蓝字关注),以及专知网站www.zhuanzhi.ai, 第一时间得到强化学习实践三 编写通用的格子世界环境类!

作者简介:

叶强,眼科专家,上海交通大学医学博士, 工学学士,现从事医学+AI相关的研究工作。

本文分享自微信公众号 - 专知(Quan_Zhuanzhi),作者:叶强

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2017-10-27

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 【AlphaGo Zero 核心技术-深度强化学习教程代码实战05】SARSA(λ)算法实现

    【导读】Google DeepMind在Nature上发表最新论文,介绍了迄今最强最新的版本AlphaGo Zero,不使用人类先验知识,使用纯强化学习,将价值...

    WZEARW
  • 【AlphaGo Zero 核心技术-深度强化学习教程代码实战06】给Agent添加记忆功能

    【导读】Google DeepMind在Nature上发表最新论文,介绍了迄今最强最新的版本AlphaGo Zero,不使用人类先验知识,使用纯强化学习,将价值...

    WZEARW
  • 【David Silver 深度强化学习教程代码实战07】 DQN的实现

    点击上方“专知”关注获取更多AI知识! 【导读】Google DeepMind在Nature上发表最新论文,介绍了迄今最强最新的版本AlphaGo Zero,不...

    WZEARW
  • 使用PyTorch进行情侣幸福度测试指南

    计算机视觉--图像和视频数据分析是深度学习目前最火的应用领域之一。因此,在学习深度学习的同时尝试运用某些计算机视觉技术做些有趣的事情会很有意思,也会让你发现些令...

    磐创AI
  • Python 还能实现图片去雾?FFA 去雾算法、暗通道去雾算法用起来! | 附代码

    Pytorch模块用来模型训练和网络层建立;其底层和Torch框架一样,但是使用Python重新写了很多内容,不仅更加灵活,支持动态图,而且提供了Python接...

    AI科技大本营
  • 用 Python 实现植物大战僵尸代码

    导读:以前很火的植物大战僵尸游戏,本想在网上找个Python版本游戏学习下,无奈没有发现比较完整的,那就自己来写一个吧。

    华章科技
  • Python《植物大战僵尸》代码实现:植物卡片选择和种植

    最近一直在给这个植物大战僵尸游戏添加新的植物和僵尸, 因为网上的图片资源有限,能加的植物和僵尸比较少, 目前进展如下。

    AI科技大本营
  • 结合知识图谱实现基于电影的推荐系统

    知识图谱(Knowledge Graph,KG)可以理解成一个知识库,用来存储实体与实体之间的关系。知识图谱可以为机器学习算法提供更多的信息,帮助模型更好地完成...

    代码医生工作室
  • 【AlphaGo Zero 核心技术-深度强化学习教程代码实战05】SARSA(λ)算法实现

    【导读】Google DeepMind在Nature上发表最新论文,介绍了迄今最强最新的版本AlphaGo Zero,不使用人类先验知识,使用纯强化学习,将价值...

    WZEARW
  • 基础数据结构 例:栈、队列、链表、数据、字典、树、等【玩转腾讯云】

    栈(stack)又名堆栈,它是一种运算受限的线性表。限定仅在表尾进行插入和删除操作的线性表。这一端被称为栈顶,把另一端称为栈底。向一个栈插入新元素又称作 进栈、...

    IT茂茂

扫码关注云+社区

领取腾讯云代金券