专栏首页朴素人工智能convlab2中强化学习方法之对话策略学习浅析

convlab2中强化学习方法之对话策略学习浅析

CrossWoZ是一个多轮对话的中文数据集。对应的github地址在这

https://github.com/thu-coai/ConvLab-2

论文里面为了解决多轮对话的对话策略问题,分别用了基于规则(RulePolicy)和多种强化学习方法(比如PPOPolicy)。结果如下。

可以看到,Success rate 在RulePolicy中的表现远高于基于强化学习模型的policy。尽管如此,本文还是以学习的态度进入github地址分析了一下作者的代码。稍微还原一下强化学习PPOPolicy在多轮对话中建模的过程。

在具体的实现过程中,一共有以下几个重要概念

  • 对话状态 s
  • 动作 a
  • 回报 r

以代码仓库中PPOPolicy中的参数为例,s 是340维的0/1分布的离散空间,分别对应着多领域对话过程中的340个状态;a 是209维0/1离散空间,分别对应着在状态s下所可能执行的对应动作;回报 r 是根据环境确定的,比如如果完成了对话,则回报会给予一个很大的数,如果没有任何增益,则为 -1。

下面是代码实现,简单易懂。

  • 值函数
class Value(nn.Module):
    def __init__(self, s_dim, hv_dim):
        super(Value, self).__init__()

        self.net = nn.Sequential(nn.Linear(s_dim, hv_dim),
                                 nn.ReLU(),
                                 nn.Linear(hv_dim, hv_dim),
                                 nn.ReLU(),
                                 nn.Linear(hv_dim, 1))

    def forward(self, s):
        """
        :param s: [b, s_dim]
        :return:  [b, 1]
        """
        value = self.net(s)

        return value

当前状态对应的价值value,完全由状态s通过三层全连接确定。

  • 动作函数a
self.net = nn.Sequential(nn.Linear(s_dim, h_dim),
                         nn.ReLU(),
                         nn.Linear(h_dim, h_dim),
                         nn.ReLU(),
                         nn.Linear(h_dim, a_dim))

状态s下的动作也是用全连接来生成。但是具体使用的时候会有采样输出

def select_action(self, s, sample=True):
    """
    :param s: [s_dim]
    :return: [a_dim]
    """
    # forward to get action probs
    # [s_dim] => [a_dim]
    a_weights = self.forward(s)
    a_probs = torch.sigmoid(a_weights)
    
    # [a_dim] => [a_dim, 2]
    a_probs = a_probs.unsqueeze(1)
    a_probs = torch.cat([1-a_probs, a_probs], 1)
    a_probs = torch.clamp(a_probs, 1e-10, 1 - 1e-10)
    
    # [a_dim, 2] => [a_dim]
    a = a_probs.multinomial(1).squeeze(1) if sample else a_probs.argmax(1)
    
    return a

然后根据nn生成的价值函数v和环境给出的真实回报r,使用贝尔曼方程和优势梯度估计获取更新的价值函数v和优势。因为多轮对话是连续的,因此代码实现的时候通过mask来控制识别单轮和多轮。新生成的价值函数可以用来更新上面的价值网络Value。优势计算的结果则可以帮助更新策略网络net,以优化动作函数a

def est_adv(self, r, v, mask):
    """
    we save a trajectory in continuous space and it reaches the ending of current trajectory when mask=0.
    :param r: reward, Tensor, [b]
    :param v: estimated value, Tensor, [b]
    :param mask: indicates ending for 0 otherwise 1, Tensor, [b]
    :return: A(s, a), V-target(s), both Tensor
    """
    batchsz = v.size(0)

    # v_target is worked out by Bellman equation.
    v_target = torch.Tensor(batchsz).to(device=DEVICE)
    delta = torch.Tensor(batchsz).to(device=DEVICE)
    A_sa = torch.Tensor(batchsz).to(device=DEVICE)

    prev_v_target = 0
    prev_v = 0
    prev_A_sa = 0
    for t in reversed(range(batchsz)):
        # mask here indicates a end of trajectory
        # this value will be treated as the target value of value network.
        # mask = 0 means the immediate reward is the real V(s) since it's end of trajectory.
        # formula: V(s_t) = r_t + gamma * V(s_t+1)
        v_target[t] = r[t] + self.gamma * prev_v_target * mask[t]

        # please refer to : https://arxiv.org/abs/1506.02438
        # for generalized adavantage estimation
        # formula: delta(s_t) = r_t + gamma * V(s_t+1) - V(s_t)
        delta[t] = r[t] + self.gamma * prev_v * mask[t] - v[t]

        # formula: A(s, a) = delta(s_t) + gamma * lamda * A(s_t+1, a_t+1)
        # here use symbol tau as lambda, but original paper uses symbol lambda.
        A_sa[t] = delta[t] + self.gamma * self.tau * prev_A_sa * mask[t]

        # update previous
        prev_v_target = v_target[t]
        prev_v = v[t]
        prev_A_sa = A_sa[t]

    # normalize A_sa
    A_sa = (A_sa - A_sa.mean()) / A_sa.std()

    return A_sa, v_target

还有很多种强化学习方法,共同的目的都是学习更好的动作函数网络。如果对强化学习方法感兴趣可以继续深入仓库探索。

本文分享自微信公众号 - 朴素人工智能(sunnyday_no1),作者:凯华

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2020-05-06

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 疫情期间网民情绪识别比赛后记

    前阵子参加了 DataFountain 举办的 疫情期间网民情绪识别[1] 比赛,最终成绩排在第 20 名,成绩不是太好,本文就是纯粹记录一下,遇到太年轻的想法...

    朴素人工智能
  • 表格问答完结篇:落地应用

    不知道大家还记不记得,上一篇文章中的X-SQL和HydraNet都是来自微软的模型。微软作为一个老牌科技公司近年不仅在云计算领域迎头赶上,在AI方面也有很多优秀...

    朴素人工智能
  • pytorch中文语言模型bert预训练代码

    ACL2020 Best Paper有一篇论文提名奖,《Don’t Stop Pretraining: Adapt Language Models to Dom...

    朴素人工智能
  • UITextField 常用方法实例

    honey缘木鱼
  • rust leetcode Longest Substring Without Repeating Characters #3

    用户2436820
  • C++ FFLIB之ffcount:通用数据分析系统

    摘要: 数据分析已经变得不可或缺,几乎每个公司都依赖数据分析进行决策。在我从事的网游领域,数据分析是策划新功能、优化游戏体验最重要的手段之一。网游领域的数据分析...

    知然
  • 弹性文件服务与云硬盘一样吗?

    在公有云上,有很多的存储产品,让我们眼花缭乱,今天我们来看下弹性文件服务SFS。初一看,与我们在私有云经常使用的NAS有些神似,又与公有云上的云硬盘有些类似。只...

    希望的田野
  • 设计模式之单例模式

    通过上面的例子,我们实现了单例模式,无论我们怎样实例化类,都只能实例化一次类,大大的节省里系统资源的创建和销毁的开销

    北溟有鱼QAQ
  • 搜索引擎&小世界网络 答辩ppt

    Defu Li
  • 用Python实现WGS84、火星坐标系、百度坐标系、web墨卡托四种坐标相互转换

      主流被使用的地理坐标系并不统一,常用的有WGS84、GCJ02(火星坐标系)、BD09(百度坐标系)以及百度地图中保存矢量信息的web墨卡托,本文利用Pyt...

    Feffery

扫码关注云+社区

领取腾讯云代金券