前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Google BERT 中文应用之春节对对联

Google BERT 中文应用之春节对对联

作者头像
刀刀老高
发布2019-03-06 10:45:10
1.3K0
发布2019-03-06 10:45:10
举报
文章被收录于专栏:奇点大数据奇点大数据

在网上看到有人用 seq2seq 训练一个对对联的机器人,很好奇能不能用Google的BERT预训练模型微调,训练出一个不仅可以对传统对子,也可以对新词新句的泛化能力更好的对对联高手。今天大年初一,这样的例子刚好应景。在Google公开的BERT源代码中,附带两个微调的例子,一个是阅读理解,run_squad.py, 另一个是双句或单句分类, run_classifier.py ,并没有命名实体识别或者是 seq2seq 的例子。这次实验我会深度修改 Google BERT 在预训练数据上的微调模型,使得输出是与输入等长的序列。即上联中的每个字都会对应下联中相同位置的一个字,此任务比seq2seq简单,不需要将上联映射到潜在空间的一个向量后使用解码器产生非等长序列。既然 BERT 对输入的每一个 token 都产生了一个潜在空间的 768 维的向量,我们只需要再加一层,将每个token的768维向量变换成字典空间的 N (N=21128)维向量即可。

数据准备

数据从这里下载, 这些数据可能是一开始某高手从新浪抓取的,并提供了spider脚本,后来spider被封,但是数据流传了下来。此下载的链接从王斌的GitHub页面找到。https://github.com/wb14123/couplet-datasetGoogle Bert 中文预训练模型使用的字典文件比对联数据集使用的字典文件要小,为了省事,我们可以直接把训练数据测试数据中出现生僻字的那些对联去除,以使 Google Bert 的tokenization 对象能够正确的将字转化为id。

构建BERT模型

我们可以直接把 run_classifier.py 复制为 run_couplet.py, 其中 couplet 就是“对联”的英文单词。在修改代码的过程中,需要时刻牢记在简单分类器的任务中,每组数据的输出是二值的,0 或1,但在目前的任务中,上联的每个字都对应着一个输出,所以输出是一个序列。这个序列可以是NER命名实体识别任务对每个字的标记,也可以是对联的下联。

1。 需要在程序中手动指定字典的大小,因为对联中上联的每个字都对应下联中的一个字,这个字需要用 one-hot 向量表示,向量维度是字典的大小。

代码语言:javascript
复制
 vocab_size = 21128

2。修改微调模型架构,将BERT最后一层的权重矩阵由 768 * 2 改为 768 * vocab_size 的大小,注意在程序中,此操作可以省略,因为权重矩阵大小定义为 768 * num_labels ,而在对联任务中 num_labels 直接等于字典的大小。但是需要万分注意的是,name_to_features字典中label_ids 的描述里面一定要加 seq_length , 否则最后建立的模型 label_ids 的形状是 (batchsize, ), 而不是我们需要的 (batchsize, seq_length)。这一步至关重要!!!

代码语言:javascript
复制
"""Creates an `input_fn` closure to be passed to TPUEstimator."""                            |   545   """Creates an `input_fn` closure to be passed to TPUEstimator."""                           
                                                                                              |   546                                                                                               
   name_to_features = {                                                                         |   547   name_to_features = {                                                                        
       "input_ids": tf.FixedLenFeature([seq_length], tf.int64),                                 |   548       "input_ids": tf.FixedLenFeature([seq_length], tf.int64),                                
       "input_mask": tf.FixedLenFeature([seq_length], tf.int64),                                |   549       "input_mask": tf.FixedLenFeature([seq_length], tf.int64),                               
       "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),                               |   550       "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),                              
       "label_ids": tf.FixedLenFeature([seq_length], tf.int64),                                 |   551       "label_ids": tf.FixedLenFeature([], tf.int64),                                          
       "is_real_example": tf.FixedLenFeature([], tf.int64),                                     |   552       "is_real_example": tf.FixedLenFeature([], tf.int64),                                    
   }            

此外,在程序里有一段注释,提示我们如果想把二值分类问题改为序列标注问题,一定要将输出从maxpooling改为序列输出:

