用金庸、古龙群侠名称训练 LSTM,会生成多么奇葩的名字?

AI 研习社按:本文转载自 Magicly 博客,获作者授权。阅读原文请见:http://magicly.me/2017/04/07/rnn-lstm-generate-name/?utm_source=tuicool&utm_medium=referral。

Magicly:之前翻译了一篇介绍RNN的文章,一直没看到作者写新的介绍LSTM的blog,于是我又找了其他资料学习。本文先介绍一下LSTM,然后用LSTM在金庸、古龙的人名上做了训练,可以生成新的武侠名字,如果有兴趣的,还可以多搜集点人名,用于给小孩儿取名呢,哈哈,justforfun,大家玩得开心…

RNN回顾

RNN的出现是为了解决状态记忆的问题,解决方法很简单,每一个时间点t的隐藏状态h(t)不再简单地依赖于数据,还依赖于前一个时间节点t-1的隐藏状态h(t-1)。可以看出这是一种递归定义(所以循环神经网络又叫递归神经网络Recursive Neural Network),h(t-1)又依赖于h(t-2),h(t-2)依赖于h(t-3)…所以h(t)依赖于之前每一个时间点的输入,也就是说h(t)记住了之前所有的输入。

上图如果按时间展开,就可以看出RNN其实也就是普通神经网络在时间上的堆叠。

RNN问题:Long-Term Dependencies

一切似乎很完美,但是如果h(t)依赖于h(t - 1000),依赖路径特别长,会导致计算梯度的时候出现梯度消失的问题,训练时间很长根本没法实际使用。下面是一个依赖路径很长的例子:

1 我老家【成都】的。。。【此处省去500字】。。。我们那里经常吃【火锅】。。。

LSTM

Long Short Term Memory神经网络,也就是LSTM,由 Hochreiter & Schmidhuber于1997年发表。它的出现就是为了解决Long-Term Dependencies的问题,很来出现了很多改进版本,目前应用在相当多的领域(包括机器翻译、对话机器人、语音识别、Image Caption等)。

标准的RNN里,重复的模块里只是一个很简单的结构,如下图:

LSTM也是类似的链表结构,不过它的内部构造要复杂得多:

上图中的图标含义如下:

LSTM的核心思想是cell state(类似于hidden state,有LSTM变种把cell state和hidden state合并了, 比如GRU)和三种门:输入门、忘记门、输出门。

cell state每次作为输入传递到下一个时间点,经过一些线性变化后继续传往再下一个时间点(我还没看过原始论文,不知道为啥有了hidden state后还要cell state,好在确实有改良版将两者合并了,所以暂时不去深究了)。

门的概念来自于电路设计(我没学过,就不敢卖弄了)。LSTM里,门控制通过门之后信息能留下多少。如下图,sigmoid层输出[0, 1]的值,决定多少数据可以穿过门, 0表示谁都过不了,1表示全部通过。

下面我们来看看每个“门”到底在干什么。

首先我们要决定之前的cell state需要保留多少。 它根据h(t-1)和x(t)计算出一个[0, 1]的数,决定cell state保留多少,0表示全部丢弃,1表示全部保留。为什么要丢弃呢,不是保留得越多越好么?假设LSTM在生成文章,里面有小明和小红,小明在看电视,小红在厨房做饭。如果当前的主语是小明, ok,那LSTM应该输出看电视相关的,比如找遥控器啊, 换台啊,如果主语已经切换到小红了, 那么接下来最好暂时把电视机忘掉,而输出洗菜、酱油、电饭煲等。

第二步就是决定输入多大程度上影响cell state。这个地方由两部分构成, 一个用sigmoid函数计算出有多少数据留下,一个用tanh函数计算出一个候选C(t)。 这个地方就好比是主语从小明切换到小红了, 电视机就应该切换到厨房。

然后我们把留下来的(t-1时刻的)cell state和新增加的合并起来,就得到了t时刻的cell state。

最后我们把cell state经过tanh压缩到[-1, 1],然后输送给输出门([0, 1]决定输出多少东西)。

现在也出了很多LSTM的变种, 有兴趣的可以看这里。另外,LSTM只是为了解决RNN的long-term dependencies,也有人从另外的角度来解决的,比如Clockwork RNNs by Koutnik, et al. (2014).

show me the code!

