【序列到序列学习】生成古诗词

生成古诗词

序列到序列学习实现两个甚至是多个不定长模型之间的映射,有着广泛的应用,包括:机器翻译、智能对话与问答、广告创意语料生成、自动编码(如金融画像编码)、判断多个文本串之间的语义相关性等。

在序列到序列学习任务中,我们首先以机器翻译任务为例,提供了多种改进模型供大家学习和使用。包括:不带注意力机制的序列到序列映射模型,这一模型是所有序列到序列学习模型的基础;使用Scheduled Sampling改善RNN模型在生成任务中的错误累积问题;带外部记忆机制的神经机器翻译,通过增强神经网络的记忆能力,来完成复杂的序列到序列学习任务。除机器翻译任务之外,我们也提供了一个基于深层LSTM网络生成古诗词,实现同语言生成的模型。

【序列到序列学习】

04

生成古诗词

|1. 简介

基于编码器-解码器(encoder-decoder)神经网络模型,利用全唐诗进行诗句-诗句(sequence to sequence)训练,实现给定诗句后,生成下一诗句。

模型中的编码器、解码器均使用堆叠双向LSTM (stacked bi-directional LSTM),默认均为3层,带有注意力单元(attention)。以下是本例的简要目录结构及说明:

.

├── data # 存储训练数据及字典

│ ├── download.sh # 下载原始数据

├── README.md # 文档

├── index.html # 文档(html格式)

├── preprocess.py # 原始数据预处理

├── generate.py # 生成诗句脚本

├── network_conf.py # 模型定义

├── reader.py # 数据读取接口

├── train.py # 训练脚本

└── utils.py # 定义实用工具函数

数据处理

|2.数据处理

A.原始数据来源

