前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >用金庸、古龙群侠名称训练 LSTM,会生成多么奇葩的名字?

用金庸、古龙群侠名称训练 LSTM,会生成多么奇葩的名字?

作者头像
AI研习社
发布2018-03-29 10:17:06
6810
发布2018-03-29 10:17:06
举报
文章被收录于专栏:AI研习社AI研习社
AI 研习社按:本文转载自 Magicly 博客,获作者授权。阅读原文请见:http://magicly.me/2017/04/07/rnn-lstm-generate-name/?utm_source=tuicool&utm_medium=referral。

Magicly:之前翻译了一篇介绍RNN的文章,一直没看到作者写新的介绍LSTM的blog,于是我又找了其他资料学习。本文先介绍一下LSTM,然后用LSTM在金庸、古龙的人名上做了训练,可以生成新的武侠名字,如果有兴趣的,还可以多搜集点人名,用于给小孩儿取名呢,哈哈,justforfun,大家玩得开心…

RNN回顾

RNN的出现是为了解决状态记忆的问题,解决方法很简单,每一个时间点t的隐藏状态h(t)不再简单地依赖于数据,还依赖于前一个时间节点t-1的隐藏状态h(t-1)。可以看出这是一种递归定义(所以循环神经网络又叫递归神经网络Recursive Neural Network),h(t-1)又依赖于h(t-2),h(t-2)依赖于h(t-3)…所以h(t)依赖于之前每一个时间点的输入,也就是说h(t)记住了之前所有的输入。

上图如果按时间展开,就可以看出RNN其实也就是普通神经网络在时间上的堆叠。

RNN问题:Long-Term Dependencies

一切似乎很完美,但是如果h(t)依赖于h(t - 1000),依赖路径特别长,会导致计算梯度的时候出现梯度消失的问题,训练时间很长根本没法实际使用。下面是一个依赖路径很长的例子:

1 我老家【成都】的。。。【此处省去500字】。。。我们那里经常吃【火锅】。。。

LSTM

Long Short Term Memory神经网络,也就是LSTM,由 Hochreiter & Schmidhuber于1997年发表。它的出现就是为了解决Long-Term Dependencies的问题,很来出现了很多改进版本,目前应用在相当多的领域(包括机器翻译对话机器人语音识别、Image Caption等)。

标准的RNN里,重复的模块里只是一个很简单的结构,如下图:

LSTM也是类似的链表结构,不过它的内部构造要复杂得多:

上图中的图标含义如下:

LSTM的核心思想是cell state(类似于hidden state,有LSTM变种把cell state和hidden state合并了, 比如GRU)和三种门:输入门、忘记门、输出门。

cell state每次作为输入传递到下一个时间点,经过一些线性变化后继续传往再下一个时间点(我还没看过原始论文,不知道为啥有了hidden state后还要cell state,好在确实有改良版将两者合并了,所以暂时不去深究了)。

门的概念来自于电路设计(我没学过,就不敢卖弄了)。LSTM里,门控制通过门之后信息能留下多少。如下图,sigmoid层输出[0, 1]的值,决定多少数据可以穿过门, 0表示谁都过不了,1表示全部通过。

下面我们来看看每个“门”到底在干什么。

首先我们要决定之前的cell state需要保留多少。 它根据h(t-1)和x(t)计算出一个[0, 1]的数,决定cell state保留多少,0表示全部丢弃,1表示全部保留。为什么要丢弃呢,不是保留得越多越好么?假设LSTM在生成文章,里面有小明和小红,小明在看电视,小红在厨房做饭。如果当前的主语是小明, ok,那LSTM应该输出看电视相关的,比如找遥控器啊, 换台啊,如果主语已经切换到小红了, 那么接下来最好暂时把电视机忘掉,而输出洗菜、酱油、电饭煲等。

第二步就是决定输入多大程度上影响cell state。这个地方由两部分构成, 一个用sigmoid函数计算出有多少数据留下,一个用tanh函数计算出一个候选C(t)。 这个地方就好比是主语从小明切换到小红了, 电视机就应该切换到厨房。

然后我们把留下来的(t-1时刻的)cell state和新增加的合并起来,就得到了t时刻的cell state。

最后我们把cell state经过tanh压缩到[-1, 1],然后输送给输出门([0, 1]决定输出多少东西)。

