# A Brief Look at Dialogue-Policy Learning with Reinforcement Learning in ConvLab-2

#### CrossWOZ is a Chinese multi-turn dialogue dataset. The corresponding GitHub repository is here:

https://github.com/thu-coai/ConvLab-2

#### To solve the dialogue-policy problem in multi-turn dialogue, the paper uses both a rule-based policy (RulePolicy) and several reinforcement-learning methods (such as PPOPolicy). The building blocks are as follows.

• Dialogue state s
• Action a
• Reward r

• Value function
```python
import torch.nn as nn

class Value(nn.Module):
    def __init__(self, s_dim, hv_dim):
        super(Value, self).__init__()

        self.net = nn.Sequential(nn.Linear(s_dim, hv_dim),
                                 nn.ReLU(),
                                 nn.Linear(hv_dim, hv_dim),
                                 nn.ReLU(),
                                 nn.Linear(hv_dim, 1))

    def forward(self, s):
        """
        :param s: [b, s_dim]
        :return:  [b, 1]
        """
        value = self.net(s)

        return value
```
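
A minimal usage sketch of the same three-layer MLP critic, with toy sizes (`s_dim=10`, `hv_dim=32` are assumptions for illustration; the real ConvLab-2 state vectors are much larger):

```python
import torch
import torch.nn as nn

# same critic architecture as above, with toy dimensions
s_dim, hv_dim = 10, 32
net = nn.Sequential(nn.Linear(s_dim, hv_dim),
                    nn.ReLU(),
                    nn.Linear(hv_dim, hv_dim),
                    nn.ReLU(),
                    nn.Linear(hv_dim, 1))

s = torch.randn(4, s_dim)   # a batch of 4 state vectors
value = net(s)              # [4, 1]: one scalar value estimate per state
```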

• Action function a
```python
self.net = nn.Sequential(nn.Linear(s_dim, h_dim),
                         nn.ReLU(),
                         nn.Linear(h_dim, h_dim),
                         nn.ReLU(),
                         nn.Linear(h_dim, a_dim))
```

```python
def select_action(self, s, sample=True):
    """
    :param s: [s_dim]
    :return: [a_dim]
    """
    # forward to get action probs
    # [s_dim] => [a_dim]
    a_weights = self.forward(s)
    a_probs = torch.sigmoid(a_weights)

    # [a_dim] => [a_dim, 2]
    a_probs = a_probs.unsqueeze(1)
    a_probs = torch.cat([1 - a_probs, a_probs], 1)
    a_probs = torch.clamp(a_probs, 1e-10, 1 - 1e-10)

    # [a_dim, 2] => [a_dim]
    a = a_probs.multinomial(1).squeeze(1) if sample else a_probs.argmax(1)

    return a
```
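
Each action dimension is sampled independently from its own two-way distribution [1-p, p], so the multinomial trick above is equivalent to per-dimension Bernoulli sampling. A standalone sketch with toy logits (the logit values here are assumptions, not ConvLab-2 output):

```python
import torch

torch.manual_seed(0)
a_weights = torch.randn(5)             # toy logits for a_dim = 5 actions
a_probs = torch.sigmoid(a_weights)     # independent "on" probability per action

# as in select_action: stack [1-p, p] and draw one column index per row
pairs = torch.cat([1 - a_probs.unsqueeze(1), a_probs.unsqueeze(1)], 1)
a = pairs.multinomial(1).squeeze(1)    # [a_dim], each entry 0 or 1

# the same distribution, expressed directly
a_direct = torch.bernoulli(a_probs)
```

This multi-label view lets the policy emit several dialogue acts in one turn, rather than choosing a single action from a softmax.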

```python
def est_adv(self, r, v, mask):
    """
    We store a flattened batch of transitions; mask = 0 marks the end of the current trajectory.
    :param r: reward, Tensor, [b]
    :param v: estimated value, Tensor, [b]
    :param mask: 0 indicates the end of a trajectory, otherwise 1, Tensor, [b]
    :return: A(s, a), V-target(s), both Tensor
    """
    batchsz = v.size(0)

    # v_target is worked out by the Bellman equation.
    v_target = torch.Tensor(batchsz).to(device=DEVICE)
    delta = torch.Tensor(batchsz).to(device=DEVICE)
    A_sa = torch.Tensor(batchsz).to(device=DEVICE)

    prev_v_target = 0
    prev_v = 0
    prev_A_sa = 0
    for t in reversed(range(batchsz)):
        # mask here indicates the end of a trajectory.
        # this value will be treated as the target value of the value network.
        # mask = 0 means the immediate reward is the real V(s), since it is the end of the trajectory.
        # formula: V(s_t) = r_t + gamma * V(s_{t+1})
        v_target[t] = r[t] + self.gamma * prev_v_target * mask[t]

        # please refer to: https://arxiv.org/abs/1506.02438
        # formula: delta(s_t) = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta[t] = r[t] + self.gamma * prev_v * mask[t] - v[t]

        # formula: A(s_t, a_t) = delta(s_t) + gamma * lambda * A(s_{t+1}, a_{t+1})
        # the symbol tau is used here where the original paper uses lambda.
        A_sa[t] = delta[t] + self.gamma * self.tau * prev_A_sa * mask[t]

        # update previous
        prev_v_target = v_target[t]
        prev_v = v[t]
        prev_A_sa = A_sa[t]

    # normalize A_sa
    A_sa = (A_sa - A_sa.mean()) / A_sa.std()

    return A_sa, v_target
```
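
The advantages produced by est_adv feed PPO's clipped surrogate objective. A minimal sketch of that loss follows (the function name, epsilon value, and flat tensor shapes are assumptions for illustration, not ConvLab-2's exact code):

```python
import torch

def ppo_clip_loss(log_pi_new, log_pi_old, A_sa, eps=0.2):
    """Clipped surrogate loss from the PPO paper (https://arxiv.org/abs/1707.06347)."""
    # probability ratio pi_new(a|s) / pi_old(a|s), computed in log space
    ratio = torch.exp(log_pi_new - log_pi_old)
    surr1 = ratio * A_sa
    surr2 = torch.clamp(ratio, 1 - eps, 1 + eps) * A_sa
    # maximize the surrogate objective => minimize its negative
    return -torch.min(surr1, surr2).mean()
```

When the new and old policies coincide, the ratio is 1 everywhere and the loss reduces to `-A_sa.mean()`; the clamp only bites once an update moves the policy more than eps away from the behavior policy.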
