深度学习第44讲:训练一个word2vec词向量

通常而言,训练一个词向量是一件非常昂贵的事情,我们一般会使用一些别人训练好的词向量模型来直接使用,很少情况下需要自己训练词向量,但这并不妨碍我们尝试来训练一个 word2vec 词向量模型进行试验。

在上一讲中,我们学习了 word2vec 的两种模型,一种是根据语境预测目标词的 CBOW 模型,另一种则是根据目标词预测语境的 skip-gram 模型。本节笔者将尝试使用 TensorFlow 根据给定语料训练一个 skip-gram 词向量模型。学习参考资料为 Andrew Ng deeplearningai 第五门课 assignment2 以及黄文坚所著的 TensorFlow 实战一书

先来回顾一下 skip-gram 词向量模型的网络结构:

skip-gram 的模型细节笔者本节不再赘述,下面我们看如何训练一个 skip-gram 模型。总体流程是先下载要训练的文本语料,然后根据语料构造词汇表,再根据词汇表和 skip-gram 模型特点生成 skip-gram 训练样本。训练样本准备好之后即可定义 skip-gram 模型网络结构,损失函数和优化计算过程,最后保存训练好的词向量即可。我们来看完整过程。

这里先导入整个试验过程所需要的 python 库。

准备语料

我们从http://mattmahoney.net/dc/网站下载好目标语料 text8.zip,当然也可以通过 python 编写 urllib 爬虫函数进行下载。

语料下载代码如下:

语料下载好后还是原始的文本,需要我们做一些进一步的处理。下面我们在读取压缩文件的同时调用 TensorFlow 的 compat.as_str 方法将语料转化为一个细分粒度以单词为单位的巨大列表:

可见整个语料被转化成了 17005207 个单词组成的巨大 list。这么大的单词数量我们肯定不能直接拿来做训练,需要进一步的对单词进行词频统计和转换。假设我们取词频 top 50000的单词作为词汇表,并将其放入 python 字典中,然后根据词汇表将列表中的每个单词根据频数排序给定一个编码,并取字典的反转形式(键值互换)。参考代码如下:

看一下词汇统计词频的前五个单词、字典中的前 10 个单词和对应的词频编码:

生成 skip-gram 训练样本

skip-gram 词向量模型是根据中间词来预测语境词,假设原始数据为 the quick brown fox jumped over the lazy dog. 现在我们需要将原始数据转化为 (quick,the)(quick,brown)(brown,quick)等词对的形式。然后我们需要定义几个关键变量:首先是生成每批训练数据的 batch_size,然后是每个单词向两边最远可以联系到距离,比如说 quick 只能和左右两个单词(quick,the)和(quick,brown) 进行联系,最后是每个单词能够生成的训练样本数量 skip_number。定义生成 skip-gram 生成样本函数如下:

生成训练样本示例如下所示:

搭建 skip-gram 模型训练过程

训练样本准备好后,便可以根据 skip-gram 模型结构进行模型搭建。同样先定义几个模型参数,第一个是训练批次 batch_size,这个我们之前在做图像处理的 CNN 模型训练的时候经常会碰到,我们训练批次为 128,然后的 embedding_size,这个是我们最后要生成词向量的维度,这里我们设置为 128,即我们要通过 skip-gram 算法将维度为 50000 的原始词汇表降维成 128 维的词向量。

执行训练:

训练过程如下图所示:

训练过程展示的 skip-gram 模型训练时的平均损失以及与验证集单词相似度最高的 8 个单词,可以看到与 may 语义相近的单词包括 can、would、will等词汇,可见由 skip-gram 模型训练得到的 word2vec 词向量表达质量是非常高的。

可视化展示和词向量保存

最后我们可以通过 t-SNE降维技术将 128 维的 skip-gram 词向量压缩到 2 维空间中进行展示,参考代码如下:

绘图效果如下:

可以看到在 2 维的词向量空间上,语义相近的词都被聚集到了一起。词向量训练好之后我们可以将其保存下来写入 txt 中方便以后调用:

最后咱们的词向量如下:

下一讲笔者将和大家分享如何使用训练好的词向量做一些 NLP 分析工作,比如说计算词汇之间的相似度,语义类比以及词汇语义除偏等分析。

参考资料:

deeplearningai.com

黄文坚 TensorFlow实战

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20181116B1XSFI00?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。

扫码关注云+社区

领取腾讯云代金券