现在也出了很多LSTM的变种, 有兴趣的可以看这里。另外,LSTM只是为了解决RNN的long-term dependencies,也有人从另外的角度来解决的,比如Clockwork RNNs by Koutnik, et al. (2014).

show me the code!

我用的Andrej Karpathy大神的代码, 做了些小改动。这个代码的好处是不依赖于任何深度学习框架,只需要有numpy就可以马上run起来!

  1. """
  2. Minimal character-level Vanilla RNN model. Written by Andrej Karpathy (@karpathy)
  3. BSD License
  4. """
  5. import numpy as np
  6. # data I/O
  7. data = open('input.txt', 'r').read() # should be simple plain text file
  8. all_names = set(data.split("\n"))
  9. chars = list(set(data))
  10. data_size, vocab_size = len(data), len(chars)
  11. print('data has %d characters, %d unique.' % (data_size, vocab_size))
  12. char_to_ix = {ch: i for i, ch in enumerate(chars)}
  13. ix_to_char = {i: ch for i, ch in enumerate(chars)}
  14. # print(char_to_ix, ix_to_char)
  15. # hyperparameters
  16. hidden_size = 100 # size of hidden layer of neurons
  17. seq_length = 25 # number of steps to unroll the RNN for
  18. learning_rate = 1e-1
  19. # model parameters
  20. Wxh = np.random.randn(hidden_size, vocab_size) * 0.01 # input to hidden
  21. Whh = np.random.randn(hidden_size, hidden_size) * 0.01 # hidden to hidden
  22. Why = np.random.randn(vocab_size, hidden_size) * 0.01 # hidden to output
  23. bh = np.zeros((hidden_size, 1)) # hidden bias
  24. by = np.zeros((vocab_size, 1)) # output bias
  25. def lossFun(inputs, targets, hprev):
  26. """
  27. inputs,targets are both list of integers.
  28. hprev is Hx1 array of initial hidden state
  29. returns the loss, gradients on model parameters, and last hidden state
  30. """
  31. xs, hs, ys, ps = {}, {}, {}, {}
  32. hs[-1] = np.copy(hprev)
  33. loss = 0
  34. # forward pass
  35. for t in range(len(inputs)):
  36. xs[t] = np.zeros((vocab_size, 1)) # encode in 1-of-k representation
  37. xs[t][inputs[t]] = 1
  38. hs[t] = np.tanh(np.dot(Wxh, xs[t]) + np.dot(Whh,
  39. hs[t - 1]) + bh) # hidden state
  40. # unnormalized log probabilities for next chars
  41. ys[t] = np.dot(Why, hs[t]) + by
  42. # probabilities for next chars
  43. ps[t] = np.exp(ys[t]) / np.sum(np.exp(ys[t]))
  44. loss += -np.log(ps[t][targets[t], 0]) # softmax (cross-entropy loss)
  45. # backward pass: compute gradients going backwards
  46. dWxh, dWhh, dWhy = np.zeros_like(
  47. Wxh), np.zeros_like(Whh), np.zeros_like(Why)
  48. dbh, dby = np.zeros_like(bh), np.zeros_like(by)
  49. dhnext = np.zeros_like(hs[0])
  50. for t in reversed(range(len(inputs))):
  51. dy = np.copy(ps[t])
  52. # backprop into y. see
  53. # http://cs231n.github.io/neural-networks-case-study/#grad if confused
  54. # here
  55. dy[targets[t]] -= 1
  56. dWhy += np.dot(dy, hs[t].T)
  57. dby += dy
  58. dh = np.dot(Why.T, dy) + dhnext # backprop into h
  59. dhraw = (1 - hs[t] * hs[t]) * dh # backprop through tanh nonlinearity
  60. dbh += dhraw
  61. dWxh += np.dot(dhraw, xs[t].T)
  62. dWhh += np.dot(dhraw, hs[t - 1].T)
  63. dhnext = np.dot(Whh.T, dhraw)
  64. for dparam in [dWxh, dWhh, dWhy, dbh, dby]:
  65. # clip to mitigate exploding gradients
  66. np.clip(dparam, -5, 5, out=dparam)
  67. return loss, dWxh, dWhh, dWhy, dbh, dby, hs[len(inputs) - 1]
  68. def sample(h, seed_ix, n):
  69. """
  70. sample a sequence of integers from the model
  71. h is memory state, seed_ix is seed letter for first time step
  72. """
  73. x = np.zeros((vocab_size, 1))
  74. x[seed_ix] = 1
  75. ixes = []
  76. for t in range(n):
  77. h = np.tanh(np.dot(Wxh, x) + np.dot(Whh, h) + bh)
  78. y = np.dot(Why, h) + by
  79. p = np.exp(y) / np.sum(np.exp(y))
  80. ix = np.random.choice(range(vocab_size), p=p.ravel())
  81. x = np.zeros((vocab_size, 1))
  82. x[ix] = 1
  83. ixes.append(ix)
  84. return ixes
  85. n, p = 0, 0
  86. mWxh, mWhh, mWhy = np.zeros_like(Wxh), np.zeros_like(Whh), np.zeros_like(Why)
  87. mbh, mby = np.zeros_like(bh), np.zeros_like(by) # memory variables for Adagrad
  88. smooth_loss = -np.log(1.0 / vocab_size) * seq_length # loss at iteration 0
  89. while True:
  90. # prepare inputs (we're sweeping from left to right in steps seq_length
  91. # long)
  92. if p + seq_length + 1 >= len(data) or n == 0:
  93. hprev = np.zeros((hidden_size, 1)) # reset RNN memory
  94. p = 0 # go from start of data
  95. inputs = [char_to_ix[ch] for ch in data[p:p + seq_length]]
  96. targets = [char_to_ix[ch] for ch in data[p + 1:p + seq_length + 1]]
  97. # sample from the model now and then
  98. if n % 100 == 0:
  99. sample_ix = sample(hprev, inputs[0], 200)
  100. txt = ''.join(ix_to_char[ix] for ix in sample_ix)
  101. print('----\n %s \n----' % (txt, ))
  102. # forward seq_length characters through the net and fetch gradient
  103. loss, dWxh, dWhh, dWhy, dbh, dby, hprev = lossFun(inputs, targets, hprev)
  104. smooth_loss = smooth_loss * 0.999 + loss * 0.001
  105. if n % 100 == 0:
  106. print('iter %d, loss: %f' % (n, smooth_loss)) # print progress
  107. # perform parameter update with Adagrad
  108. for param, dparam, mem in zip([Wxh, Whh, Why, bh, by],
  109. [dWxh, dWhh, dWhy, dbh, dby],
  110. [mWxh, mWhh, mWhy, mbh, mby]):
  111. mem += dparam * dparam
  112. param += -learning_rate * dparam / \
  113. np.sqrt(mem + 1e-8) # adagrad update
  114. p += seq_length # move data pointer
  115. n += 1 # iteration counter
  116. if ((smooth_loss < 10) or (n >= 20000)):
  117. sample_ix = sample(hprev, inputs[0], 2000)
  118. txt = ''.join(ix_to_char[ix] for ix in sample_ix)
  119. predicted_names = set(txt.split("\n"))
  120. new_names = predicted_names - all_names
  121. print(new_names)
  122. print('predicted names len: %d, new_names len: %d.\n' % (len(predicted_names), len(new_names)))
  123. break