代码语言:javascript
复制
# In the demo, we are doing a simple classification task on the entire                            |   619   # In the demo, we are doing a simple classification task on the entire                      
# segment.                                                                                        |   620   # segment.                                                                                 
#                                                                                                 |   621   #                                                                                           
# If you want to use the token-level output, use model.get_sequence_output()                      |   622   # If you want to use the token-level output, use model.get_sequence_output()                
# instead.                                                                                        |   623   # instead.                                                                                  
#output_layer = model.get_pooled_output()                                                         |   624   output_layer = model.get_pooled_output()                                                    
output_layer = model.get_sequence_output()      

在计算损失函数的部分,也必须注意将BERT模型的输出特征形状改变为 H =(sequence_length, hiddensize),此矩阵与 权重矩阵 ( hiddensize, vocabsize) 相乘,会得到一个最终的输出矩阵,logits =(sequence_length, vocabsize),使用

代码语言:javascript
复制
probabilities = tf.nn.softmax(logits, axis=-1)

可以从 logits 计算出输出序列中某个位置是字典中每个字的几率,使用argmax()函数选取对应每个位置几率最大的那个字的id,

代码语言:javascript
复制
pred_ids = tf.argmax(probabilities, axis=-1)

并将其转化为字典文件中这个id对应的文字,

代码语言:javascript
复制
tokenizer.convert_ids_to_tokens(pred_ids)[1:ntokens+1]

3。添加 CoupletProcessor()

代码语言:javascript
复制
class CoupletProcessor(DataProcessor):
  """Processor for the Couplet data set."""
  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.csv")), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.csv")), "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test.csv")), "test")

  def get_labels(self):
    """See base class."""
    return ["%d"%v for v in range(vocab_size)]

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []

    for (i, line) in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      if set_type == "test":
        text_a = tokenization.convert_to_unicode(line[0])
        # 测试数据随便设标签,注意数据上联在左,下联(label)在右,'\t'分开
        label = tokenization.convert_to_unicode(line[0])
      else:
        text_a = tokenization.convert_to_unicode(line[0])
        label = tokenization.convert_to_unicode(line[1])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
    return examples

4。最后记得注册我们的对联处理器,

代码语言:javascript
复制
 processors = {                                                                                    |   820   processors = {                                                                              
        "cola": ColaProcessor,                                                                        |   821       "cola": ColaProcessor,                                                                 
        "mnli": MnliProcessor,                                                                        |   822       "mnli": MnliProcessor,                                                                  
        "mrpc": MrpcProcessor,                                                                        |   823       "mrpc": MrpcProcessor,                                                                  
        "xnli": XnliProcessor,                                                                        |   824       "xnli": XnliProcessor,                                                                  
        "weibo":WeiboProcessor,                                                                       |   825       "weibo":WeiboProcessor                                                                  
        "couplet":CoupletProcessor                                                                    |       ----------------------------------------------------------------------------------------------
    }            

训练过程

将训练数据放在couplet_data 中,从 Google Bert 的 Github 页面下载预训练好的中文模型,运行我们修改好的 run_couplets.py 脚本,即可训练,验证以及测试。

代码语言:javascript
复制
export BERT_BASE_DIR="chinese_L-12_H-768_A-12"export COUPLET_DIR="couplet_data"python bert/run_couplets.py \
  --task_name=couplet \
  --do_train=True \
  --do_eval=True \
  --do_predict=True \
  --data_dir=$COUPLET_DIR \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --train_batch_size=32 \
  --num_train_epochs=5 \
  --learning_rate=3e-5 \
  --max_seq_length=56 \
  --output_dir=couplet_data/output/

结果分析

这个任务在一块 V100 GPU 处理器上大概要训练6小时,在4块 K40m 上大概要跑24 小时。这个时间可以通过 BERT 的输出日志估算出来

INFO:tensorflow: Num steps = 115357 INFO:tensorflow:global_step/sec: 5.96756

预计总共消耗时间:T = 115357 / 5.96 / 3600 = 5.37 小时。

在NER任务中,每个字的序列标注类型可能最多只有十几种,所以微调时需要重新训练的参数矩阵大概有 768 * 10 个左右,而在对联任务中,微调时需要重新训练的参数矩阵大小为 768 * 21128 个,这是一个巨大的数字,1500万,我不知道最终机器是否能够成功的学会映射矩阵。也许更聪明的做法是训练一个 768 * 128维的权重矩阵,然后使用 embedding 的逆操作,将128维向量反向投影到21128维字典空间。不过现阶段先不管,暴力尝试一把,不行未来再更新 :)

现在大概暴力训练了四分之一,先停下,看看预测效果,作为对比,使用了网站ai.binwang.me/couplet/ 的预测结果,标记为Seq2Seq :

