在TensorFlow上用LSTM做情感分析

小白有话说

最近有点忙,果然我不适合女强人这么高难度的角色,还是安心当我的废柴好了。手动再见~

这次分享的是上周五开会前做的东西,在TensorFlow上使用LSTM做情感分析。毕业生已经陆续答辩结束了.....被放养了这么久

花了好久才算初步弄明白LSTM的原理,呵呵,只是初步而已,从网上找了段代码,跑通了才有底气做接下来细心钻研的工作啊!

数据集:IMDB电影评论集;tensorflow 1.1.0

1

RNN

Recurrent:时间维度的展开,代表信息在时间维度从前往后的传递和积累,在神经网络结构上表现为后面的神经网络的隐藏层的输入是前面的神经网络的隐藏层的输出。

Recursive:空间维度的展开,是一个树结构,假设句子是一个树状结构,由几个部分(主语,谓语,宾语)组成,而每个部分又可以在分成几个小部分,即某一部分的信息由它的子树的信息组合而来,整句话的信息由组成这句话的几个部分组成。

但是RNN的缺陷也是很明显的,存在梯度爆炸和梯度消失问题(具体是什么意思,并没有弄明白,主要是我的重点本来也不是RNN....)

2

LSTM

LSTM,长短时记忆网络,是一种特殊的RNN,可以学习长期依赖信息。通过刻意的设计来避免长期依赖问题。记住长期的信息在实践中是 LSTM 的默认行为,而非需要付出很大代价才能获得的能力。

下面是LSTM最经典的原理图:

由图中可以看出,

LSTM有三个门,input gate(输入门)、output gate(输出门)、forget gate(遗忘门);

有4个输入,每个输入都会经过一个激活函数(一般用的是sigmoid函数),比如,z经过激活函数得到g(z);

g(z)与输入门产生的f(zi)相乘,如果f(zi)在0附近,相当于输入门是关闭的,要是为1,相当于输入门是打开的;

再看遗忘门,输入zf,得到相应值f(zf),与之前“记忆体”中储值c相乘得到cf(zf);

到了这一步,若是RNN,是直接写入输入值得,而LSTM,则是先得到C`=g(z)f(zi)+cf(zf),把c1写入“记忆”。要是f(zf)为0,则意味着把一起的“记忆”遗忘了,不再写入;要是为1,说明“记忆”会保存;

输出门同理可得了。

这么一看,似乎明白了,又似乎没明白...

3

Tensorflow

TensorFlow是一个使用数据流图进行数值计算的开放源代码软件库。

图中的节点代表数学运算,而图中的边则代表在这些节点之间传递的多维数组(张量)。

从目前的文档看,TensorFlow支持CNN、RNN和LSTM算法,这都是目前在Image,Speech和NLP最流行的深度神经网络模型。

所以,选择tensorflow不是偶然嘛.....(心里表示呵呵呵呵)

LSTM-SA

1

向量化

训练一个词向量生成模型(比如Word2Vec)或者加载预训练的词向量

2

RNN

RNN(使用LSTM单元)图形创建

8

训练

9

测试

1

加载预训练模型

Google新闻训练集上训练Word2Vec模型:300万个词向量,每个向量维数是300

导入的是:由一个类似的词向量生成模型Glove训练。矩阵包含40万个词向量,每个维数为50。

2

训练集

IMDB电影评论数据集。这个集合中有25000个电影评论,12,500次正面评论和12,500次评论。

可视化数据(确定最大序列长度为250)

3

定义超参数

4

指定占位符

一个用于输入到网络中,一个用于标签。

定义这些占位符的最重要的是了解每个维度。

5

把输入填充进LSTM网络

函数输入一个整数代表我们要用到的LSTM单元数。这是用来调整以利于确定最优值的超参数之一。然后我们将LSTM单元包装在一个退出层,以防止网络过拟合。

6

定义正确的预测和精度指标

正确的预测公式通过查看2个输出值的最大值的索引

查看它是否与训练标签相匹配工作

7

定义标准交叉熵

基于最终预测值上的激活函数层定义标准交叉熵,使用Adam优化器,默认学习率为0.01。

8

加载预训练模型

9

测试

代码如下

import tensorflow as tf

hello = tf.constant('Hello, TensorFlow!')

sess = tf.Session()

print(sess.run(hello))

运行结果:

numDimensions = 300

maxSeqLength = 250 #最大序列长度

batchSize = 24 #批处理大小

lstmUnits = 64 #LSTM单元数

numClasses = 2 #输出类别数

iterations = 100000 #训练次数

import numpy as np

wordsList = np.load('.\wordsList.npy').tolist()

wordsList = [word.decode('UTF-8') for word in wordsList] #Encode words as UTF-8

wordVectors = np.load('.\wordVectors.npy')

import tensorflow as tf

tf.reset_default_graph()

labels = tf.placeholder(tf.float32, [batchSize, numClasses])

input_data = tf.placeholder(tf.int32, [batchSize, maxSeqLength])

data = tf.Variable(tf.zeros([batchSize, maxSeqLength, numDimensions]),dtype=tf.float32)

lstmCell = tf.contrib.rnn.DropoutWrapper(cell=lstmCell, output_keep_prob=0.25)

weight = tf.Variable(tf.truncated_normal([lstmUnits, numClasses]))

bias = tf.Variable(tf.constant(0.1, shape=[numClasses]))

value = tf.transpose(value, [1, 0, 2])

last = tf.gather(value, int(value.get_shape()[0]) - 1)

prediction = (tf.matmul(last, weight) + bias)

correctPred = tf.equal(tf.argmax(prediction,1), tf.argmax(labels,1))

accuracy = tf.reduce_mean(tf.cast(correctPred, tf.float32))

sess = tf.InteractiveSession()

saver = tf.train.Saver()

saver.restore(sess, tf.train.latest_checkpoint('.\models'))

import re

strip_special_chars = re.compile("[^A-Za-z0-9 ]+")

def cleanSentences(string):

string = string.lower().replace("

", " ")

return re.sub(strip_special_chars, "", string.lower())

def getSentenceMatrix(sentence):

arr = np.zeros([batchSize, maxSeqLength])

sentenceMatrix = np.zeros([batchSize,maxSeqLength], dtype='int32')

cleanedSentence = cleanSentences(sentence)

split = cleanedSentence.split()

for indexCounter,word in enumerate(split):

try:

sentenceMatrix[0,indexCounter] = wordsList.index(word)

except ValueError:

sentenceMatrix[0,indexCounter] = 399999 #Vector for unkown words

return sentenceMatrix

inputText = "That movie was terrible."

inputMatrix = getSentenceMatrix(inputText)

predictedSentiment = sess.run(prediction, )[0]

if (predictedSentiment[0] > predictedSentiment[1]):

print ("Positive Sentiment")

else:

print ("Negative Sentiment")

运行结果:

写在最后:

tensorflow的版本一定要是1.1.0,如果不是会报错,虽然我也不知道是为什么,但是就是会报错....可怜的我啊,起初我是不信的,就觉得怎么可能是版本的问题呢!百度,谷歌,github上翻遍了,发现还真就是版本的问题....所以如果你也遇到报错,不妨改改版本吧~

代码中用到的数据集,还有预加载的word2vec模型和预训练模型,就不传了(主要是我也不会),如果有需要,私信我。

如果文中有不严谨或者不对的对方,欢迎指正,不接受批评。如果你也是NLP小白,欢迎加入,一起学习~

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180522G1R06U00?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。

扫码关注云+社区

领取腾讯云代金券