前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >模仿学习(Imitation Learning)完全介绍

模仿学习(Imitation Learning)完全介绍

作者头像
用户1737318
发布2018-06-06 15:51:12
3.8K0
发布2018-06-06 15:51:12
举报
文章被收录于专栏:人工智能头条人工智能头条

作者:罗宇矗 原文:模仿学习(Imitation Learning)完全介绍(一) http://dwz.cn/5wOd4F

在传统的强化学习任务中,通常通过计算累积奖赏来学习最优策略(policy),这种方式简单直接,而且在可以获得较多训练数据的情况下有较好的表现。

然而在多步决策(sequential decision)中,学习器不能频繁地得到奖励,且这种基于累积奖赏及学习方式存在非常巨大的搜索空间。

而模仿学习(Imitation Learning)的方法经过多年的发展,已经能够很好地解决多步决策问题,在机器人、 NLP 等领域也有很多的应用。

模仿学习是指从示教者提供的范例中学习,一般提供人类专家的决策数据

,每个决策包含状态和动作序列

,将所有「状态-动作对」抽取出来构造新的集合

之后就可以把状态作为特征(feature),动作作为标记(label)进行分类(对于离散动作)或回归(对于连续动作)的学习从而得到最优策略模型。

模型的训练目标是使模型生成的状态-动作轨迹分布和输入的轨迹分布相匹配。从某种角度说,有点像自动编码器(Autoencoder),也与目前大火的 GANs 很类似。

在简单自动驾驶任务中(如下图),状态就是指汽车摄像头所观测到的画面

(很多强化学习任务中

是可以互换的),动作即转向角度。

根据人类提供的状态动作对来习得驾驶策略。这个任务也叫做行为克隆(Behavior Cloning),即作为监督学习的模仿学习。

但是不是我们完成训练后模型就能够有比较好的效果?

答案是否定的,这里存在复合误差(compounding errors),训练好的策略模型

执行的轨迹和训练轨迹的误差会随时间的增加而越变越大,用公式表示即

代表在 t 时刻c误差的概率,在每个时刻 T ,

),具体效果见下图:

数据增广 (Data Augmentation)

为了解决误差随时间越来越大的问题,可以采用数据增广(Data Augmentation)方法,如下图。

这是一个端对端的自动驾驶解决方案(NVIDIA 2016),汽车装配了左右两侧的摄像头与中央摄像头来获取当前观测的环境,并且能够通过 Back propagation 使其从错误状态中恢复。

它在训练模型前人为地调整了环境不好时汽车的运动动作,另外,摄像头图像的识别采用的是卷积神经网络。

如果我们能够获取大量的训练数据形成一个合适的状态概率分布或者说样本空间,从而得到一个很好的策略模型同样能直接地解决这个问题(如图)。

但这往往不太现实,因为需要耗费的成本太大。起初大部分研究者也几乎全在研究如何优化策略减少误差,并提出了很多方法,但都不是十分有效。

DAgger (Dataset Aggregation)

该方法则将研究目的从优化策略

,即令

趋近

,转移到增加训练数据上,即令样本空间更加接近真实样本空间。具体算法如下:

1.通过数据集

训练出策略

2.执行

得到一个新的数据集

3.人工给

中的状态标上动作(action)

4.聚合(Aggregate):

5.跳到步骤 1

从上述步骤中,我们可以看到 DAgger 最大的问题是第三步,第三步需要人为地去打标记,这是没有人愿意干的工作。

那是否有「自动」的办法代替人完成这个工作呢?

答案是肯定的,比如 iLQR 方法,比较复杂,会在之后发的关于 CS294 的笔记里细说。

对模仿学习的应用场景而言,在仅仅使用模仿学习算法本身时,没有具体的理论去说明模仿学习什么时候表现好,什么时候表现不好,但很多时候都得不到很好的效果。

它通常在下面几种情况里表现很好:

  1. 有很好的观测器,比如前文提到的左右摄像头引入了更强更丰富的状态信息
  2. 数据取样自一个很健壮的路径规划分布
  3. 增加更多在线策略生成的数据,比如使用 DAgger

下面是 CS294-112 中提到的模仿学习的两个具体应用:

案例一(下图)是通过三个摄像头采集的图像描述