本例使用中华古诗词数据库(https://github.com/chinese-poetry/chinese-poetry)中收集的全唐诗作为训练数据,共有约5.4万首唐诗。

B.原始数据下载

cd data && ./download.sh && cd ..

C.数据预处理

python preprocess.py --datadir data/raw --outfile data/poems.txt --dictfile data/dict.txt

上述脚本执行完后将生成处理好的训练数据poems.txt和字典dict.txt。字典的构建以字为单位,使用出现频数至少为10的字构建字典。poems.txt中每行为一首唐诗的信息,分为三列,分别为题目、作者、诗内容。在诗内容中,诗句之间用.分隔。

训练数据示例:

登鸛雀樓 王之渙 白日依山盡.黃河入海流.欲窮千里目.更上一層樓

觀獵 李白 太守耀清威.乘閑弄晚暉.江沙橫獵騎.山火遶行圍.箭逐雲鴻落.鷹隨月兔飛.不知白日暮.歡賞夜方歸

晦日重宴 陳嘉言 高門引冠蓋.下客抱支離.綺席珍羞滿.文場翰藻摛.蓂華彫上月.柳色藹春池.日斜歸戚里.連騎勒金羈

模型训练时,使用每一诗句作为模型输入,下一诗句作为预测目标。

|3. 模型训练

训练脚本train.py中的命令行参数可以通过python train.py --help查看。主要参数说明如下:

  • num_passes: 训练pass数 ;
  • batch_size: batch: 大小 ;
  • use_gpu: 是否使用GPU ;
  • trainer_count: trainer数目,默认为1;
  • save_dir_path: 模型存储路径,默认为当前目录下models目录 ;
  • encoder_depth: 模型中编码器LSTM深度,默认为3;
  • decoder_depth: 模型中解码器LSTM深度,默认为3 ;
  • train_data_path: 训练数据路径 ;
  • word_dict_path: 数据字典路径;
  • init_model_path: 初始模型路径,从头训练时无需指定。

A.训练执行

python train.py \

--num_passes 50 \

--batch_size 256 \

--use_gpu True \

--trainer_count 1 \

--save_dir_path models \

--train_data_path data/poems.txt \

--word_dict_path data/dict.txt \

2>&1 | tee train.log

每个pass训练结束后,模型参数将保存在models目录下。训练日志保存在train.log中。

B.最优模型参数

寻找cost最小的pass,使用该pass对应的模型参数用于后续预测。

python -c 'import utils; utils.find_optiaml_pass("./train.log")'

|4. 生成诗句

使用generate.py脚本对输入诗句生成下一诗句,命令行参数可通过python generate.py --help查看。 主要参数说明如下:

  • model_path: 训练好的模型参数文件 ;
  • word_dict_path: 数据字典路径 ;
  • test_data_path: 输入数据路径 ;
  • batch_size: batch:大小,默认为1;
  • beam_size: beam search:中搜索范围大小,默认为5 ;
  • save_file: 输出保存路径;
  • use_gpu: 是否使用GPU。

执行生成

例如将诗句 孤帆遠影碧空盡 保存在文件 input.txt 中作为预测下句诗的输入,执行命令:

python generate.py \

--model_path models/pass_00049.tar.gz \

--word_dict_path data/dict.txt \

--test_data_path input.txt \

--save_file output.txt

生成结果将保存在文件 output.txt 中。对于上述示例输入,生成的诗句如下:

-9.6987 萬 壑 清 風 黃 葉 多

-10.0737 萬 里 遠 山 紅 葉 深

-10.4233 萬 壑 清 波 紅 一 流

-10.4802 萬 壑 清 風 黃 葉 深

-10.9060 萬 壑 清 風 紅 葉 多

今 日 AI 资 讯

(如欲了解详情,在后台回复当日日期数字,例如“316”即可!)

1.百度10.55亿元入股创维酷开。(量子位)

2.显著超越流行长短时记忆网络,阿里提出DFSMN语音识别声学模型。(新智元)

3.进化算法+AutoML,谷歌提出新型神经网络架构搜索方法。(机器之心)

本文分享自微信公众号 - PaddlePaddle(PaddleOpenSource)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2018-03-16

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏钱塘大数据

理工男图解零维到十维空间,烧脑已过度,受不了啦!

让我们从一个点开始,和我们几何意义上的点一样,它没有大小、没有维度。它只是被想象出来的、作为标志一个位置的点。它什么也没有,空间、时间通通不存在,这就是零维度。

33830
来自专栏Ken的杂谈

【系统设置】CentOS 修改机器名

18130
来自专栏前端桃园

知识体系解决迷茫的你

最近在星球里群里都有小伙伴说道自己对未来的路比较迷茫,一旦闲下来就不知道自己改干啥,今天我这篇文章就是让你觉得一天给你 25 个小时你都不够用,觉得睡觉都是浪费...

21840
来自专栏haifeiWu与他朋友们的专栏

复杂业务下向Mysql导入30万条数据代码优化的踩坑记录

从毕业到现在第一次接触到超过30万条数据导入MySQL的场景(有点low),就是在顺丰公司接入我司EMM产品时需要将AD中的员工数据导入MySQL中,因此楼主负...

29740
来自专栏腾讯社交用户体验设计

ISUX Xcube智能一键生成H5

51220
来自专栏FSociety

SQL中GROUP BY用法示例

GROUP BY我们可以先从字面上来理解,GROUP表示分组,BY后面写字段名,就表示根据哪个字段进行分组,如果有用Excel比较多的话,GROUP BY比较类...

5.2K20
来自专栏怀英的自我修炼

考研英语-1-导学

英二图表作文要重视。总体而言,英语一会比英语二难点。不过就写作而言,英语二会比英语一有难度,毕竟图表作文并不好写。

11910
来自专栏微信公众号:小白课代表

不只是软件,在线也可以免费下载百度文库了。

不管是学生,还是职场员工,下载各种文档几乎是不可避免的,各种XXX.docx,XXX.pptx更是家常便饭,人们最常用的就是百度文库,豆丁文库,道客巴巴这些下载...

44630
来自专栏钱塘大数据

中国互联网协会发布:《2018中国互联网发展报告》

在2018中国互联网大会闭幕论坛上,中国互联网协会正式发布《中国互联网发展报告2018》(以下简称《报告》)。《中国互联网发展报告》是由中国互联网协会与中国互联...

13750
来自专栏腾讯高校合作

【倒计时7天】2018教育部-腾讯公司产学合作协同育人项目申请即将截止!

15720

扫码关注云+社区

领取腾讯云代金券

年度创作总结 领取年终奖励