tensorflow学习笔记(三十九):双向rnn

tensorflow 双向 rnn

如何在tensorflow中实现双向rnn

单层双向rnn

单层双向rnn (cs224d) tensorflow中已经提供了双向rnn的接口,它就是tf.nn.bidirectional_dynamic_rnn(). 我们先来看一下这个接口怎么用.

bidirectional_dynamic_rnn(
    cell_fw, #前向 rnn cell
    cell_bw, #反向 rnn cell
    inputs, #输入序列.
    sequence_length=None,# 序列长度
    initial_state_fw=None,#前向rnn_cell的初始状态
    initial_state_bw=None,#反向rnn_cell的初始状态
    dtype=None,#数据类型
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

返回值:一个tuple(outputs, outputs_states), 其中,outputs是一个tuple(outputs_fw, outputs_bw). 关于outputs_fwoutputs_bw,如果time_major=True则它俩也是time_major的,vice versa. 如果想要concatenate的话,直接使用tf.concat(outputs, 2)即可.

如何使用: bidirectional_dynamic_rnn 在使用上和 dynamic_rnn是非常相似的.

  1. 定义前向和反向rnn_cell
  2. 定义前向和反向rnn_cell的初始状态
  3. 准备好序列
  4. 调用bidirectional_dynamic_rnn
import tensorflow as tf
from tensorflow.contrib import rnn
cell_fw = rnn.LSTMCell(10)
cell_bw = rnn.LSTMCell(10)
initial_state_fw = cell_fw.zero_state(batch_size)
initial_state_bw = cell_bw.zero_state(batch_size)
seq = ...
seq_length = ...
(outputs, states)=tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, seq,
 seq_length, initial_state_fw,initial_state_bw)
out = tf.concat(outputs, 2)
# ....

多层双向rnn

多层双向rnn(cs224d)

单层双向rnn可以通过上述方法简单的实现,但是多层的双向rnn就不能使将MultiRNNCell传给bidirectional_dynamic_rnn了. 想要知道为什么,我们需要看一下bidirectional_dynamic_rnn的源码片段.

with vs.variable_scope(scope or "bidirectional_rnn"):
  # Forward direction
  with vs.variable_scope("fw") as fw_scope:
    output_fw, output_state_fw = dynamic_rnn(
        cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
        initial_state=initial_state_fw, dtype=dtype,
        parallel_iterations=parallel_iterations, swap_memory=swap_memory,
        time_major=time_major, scope=fw_scope)

这只是一小部分代码,但足以看出,bi-rnn实际上是依靠dynamic-rnn实现的,如果我们使用MuitiRNNCell的话,那么每层之间不同方向之间交互就被忽略了.所以我们可以自己实现一个工具函数,通过多次调用bidirectional_dynamic_rnn来实现多层的双向RNN 这是我对多层双向RNN的一个精简版的实现,如有错误,欢迎指出

bidirectional_dynamic_rnn源码一探

上面我们已经看到了正向过程的代码实现,下面来看一下剩下的反向部分的实现. 其实反向的过程就是做了两次reverse 1. 第一次reverse:将输入序列进行reverse,然后送入dynamic_rnn做一次运算. 2. 第二次reverse:将上面dynamic_rnn返回的outputs进行reverse,保证正向和反向输出的time是对上的.

def _reverse(input_, seq_lengths, seq_dim, batch_dim):
  if seq_lengths is not None:
    return array_ops.reverse_sequence(
        input=input_, seq_lengths=seq_lengths,
        seq_dim=seq_dim, batch_dim=batch_dim)
  else:
    return array_ops.reverse(input_, axis=[seq_dim])

with vs.variable_scope("bw") as bw_scope:
  inputs_reverse = _reverse(
      inputs, seq_lengths=sequence_length,
      seq_dim=time_dim, batch_dim=batch_dim)
  tmp, output_state_bw = dynamic_rnn(
      cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
      initial_state=initial_state_bw, dtype=dtype,
      parallel_iterations=parallel_iterations, swap_memory=swap_memory,
      time_major=time_major, scope=bw_scope)

output_bw = _reverse(
  tmp, seq_lengths=sequence_length,
  seq_dim=time_dim, batch_dim=batch_dim)

outputs = (output_fw, output_bw)
output_states = (output_state_fw, output_state_bw)

