Sampled Softmax

sampled softmax原论文:On Using Very Large Target Vocabulary for Neural Machine Translation 以及tensorflow关于candidate sampling的文档:candidate sampling

1. 问题背景

在神经机器翻译中,训练的复杂度以及解码的复杂度和词汇表的大小成正比。当输出的词汇表巨大时,传统的softmax由于要计算每一个类的logits就会有问题。在论文Neural Machine Translation by Jointly Learning to Align and Translate 中,带有attention的decoder中权重的公式如下:

因为我们输出的是一个概率值,所以(6)式的归一化银子ZZ的计算就需要将词汇表当中的logits都计算一遍,这个代价是很大的。 基于此,作者提出了一种采样的方法,使得我们在训练的时候,输出为原来输出的一个子集。(关于其它的解决方法,作者也有提,感兴趣的可以看原文,本篇博客只关注Sampled Softmax)

2. 解决方法

(感觉还是tensorflow文档说的清楚一点,最初看论文的时候还以为是相当于把一个单词划分到最近的一个类,那样的话,应该会有不同类别的关系啊不然也不make sense啊,但是看tensorflow源码就只有采样的过程啊,笑cry)

3. tensorflow的实现

def sampled_softmax_loss(weights,
                         biases,
                         labels,
                         inputs,
                         num_sampled, # 每一个batch随机选择的类别
                         num_classes, # 所有可能的类别
                         num_true=1, #每一个sample的类别数量
                         sampled_values=None,
                         remove_accidental_hits=True,
                         partition_strategy="mod",
                         name="sampled_softmax_loss"):

tensorflow对于使用的建议:仅仅在训练阶段使用,在inference或者evaluation的时候还是需要使用full softmax。

原文: This operation is for training only. It is generally an underestimate of the full softmax loss. A common use case is to use this method for training, and calculate the full softmax loss for evaluation or inference.

这个函数的主体主要调用了另外一个函数:

logits, labels = _compute_sampled_logits(
      weights=weights,
      biases=biases,
      labels=labels,
      inputs=inputs,
      num_sampled=num_sampled,
      num_classes=num_classes,
      num_true=num_true,
      sampled_values=sampled_values,
      subtract_log_q=True,
      remove_accidental_hits=remove_accidental_hits,
      partition_strategy=partition_strategy,
      name=name)

上述函数的返回值shape为:[batch_size, num_true + num_sampled]即可能的class为: Si∪tiS_i \cup{t_i} 而这个函数采样集合的代码如下:

sampled_values=candidate_sampling_ops.log_uniform_candidate_sampler(
          true_classes=labels,# 真实的label
          num_true=num_true,
          num_sampled=num_sampled, # 需要采样的子集大小
          unique=True,
          range_max=num_classes)

而这个函数主要是按照log-uniform distribution(Zipfian distribution)来采样出一个子集,Zipfian distribution 即Zipf法则,以下为Wikipedia关于Zipf’s law的解释:

Zipf’s law states that given some corpus of natural language utterances, the frequency of any word is inversely proportional to its rank in the frequency table.

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏人工智能

理解卷积

原文作者:Christopher Olah

56414
来自专栏人工智能

机器学习,Hello World from Javascript!

导语 JavaScript 适合做机器学习吗?这是一个问号。但每一位开发者都应该了解机器学习解决问题的思维和方法,并思考:它将会给我们的工作带来什么?同样,算法...

1825
来自专栏AI研习社

教你从零开始检测皮卡丘-CNN目标检测入门教程(下)

本文为大家介绍实验过程中训练、测试过程及结果。算法和数据集参见《从零开始码一个皮卡丘检测器-CNN目标检测入门教程(上)》 训练 Train 损失函数 Lo...

2733
来自专栏大数据挖掘DT机器学习

R语言实现 支持向量机

一、SVM的想法 回到我们最开始讨论的KNN算法,它占用的内存十分的大,而且需要的运算量也非常大。那么我们有没有可能找到几个最有代表性的点(即保...

2723
来自专栏一名叫大蕉的程序员

尝试克服一下小伙伴对神经网络的恐惧No.26

我是小蕉。 研表究明,这的网官的demo,代码确实的是己打自的。 这两天仔细研究了一下神经网络,简单的结构其实没想象中那么恐怖,只是我们自己吓自己,今天希望能把...

1856
来自专栏大数据挖掘DT机器学习

R语言与机器学习(分类算法)支持向量机

说到支持向量机,必须要提到july大神的《支持向量机通俗导论》,个人感觉再怎么写也不可能写得比他更好的了。这也正如青莲居士见到崔颢的黄鹤楼后也...

2734
来自专栏机器学习算法全栈工程师

如何利用深度学习写诗歌(使用Python进行文本生成)

翻译:李雪冬 编辑:李雪冬 前 言 从短篇小说到写5万字的小说,机器不断涌现出前所未有的词汇。在web上有大量的例子可供开发人员使...

5467
来自专栏黄成甲

数据分析之数据处理

数据处理是根据数据分析目的,将收集到的数据,用适当的处理方法进行加工、整理,形成适合数据分析的要求样式,它是数据分析前必不可少的工作,并且在整个数据分析工作量中...

852
来自专栏腾讯Bugly的专栏

基于 TensorFlow 在手机端实现文档检测

手机端运行卷积神经网络的一次实践 — 基于 TensorFlow 和 OpenCV 实现文档检测功能 1. 前言 本文不是神经网络或机器学习的入门教学,而是通过...

4504
来自专栏月色的自留地

从锅炉工到AI专家(4)

1657

扫码关注云+社区