专栏首页计算机视觉漫谈Prioritized Experience Replay (DQN)——让DQN变得更会学习

Prioritized Experience Replay (DQN)——让DQN变得更会学习

目录

1.前言2.算法2.1 SumTree有效抽样2.2 Memory类2.3 更新方法对比结果

1.前言

这次我们还是使用MountainCar来进行实验,因为这次我们不需要重度改变它的reward了。所以只要是没有拿到小旗子reward=-1,拿到小旗子时,我们定义它获得了+10的reward。比起之前DQN中,这个reward定义更加准确。如果使用这种reward定义方式,可以想象Natural DQN会花很长时间学习。因为记忆库中只有很少很少的+10的reward可以学习,正负样本不一样。而使用Prioritized replay,就会重视这种少量,但值得学习的样本。

接下来我们就来看看他是怎么做到的。

2.算法

这一套算法的重点就在我们batch抽样的时候并不是随机抽样的,而是按照Memory中的样本优先级来抽。所以这能更有效地找到我们需要学习的样本。

那么样本的优先级是怎么定呢?原来我们可以用到TD-error,也就是Q现实-Q估计来规定优先学习的成都。如果TD-error越大,就代表我们的预测精度还有很多上升空间,那么这个样本就越需要被学习,也就是优先级p越高。

有了TD-error就有了优先级p,那我们如何有效地根据p来抽样呢?如果每次抽样都需要针对p对所有样本排序,这将会是一件非常消耗计算能力的事情,好在我们还有其他方法,这种方法不会对得到的样本进行排序,这就是论文中说到的SumTree

SumTree是一种树形结构,每片输液存储每个样本的优先级p,每个树枝节点只有两个分叉,节点的值是两个分叉的合,所以SumTree的顶端就是所有p的合。正如下面图片,最下面一层树叶存储样本的p。叶子上一层最左边的13=3+10,按这个规律相加,顶层的roor就是全部p的合了。

抽样的时,我们会将p的总和除以batch size,分成batch size那么多区间,(n=sum(p)/batch_size)。如果将所有node的priority加起来是42的话,我们如果抽6个样本,这时的区间拥有的priority可能是这样的: [0-7], [7-14], [14-21], [21-28], [28-35], [35-42] 然后在每个区间里随机选取一个数。比如在第4个区间[21-28]选到了24,就按照这个24从最顶上的42开始往下搜索。首先看到最顶上42下面有两个child nodes,拿着手中的24对比左边的chlid29,如果左边的chlid比自己手中的值大,那我们就走左边这条路,接着再对比29下面的左边那个点13,这时,手中的24比13大,那我们就走右边的路,并且将手中的值根据13修改一下,变成24-13=11.接着拿11和13右下角的12比,结果12比11大,那我们就选12当做这次选到的priority,并且也选择12对应的数据。

2.1 SumTree有效抽样

 1class SumTree(object):
 2    # 建立 tree 和 data,
 3    # 因为 SumTree 有特殊的数据结构,
 4    # 所以两者都能用一个一维 np.array 来存储
 5    def __init__(self, capacity):
 6
 7    # 当有新 sample 时, 添加进 tree 和 data
 8    def add(self, p, data):
 9
10    # 当 sample 被 train, 有了新的 TD-error, 就在 tree 中更新
11    def update(self, tree_idx, p):
12
13    # 根据选取的 v 点抽取样本
14    def get_leaf(self, v):
15
16    # 获取 sum(priorities)
17    @property
18    def totoal_p(self):

2.2 Memory类

 1class Memory(object):
 2    # 建立 SumTree 和各种参数
 3    def __init__(self, capacity):
 4
 5    # 存储数据, 更新 SumTree
 6    def store(self, transition):
 7
 8    # 抽取 sample
 9    def sample(self, n):
10
11    # train 完被抽取的 samples 后更新在 tree 中的 sample 的 priority
12    def batch_update(self, tree_idx, abs_errors):

具体完整的代码我在最后会附上我github的链接,这里说一下这个关于ISweight到底怎么算。需要提到一点是,代码中的计算方法是经过了简化的,将论文中的步骤合并了一些,比如:prob = p / self.tree.total_p; ISWeights = np.power(prob/min_prob, -self.beta)

在paper 中, ISWeight = (N*Pj)^(-beta) / maxi_wi里面的maxi_wi是为了 normalize ISWeight, 所以我们先把他放在一边. 所以单纯的importance sampling 就是(N*Pj)^(-beta),那 maxi_wi = maxi[(N*Pi)^(-beta)].

如果将这两个式子合并,

ISWeight = (N*Pj)^(-beta) / maxi[ (N*Pi)^(-beta) ]

而且如果将maxi[ (N*Pi)^(-beta)]中的 (-beta) 提出来, 这就变成了mini[ (N*Pi) ] ^ (-beta)

看出来了吧, 有的东西可以抵消掉的. 最后

ISWeight = (Pj / mini[Pi])^(-beta)

这样我们就有了代码中的样子.

还有代码中的alpha是一个决定我们要使用多少 ISweight 的影响, 如果alpha = 0,我们就没使用到任何 Importance Sampling.

2.3 更新方法

我们在_init_中加一个prioritized参数来表示DQN是否具备prioritized能力。为了对比的需要,我们的tf.Session()也单独传入,并移除原本在DQN代码中的这一句:self.sess.run(tf.global_variables_initializer())

