前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >mxnet RNN简单剖析官方github教程部分代码

mxnet RNN简单剖析官方github教程部分代码

作者头像
月见樽
发布2018-04-27 12:12:52
1.2K0
发布2018-04-27 12:12:52
举报
代码语言:javascript
复制
import mxnet as mx

官方github教程部分代码

网络生成

代码语言:javascript
复制
num_layers = 2
num_hidden = 256
stack = mx.rnn.SequentialRNNCell()
for i in range(num_layers):
    stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_'%i))
  • mx.rnn.SequentialRNNCell():RNN容器,用于组合多个RNN层
  • mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_'%i):LSTM单元
代码语言:javascript
复制
num_embed = 256
def sym_gen(seq_len):
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('softmax_label')
    embed = mx.sym.Embedding(data=data, input_dim=1000,output_dim=num_embed, name='embed')
#   数据生成,定义Variable并进行词向量化

    stack.reset()
    outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True)
#   按时间展开输出和状态
    
    pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden))
    pred = mx.sym.FullyConnected(data=pred, num_hidden=1000, name='pred')
#   变换输出形式,将输出变为(-1,num_hidden)尺寸

    label = mx.sym.Reshape(label, shape=(-1,))
    pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')
#   展平label,并计算代价函数
    
    return pred, ('data',), ('softmax_label',)
sym_gen(1)
代码语言:javascript
复制
(<Symbol softmax>, ('data',), ('softmax_label',))
  • unroll()函数按时间展开RNN单元,输出最终的运算结果
  • 输出接全连接层,再转换为词向量

官方API文档代码

数据转换

代码语言:javascript
复制
step_input = mx.symbol.Variable('step_data')

# First we embed our raw input data to be used as LSTM's input.
embedded_step = mx.symbol.Embedding(data=step_input, \
                                    input_dim=50, \
                                    output_dim=50)
# print(embedded_step.shape)
mx.viz.plot_network(symbol=embedded_step)
# Then we create an LSTM cell.

output_7_0.png

Embedding是一种词向量化技术,这种技术可以保持语义(例如相近语义的词的向量距离会较近),将尺寸为(d0,d1...dn)的输入向量进行词向量化技术后转换为尺寸为(d0,d1,...,dn,out_dim)的向量,多出的一维为词向量,即使用一个向量代替原来一个词的位置。

  • 参数input_dim为输入向量的范围,即输入data的范围在[0,input_dim)之间
  • 参数output_dim为词向量大小
  • 可选参数weight,可传入指定的词向量字典
  • 可选参数name,可传入名称
代码语言:javascript
复制
vocabulary_size = 26
embed_dim = 16
seq_len, batch_size = (10, 64)
input = mx.sym.Variable('letters')
op = mx.sym.Embedding(data=input, input_dim=vocabulary_size, output_dim=embed_dim,name='embed')
op.infer_shape(letters=(seq_len, batch_size))
代码语言:javascript
复制
([(10, 64), (26, 16)], [(10, 64, 16)], [])

上文的例子可以看出输入向量尺寸为(10,64),输出向量尺寸变为了(10,64,16)

网络构建

使用了隐层为50的LSTM单元,并带入转换好的数据,该图绘制出的lstm图较经典LSTM有一些出入

代码语言:javascript
复制
lstm_cell = mx.rnn.LSTMCell(num_hidden=50)
begin_state = lstm_cell.begin_state()
output, states = lstm_cell(embedded_step, begin_state)
mx.viz.plot_network(symbol=output)

output_11_0.png

LSTM的源码的构造函数如下:

代码语言:javascript
复制
def __init__(self, num_hidden, prefix='lstm_', params=None, forget_bias=1.0):
        super(LSTMCell, self).__init__(prefix=prefix, params=params)

        self._num_hidden = num_hidden
        self._iW = self.params.get('i2h_weight')
        self._hW = self.params.get('h2h_weight')
        # we add the forget_bias to i2h_bias, this adds the bias to the forget gate activation
        self._iB = self.params.get('i2h_bias', init=init.LSTMBias(forget_bias=forget_bias))
        self._hB = self.params.get('h2h_bias')