view rawmin-char-rnn.py hosted with ❤ by GitHub

然后从网上找了金庸小说的人名,做了些预处理,每行一个名字,保存到input.txt里,运行代码就可以了。古龙的没有找到比较全的名字, 只有这份武功排行榜,只有100多人。

下面是根据两份名单训练的结果,已经将完全一致的名字(比如段誉)去除了,所以下面的都是LSTM“新创作发明”的名字哈。来, 大家猜猜哪一个结果是金庸的, 哪一个是古龙的呢?

{'姜曾铁', '袁南兰', '石万奉', '郭万嗔', '蔡家', '程伯芷', '汪铁志', '陈衣', '薛铁','哈赤蔡师', '殷飞虹', '钟小砚', '凤一刀', '宝兰', '齐飞虹', '无若之', '王老英', '钟','钟百胜', '师', '李沅震', '曹兰', '赵一刀', '钟灵四', '宗家妹', '崔树胜', '桑飞西','上官公希轰', '刘之余人童怀道', '周云鹤', '天', '凤', '西灵素', '大智虎师', '阮徒忠','王兆能', '袁铮衣商宝鹤', '常伯凤', '苗人大', '倪不凤', '蔡铁', '无伯志', '凤一弼','曹鹊', '黄宾', '曾铁文', '姬胡峰', '李何豹', '上官铁', '童灵同', '古若之', '慕官景岳','崔百真', '陈官', '陈钟', '倪调峰', '妹沅刀', '徐双英', '任通督', '上官铁褚容', '大剑太','胡阳', '生', '南仁郑', '南调', '石双震', '海铁山', '殷鹤真', '司鱼督', '德小','若四', '武通涛', '田青农', '常尘英', '常不志', '倪不涛', '欧阳', '大提督', '胡玉堂','陈宝鹤', '南仁通四蒋赫侯'}

