R2D2: Recurrent Replay Distributed DQN
An introduction to the R2D2 algorithm (article in Chinese):
https://mp.weixin.qq.com/s/KoFuCePBNPCXGWKpAy_wjw
Core ideas of the algorithm:
• Stored state: Storing the recurrent state in replay and using it to initialize the network at training time. This partially remedies the weakness of the zero start state strategy, however it may suffer from the effect of ‘representational drift’ leading to ‘recurrent state staleness’, as the stored recurrent state generated by a sufficiently old network could differ significantly from a typical state produced by a more recent version.
• Burn-in: Allow the network a ‘burn-in period’ by using a portion of the replay sequence only for unrolling the network and producing a start state, and update the network only on the remaining part of the sequence. We hypothesize that this allows the network to partially recover from a poor start state (zero, or stored but stale) and find itself in a better initial state before being required to produce accurate outputs.
Ordinary reinforcement-learning Buffer:
The Buffer stores (s, a, r, d, s') transitions. The data is typically two-dimensional: the first dimension is batch_size and the second is the feature dimension of each field (e.g. obs_dim, act_dim). After the Buffer size is set, data is stored queue-style (oldest entries overwritten first), and random samples are drawn to train the reinforcement-learning model.
R2D2 Buffer:
The R2D2 Buffer has one more dimension than an ordinary Buffer: it stores fixed-length segments of (s, a, r, d, s') sequences. The first dimension is batch_size, the second is the sequence length (e.g. Lb + Lt), and the third is the feature dimension of each field (e.g. obs_dim, act_dim). As before, after the Buffer size is set, data is stored queue-style and sampled randomly for training.
Instantiating the Buffer:
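A minimal sketch of such a sequence buffer, assuming the three-dimensional layout described above (buffer_size, Lb + Lt, feature_dim); the class and field names are illustrative, not taken from the linked code:

```python
import numpy as np

class SequenceReplayBuffer:
    """Stores fixed-length (s, a, r, d, s') sequences for an RNN agent.

    Illustrative sketch: each row holds one whole sequence of length
    seq_len = Lb + Lt, so sampled batches have shape
    (batch_size, seq_len, feature_dim).
    """

    def __init__(self, obs_dim, act_dim, size, seq_len):
        self.obs_buf = np.zeros((size, seq_len, obs_dim), dtype=np.float32)
        self.obs2_buf = np.zeros((size, seq_len, obs_dim), dtype=np.float32)
        self.act_buf = np.zeros((size, seq_len, act_dim), dtype=np.float32)
        self.rew_buf = np.zeros((size, seq_len), dtype=np.float32)
        self.done_buf = np.zeros((size, seq_len), dtype=np.float32)
        self.ptr, self.full, self.max_size = 0, False, size

    def store(self, obs_seq, act_seq, rew_seq, done_seq, obs2_seq):
        # Queue-style (circular) storage: overwrite the oldest row when full.
        self.obs_buf[self.ptr] = obs_seq
        self.act_buf[self.ptr] = act_seq
        self.rew_buf[self.ptr] = rew_seq
        self.done_buf[self.ptr] = done_seq
        self.obs2_buf[self.ptr] = obs2_seq
        self.ptr += 1
        if self.ptr == self.max_size:
            self.ptr, self.full = 0, True

    def sample_batch(self, batch_size=32):
        # Uniform random sampling over the filled portion of the buffer.
        high = self.max_size if self.full else self.ptr
        idxs = np.random.randint(0, high, size=batch_size)
        return dict(obs1=self.obs_buf[idxs], obs2=self.obs2_buf[idxs],
                    acts=self.act_buf[idxs], rews=self.rew_buf[idxs],
                    done=self.done_buf[idxs])
```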
Collecting the data:
Once the Buffer structure is fixed, the data produced by interacting with the environment must be reshaped to match it. A queue from the Python standard library serves as the temporary data structure for collection:
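A sketch of this temporary storage using `collections.deque`; Lb, Lt, and the observation size here are placeholder values, not from the source:

```python
from collections import deque
import numpy as np

# One fixed-length deque per field of (s, a, r, d, s'); only the
# observation queue is shown. A deque with maxlen drops the oldest
# entry automatically once it is full.
Lb, Lt, obs_dim = 10, 20, 3  # illustrative lengths

obs_queue = deque(maxlen=Lb + Lt)

# Simulated interaction: append one fake observation per environment step.
for t in range(40):
    obs_queue.append(np.full(obs_dim, t, dtype=np.float32))

# Stacking the deque yields the (Lb + Lt, obs_dim) slice the Buffer expects.
seq = np.stack(obs_queue)
```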
Right after an environment reset there is no 'Burn-in' data yet, so np.zeros placeholders can be used instead:
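This zero-fill-on-reset step might look as follows; the helper name and sizes are hypothetical:

```python
import numpy as np
from collections import deque

Lb, Lt, obs_dim, act_dim = 10, 20, 3, 2  # illustrative sizes

obs_queue = deque(maxlen=Lb + Lt)
act_queue = deque(maxlen=Lb + Lt)

def on_reset():
    """Hypothetical helper: pre-fill the burn-in portion of each queue
    with Lb all-zero placeholder transitions after env.reset()."""
    obs_queue.clear()
    act_queue.clear()
    for _ in range(Lb):
        obs_queue.append(np.zeros(obs_dim, dtype=np.float32))
        act_queue.append(np.zeros(act_dim, dtype=np.float32))

on_reset()
```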
Lb is the length of the 'Burn-in' segment and Lt the length of the 'Training' segment; they differ in whether they produce a reinforcement-learning loss signal and whether they are used to update the RNN. Every Lt steps, the contents of the queue are stored into the Buffer:
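The "store every Lt steps" logic can be sketched as below; Lb, Lt, and the observation size are placeholder values, and `buffer_rows` stands in for the real Buffer:

```python
import numpy as np
from collections import deque

Lb, Lt, obs_dim = 4, 6, 3  # illustrative lengths

obs_queue = deque(maxlen=Lb + Lt)
for _ in range(Lb):                      # zero burn-in right after reset
    obs_queue.append(np.zeros(obs_dim, dtype=np.float32))

buffer_rows = []                         # stand-in for the real Buffer
t = 0
for step in range(24):                   # fake interaction loop
    obs_queue.append(np.full(obs_dim, step, dtype=np.float32))
    t += 1
    if t % Lt == 0:
        # Every Lt steps, flush the queue as one (Lb + Lt, obs_dim)
        # sequence; consecutive sequences overlap by the Lb burn-in part.
        buffer_rows.append(np.stack(obs_queue))
```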
In addition, when the environment returns done=True, a reset is needed and the data stream ends; np.zeros data is again used to pad the sequence up to length Lt, and a saved mask, data01, records whether each entry is real data (True) or np.zeros padding (False).
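A sketch of this episode-end padding; `pad_terminal_sequence` and the sizes are hypothetical, only the data01 True/False convention comes from the text above:

```python
import numpy as np

Lb, Lt, obs_dim = 4, 6, 3  # illustrative lengths

def pad_terminal_sequence(partial_obs):
    """Hypothetical helper: when done=True cuts the episode short, pad the
    sequence up to Lb + Lt with zeros and build the data01 mask marking
    real transitions True and zero padding False."""
    n_real = len(partial_obs)
    n_missing = (Lb + Lt) - n_real
    obs_seq = np.concatenate(
        [np.stack(partial_obs),
         np.zeros((n_missing, obs_dim), dtype=np.float32)])
    data01 = np.concatenate(
        [np.ones(n_real, dtype=bool),       # real data
         np.zeros(n_missing, dtype=bool)])  # np.zeros padding
    return obs_seq, data01
```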
For the full code, see:
https://github.com/createamind/DRL/blob/master/spinup/algos/sac1_rnn/sac1_rnn.py