我用的Andrej Karpathy大神的代码, 做了些小改动。这个代码的好处是不依赖于任何深度学习框架,只需要有numpy就可以马上run起来!

  1. """
  2. Minimal character-level Vanilla RNN model. Written by Andrej Karpathy (@karpathy)
  3. BSD License
  4. """
  5. import numpy as np
  6. # data I/O
  7. data = open('input.txt', 'r').read() # should be simple plain text file
  8. all_names = set(data.split("\n"))
  9. chars = list(set(data))
  10. data_size, vocab_size = len(data), len(chars)
  11. print('data has %d characters, %d unique.' % (data_size, vocab_size))
  12. char_to_ix = {ch: i for i, ch in enumerate(chars)}
  13. ix_to_char = {i: ch for i, ch in enumerate(chars)}
  14. # print(char_to_ix, ix_to_char)
  15. # hyperparameters
  16. hidden_size = 100 # size of hidden layer of neurons
  17. seq_length = 25 # number of steps to unroll the RNN for
  18. learning_rate = 1e-1
  19. # model parameters
  20. Wxh = np.random.randn(hidden_size, vocab_size) * 0.01 # input to hidden
  21. Whh = np.random.randn(hidden_size, hidden_size) * 0.01 # hidden to hidden
  22. Why = np.random.randn(vocab_size, hidden_size) * 0.01 # hidden to output
  23. bh = np.zeros((hidden_size, 1)) # hidden bias
  24. by = np.zeros((vocab_size, 1)) # output bias
  25. def lossFun(inputs, targets, hprev):
  26. """
  27. inputs,targets are both list of integers.
  28. hprev is Hx1 array of initial hidden state
  29. returns the loss, gradients on model parameters, and last hidden state
  30. """
  31. xs, hs, ys, ps = {}, {}, {}, {}
  32. hs[-1] = np.copy(hprev)
  33. loss = 0
  34. # forward pass
  35. for t in range(len(inputs)):
  36. xs[t] = np.zeros((vocab_size, 1)) # encode in 1-of-k representation
  37. xs[t][inputs[t]] = 1
  38. hs[t] = np.tanh(np.dot(Wxh, xs[t]) + np.dot(Whh,
  39. hs[t - 1]) + bh) # hidden state
  40. # unnormalized log probabilities for next chars
  41. ys[t] = np.dot(Why, hs[t]) + by
  42. # probabilities for next chars
  43. ps[t] = np.exp(ys[t]) / np.sum(np.exp(ys[t]))
  44. loss += -np.log(ps[t][targets[t], 0]) # softmax (cross-entropy loss)
  45. # backward pass: compute gradients going backwards
  46. dWxh, dWhh, dWhy = np.zeros_like(
  47. Wxh), np.zeros_like(Whh), np.zeros_like(Why)
  48. dbh, dby = np.zeros_like(bh), np.zeros_like(by)
  49. dhnext = np.zeros_like(hs[0])
  50. for t in reversed(range(len(inputs))):
  51. dy = np.copy(ps[t])
  52. # backprop into y. see
  53. # http://cs231n.github.io/neural-networks-case-study/#grad if confused
  54. # here
  55. dy[targets[t]] -= 1
  56. dWhy += np.dot(dy, hs[t].T)
  57. dby += dy
  58. dh = np.dot(Why.T, dy) + dhnext # backprop into h
  59. dhraw = (1 - hs[t] * hs[t]) * dh # backprop through tanh nonlinearity
  60. dbh += dhraw
  61. dWxh += np.dot(dhraw, xs[t].T)
  62. dWhh += np.dot(dhraw, hs[t - 1].T)
  63. dhnext = np.dot(Whh.T, dhraw)
  64. for dparam in [dWxh, dWhh, dWhy, dbh, dby]:
  65. # clip to mitigate exploding gradients
  66. np.clip(dparam, -5, 5, out=dparam)
  67. return loss, dWxh, dWhh, dWhy, dbh, dby, hs[len(inputs) - 1]
  68. def sample(h, seed_ix, n):
  69. """
  70. sample a sequence of integers from the model
  71. h is memory state, seed_ix is seed letter for first time step
  72. """
  73. x = np.zeros((vocab_size, 1))
  74. x[seed_ix] = 1
  75. ixes = []
  76. for t in range(n):
  77. h = np.tanh(np.dot(Wxh, x) + np.dot(Whh, h) + bh)
  78. y = np.dot(Why, h) + by
  79. p = np.exp(y) / np.sum(np.exp(y))
  80. ix = np.random.choice(range(vocab_size), p=p.ravel())
  81. x = np.zeros((vocab_size, 1))
  82. x[ix] = 1
  83. ixes.append(ix)
  84. return ixes
  85. n, p = 0, 0
  86. mWxh, mWhh, mWhy = np.zeros_like(Wxh), np.zeros_like(Whh), np.zeros_like(Why)
  87. mbh, mby = np.zeros_like(bh), np.zeros_like(by) # memory variables for Adagrad
  88. smooth_loss = -np.log(1.0 / vocab_size) * seq_length # loss at iteration 0
  89. while True:
  90. # prepare inputs (we're sweeping from left to right in steps seq_length
  91. # long)
  92. if p + seq_length + 1 >= len(data) or n == 0:
  93. hprev = np.zeros((hidden_size, 1)) # reset RNN memory
  94. p = 0 # go from start of data
  95. inputs = [char_to_ix[ch] for ch in data[p:p + seq_length]]
  96. targets = [char_to_ix[ch] for ch in data[p + 1:p + seq_length + 1]]
  97. # sample from the model now and then
  98. if n % 100 == 0:
  99. sample_ix = sample(hprev, inputs[0], 200)
  100. txt = ''.join(ix_to_char[ix] for ix in sample_ix)
  101. print('----\n %s \n----' % (txt, ))
  102. # forward seq_length characters through the net and fetch gradient
  103. loss, dWxh, dWhh, dWhy, dbh, dby, hprev = lossFun(inputs, targets, hprev)
  104. smooth_loss = smooth_loss * 0.999 + loss * 0.001
  105. if n % 100 == 0:
  106. print('iter %d, loss: %f' % (n, smooth_loss)) # print progress
  107. # perform parameter update with Adagrad
  108. for param, dparam, mem in zip([Wxh, Whh, Why, bh, by],
  109. [dWxh, dWhh, dWhy, dbh, dby],
  110. [mWxh, mWhh, mWhy, mbh, mby]):
  111. mem += dparam * dparam
  112. param += -learning_rate * dparam / \
  113. np.sqrt(mem + 1e-8) # adagrad update
  114. p += seq_length # move data pointer
  115. n += 1 # iteration counter
  116. if ((smooth_loss < 10) or (n >= 20000)):
  117. sample_ix = sample(hprev, inputs[0], 2000)
  118. txt = ''.join(ix_to_char[ix] for ix in sample_ix)
  119. predicted_names = set(txt.split("\n"))
  120. new_names = predicted_names - all_names
  121. print(new_names)
  122. print('predicted names len: %d, new_names len: %d.\n' % (len(predicted_names), len(new_names)))
  123. break