return (outputs, output_states)

tf.reverse_sequence

对序列中某一部分进行反转

reverse_sequence(
    input,#输入序列,将被reverse的序列
    seq_lengths,#1Dtensor,表示输入序列长度
    seq_axis=None,# 哪维代表序列
    batch_axis=None, #哪维代表 batch
    name=None,
    seq_dim=None,
    batch_dim=None
)

官网上的例子给的非常好,这里就直接粘贴过来:

# Given this:
batch_dim = 0
seq_dim = 1
input.dims = (4, 8, ...)
seq_lengths = [7, 2, 3, 5]

# then slices of input are reversed on seq_dim, but only up to seq_lengths:
output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...]
output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...]
output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...]
output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...]

# while entries past seq_lens are copied through:
output[0, 7:, :, ...] = input[0, 7:, :, ...]
output[1, 2:, :, ...] = input[1, 2:, :, ...]
output[2, 3:, :, ...] = input[2, 3:, :, ...]
output[3, 2:, :, ...] = input[3, 2:, :, ...]

例二:

# Given this:
batch_dim = 2
seq_dim = 0
input.dims = (8, ?, 4, ...)
seq_lengths = [7, 2, 3, 5]

# then slices of input are reversed on seq_dim, but only up to seq_lengths:
output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...]
output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...]
output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...]
output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...]

# while entries past seq_lens are copied through:
output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...]
output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...]
output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...]
output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...]

参考资料

https://cs224d.stanford.edu/lecture_notes/LectureNotes4.pdf https://www.tensorflow.org/api_docs/python/tf/reverse_sequence

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏数据结构与算法

素数的筛法

素数的筛法有很多种 在此给出常见的三种方法 以下给出的所有代码均已通过这里的测试 埃拉托斯特尼筛法 名字好长 :joy:  不过代码很短 思路非常简单,对于每一...

3416
来自专栏ml

由判断三一点是否在三角形内部而引发的思考.....

判断一个点是否在三角形里面(包括边界上),这个问题对于许多初学者来说,可谓是一头雾水,如何判断呢? 假如有四个点A(x0,y0),B(x1,y1),C(x2,y...

2758
来自专栏小詹同学

Leetcode打卡 | No.016 最接近的三数之和

欢迎和小詹一起定期刷leetcode,每周一和周五更新一题,每一题都吃透,欢迎一题多解,寻找最优解!这个记录帖哪怕只有一个读者,小詹也会坚持刷下去的!

1184
来自专栏尾尾部落

[剑指offer] 矩阵中的路径

请设计一个函数,用来判断在一个矩阵中是否存在一条包含某字符串所有字符的路径。路径可以从矩阵中的任意一个格子开始,每一步可以在矩阵中向左,向右,向上,向下移动一个...

873
来自专栏机器学习算法全栈工程师

一道网易笔试题引发的血案……

作者:柳行刚 编辑:黄俊嘉 网易的2016年笔试题,题目经典。 所以特地找来给各位有兴趣的童鞋看看, 有详细的解题思路以及代码喔~ 各位,请看题! 题目描述: ...

39612
来自专栏数据结构与算法

洛谷P3796 【模板】AC自动机(加强版)

1213
来自专栏和蔼的张星的图像处理专栏

28. 搜索二维矩阵二分法

写出一个高效的算法来搜索 m × n矩阵中的值。 这个矩阵具有以下特性: 每行中的整数从左到右是排序的。 每行的第一个数大于上一行的最后一个整数。 样例...

512
来自专栏小樱的经验随笔

从零开始学算法:高精度计算

前言:由于计算机运算是有模运算,数据范围的表示有一定限制,如整型int(C++中int 与long相同)表达范围是(-2^31~2^31-1),unsigned...

33713
来自专栏数据结构与算法

字符串hash入门

简单介绍一下字符串hash 相信大家对于hash都不陌生 hash算法广泛应用于计算机的各类领域,像什么md5,文件效验,磁力链接 等等都会用到hash算法 在...

3195
来自专栏冰霜之地

神奇的德布鲁因序列

数学中存在这样一个序列,它充满魔力,在实际工程中也有一部分的应用。今天就打算分享一下这个序列,它在 Google S2 中是如何使用的以及它在图论中,其他领域中...

1023

扫码关注云+社区