Char RNN原理介绍以及文本生成实践

正文共1523张图,3张图,预计阅读时间8分钟。

1、简介

Char-RNN,字符级循环神经网络,出自于Andrej Karpathy写的The Unreasonable Effectiveness of Recurrent Neural Networks。众所周知,RNN非常擅长处理序列问题。序列数据前后有很强的关联性,而RNN通过每个单元权重与偏置的共享以及循环计算(前面处理过的信息会被利用处理后续信息)来体现。Char-RNN模型是从字符的维度上,让机器生成文本,即通过已经观测到的字符出发,预测下一个字符出现的概率,也就是序列数据的推测。现在网上介绍的用深度学习写歌、写诗、写小说的大多都是基于这个方法。

在基本的RNN单元中,只有一个隐藏状态,对于长距离的记忆效果很差(序列开始的信息在后期保留很少),而且存在梯度消失的问题,因此诞生了许多变体,如LSTM、GRU等。本文介绍的Char-RNN就是选用LSTM作为基本模型。

2、char RNN原理

Char RNN 原理

上图展示了Char-RNN的原理。以要让模型学习写出“hello”为例,Char-RNN的输入输出层都是以字符为单位。输入“h”,应该输出“e”;输入“e”,则应该输出后续的“l”。输入层我们可以用只有一个元素为1的向量来编码不同的字符,例如,h被编码为“1000”、“e”被编码为“0100”,而“l”被编码为“0010”。使用RNN的学习目标是,可以让生成的下一个字符尽量与训练样本里的目标输出一致。

在图一的例子中,根据前两个字符产生的状态和第三个输入“l”预测出的下一个字符的向量为<0.1, 0.5, 1.9, -1.1>,最大的一维是第三维,对应的字符则为“0010”,正好是“l”。这就是一个正确的预测。但从第一个“h”得到的输出向量是第四维最大,对应的并不是“e”,这样就产生代价。学习的过程就是不断降低这个代价。学习到的模型,对任何输入字符可以很好地不断预测下一个字符,如此一来就能生成句子或段落。

3、实践

下面是一个利用Char RNN实现写诗的应用,代码来自来自原先比较火的项目:https://github.com/jinfagang/tensorflow_poems,然后自己将其做成WEB应用,凑着学习了下如何使用tensorflow实现char rnn

 1def char_rnn(model,input_data,output_data,vocab_size,rnn_size=128,num_layers=2,batch_size=64,
 2     learning_rate=0.01):
 3"""
 4
 5:param model: rnn单元的类型 rnn, lstm gru
 6:param input_data: 输入数据
 7:param output_data: 输出数据
 8:param vocab_size: 词汇大小
 9:param rnn_size:
10:param num_layers:
11:param batch_size:
12:param learning_rate:学习率
13:return:
14"""
15end_points = {}
16
17if model=='rnn':
18cell_fun=tf.contrib.rnn.BasicRNNCell
19elif model=='gru':
20cell_fun=tf.contrib.rnn.GRUCell
21elif model=='lstm':
22cell_fun=tf.contrib.rnn.BasicLSTMCell
23
24cell = cell_fun(rnn_size, state_is_tuple=True)
25cell = tf.contrib.rnn.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
26
27if output_data is not None:
28initial_state = cell.zero_state(batch_size, tf.float32)
29else:
30initial_state = cell.zero_state(1, tf.float32)
31
32with tf.device("/cpu:0"):
33embedding=tf.get_variable('embedding',initializer=tf.random_uniform(
34    [vocab_size+1,rnn_size],-1.0,1.0))
35
36inputs=tf.nn.embedding_lookup(embedding,input_data)
37
38# [batch_size, ?, rnn_size] = [64, ?, 128]
39outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)
40output = tf.reshape(outputs, [-1, rnn_size])
41
42# logit计算
43weights = tf.Variable(tf.truncated_normal([rnn_size, vocab_size + 1]))
44bias = tf.Variable(tf.zeros(shape=[vocab_size + 1]))
45logits = tf.nn.bias_add(tf.matmul(output, weights), bias=bias)
46# [?, vocab_size+1]
47if output_data is not None:
48# 独热编码
49labels = tf.one_hot(tf.reshape(output_data, [-1]), depth=vocab_size + 1)
50# [?, vocab_size+1]
51
52loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
53# [?, vocab_size+1]
54
55total_loss = tf.reduce_mean(loss)
56train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)
57
58end_points['initial_state'] = initial_state
59end_points['output'] = output
60end_points['train_op'] = train_op
61end_points['total_loss'] = total_loss
62end_points['loss'] = loss
63end_points['last_state'] = last_state
64else:
65prediction = tf.nn.softmax(logits)
66
67end_points['initial_state'] = initial_state
68end_points['last_state'] = last_state
69end_points['prediction'] = prediction
70
71return end_points