view rawmin-char-rnn.py hosted with ❤ by GitHub

然后从网上找了金庸小说的人名,做了些预处理,每行一个名字,保存到input.txt里,运行代码就可以了。古龙的没有找到比较全的名字, 只有这份武功排行榜,只有100多人。

下面是根据两份名单训练的结果,已经将完全一致的名字(比如段誉)去除了,所以下面的都是LSTM“新创作发明”的名字哈。来, 大家猜猜哪一个结果是金庸的, 哪一个是古龙的呢?

{'姜曾铁', '袁南兰', '石万奉', '郭万嗔', '蔡家', '程伯芷', '汪铁志', '陈衣', '薛铁','哈赤蔡师', '殷飞虹', '钟小砚', '凤一刀', '宝兰', '齐飞虹', '无若之', '王老英', '钟','钟百胜', '师', '李沅震', '曹兰', '赵一刀', '钟灵四', '宗家妹', '崔树胜', '桑飞西','上官公希轰', '刘之余人童怀道', '周云鹤', '天', '凤', '西灵素', '大智虎师', '阮徒忠','王兆能', '袁铮衣商宝鹤', '常伯凤', '苗人大', '倪不凤', '蔡铁', '无伯志', '凤一弼','曹鹊', '黄宾', '曾铁文', '姬胡峰', '李何豹', '上官铁', '童灵同', '古若之', '慕官景岳','崔百真', '陈官', '陈钟', '倪调峰', '妹沅刀', '徐双英', '任通督', '上官铁褚容', '大剑太','胡阳', '生', '南仁郑', '南调', '石双震', '海铁山', '殷鹤真', '司鱼督', '德小','若四', '武通涛', '田青农', '常尘英', '常不志', '倪不涛', '欧阳', '大提督', '胡玉堂','陈宝鹤', '南仁通四蒋赫侯'}

