# DQN三大改进(二)-Prioritised replay

Prioritised replay原文：https://arxiv.org/pdf/1511.05952.pdf 代码地址：https://github.com/princewen/tensorflow_practice/tree/master/Prioritized_Replay_DQN_demo

# 1、背景

DQN中有两个关键的技术，叫做经验回放和双网络结构。

DQN中的损失函数定义为：

$$L(\theta) = \mathbb{E}\big[(q_{target} - Q(s, a; \theta))^2\big]$$

q-target如何计算呢？根据下面的公式：

$$q_{target} = r + \gamma \max_{a'} Q(s', a'; \theta^-)$$

SumTree 是一种树形结构, 每片树叶存储每个样本的优先级 p, 每个树枝节点只有两个分叉, 节点的值是两个分叉的和, 所以 SumTree 的顶端就是所有 p 的和. 如下图所示。最下面一层树叶存储样本的 p, 叶子上一层最左边的 13 = 3 + 10, 按这个规律相加, 顶层的 root 就是全部 p 的和了.

抽样时, 将总优先级（此例中 root = 42）均分为 n = 6 个区间：[0-7), [7-14), [14-21), [21-28), [28-35), [35-42)，再在每个区间内均匀抽取一个数 v, 从树顶向下查找 v 所落入的叶子, 即为被抽中的样本。

# 2、代码实现

```python
# --------------------- input ----------------------
# TF1 graph placeholders fed at session run time.
self.s = tf.placeholder(tf.float32, [None, self.n_features], name='s')                # current state batch
self.q_target = tf.placeholder(tf.float32, [None, self.n_actions], name='Q_target')  # TD targets
self.s_ = tf.placeholder(tf.float32, [None, self.n_features], name='s_')             # next state batch
if self.prioritized:
    # Per-sample importance-sampling weights that correct the bias
    # introduced by non-uniform (prioritized) replay sampling.
    self.ISWeights = tf.placeholder(tf.float32, [None, 1], name='IS_weights')
```

```python
def build_layers(s, c_names, n_l1, w_initializer, b_initializer, trainable):
    """Build a two-layer fully-connected Q-network.

    s: state placeholder, shape [None, n_features].
    c_names: TF collection names the variables are registered in
        (used later to copy eval-net weights into the target net).
    n_l1: number of hidden units in the first layer.
    trainable: False for the target network, so its weights change
        only through the explicit copy op, never through gradients.
    Returns a [None, n_actions] tensor of Q-values.
    """
    with tf.variable_scope('l1'):
        w1 = tf.get_variable('w1', [self.n_features, n_l1], initializer=w_initializer, collections=c_names, trainable=trainable)
        b1 = tf.get_variable('b1', [1, n_l1], initializer=b_initializer, collections=c_names, trainable=trainable)
        l1 = tf.nn.relu(tf.matmul(s, w1) + b1)

    with tf.variable_scope('l2'):
        w2 = tf.get_variable('w2', [n_l1, self.n_actions], initializer=w_initializer, collections=c_names, trainable=trainable)
        b2 = tf.get_variable('b2', [1, self.n_actions], initializer=b_initializer, collections=c_names, trainable=trainable)
        out = tf.matmul(l1, w2) + b2
    return out
```

```python
# --------------------- eval net ---------------------
with tf.variable_scope('eval_net'):
    c_names, n_l1, w_initializer, b_initializer = \
        ['eval_net_params', tf.GraphKeys.GLOBAL_VARIABLES], 20, \
        tf.random_normal_initializer(0., 0.3), tf.constant_initializer(0.1)  # config of layers
    # Trainable network that produces Q(s, a) and is updated every learn step.
    self.q_eval = build_layers(self.s, c_names, n_l1, w_initializer, b_initializer, True)

# --------------------- target net ---------------------
with tf.variable_scope('target_net'):
    c_names = ['target_net_params', tf.GraphKeys.GLOBAL_VARIABLES]
    # Frozen copy (trainable=False) evaluated on next states s_; synced periodically.
    self.q_next = build_layers(self.s_, c_names, n_l1, w_initializer, b_initializer, False)
```

```python
# --------------------- loss and train ---------------------
with tf.variable_scope('loss'):
    if self.prioritized:
        # Per-sample |TD error|, fetched after each train step and fed
        # back to the SumTree to refresh leaf priorities.
        self.abs_errors = tf.reduce_sum(tf.abs(self.q_target - self.q_eval), axis=1)    # for updating Sumtree
        # IS weights down-weight frequently sampled (high-priority) transitions.
        self.loss = tf.reduce_mean(self.ISWeights * tf.squared_difference(self.q_target, self.q_eval))
    else:
        self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval))
with tf.variable_scope('train'):
    self._train_op = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)
```

```python
def __init__(self, capacity):
    # capacity: number of leaves = maximum number of stored transitions.
    self.capacity = capacity
    self.data_pointer = 0  # next leaf slot to write; wraps around when full
    # Complete binary tree stored flat: capacity - 1 internal sum nodes
    # followed by capacity leaves holding the per-sample priorities.
    self.tree = np.zeros(2 * capacity - 1)
    # Parallel array holding the transitions themselves (one per leaf).
    self.data = np.zeros(capacity, dtype=object)

@property
def total_p(self):
    return self.tree[0]  # the root — sum of all leaf priorities
```

def add(self, p, data):
    """Store `data` with priority `p`, overwriting the oldest entry when full."""
    # Leaves occupy indices [capacity - 1, 2 * capacity - 2] of the flat tree.
    leaf_idx = self.capacity - 1 + self.data_pointer
    self.data[self.data_pointer] = data
    self.update(leaf_idx, p)  # write the priority and propagate it to the root
    self.data_pointer += 1
    if self.data_pointer >= self.capacity:
        # Ring-buffer behaviour: wrap and start replacing the oldest samples.
        self.data_pointer = 0

def update(self, tree_idx, p):
    """Set leaf `tree_idx` to priority `p` and propagate the change upward."""
    delta = p - self.tree[tree_idx]
    self.tree[tree_idx] = p
    # Every ancestor's subtree sum changes by exactly the same delta.
    parent = tree_idx
    while parent != 0:
        parent = (parent - 1) // 2
        self.tree[parent] += delta

def get_leaf(self, v):
    """Find the leaf whose cumulative-priority interval contains `v`.

    Standard SumTree descent: at each internal node go left when `v`
    fits inside the left child's mass, otherwise subtract that mass
    and go right. Returns (tree index of the leaf, its priority,
    the transition stored at that leaf).
    """
    parent_idx = 0
    # Descend while the current node still has children in the flat array.
    while 2 * parent_idx + 1 < len(self.tree):
        left = 2 * parent_idx + 1
        right = left + 1
        if v <= self.tree[left]:
            parent_idx = left
        else:
            v -= self.tree[left]
            parent_idx = right
    leaf_idx = parent_idx
    data_idx = leaf_idx - self.capacity + 1  # map tree leaf -> data slot
    return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

```python
def __init__(self, capacity):
    # SumTree stores priorities; transitions live in its parallel data array.
    self.tree = SumTree(capacity)
    # NOTE: this epsilon is a priority offset, unrelated to the agent's
    # exploration epsilon.
    self.epsilon = 0.01  # small amount added to |TD error| to avoid zero priority
    self.alpha = 0.6  # [0~1] how strongly TD error shapes priority (0 = uniform replay)
    self.beta = 0.4  # importance-sampling exponent, annealed from here up to 1
    self.beta_increment_per_sampling = 0.001  # beta annealing step per sample() call
    self.abs_err_upper = 1.  # clip |TD error| so one outlier cannot dominate sampling
```

def store(self, transition):
    """Insert `transition` with the current maximum leaf priority.

    Giving new samples the max priority guarantees each one is replayed
    at least once before its priority is re-estimated from its TD error.
    """
    leaf_priorities = self.tree.tree[-self.tree.capacity:]
    priority = np.max(leaf_priorities)
    if priority == 0:
        # Tree is empty so far: fall back to the clipped-error upper bound.
        priority = self.abs_err_upper
    self.tree.add(priority, transition)

def sample(self, n):
    """Draw `n` transitions via stratified priority sampling.

    The total priority mass is split into `n` equal segments and one
    value is drawn uniformly from each, approximating proportional
    sampling while guaranteeing coverage of the priority range.

    Returns (tree leaf indices, stacked transitions, IS weights (n, 1)).
    """
    b_idx = np.empty((n,), dtype=np.int32)
    b_memory = np.empty((n, self.tree.data[0].size))
    ISWeights = np.empty((n, 1))

    pri_seg = self.tree.total_p / n  # width of each stratified segment
    # Anneal beta toward 1 so the IS correction becomes exact over training.
    self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])

    # Smallest sampling probability, used to normalise IS weights so the
    # largest weight equals 1.
    min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p
    if min_prob == 0:
        # BUG FIX: before the buffer is full, untouched leaves hold priority
        # 0, making min_prob 0 and `prob / min_prob` a division by zero
        # (infinite IS weights). Clamp to a small positive value instead.
        min_prob = 0.00001

    for i in range(n):
        a, b = pri_seg * i, pri_seg * (i + 1)
        v = np.random.uniform(a, b)
        idx, p, data = self.tree.get_leaf(v)
        prob = p / self.tree.total_p
        # w_i = (N * P(i))^-beta normalised by the max weight -> (P(i)/min P)^-beta
        ISWeights[i, 0] = np.power(prob / min_prob, -self.beta)
        b_idx[i], b_memory[i, :] = idx, data
    return b_idx, b_memory, ISWeights