上联:秃笔写传奇,扬民族正气

新浪论坛下联:: 苦心辑文史,著乡党风流 BERT-Couplet下联: 清 心 传 壮 志 , 展 中 国 新 风 Seq2Seq:丹心抒壮志,树华夏新风

上联: 苦辣酸甜,遭遇一生应不少

新浪论坛下联:: 诗书画印,兼擅四绝已无多 BERT-Couplet下联: 悲 非 苦 辣 , 感 来 万 世 总 无 多 Seq2Seq:悲欢离合,相逢半世总难完

上联:欲知古邑崇文史,但看千载瓷窑,千秋贡院

新浪论坛下联:莫问定州尚德风,且赏一方庙宇,一曲秧歌 BERT-Couplet下联: 不 看 新 城 创 古 风 , 更 看 一 方 风 水 , 一 里 春 风 Seq2Seq:且看名城焕新容,更喜一城锦绣,一派澎园

上联:明月妆楼舒且雅

新浪论坛下联:公司待客信而真 BERT-Couplet下联: 清 风 入 案 静 而 香 Seq2Seq: 清风入座爽而幽

上联:磨一砚春光入句

新浪论坛下联:赊半弯珪月赠人 BERT-Couplet下联: 借 三 杯 月 色 成 诗 Seq2Seq: 赊几分月色题诗

上联:太空漫步,神州九万里灰霾,奔来眼底

新浪论坛下联:寰宇聚焦,极地千平方臭氧,惊醒议程 BERT-Couplet下联: 大 海 放 歌 , 华 夏 五 千 年 锦 气 , 化 入 心 中 Seq2Seq: 大海捞珍,华夏五千年历史,涌上心头

上联:万里春光美

新浪论坛下联:九州瑞气浓 BERT-Couplet下联: 千 秋 气 象 新 Seq2Seq: 千家瑞气新

上联:天国遗珠,撒落苗都星七点

新浪论坛下联:乌江涌翠,添来彭水韵三分 BERT-Couplet下联: 人 州 流 彩 , 铺 开 盛 苑 月 千 年 稍加修改:神 州 流 彩 , 铺 开 盛 苑 月 千 年 Seq2Seq: 人间胜景,收藏古镇景千重

上联:里肇紫岩,溪山入画

新浪论坛下联:乡连继绵,榆社长春 BERT-Couplet下联: 门 开 碧 水 , 风 月 成 诗 Seq2Seq: 中流玄壑,云水为师

上联:如此江山需放眼

新浪论坛下联:若干风雨也抒怀 BERT-Couplet下联:这 般 岁 月 要 关 心 Seq2Seq: 这些风月不关心

上联:倾心李杜千秋句

新浪论坛下联:绝唱关雎万古篇 BERT-Couplet下联:放 意 江 山 万 古 诗 Seq2Seq: 放眼江山万里图

上联:杯中影

新浪论坛下联:节外枝 BERT-Couplet下联: 笔 上 花 Seq2Seq: 座上宾

上联:辟地筑幽居,凭栏处竹瀑松风,溪云涧月

新浪论坛下联:与人增古趣,信步间神祠佛舍,壁字岩诗 BERT-Couplet下联: 登 天 开 雅 境 , 把 笔 间 花 风 竹 韵 , 古 水 诗 风 Seq2Seq: 寻幽寻胜迹,把酒时松涛竹韵,竹雨松涛

当然,并不是所有句子都对的很完美,但是有些句子真是对的很惊艳!!!

很期待完全训练之后,以及更改架构,减少微调时需要训练的参数后 BERT-Couplet 的表现。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-02-19,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 智能工场AIWorkshop 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 数据准备
  • 构建BERT模型
  • 训练过程
  • 结果分析
  • 上联:秃笔写传奇,扬民族正气
  • 上联: 苦辣酸甜,遭遇一生应不少
  • 上联:欲知古邑崇文史,但看千载瓷窑,千秋贡院
  • 上联:明月妆楼舒且雅
  • 上联:磨一砚春光入句
  • 上联:太空漫步,神州九万里灰霾,奔来眼底
  • 上联:万里春光美
  • 上联:天国遗珠,撒落苗都星七点
  • 上联:里肇紫岩,溪山入画
  • 上联:如此江山需放眼
  • 上联:倾心李杜千秋句
  • 上联:杯中影
  • 上联:辟地筑幽居,凭栏处竹瀑松风,溪云涧月
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档