{'邀三', '熊猫开', '鹰星', '陆开', '花', '薛玉罗平', '南宫主', '南宫九', '孙夫人','荆董灭', '铁不愁', '裴独', '玮剑', '人', '陆小龙王紫无牙', '连千里', '仲先生','俞白', '方大', '叶雷一魂', '独孤上红', '叶怜花', '雷大归', '恕飞', '白双发','邀一郎', '东楼', '铁中十一点红', '凤星真', '无魏柳老凤三', '萧猫儿', '东郭先凤','日孙', '地先生', '孟摘星', '江小小凤', '花双楼', '李佩', '仇珏', '白坏刹', '燕悲情','姬悲雁', '东郭大', '谢晓陆凤', '碧玉伯', '司实三', '陆浪', '赵布雁', '荆孤蓝','怜燕南天', '萧怜静', '龙布雁', '东郭鱼', '司东郭金天', '薛啸天', '熊宝玉', '无莫静','柳罗李', '东官小鱼', '渐飞', '陆地鱼', '阿吹王', '高傲', '萧十三', '龙童', '玉罗赵','谢郎唐傲', '铁夜帝', '江小凤', '孙玉玉夜', '仇仲忍', '萧地孙', '铁莫棠', '柴星夫','展夫人', '碧玉', '老无鱼', '铁铁花', '独', '薛月宫九', '老郭和尚', '东郭大路陆上龙关飞','司藏', '李千', '孙白人', '南双平', '王玮', '姬原情', '东郭大路孙玉', '白玉罗生', '高儿','东珏天', '萧王尚', '九', '凤三静', '和空摘星', '关吹雪', '上官官小凤', '仇上官金飞','陆上龙啸天', '司空星魂', '邀衣人', '主', '李寻欢天', '东情', '玉夫随', '赵小凤', '东郭灭', '邀祟厚', '司空星'}

感兴趣的还可以用古代诗人、词人等的名字来做训练,大家机器好或者有时间的可以多训练下,训练得越多越准确。

总结

RNN由于具有记忆功能,在NLP、Speech、Computer Vision等诸多领域都展示了强大的力量。实际上,RNN是图灵等价的。

1 If training vanilla neural nets is optimization over functions, training recurrent nets is optimization over programs.

LSTM是一种目前相当常用和实用的RNN算法,主要解决了RNN的long-term dependencies问题。另外RNN也一直在产生新的研究,比如Attention机制。有空再介绍咯。。。

Refers

http://colah.github.io/posts/2015-08-Understanding-LSTMs/

http://karpathy.github.io/2015/05/21/rnn-effectiveness/

https://www.zhihu.com/question/29411132

https://gist.github.com/karpathy/d4dee566867f8291f086

https://deeplearning4j.org/lstm.html

延伸阅读:人脸识别 + 手机推送,从此不再害怕老板背后偷袭!

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2017-04-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI研习社 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • RNN回顾
  • RNN问题:Long-Term Dependencies
  • LSTM
  • show me the code!
  • 总结
  • Refers
相关产品与服务
人脸识别
腾讯云神图·人脸识别(Face Recognition)基于腾讯优图强大的面部分析技术,提供包括人脸检测与分析、比对、搜索、验证、五官定位、活体检测等多种功能,为开发者和企业提供高性能高可用的人脸识别服务。 可应用于在线娱乐、在线身份认证等多种应用场景,充分满足各行业客户的人脸属性识别及用户身份确认等需求。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档