def batch_update(self, tree_idx, abs_errors):
    """Refresh the priorities of the sampled leaves from fresh TD errors.

    tree_idx: leaf indices previously returned by sample().
    abs_errors: per-sample |TD error|; the array is NOT modified.
    """
    # BUG FIX: the original used `abs_errors += self.epsilon`, which mutates
    # the caller's NumPy array in place — an unintended side effect. Build a
    # new array instead. np.abs also makes the code match the original
    # comment's intent ("convert to abs") and is a no-op for the already
    # non-negative errors produced upstream.
    errors = np.abs(abs_errors) + self.epsilon  # epsilon avoids zero priority
    clipped_errors = np.minimum(errors, self.abs_err_upper)
    ps = np.power(clipped_errors, self.alpha)  # priority = |error|^alpha
    for ti, p in zip(tree_idx, ps):
        self.tree.update(ti, p)

```python
def choose_action(self, observation):
    # Add a batch dimension: the network expects shape [None, n_features].
    observation = observation[np.newaxis, :]
    # NOTE: in this code self.epsilon is the probability of acting GREEDILY
    # (it is annealed upward toward epsilon_max in learn()), not the
    # exploration rate as in many other DQN implementations.
    if np.random.uniform() < self.epsilon:
        actions_value = self.sess.run(self.q_eval, feed_dict={self.s: observation})
        action = np.argmax(actions_value)  # exploit: best predicted action
    else:
        action = np.random.randint(0, self.n_actions)  # explore: random action
    return action
```

