前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【序列到序列学习】生成古诗词

【序列到序列学习】生成古诗词

作者头像
用户1386409
发布2018-04-02 14:38:13
1.5K1
发布2018-04-02 14:38:13
举报
文章被收录于专栏:PaddlePaddle

生成古诗词

序列到序列学习实现两个甚至是多个不定长模型之间的映射,有着广泛的应用,包括:机器翻译、智能对话与问答、广告创意语料生成、自动编码(如金融画像编码)、判断多个文本串之间的语义相关性等。

在序列到序列学习任务中,我们首先以机器翻译任务为例,提供了多种改进模型供大家学习和使用。包括:不带注意力机制的序列到序列映射模型,这一模型是所有序列到序列学习模型的基础;使用Scheduled Sampling改善RNN模型在生成任务中的错误累积问题;带外部记忆机制的神经机器翻译,通过增强神经网络的记忆能力,来完成复杂的序列到序列学习任务。除机器翻译任务之外,我们也提供了一个基于深层LSTM网络生成古诗词,实现同语言生成的模型。

【序列到序列学习】

04

生成古诗词

|1. 简介

基于编码器-解码器(encoder-decoder)神经网络模型,利用全唐诗进行诗句-诗句(sequence to sequence)训练,实现给定诗句后,生成下一诗句。

模型中的编码器、解码器均使用堆叠双向LSTM (stacked bi-directional LSTM),默认均为3层,带有注意力单元(attention)。以下是本例的简要目录结构及说明:

.

├── data # 存储训练数据及字典

│ ├── download.sh # 下载原始数据

├── README.md # 文档

├── index.html # 文档(html格式)

├── preprocess.py # 原始数据预处理

├── generate.py # 生成诗句脚本

├── network_conf.py # 模型定义

├── reader.py # 数据读取接口

├── train.py # 训练脚本

└── utils.py # 定义实用工具函数

数据处理

|2.数据处理

A.原始数据来源

本例使用中华古诗词数据库(https://github.com/chinese-poetry/chinese-poetry)中收集的全唐诗作为训练数据,共有约5.4万首唐诗。

B.原始数据下载

cd data && ./download.sh && cd ..

C.数据预处理

python preprocess.py --datadir data/raw --outfile data/poems.txt --dictfile data/dict.txt

上述脚本执行完后将生成处理好的训练数据poems.txt和字典dict.txt。字典的构建以字为单位,使用出现频数至少为10的字构建字典。poems.txt中每行为一首唐诗的信息,分为三列,分别为题目、作者、诗内容。在诗内容中,诗句之间用.分隔。

训练数据示例:

登鸛雀樓 王之渙 白日依山盡.黃河入海流.欲窮千里目.更上一層樓

觀獵 李白 太守耀清威.乘閑弄晚暉.江沙橫獵騎.山火遶行圍.箭逐雲鴻落.鷹隨月兔飛.不知白日暮.歡賞夜方歸

晦日重宴 陳嘉言 高門引冠蓋.下客抱支離.綺席珍羞滿.文場翰藻摛.蓂華彫上月.柳色藹春池.日斜歸戚里.連騎勒金羈

模型训练时,使用每一诗句作为模型输入,下一诗句作为预测目标。

|3. 模型训练

训练脚本train.py中的命令行参数可以通过python train.py --help查看。主要参数说明如下:

  • num_passes: 训练pass数 ;
  • batch_size: batch: 大小 ;
  • use_gpu: 是否使用GPU ;
  • trainer_count: trainer数目,默认为1;
  • save_dir_path: 模型存储路径,默认为当前目录下models目录 ;
  • encoder_depth: 模型中编码器LSTM深度,默认为3;
  • decoder_depth: 模型中解码器LSTM深度,默认为3 ;
  • train_data_path: 训练数据路径 ;
  • word_dict_path: 数据字典路径;
  • init_model_path: 初始模型路径,从头训练时无需指定。

A.训练执行

python train.py \

--num_passes 50 \

--batch_size 256 \

--use_gpu True \

--trainer_count 1 \

--save_dir_path models \

--train_data_path data/poems.txt \

--word_dict_path data/dict.txt \

2>&1 | tee train.log

每个pass训练结束后,模型参数将保存在models目录下。训练日志保存在train.log中。

B.最优模型参数

寻找cost最小的pass,使用该pass对应的模型参数用于后续预测。

python -c 'import utils; utils.find_optiaml_pass("./train.log")'

|4. 生成诗句

使用generate.py脚本对输入诗句生成下一诗句,命令行参数可通过python generate.py --help查看。 主要参数说明如下:

  • model_path: 训练好的模型参数文件 ;
  • word_dict_path: 数据字典路径 ;
  • test_data_path: 输入数据路径 ;
  • batch_size: batch:大小,默认为1;
  • beam_size: beam search:中搜索范围大小,默认为5 ;
  • save_file: 输出保存路径;
  • use_gpu: 是否使用GPU。

执行生成

例如将诗句 孤帆遠影碧空盡 保存在文件 input.txt 中作为预测下句诗的输入,执行命令:

python generate.py \

--model_path models/pass_00049.tar.gz \

--word_dict_path data/dict.txt \

--test_data_path input.txt \

--save_file output.txt

生成结果将保存在文件 output.txt 中。对于上述示例输入,生成的诗句如下:

-9.6987 萬 壑 清 風 黃 葉 多

-10.0737 萬 里 遠 山 紅 葉 深

-10.4233 萬 壑 清 波 紅 一 流

-10.4802 萬 壑 清 風 黃 葉 深

-10.9060 萬 壑 清 風 紅 葉 多

今 日 AI 资 讯

(如欲了解详情,在后台回复当日日期数字,例如“316”即可!)

1.百度10.55亿元入股创维酷开。(量子位)

2.显著超越流行长短时记忆网络,阿里提出DFSMN语音识别声学模型。(新智元)

3.进化算法+AutoML,谷歌提出新型神经网络架构搜索方法。(机器之心)

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2018-03-16,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 PaddlePaddle 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
机器翻译
机器翻译(Tencent Machine Translation,TMT)结合了神经机器翻译和统计机器翻译的优点,从大规模双语语料库自动学习翻译知识,实现从源语言文本到目标语言文本的自动翻译,目前可支持十余种语言的互译。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档