搭建神经网络时,我们发现DQN with Prioritized replay只多了一个ISWeights,这个正是刚刚算法中提到的Importance-Sampling Weights,用来恢复被Prioritized replay打乱的抽样概率分布。

 1class DQNPrioritizedReplay:
 2    def _build_net(self)
 3        ...
 4        # self.prioritized 时 eval net 的 input 多加了一个 ISWeights
 5        self.s = tf.placeholder(tf.float32, [None, self.n_features], name='s')  # input
 6        self.q_target = tf.placeholder(tf.float32, [None, self.n_actions], name='Q_target')  # for calculating loss
 7        if self.prioritized:
 8            self.ISWeights = tf.placeholder(tf.float32, [None, 1], name='IS_weights')
 9
10        ...
11        # 为了得到 abs 的 TD error 并用于修改这些 sample 的 priority, 我们修改如下
12        with tf.variable_scope('loss'):
13            if self.prioritized:
14                self.abs_errors = tf.reduce_sum(tf.abs(self.q_target - self.q_eval), axis=1)    # for updating Sumtree
15                self.loss = tf.reduce_mean(self.ISWeights * tf.squared_difference(self.q_target, self.q_eval))
16            else:
17                self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval))

因为和Natural DQN使用的Memory不一样,所以在存储transition的时候方式也略不相同

 1class DQNPrioritizedReplay:
 2    def store_transition(self, s, a, r, s_):
 3        if self.prioritized:    # prioritized replay
 4            transition = np.hstack((s, [a, r], s_))
 5            self.memory.store(transition)
 6        else:       # random replay
 7            if not hasattr(self, 'memory_counter'):
 8                self.memory_counter = 0
 9            transition = np.hstack((s, [a, r], s_))
10            index = self.memory_counter % self.memory_size
11            self.memory[index, :] = transition
12            self.memory_counter += 1

我们在learn()部分的改变也在如下展示:

 1class DQNPrioritizedReplay:
 2    def learn(self):
 3        ...
 4        # 相对于 DQN 代码, 改变的部分
 5        if self.prioritized:
 6            tree_idx, batch_memory, ISWeights = self.memory.sample(self.batch_size)
 7        else:
 8            sample_index = np.random.choice(self.memory_size, size=self.batch_size)
 9            batch_memory = self.memory[sample_index, :]
10
11        ...
12
13        if self.prioritized:
14            _, abs_errors, self.cost = self.sess.run([self._train_op, self.abs_errors, self.loss],
15                                         feed_dict={self.s: batch_memory[:, :self.n_features],
16                                                    self.q_target: q_target,
17                                                    self.ISWeights: ISWeights})
18            self.memory.batch_update(tree_idx, abs_errors)   # update priority
19        else:
20            _, self.cost = self.sess.run([self._train_op, self.loss],
21                                         feed_dict={self.s: batch_memory[:, :self.n_features],
22                                                    self.q_target: q_target})
23
24        ...

对比结果

运行我Github中的这个MountainCar脚本,我们就不难发现,我们都从两种方法最初拿到第一个R+=10奖励的时候算起,看看经历过一次R+=10后,他们有没有好好利用这次的奖励,可以看出,有 Prioritized replay的可以高效地利用这些不常拿到的奖励,并好好学习他们。所以Prioritized replay 会更快结束每个 episode, 很快就到达了小旗子。

本文分享自微信公众号 - 计算机视觉漫谈(gh_41e38694bc19),作者:CristianoC

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-07-24

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 深度强化学习之DQN实战

    今天我们会将我们上一篇文章讲解的DQN的理论进行实战,实战的背景目前仍然是探险者上天堂游戏,不过在下一次开始我们会使用OpenAI gym的环境库,玩任何我们想...

    CristianoC
  • 迁移学习(Transfer Learning)

    距离上次更公众号已经有一段时间了,寒假到开学这段时间都没有更新,笔者在这跟大家说声抱歉。这个学期可能会更新一些有关深度学习的文章,尽量保持一周一更,也希望大家监...

    CristianoC
  • DDPG实战

    1.前言2.代码2.1 主结构2.2 Actor Critic2.3 经验池2.4 每回合算法

    CristianoC
  • Data Structurestackheapheap的实现索引堆tree并查集图 Graph

    堆的基本性质: ①堆中的某一个节点总是不小于或不大于其父节点的值。 ②堆总是一棵完全二叉树 比较经典的堆有二叉堆,费波纳茨堆等等。如果一棵二叉树最下层上的...

    西红柿炒鸡蛋
  • 用python解析xml的几种方法

    4.开始部分:<?xml version="1.0" encoding="utf-8"?>

    py3study
  • PyQt5 非模态对话框(apply 型)

    如果希望用户可以重复更改对话框中的参数并能马上看到修改结果,那么就要使用非模态对话框,这样用户就可以按照他们喜欢的方式来持续不断地修改数据并验证修改的结果了。

    用户6021899
  • 用 Python 写个坦克大战

    坦克大战是一款策略类的平面射击游戏,于 1985 年由 Namco 游戏公司发布,尽管时至今日已经有了很多衍生类的游戏,但这款游戏仍然受到了相当一部分人的欢迎,...

    纯洁的微笑
  • PyQt5 多文档窗口界面

    本篇介绍多文档窗口界面的写法。代码中并未实现关闭窗口前文档未保存的提醒对话框,因为之前已有介绍。

    用户6021899
  • Data Structures and Algorithms Basics(007):Stack-Queue

    用户5473628
  • html5点击出现燃放烟花特效

    今天我发现了一个非常好的html特效,是由HTML5来实现的,效果非常绚丽。效果如下:

    无邪Z

扫码关注云+社区

领取腾讯云代金券