Char RNN原理介绍以及文本生成实践

1 简介

Char-RNN,字符级循环神经网络,出自于Andrej Karpathy写的The Unreasonable Effectiveness of Recurrent Neural Networks。众所周知,RNN非常擅长处理序列问题。序列数据前后有很强的关联性,而RNN通过每个单元权重与偏置的共享以及循环计算(前面处理过的信息会被利用处理后续信息)来体现。Char-RNN模型是从字符的维度上,让机器生成文本,即通过已经观测到的字符出发,预测下一个字符出现的概率,也就是序列数据的推测。现在网上介绍的用深度学习写歌、写诗、写小说的大多都是基于这个方法。

在基本的RNN单元中,只有一个隐藏状态,对于长距离的记忆效果很差(序列开始的信息在后期保留很少),而且存在梯度消失的问题,因此诞生了许多变体,如LSTM、GRU等。本文介绍的Char-RNN就是选用LSTM作为基本模型。

2 Char RNN 原理

Char RNN 原理

上图展示了Char-RNN的原理。以要让模型学习写出“hello”为例,Char-RNN的输入输出层都是以字符为单位。输入“h”,应该输出“e”;输入“e”,则应该输出后续的“l”。输入层我们可以用只有一个元素为1的向量来编码不同的字符,例如,h被编码为“1000”、“e”被编码为“0100”,而“l”被编码为“0010”。使用RNN的学习目标是,可以让生成的下一个字符尽量与训练样本里的目标输出一致。在图一的例子中,根据前两个字符产生的状态和第三个输入“l”预测出的下一个字符的向量为<0.1, 0.5, 1.9, -1.1>,最大的一维是第三维,对应的字符则为“0010”,正好是“l”。这就是一个正确的预测。但从第一个“h”得到的输出向量是第四维最大,对应的并不是“e”,这样就产生代价。学习的过程就是不断降低这个代价。学习到的模型,对任何输入字符可以很好地不断预测下一个字符,如此一来就能生成句子或段落。

3 实践

下面是一个利用Char RNN实现写诗的应用,代码来自来自原先比较火的项目:https://github.com/jinfagang/tensorflow_poems,然后自己将其做成WEB应用,凑着学习了下如何使用tensorflow实现char rnn

def char_rnn(model,input_data,output_data,vocab_size,rnn_size=128,num_layers=2,batch_size=64,
             learning_rate=0.01):
    """

    :param model: rnn单元的类型 rnn, lstm gru
    :param input_data: 输入数据
    :param output_data: 输出数据
    :param vocab_size: 词汇大小
    :param rnn_size:
    :param num_layers:
    :param batch_size:
    :param learning_rate:学习率
    :return:
    """
    end_points = {}

    if model=='rnn':
        cell_fun=tf.contrib.rnn.BasicRNNCell
    elif model=='gru':
        cell_fun=tf.contrib.rnn.GRUCell
    elif model=='lstm':
        cell_fun=tf.contrib.rnn.BasicLSTMCell

    cell = cell_fun(rnn_size, state_is_tuple=True)
    cell = tf.contrib.rnn.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

    if output_data is not None:
        initial_state = cell.zero_state(batch_size, tf.float32)
    else:
        initial_state = cell.zero_state(1, tf.float32)

    with tf.device("/cpu:0"):
        embedding=tf.get_variable('embedding',initializer=tf.random_uniform(
            [vocab_size+1,rnn_size],-1.0,1.0))

        inputs=tf.nn.embedding_lookup(embedding,input_data)



    # [batch_size, ?, rnn_size] = [64, ?, 128]
    outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)
    output = tf.reshape(outputs, [-1, rnn_size])

    # logit计算
    weights = tf.Variable(tf.truncated_normal([rnn_size, vocab_size + 1]))
    bias = tf.Variable(tf.zeros(shape=[vocab_size + 1]))
    logits = tf.nn.bias_add(tf.matmul(output, weights), bias=bias)
    # [?, vocab_size+1]


    if output_data is not None:
        # 独热编码
        labels = tf.one_hot(tf.reshape(output_data, [-1]), depth=vocab_size + 1)
        # [?, vocab_size+1]

        loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
        # [?, vocab_size+1]

        total_loss = tf.reduce_mean(loss)
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

        end_points['initial_state'] = initial_state
        end_points['output'] = output
        end_points['train_op'] = train_op
        end_points['total_loss'] = total_loss
        end_points['loss'] = loss
        end_points['last_state'] = last_state
    else:
        prediction = tf.nn.softmax(logits)

        end_points['initial_state'] = initial_state
        end_points['last_state'] = last_state
        end_points['prediction'] = prediction

    return end_points

效果如下:

效果 1

效果 2

项目地址:https://github.com/yanqiangmiffy/char-rnn-writer/

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Astropeak

基于隐马尔科夫模型的中文分词方法

本文主要讲述隐马尔科夫模及其在中文分词中的应用。 基于中文分词语料库,建立中文分词的隐马尔科夫模型,最后用维特比方法进行求解。

11630
来自专栏杨熹的专栏

word2vec 模型思想和代码实现

CS224d-Day 3: word2vec 有两个模型,CBOW 和 Skip-Gram,今天先讲 Skip-Gram 的算法和实现。 课件: https:...

46650
来自专栏AI启蒙研究院

【读书笔记】之矩阵知识梳理

10220
来自专栏社区的朋友们

深度学习入门实战(三):图片分类中的逻辑回归

上一讲我们介绍了一下线性回归如何通过TensorFlow训练,这一讲我们介绍下逻辑回归模型,一句话说概括,逻辑回归就是多分类问题。

4.7K00
来自专栏ATYUN订阅号

【教程】OpenCV—Node.js教程系列:用Tensorflow和Caffe“做游戏”

? 今天我们来看看OpenCV的深度神经网络模块。如果你想要释放神经网络的awesomeness来识别和分类图像中的物体,但完全不知道深度学习如何工作,也不知...

53780
来自专栏SnailTyan

Very Deep Convolutional Networks for Large-Scale Image Recognition—VGG论文翻译—中英文对照

Very Deep Convolutional Networks for Large-Scale Image Recognition ABSTRACT In t...

31800
来自专栏人工智能LeadAI

线性回归与最小二乘法 | 机器学习笔记

这篇笔记会将几本的线性回归概念和最小二乘法。 在机器学习中,一个重要而且常见的问题就是学习和预测特征变量(自变量)与响应的响应变量(应变量)之间的函数关系 ...

36070
来自专栏专知

【干货】对抗自编码器PyTorch手把手实战系列——PyTorch实现自编码器

即使是非计算机行业, 大家也知道很多有名的神经网络结构, 比如CNN在处理图像上非常厉害, RNN能够建模序列数据. 然而CNN, RNN之类的神经网络结构本身...

69070
来自专栏机器学习原理

机器学习(7)——聚类算法聚类算法

聚类算法 前面介绍的集中算法都是属于有监督机器学习方法,这章和前面不同,介绍无监督学习算法,也就是聚类算法。在无监督学习中,目标属性是不存在的,也就是所说的不存...

1.6K70
来自专栏cloudskyme

跟我一起数据挖掘(23)——C4.5

C4.5简介 C4.5是一系列用在机器学习和数据挖掘的分类问题中的算法。它的目标是监督学习:给定一个数据集,其中的每一个元组都能用一组属性值来描述,每一个元组属...

36490

扫码关注云+社区

领取腾讯云代金券