效果如下:

效果 1

效果 2

项目地址:https://github.com/yanqiangmiffy/char-rnn-writer/

4、参考资料

1、yanqiangmiffy/char-rnn-writer: 基于Char RNN实现的“作家”应用,可以写诗也可以写小说,看起来还ok

2、【深度学习】文本生成 - Django's blog - 博客园

3、简单的Char RNN生成文本

4、The Unreasonable Effectiveness of Recurrent Neural Networks

5、Recurrent Neural Networks (RNN) – Part 1: Basic RNN / Char-RNN – The Neural Perspective

6、Tensorflow下Char-RNN项目代码详解-学路网-学习路上 有我相伴

7、hzy46/Char-RNN-TensorFlow: Multi-language Char RNN for TensorFlow >= 1.2.

8、[译] RNN 循环神经网络系列 1:基本 RNN 与 CHAR-RNN-博客-云栖社区-阿里云

9、简单的Char RNN生成文本 - 简书

原文链接:https://www.jianshu.com/p/c55caf4c6467

查阅更为简洁方便的分类文章以及最新的课程、产品信息,请移步至全新呈现的“LeadAI学院官网”: www.leadai.org

原文发布于微信公众号 - 人工智能LeadAI(atleadai)

原文发表时间:2018-08-03

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏WOLFRAM

梵高《星夜》的『现代版』

17950
来自专栏灯塔大数据

干货|2017校招数据分析岗位笔试/面试知识点

2017校招正在火热的进行,后面会不断更新涉及到的相关知识点。 尽管听说今年几个大互联网公司招的人超少,但好像哪一年都说是就业困难,能够进去当然最好,不能进去...

69170
来自专栏个人分享

Kmeans算法学习与SparkMlLib Kmeans算法尝试

K-means算法是最为经典的基于划分的聚类方法,是十大经典数据挖掘算法之一。K-means算法的基本思想是:以空间中k个点为中心进行聚类,对最靠近他们的对象归...

22910
来自专栏素质云笔记

LSH︱python实现局部敏感随机投影森林——LSHForest/sklearn(一)

关于局部敏感哈希算法,之前用R语言实现过,但是由于在R中效能太低,于是放弃用LSH来做相似性检索。学了python发现很多模块都能实现,而且通过随机投影森林让查...

46180
来自专栏重庆的技术分享区

3吴恩达Meachine-Learing之线性代数回顾-(Linear-Algebra-Review)

17440
来自专栏崔庆才的专栏

Learning to Rank概述

Learning to Rank,即排序学习,简称为 L2R,它是构建排序模型的机器学习方法,在信息检索、自然语言处理、数据挖掘等场景中具有重要的作用。其达到的...

46250
来自专栏小鹏的专栏

一个隐马尔科夫模型的应用实例:中文分词

什么问题用HMM解决 现实生活中有这样一类随机现象,在已知现在情况的条件下,未来时刻的情况只与现在有关,而与遥远的过去并无直接关系。 比如天气预测,如果我...

40770
来自专栏PPV课数据科学社区

2017校招数据分析岗笔试/面试知识点

2017校招正在火热的进行,后面会不断更新涉及到的相关知识点。尽管听说今年几个大互联网公司招的人超少,但好像哪一年都说是就业困难,能够进去当然最好,不能进去是不...

67370
来自专栏崔庆才的专栏

深度学习效果不好?试试 Batch Normalization 吧!

Batch Normalization(简称BN)自从提出之后,因为效果特别好,很快被作为深度学习的标准工具应用在了各种场合。BN大法虽然好,但是也存在一些局...

1.5K30
来自专栏书山有路勤为径

卡尔曼滤波器(Kalman Filters)

卡尔曼滤波器,这是一种使用噪声传感器测量(和贝叶斯规则)来生成未知量的可靠估计的算法(例如车辆可能在3秒内的位置)。

44530

扫码关注云+社区

领取腾讯云代金券