AI 技术讲座精选:用Keras和Gym实现深度强化学习

综述

本篇博文将向大家演示,如何在仅仅运用78行代码的情况下,利用 Keras 和 Gym 实现深度强化学习,使得 CartPole 这款游戏得以成功运行。

即使你不知道强化学习是什么也没关系,我会在文章中给出详细的介绍。

本文代码可从 GitHub 上找到:https://github.com/keon/deep-q-learning

强化学习

强化学习是机器学习的其中一种类型,它使得你能够创建 AI agent,在与环境的互动中完成学习。与我们学习骑自行车的过程一样,在反复试验和不断纠错中逐渐学习。如上图所示,大脑表示作用于环境的 AI agent。每进行一个动作,agent 都会收到一个反馈。该反馈包含两部分的内容:一个是“奖励”(reward),另一个是则为“环境的下一状态”。“奖励”的内容通常由人来定义。如果拿学习自行车的过程相类比的话,该“奖励”就可以是与原先出发点之间的距离。

深度强化学习

Google DeepMind 曾于2013年发表著名的论文“用深度强化学习实现Atari”(Playing Atari with Deep Reinforcement Learning,https://arxiv.org/abs/1312.5602),文中提出了一种叫做“Deep Q Network(简称 DQN)”的新算法。其演示了一个 AIagent 学习玩游戏的过程,AIagent 要做的只是观察屏幕,不需要具备任何关于该游戏的先验知识。最终结果让人印象深刻。正是这篇论文,开启了“深层强化学习”的新时代——将深层学习和强化学习相结合。

DeepMind Atari 游戏视频链接:

https://www.youtube.com/watch?v=V1eYniJ0Rnk

在Q学习算法中,有一个Q函数,可基于某个状态估计 reward。 类似地,在深Q网络算法中,我们将神经网络看作Q函数,基于某个状态来估计 reward。下面我们会详细说明。

Cartpole 游戏

通常,训练 agent 玩 Atari 游戏需要一定的时间(可能几个小时,也可能是一天)。为了简便起见,我们让 agent 玩一个相对简单些的游戏 CartPole,使用的方法和上述2013年 Google DeepMind 所发表论文中的一致。

CartPole 是 OpenAI Gym(一个游戏模拟器)中最简单的环境之一。从上面的动画中我们可以看到,CartPole 的目标是:设法去平衡正在移动的小车顶部的竖杆。从其状态中我们可以得到4条信息,包括杆的角度和小车的位置等,但是不包括像素信息。agent 通过一系列0或1的动作,推动小车向左或向右移动。

Gym 让 agent 与游戏环境之间的互动变得非常简单。

如上所述,动作可以是0也可以是1。我们在传送0或1的同时,env(游戏环境)会给出相应的结果。Done 是一个布尔值,用于表示游戏结束与否。旧的 state 信息与 action 和 next_state 相配合,reward 是我们训练 agent 所必须的信息。

利用 Keras 实现简单神经网络

本篇文章并不是一篇关于深度学习和神经网络的文章。因此,这里我们仅仅将神经网络看作是黑盒算法。神经网络算法的一般过程为:对输入和输出数据对进行学习;得到一种模型;根据未知输入数据预测输出。不过我们应该理解神经网络在 DQN 算法中所扮演的角色。

接下来我们将要使用的神经网络与上图类似,总共有1个输入层和3个隐含层,并且输入层需要接收4个信息。我们的输出层将包含2个节点,也就是说输出两个值,因为游戏有两个按钮(0和1)。

在 Keras 的帮助下,基本神经网络的实现过程真的简单了许多。下面的代码创建一了个空的神经网络模型。Activation、loss 和 optimizer是与神经网络特性相关的参数,此处不做讨论。

神经网络必须经过训练才能在环境数据的基础上进行理解并作出预测,我们需要事先为其提供必要的信息。Fit() 函数将 states 和 target_f 信息传给模型,下面我将给出解释,其他的参数可以暂且搁置。

该训练过程使得神经网络可以根据某个 state 预测 reward 值(target_f)。

当对模型调用 predict() 函数时,模型将根据训练数据预测当前状态的 reward 值。

实现 DQN

存储和重写是 DQN 算法两个最为显著的特征。两者都是相当简单的概念。

存储

DQN 的问题之一是,算法中使用的神经网络会丢失先前所得的经验,只能记住当下的新经验。因此,我们需要一个含有先前经验值和观察值的列表(实际上是一个 memory 数组),以便用先前经验重新训练模型。调用这个 memory 数组,并使用 remember() 函数将 state、action、reward 和 next_State 信息写入内存中。

在我们的示例中,内存列表具有以下形式:

如下所示,remember 函数会简单地将 state、action、reward 和 next_State存入内存:

是不是很简单呢?

重写

利用 momory 中的经验来训练神经网络的方法称为 replay()。首先,我们从 memory 中抽取一些经验并把它们命名为 batches。

上述代码通过对内存索引随机采样得到 batches。例如,如果 batches 是[1,5,2,7],则每个数字对应表示内存索引1,5,2,7。

为了使 agent 能够长久并良好地运行下去,我们不仅要考虑当下的 rewards,还要考虑将来会得到的 future rewards。鉴于此,我们会为其设置衰减率(discount rate 或 gamma)。这样一来,在既定状态的基础上,agent 将会学习以取得衰减后的 future reward 的最大值。

Agent 如何决策

首先,agent 将按一定比例(探索率(epsilon),即 agent 随机动作的次数)随机选择动作(0或1)。这是因为,在 agent 还没有得到最终的模型时,尽可能多的尝试会非常有利。当不能随机决定动作时,agent 将基于当前状态预测 reward 值,并选择能得到最高 reward 值的动作。Np.argmax() 是选择 act_values [0] 中两个元素之间最大值的函数。

Act_values [0] 形式如下:[0.67,0.2],其中的数字分别表示选择动作0和1的 reward 值。Argmax() 函数会给出最大值的索引。如在[0.67,0.2]的示例中,若是第0个索引中的值最大,argmax 就会返回0。

主要参数

有些参数必须要传给强化学习 agent,这些参数你会在程序中多次看到:

  • episodeds ——我们想要agent进行的游戏次数。
  • gamma——衰减率,计算衰减的future reward值。
  • epsilon——探索率,agent随机决定其动作的比例。
  • epsilon_decay——随着agent越来越擅长游戏,我们想要减少的随机探索数量。
  • epsilon_min——我们希望agent随机探索的最低次数。
  • learning_rate——在每次迭代的过程中,神经网络的学习率。

整合:深度强化学习 Agnet 代码

以上就是我们对 agent 各个部分的解释。接下来的代码将我们的讨论变为了现实,我们将这个既美观又整洁的类称为 DQNAgent。

让我们一起训练这个 Agent

训练部分的代码更短。在评论部分我会做出解释。

结果

一开始,agent 通过随机地操作进行探索。

经过多个阶段的学习后:

  1. 小车可以保持杆平衡。
  2. 当超出边界时,结束游戏。
  3. 当靠近边界时,agent试图远离,但杆会掉下来。
  4. 小车可以保持杆的平衡并能控制住杆。

经过几百个回合(episode)后(花了10分钟),agent 开始学习使如何最大化 reward 的值(score)。

最后,一个技艺高超的 CartPole 游戏玩家诞生了!

文中的代码均可在 GitHub 上找到。与此同时,我也保存了模型权重,如果你想省去训练过程,可以直接使用。(https://github.com/keon/deep-q-learning

参考文献

  • Playing Atari with Deep Reinforcement Learning(https://arxiv.org/abs/1312.5602
  • Human-level Control Through Deep Reinforcement Learning(http://home.uchicago.edu/~arij/journalclub/papers/2015_Mnih_et_al.pdf

本文作者 Keon Kim 目前就读于纽约大学,致力于研究数据科学、网络开发和机器学习,尤其热衷于深度学习。

本文由 AI100 编译,转载需得到本公众号同意。


编译:AI100

原文链接:https://keon.io/rl/deep-q-learning-with-keras-and-gym/


原文发布于微信公众号 - AI科技大本营(rgznai100)

原文发表时间:2017-02-18

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏旭的专栏

机器学习之离散特征自动化扩展与组合

机器学习中特征工程处于十分重要的位置,整个学习中至少会花二分之一甚至更多的时间用于特征工程至少(对于统计学习)。但是如何能够有效地实现特征的自动化扩展、组合和量...

1.2K1
来自专栏玉树芝兰

如何用Python和深度神经网络发现即将流失的客户?

想不想了解如何用Python快速搭建深度神经网络,完成数据分类任务?本文一步步为你展示这一过程,让你初步领略深度学习模型的强大和易用。

2643
来自专栏人工智能头条

NLP通用模型诞生?一个模型搞定十大自然语言常见任务

1352
来自专栏IT派

7月机器学习最佳开源项目Top 10

【导读】七月就要结束了,在即将到来的 7 月最后一个周末,人工智能头条为大家整理了本月 ML 最受欢迎的十大开源项目。就算放假在家也可以知道大家现在都在学些什...

791
来自专栏机器之心

学界 | Fashion-MNIST:替代MNIST手写数字集的图像数据集

机器之心转载 公众号:PaperWeekly 作者:肖涵 FashionMNIST 是一个替代 MNIST 手写数字集 [1] 的图像数据集。 它是由 Zala...

4559
来自专栏机器之心

学界 | 强化学习+树搜索:一种新型程序合成方法

作者:Riley Simmons-Edler、Anders Miltner、Sebastian Seung

902
来自专栏机器之心

深度 | 苹果博客解读iPhone上的人脸识别深度神经网络

39110
来自专栏AI研习社

分享一波关于做 Kaggle 比赛,Jdata,天池的经验,看完我这篇就够了。

Kaggle 的数据挖掘比赛近年来很火,以至于中国兴起了很多很多类似的比赛,做了两个这种类型的比赛了,Jdata 用户商品购买预测和用户位置精准预测,积累了相当...

6374
来自专栏新智元

深度学习应用系统分析:应用组合和形态矩阵找到正确路径

【新智元导读】本文收录了arXiv.org上关于深度学习的一些最新的研究论文,列出了这些文章的内容,包括“深度学习八大灵感应用”、“深度学习用例”、“科学与工程...

43410
来自专栏新智元

【大咖读论文】田渊栋评 DeepMind 深度强化学习的异步算法

【新智元导读】在ICML2016上,谷歌 DeepMind 有9篇论文被接收,新智元特邀Facebook 人工智能实验室研究员田渊栋对其中的《 深度增强学习的异...

4125

扫码关注云+社区

领取腾讯云代金券