,人类行走的方向作为

拥有离散的三个量,直走(straight)、右转(right)和左转(left)。

获得数据后直接作为分类问题(classification)进行训练,取得了比较好的效果。shi x论文见文末参考资料[4]。

案例二使用了前面提到的 DAgger 方法,图中红线的位置是每个状态的标记,第三步即人工拖拽红线到合适的位置。论文见文末参考资料[5]。

结构化预测(Structured prediction)

结构化预测问题由输入空间

,输出空间

服从的一个固定但未知的分布

和一个非负的损失函数

组成。

其目的是用样本中的数据习得能最小化损失的映射

。下面是使用结构化预测在 Sequence labelling 中的几个样例:

1. Part of speech tagging(语言词性标记)

2. Name Entity Recognition(名字检测)

3. Attentive Tracking(目标跟踪)

除此之外,结构化预测在字幕生成、机器翻译等 NLP 领域也有应用。

下面是结构化预测在强化学习中的定义:

Sequence labelling

  • State: 输入序列 x 和已经生成的标签
  • Actions: 下一个输出的标签
  • Reward: 当预测值和真实值符合时

reward>0

Image captioning(图像标题)

标题:A blue monster is eating a cookie

  • State: 图像信息和已经生成词句
  • Actions: 下一个输出的词语
  • Reward: 当预测值和真实值符合时

reward>0

(这里的损失函数不可分解)

结构化预测问题通常会用到循环神经网络(RNN),值得注意的是

在 RNN 模型中,对训练数据

,需寻找到最优参数

其中

再把符号

表示参数向量和策略函数,可得

参考策略(输出 y 的真实值)是

。算法如下:

具体关于这个应用的细节可以参见论文:参考资料[8]。

再举一个使用 LSTMs 的有监督模仿学习案例,如下图。机器人需要通过模仿学习学会拣起并放置小物块,这里的输入 X 是所有对象的姿势、旋转和位置情况,以及操纵装置的状态,输出 y 是操纵装置的下一个 pose。具体实现论文见参考资料[9]。

模仿学习的方法还有 Interaction & active learning,Inverse reinforcement learning(IRL) 等。IRL 是应用比较多的一种方法,与其它方法直接模仿人类行为不同,它是根据观测来确定奖励函数,之后的文章会具体展开。

最后再说下模仿学习的局限:

一是需要人类专家提供数据这点通常得花大量精力,而且采用深度网络的学习方式又需要大量的数据;

二是在一些复杂困难的行为里,我们无法提供相关的行为数据,比如飞机的花式表演。

另外,人类可以自身的实践里获得越来越多的数据来实现自我的一步步提升,机器应该也可以逐渐达到这个效果。

参考资料:

1.CMU 10703: Deep RL and Control (Lecture 3/6)(https://katefvision.github.io/katefSlides/immitation_learning_I_katef.pdf)

2.CS 294 Lecture 2-3(https://zhuanlan.zhihu.com/p/25298020)

3.周志华《机器学习》强化学习-模仿学习部分

4.Alessandro Giusti et al., A Machine Learning Approach to Visual Perception of Forest Trails for Mobile Robots(http://rpg.ifi.uzh.ch/docs/RAL16_Giusti.pdf)

5.Shreyansh Daftry et al.,Learning Transferable Policies for Monocular Reactive MAV Control(https://arxiv.org/abs/1608.00627)

6.Bagnell, An Invitation to Imitation(http://www.ri.cmu.edu/publication_view.html?pub_id=7891)

7.Bojarski et al., End to End Learning for Self-Driving Cars(建议阅读)(https://arxiv.org/abs/1604.07316)

8.Nguyen, Imitation Learning with Recurrent Neural Networks(https://arxiv.org/abs/1607.05241)

9.Rouhollah et al., Learning real manipulation tasks from virtual demonstrations using LSTM(https://arxiv.org/abs/1603.03833)

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2017-03-14,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 人工智能头条 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
机器翻译
机器翻译(Tencent Machine Translation,TMT)结合了神经机器翻译和统计机器翻译的优点,从大规模双语语料库自动学习翻译知识,实现从源语言文本到目标语言文本的自动翻译,目前可支持十余种语言的互译。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档