def store(self, s, a, r, s_):
    """Save one (s, a, r, s') transition as a single flat row.

    Prioritized mode delegates to the SumTree-backed memory (which assigns
    max priority to new samples); otherwise a plain ring buffer with
    uniform replacement is used.
    """
    transition = np.hstack((s, [a, r], s_))
    if self.prioritized:
        self.memory.store(transition)
    else:  # uniform random replay
        if not hasattr(self, 'memory_counter'):
            self.memory_counter = 0  # lazily created on first store
        row = self.memory_counter % self.memory_size  # ring-buffer slot
        self.memory[row, :] = transition
        self.memory_counter += 1

```python
# Build the op list that copies eval-net weights into the target net.
t_params = tf.get_collection('target_net_params')
e_params = tf.get_collection('eval_net_params')
self.replace_target_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]

# Periodically sync the target network (the "fixed Q-target" trick).
if self.learn_step_counter % self.replace_target_iter == 0:
    self.sess.run(self.replace_target_op)
    print('\ntarget_params_replaced\n')
```

```python
if self.prioritized:
    # Stratified priority sampling; also returns the SumTree leaf indices
    # (needed to update priorities after training) and the IS weights.
    tree_idx, batch_memory, ISWeights = self.memory.sample(self.batch_size)
else:
    # Uniform sampling from the ring buffer.
    # NOTE(review): indices are drawn over the whole buffer even before it
    # is full, so early batches may contain empty (all-zero) rows — verify
    # against how learn() is first triggered.
    sample_index = np.random.choice(self.memory_size, size=self.batch_size)
    batch_memory = self.memory[sample_index, :]
```