其中:self.params.get()方法为尝试找到传入名称对应的Variable,若找不到则新建,因此该LSTM单元一共仅有两对参数:iW和iB,hW和hB

前向传播函数如下:

代码语言:javascript
复制
    def __call__(self, inputs, states):
        self._counter += 1
        name = '%st%d_'%(self._prefix, self._counter)
        i2h = symbol.FullyConnected(data=inputs, weight=self._iW, bias=self._iB,
                                    num_hidden=self._num_hidden*4,
                                    name='%si2h'%name)
        h2h = symbol.FullyConnected(data=states[0], weight=self._hW, bias=self._hB,
                                    num_hidden=self._num_hidden*4,
                                    name='%sh2h'%name)
        gates = i2h + h2h
        slice_gates = symbol.SliceChannel(gates, num_outputs=4,name="%sslice"%name)
        in_gate = symbol.Activation(slice_gates[0], act_type="sigmoid",name='%si'%name)
        forget_gate = symbol.Activation(slice_gates[1], act_type="sigmoid",name='%sf'%name)
        in_transform = symbol.Activation(slice_gates[2], act_type="tanh",name='%sc'%name)
        out_gate = symbol.Activation(slice_gates[3], act_type="sigmoid",name='%so'%name)
        next_c = symbol._internal._plus(forget_gate * states[1], in_gate * in_transform,name='%sstate'%name)
        next_h = symbol._internal._mul(out_gate, symbol.Activation(next_c, act_type="tanh"),name='%sout'%name)
        return next_h, [next_h, next_c]

可以看出,LSTM的实现过程如下所示

  1. 计算隐层输入与状态,隐层的channel数量是配置的hidden_num的四倍
  2. 将隐层输入结果和隐层状态相加,并按channel数量切分为4份
    • 第一份作为输入门层,经过sigmoid函数
    • 第二份作为忘记门层,经过sigmoid函数
    • 第三份作为输入转换层,经过tanh函数
    • 第四份作为输出门层,经过sigmoid函数
  3. 产生输出
    • 输出状态为忘记门层乘状态的一部分加输入门层乘输入转换层
    • 输出结果为输出状态经过tanh乘输出门层

结果生成

代码语言:javascript
复制
sequence_length = 10
input_dim = 10
seq_input = mx.symbol.Variable('seq_data')
embedded_seq = mx.symbol.Embedding(data=seq_input, \
                                   input_dim=input_dim, \
                                   output_dim=embed_dim)
outputs, states = lstm_cell.unroll(length=sequence_length, \
                                   inputs=embedded_seq, \
                                   layout='NTC', \
                                   merge_outputs=True)

使用unroll方法按时间展平运算,输入数据为(batch_size,lenght,...)(layout="NTC)或(lenght,batch,...)(layout="TNC)

该函数的源码为:

代码语言:javascript
复制
def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
        self.reset()

        inputs, _ = _normalize_sequence(length, inputs, layout, False)
        if begin_state is None:
            begin_state = self.begin_state()
        states = begin_state
        outputs = []
        for i in range(length):
            output, states = self(inputs[i], states)
            outputs.append(output)
        outputs, _ = _normalize_sequence(length, outputs, layout, merge_outputs)
        return outputs, states

方法_normalize_sequence是对输入做一些处理,由一个for循环可以看出该方法循环了网络运算

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2018.01.31 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 官方github教程部分代码
    • 网络生成
      • 官方API文档代码
        • 数据转换
        • 网络构建
        • 结果生成
    相关产品与服务
    容器服务
    腾讯云容器服务(Tencent Kubernetes Engine, TKE)基于原生 kubernetes 提供以容器为核心的、高度可扩展的高性能容器管理服务,覆盖 Serverless、边缘计算、分布式云等多种业务部署场景,业内首创单个集群兼容多种计算节点的容器资源管理模式。同时产品作为云原生 Finops 领先布道者,主导开源项目Crane,全面助力客户实现资源优化、成本控制。
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档