tensorflow学习笔记(三十九):双向rnn

tensorflow 双向 rnn

如何在tensorflow中实现双向rnn

单层双向rnn

单层双向rnn (cs224d) tensorflow中已经提供了双向rnn的接口,它就是tf.nn.bidirectional_dynamic_rnn(). 我们先来看一下这个接口怎么用.

bidirectional_dynamic_rnn(
    cell_fw, #前向 rnn cell
    cell_bw, #反向 rnn cell
    inputs, #输入序列.
    sequence_length=None,# 序列长度
    initial_state_fw=None,#前向rnn_cell的初始状态
    initial_state_bw=None,#反向rnn_cell的初始状态
    dtype=None,#数据类型
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

返回值:一个tuple(outputs, outputs_states), 其中,outputs是一个tuple(outputs_fw, outputs_bw). 关于outputs_fwoutputs_bw,如果time_major=True则它俩也是time_major的,vice versa. 如果想要concatenate的话,直接使用tf.concat(outputs, 2)即可.

如何使用: bidirectional_dynamic_rnn 在使用上和 dynamic_rnn是非常相似的.

  1. 定义前向和反向rnn_cell
  2. 定义前向和反向rnn_cell的初始状态
  3. 准备好序列
  4. 调用bidirectional_dynamic_rnn
import tensorflow as tf
from tensorflow.contrib import rnn
cell_fw = rnn.LSTMCell(10)
cell_bw = rnn.LSTMCell(10)
initial_state_fw = cell_fw.zero_state(batch_size)
initial_state_bw = cell_bw.zero_state(batch_size)
seq = ...
seq_length = ...
(outputs, states)=tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, seq,
 seq_length, initial_state_fw,initial_state_bw)
out = tf.concat(outputs, 2)
# ....

多层双向rnn

多层双向rnn(cs224d)

单层双向rnn可以通过上述方法简单的实现,但是多层的双向rnn就不能使将MultiRNNCell传给bidirectional_dynamic_rnn了. 想要知道为什么,我们需要看一下bidirectional_dynamic_rnn的源码片段.

with vs.variable_scope(scope or "bidirectional_rnn"):
  # Forward direction
  with vs.variable_scope("fw") as fw_scope:
    output_fw, output_state_fw = dynamic_rnn(
        cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
        initial_state=initial_state_fw, dtype=dtype,
        parallel_iterations=parallel_iterations, swap_memory=swap_memory,
        time_major=time_major, scope=fw_scope)

这只是一小部分代码,但足以看出,bi-rnn实际上是依靠dynamic-rnn实现的,如果我们使用MuitiRNNCell的话,那么每层之间不同方向之间交互就被忽略了.所以我们可以自己实现一个工具函数,通过多次调用bidirectional_dynamic_rnn来实现多层的双向RNN 这是我对多层双向RNN的一个精简版的实现,如有错误,欢迎指出

bidirectional_dynamic_rnn源码一探

上面我们已经看到了正向过程的代码实现,下面来看一下剩下的反向部分的实现. 其实反向的过程就是做了两次reverse 1. 第一次reverse:将输入序列进行reverse,然后送入dynamic_rnn做一次运算. 2. 第二次reverse:将上面dynamic_rnn返回的outputs进行reverse,保证正向和反向输出的time是对上的.

def _reverse(input_, seq_lengths, seq_dim, batch_dim):
  if seq_lengths is not None:
    return array_ops.reverse_sequence(
        input=input_, seq_lengths=seq_lengths,
        seq_dim=seq_dim, batch_dim=batch_dim)
  else:
    return array_ops.reverse(input_, axis=[seq_dim])

with vs.variable_scope("bw") as bw_scope:
  inputs_reverse = _reverse(
      inputs, seq_lengths=sequence_length,
      seq_dim=time_dim, batch_dim=batch_dim)
  tmp, output_state_bw = dynamic_rnn(
      cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
      initial_state=initial_state_bw, dtype=dtype,
      parallel_iterations=parallel_iterations, swap_memory=swap_memory,
      time_major=time_major, scope=bw_scope)