```python
# Q-values for next states (target net) and current states (eval net).
q_next, q_eval = self.sess.run(
    [self.q_next, self.q_eval],
    feed_dict={self.s_: batch_memory[:, -self.n_features:],
               self.s: batch_memory[:, :self.n_features]})

# Start from q_eval so only the taken action's entry differs; the
# squared-difference loss is then zero for every other action.
q_target = q_eval.copy()
batch_index = np.arange(self.batch_size, dtype=np.int32)
eval_act_index = batch_memory[:, self.n_features].astype(int)  # action column
reward = batch_memory[:, self.n_features + 1]                  # reward column

# Bellman target for the taken action: r + gamma * max_a' Q_target(s', a').
q_target[batch_index, eval_act_index] = reward + self.gamma * np.max(q_next, axis=1)

if self.prioritized:
    # Train, and fetch |TD errors| to refresh the SumTree priorities.
    _, abs_errors, self.cost = self.sess.run([self._train_op, self.abs_errors, self.loss],
        feed_dict={self.s: batch_memory[:, :self.n_features],
                   self.q_target: q_target,
                   self.ISWeights: ISWeights})
    self.memory.batch_update(tree_idx, abs_errors)  # update priority
else:
    _, self.cost = self.sess.run([self._train_op, self.loss],
        feed_dict={self.s: batch_memory[:, :self.n_features],
                   self.q_target: q_target})

self.cost_his.append(self.cost)
# epsilon here is the greedy-action probability: anneal it up to epsilon_max.
self.epsilon = self.epsilon + self.epsilon_increment if self.epsilon < self.epsilon_max else self.epsilon_max
self.learn_step_counter += 1
```

# 3、参考文献

1、Prioritized Experience Replay (DQN) (Tensorflow)：https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/4-6-prioritized-replay/ 2、PRIORITIZED EXPERIENCE REPLAY：https://www.cnblogs.com/wangxiaocvpr/p/5660232.html

261 篇文章89 人订阅

0 条评论

## 相关文章

46770

23550

444100

### Github 项目推荐 | 用于 C/C++、Java、Matlab/Octave 的特征选择工具箱

FEAST 是用于 C/C++、Java、Matlab/Octave 的特征选择工具集合，它提供了基于滤波器特征选择算法的常用互信息的实现以及 RELIEF 的...

51880

2.8K30

3.1K90

1.1K40

13710

### 【论文推荐】最新六篇情感分析相关论文—深度上下文、支持向量机、两级LSTM、多模态情感分析、软件工程、代码混合

【导读】专知内容组整理了最近六篇情感分析（Sentiment Analysis）相关文章，为大家进行介绍，欢迎查看! 1. Deep contextualize...

1.4K130

19450