前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >AI 技术讲座精选:用Keras和Gym实现深度强化学习

AI 技术讲座精选:用Keras和Gym实现深度强化学习

作者头像
AI科技大本营
发布2018-04-26 11:10:37
1.3K0
发布2018-04-26 11:10:37
举报
文章被收录于专栏:AI科技大本营的专栏

综述

本篇博文将向大家演示,如何在仅仅运用78行代码的情况下,利用 Keras 和 Gym 实现深度强化学习,使得 CartPole 这款游戏得以成功运行。

即使你不知道强化学习是什么也没关系,我会在文章中给出详细的介绍。

本文代码可从 GitHub 上找到:https://github.com/keon/deep-q-learning

强化学习

强化学习是机器学习的其中一种类型,它使得你能够创建 AI agent,在与环境的互动中完成学习。与我们学习骑自行车的过程一样,在反复试验和不断纠错中逐渐学习。如上图所示,大脑表示作用于环境的 AI agent。每进行一个动作,agent 都会收到一个反馈。该反馈包含两部分的内容:一个是“奖励”(reward),另一个是则为“环境的下一状态”。“奖励”的内容通常由人来定义。如果拿学习自行车的过程相类比的话,该“奖励”就可以是与原先出发点之间的距离。

深度强化学习

Google DeepMind 曾于2013年发表著名的论文“用深度强化学习实现Atari”(Playing Atari with Deep Reinforcement Learning,https://arxiv.org/abs/1312.5602),文中提出了一种叫做“Deep Q Network(简称 DQN)”的新算法。其演示了一个 AIagent 学习玩游戏的过程,AIagent 要做的只是观察屏幕,不需要具备任何关于该游戏的先验知识。最终结果让人印象深刻。正是这篇论文,开启了“深层强化学习”的新时代——将深层学习和强化学习相结合。

DeepMind Atari 游戏视频链接:

https://www.youtube.com/watch?v=V1eYniJ0Rnk

在Q学习算法中,有一个Q函数,可基于某个状态估计 reward。 类似地,在深Q网络算法中,我们将神经网络看作Q函数,基于某个状态来估计 reward。下面我们会详细说明。

Cartpole 游戏

通常,训练 agent 玩 Atari 游戏需要一定的时间(可能几个小时,也可能是一天)。为了简便起见,我们让 agent 玩一个相对简单些的游戏 CartPole,使用的方法和上述2013年 Google DeepMind 所发表论文中的一致。

CartPole 是 OpenAI Gym(一个游戏模拟器)中最简单的环境之一。从上面的动画中我们可以看到,CartPole 的目标是:设法去平衡正在移动的小车顶部的竖杆。从其状态中我们可以得到4条信息,包括杆的角度和小车的位置等,但是不包括像素信息。agent 通过一系列0或1的动作,推动小车向左或向右移动。

Gym 让 agent 与游戏环境之间的互动变得非常简单。

如上所述,动作可以是0也可以是1。我们在传送0或1的同时,env(游戏环境)会给出相应的结果。Done 是一个布尔值,用于表示游戏结束与否。旧的 state 信息与 action 和 next_state 相配合,reward 是我们训练 agent 所必须的信息。

利用 Keras 实现简单神经网络

本篇文章并不是一篇关于深度学习和神经网络的文章。因此,这里我们仅仅将神经网络看作是黑盒算法。神经网络算法的一般过程为:对输入和输出数据对进行学习;得到一种模型;根据未知输入数据预测输出。不过我们应该理解神经网络在 DQN 算法中所扮演的角色。

接下来我们将要使用的神经网络与上图类似,总共有1个输入层和3个隐含层,并且输入层需要接收4个信息。我们的输出层将包含2个节点,也就是说输出两个值,因为游戏有两个按钮(0和1)。

在 Keras 的帮助下,基本神经网络的实现过程真的简单了许多。下面的代码创建一了个空的神经网络模型。Activation、loss 和 optimizer是与神经网络特性相关的参数,此处不做讨论。

神经网络必须经过训练才能在环境数据的基础上进行理解并作出预测,我们需要事先为其提供必要的信息。Fit() 函数将 states 和 target_f 信息传给模型,下面我将给出解释,其他的参数可以暂且搁置。

该训练过程使得神经网络可以根据某个 state 预测 reward 值(target_f)。

当对模型调用 predict() 函数时,模型将根据训练数据预测当前状态的 reward 值。

实现 DQN

存储和重写是 DQN 算法两个最为显著的特征。两者都是相当简单的概念。

存储

DQN 的问题之一是,算法中使用的神经网络会丢失先前所得的经验,只能记住当下的新经验。因此,我们需要一个含有先前经验值和观察值的列表(实际上是一个 memory 数组),以便用先前经验重新训练模型。调用这个 memory 数组,并使用 remember() 函数将 state、action、reward 和 next_State 信息写入内存中。

在我们的示例中,内存列表具有以下形式:

如下所示,remember 函数会简单地将 state、action、reward 和 next_State存入内存:

是不是很简单呢?

重写

利用 momory 中的经验来训练神经网络的方法称为 replay()。首先,我们从 memory 中抽取一些经验并把它们命名为 batches。

上述代码通过对内存索引随机采样得到 batches。例如,如果 batches 是[1,5,2,7],则每个数字对应表示内存索引1,5,2,7。

为了使 agent 能够长久并良好地运行下去,我们不仅要考虑当下的 rewards,还要考虑将来会得到的 future rewards。鉴于此,我们会为其设置衰减率(discount rate 或 gamma)。这样一来,在既定状态的基础上,agent 将会学习以取得衰减后的 future reward 的最大值。

Agent 如何决策

首先,agent 将按一定比例(探索率(epsilon),即 agent 随机动作的次数)随机选择动作(0或1)。这是因为,在 agent 还没有得到最终的模型时,尽可能多的尝试会非常有利。当不能随机决定动作时,agent 将基于当前状态预测 reward 值,并选择能得到最高 reward 值的动作。Np.argmax() 是选择 act_values [0] 中两个元素之间最大值的函数。

Act_values [0] 形式如下:[0.67,0.2],其中的数字分别表示选择动作0和1的 reward 值。Argmax() 函数会给出最大值的索引。如在[0.67,0.2]的示例中,若是第0个索引中的值最大,argmax 就会返回0。

主要参数

有些参数必须要传给强化学习 agent,这些参数你会在程序中多次看到:

  • episodeds ——我们想要agent进行的游戏次数。
  • gamma——衰减率,计算衰减的future reward值。
  • epsilon——探索率,agent随机决定其动作的比例。
  • epsilon_decay——随着agent越来越擅长游戏,我们想要减少的随机探索数量。
  • epsilon_min——我们希望agent随机探索的最低次数。
  • learning_rate——在每次迭代的过程中,神经网络的学习率。

整合:深度强化学习 Agnet 代码

以上就是我们对 agent 各个部分的解释。接下来的代码将我们的讨论变为了现实,我们将这个既美观又整洁的类称为 DQNAgent。

让我们一起训练这个 Agent

训练部分的代码更短。在评论部分我会做出解释。

结果

一开始,agent 通过随机地操作进行探索。

经过多个阶段的学习后:

  1. 小车可以保持杆平衡。
  2. 当超出边界时,结束游戏。
  3. 当靠近边界时,agent试图远离,但杆会掉下来。
  4. 小车可以保持杆的平衡并能控制住杆。

经过几百个回合(episode)后(花了10分钟),agent 开始学习使如何最大化 reward 的值(score)。

最后,一个技艺高超的 CartPole 游戏玩家诞生了!

文中的代码均可在 GitHub 上找到。与此同时,我也保存了模型权重,如果你想省去训练过程,可以直接使用。(https://github.com/keon/deep-q-learning

参考文献

  • Playing Atari with Deep Reinforcement Learning(https://arxiv.org/abs/1312.5602
  • Human-level Control Through Deep Reinforcement Learning(http://home.uchicago.edu/~arij/journalclub/papers/2015_Mnih_et_al.pdf

本文作者 Keon Kim 目前就读于纽约大学,致力于研究数据科学、网络开发和机器学习,尤其热衷于深度学习。

本文由 AI100 编译,转载需得到本公众号同意。


编译:AI100

原文链接:https://keon.io/rl/deep-q-learning-with-keras-and-gym/


本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2017-02-18,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI科技大本营 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档