output_bw = _reverse(
  tmp, seq_lengths=sequence_length,
  seq_dim=time_dim, batch_dim=batch_dim)

outputs = (output_fw, output_bw)
output_states = (output_state_fw, output_state_bw)

return (outputs, output_states)

tf.reverse_sequence

对序列中某一部分进行反转

reverse_sequence(
    input,#输入序列,将被reverse的序列
    seq_lengths,#1Dtensor,表示输入序列长度
    seq_axis=None,# 哪维代表序列
    batch_axis=None, #哪维代表 batch
    name=None,
    seq_dim=None,
    batch_dim=None
)

官网上的例子给的非常好,这里就直接粘贴过来:

# Given this:
batch_dim = 0
seq_dim = 1
input.dims = (4, 8, ...)
seq_lengths = [7, 2, 3, 5]

# then slices of input are reversed on seq_dim, but only up to seq_lengths:
output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...]
output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...]
output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...]
output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...]

# while entries past seq_lens are copied through:
output[0, 7:, :, ...] = input[0, 7:, :, ...]
output[1, 2:, :, ...] = input[1, 2:, :, ...]
output[2, 3:, :, ...] = input[2, 3:, :, ...]
output[3, 2:, :, ...] = input[3, 2:, :, ...]

例二:

# Given this:
batch_dim = 2
seq_dim = 0
input.dims = (8, ?, 4, ...)
seq_lengths = [7, 2, 3, 5]

# then slices of input are reversed on seq_dim, but only up to seq_lengths:
output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...]
output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...]
output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...]
output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...]

# while entries past seq_lens are copied through:
output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...]
output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...]
output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...]
output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...]

参考资料

https://cs224d.stanford.edu/lecture_notes/LectureNotes4.pdf https://www.tensorflow.org/api_docs/python/tf/reverse_sequence

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏上善若水

002计算机图形学之直线画线算法

主要思想是,由于我们在缓存区上画点,全部是整数。那么在画线的时候,当斜率k小于1的时候,下一个点是取(x+1,y+1)还是(x+1,y)取决于点(x+1,y+0...

1002
来自专栏腾讯NEXT学位

你需要知道的算法之基础篇

3777
来自专栏Python小屋

Python标准库itertools中函数精要

1、count() >>> import itertools >>> x = itertools.count(3) >>> x count(3) >>> for...

3488
来自专栏java一日一条

我是如何击败Java自带排序算法的

Java 8 对自带的排序算法进行了很好的优化。对于整形和其他的基本类型, Arrays.sort() 综合利用了双枢轴快速排序、归并排序和启发式插入排序。这个...

521
来自专栏AI研习社

如何准备机器学习工程师的面试?

本文给到的是相关具体可能会被问及的问题 (编程、基础算法、机器学习算法)。从本次关于算法工程师常见的九十个问题大多是各类网站的问题汇总,希望你能从中分析出一些端...

37616
来自专栏云霄雨霁

子字符串查找----各种算法总结

1800
来自专栏软件测试经验与教训

Python学习笔记(11)递归

3145
来自专栏新工科课程建设探讨——以能源与动力工程专业为例

5.1 一维导热算例

算例:一根长11m的铁棒,左侧温度100℃,右侧0℃,试计算其稳态温度场。我们将铁棒均匀分割成11段,每段1m长,假设截面积为1㎡。首先写出一维稳态常物...

820
来自专栏数据结构与算法

BZOJ4407: 于神之怒加强版(莫比乌斯反演 线性筛)

感觉好迷茫啊,很多变换看的一脸懵逼却又不知道去哪里学。一道题做一上午也是没谁了,,

982
来自专栏IT可乐

Java数据结构和算法(一)——简介

  本系列博客我们将学习数据结构和算法,为什么要学习数据结构和算法,这里我举个简单的例子。   编程好比是一辆汽车,而数据结构和算法是汽车内部的变速箱。一个开车...

2279

扫码关注云+社区