{'邀三', '熊猫开', '鹰星', '陆开', '花', '薛玉罗平', '南宫主', '南宫九', '孙夫人','荆董灭', '铁不愁', '裴独', '玮剑', '人', '陆小龙王紫无牙', '连千里', '仲先生','俞白', '方大', '叶雷一魂', '独孤上红', '叶怜花', '雷大归', '恕飞', '白双发','邀一郎', '东楼', '铁中十一点红', '凤星真', '无魏柳老凤三', '萧猫儿', '东郭先凤','日孙', '地先生', '孟摘星', '江小小凤', '花双楼', '李佩', '仇珏', '白坏刹', '燕悲情','姬悲雁', '东郭大', '谢晓陆凤', '碧玉伯', '司实三', '陆浪', '赵布雁', '荆孤蓝','怜燕南天', '萧怜静', '龙布雁', '东郭鱼', '司东郭金天', '薛啸天', '熊宝玉', '无莫静','柳罗李', '东官小鱼', '渐飞', '陆地鱼', '阿吹王', '高傲', '萧十三', '龙童', '玉罗赵','谢郎唐傲', '铁夜帝', '江小凤', '孙玉玉夜', '仇仲忍', '萧地孙', '铁莫棠', '柴星夫','展夫人', '碧玉', '老无鱼', '铁铁花', '独', '薛月宫九', '老郭和尚', '东郭大路陆上龙关飞','司藏', '李千', '孙白人', '南双平', '王玮', '姬原情', '东郭大路孙玉', '白玉罗生', '高儿','东珏天', '萧王尚', '九', '凤三静', '和空摘星', '关吹雪', '上官官小凤', '仇上官金飞','陆上龙啸天', '司空星魂', '邀衣人', '主', '李寻欢天', '东情', '玉夫随', '赵小凤', '东郭灭', '邀祟厚', '司空星'}

感兴趣的还可以用古代诗人、词人等的名字来做训练,大家机器好或者有时间的可以多训练下,训练得越多越准确。

总结

RNN由于具有记忆功能,在NLP、Speech、Computer Vision等诸多领域都展示了强大的力量。实际上,RNN是图灵等价的。

1 If training vanilla neural nets is optimization over functions, training recurrent nets is optimization over programs.

LSTM是一种目前相当常用和实用的RNN算法,主要解决了RNN的long-term dependencies问题。另外RNN也一直在产生新的研究,比如Attention机制。有空再介绍咯。。。

Refers

http://colah.github.io/posts/2015-08-Understanding-LSTMs/

http://karpathy.github.io/2015/05/21/rnn-effectiveness/

https://www.zhihu.com/question/29411132

https://gist.github.com/karpathy/d4dee566867f8291f086

https://deeplearning4j.org/lstm.html

延伸阅读:人脸识别 + 手机推送,从此不再害怕老板背后偷袭!

原文发布于微信公众号 - AI研习社(okweiwu)

原文发表时间:2017-04-24

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器之心

资源 | 囊括欧亚非大陆多种语言的25个平行语料库数据集(拿走不谢!)

原文链接:https://gengo.ai/datasets/25-best-parallel-text-datasets-for-machine-transl...

2433
来自专栏专知

【AlphaGo Zero 核心技术-深度强化学习教程代码实战06】给Agent添加记忆功能

【导读】Google DeepMind在Nature上发表最新论文,介绍了迄今最强最新的版本AlphaGo Zero,不使用人类先验知识,使用纯强化学习,将价值...

5356
来自专栏机器之心

业界 | 微软提出基于程序图简化程序分析,直接从源代码中学习

1793
来自专栏大数据风控

数据分析中非常实用的自编函数和代码模块整理

大家周末好! 搞了接近四个周的模型开发工作,今天整理代码文件,评分卡模型基本告一段落了。那么在模型开发或者是我们日常的数据分析工作中,根据我们具体的业务需求,经...

25010
来自专栏SDNLAB

5G革命的技术,一个都不能少

第五代移动网络简称5G是产业界即将实现的移动技术革命,是LTE-A网络的深层演进技术。5G网络中的关键技术包括MIMO、OFDM、SC-FDMA等。 超密集微型...

45712
来自专栏数说工作室

海量文本用 Simhash, 2小时变4秒! | 文本分析:大规模文本处理(2)

这是一个相似匹配的问题(文本相似匹配基础→ 词频与余弦相似度)。但是,亿级数据库,用传统的相似度计算方法太慢了,我们需要一个文本查询方法,可以快速的把一段文本的...

4763
来自专栏iOSDevLog

Scikit-Learn教程:棒球分析 (一)

一个scikit-learn教程,通过将数据建模到KMeans聚类模型和线性回归模型来预测MLB每赛季的胜利。

2112
来自专栏大数据

季节性单位根

正如MAT8181课程中所讨论的那样,至少有两种非平稳的时间序列:存在趋势的和存在单位根(这种类型被称为 单整的)。单位根测试不能用来评估一个时间序列是否平稳,...

3525
来自专栏Python小屋

Python计算电场中两点间的电势差

根据组合数定义,需要计算3个数的阶乘,在很多编程语言中都很难直接使用整型变量表示大数的阶乘结果,虽然Python并不存在这个问题,但是计算大数的阶乘仍需要相当多...

961
来自专栏iOSDevLog

机器学习研究和开发所需的组件列表

Here is a list of components that are needed for the successful machine learning...

1022

扫码关注云+社区

领取腾讯云代金券