专栏首页机器学习算法与理论Tensorflow 之RNNinputs: shape = (batch_size, time_steps, input_size)cell: RNNCellinitial_state: shape

Tensorflow 之RNNinputs: shape = (batch_size, time_steps, input_size)cell: RNNCellinitial_state: shape

labels = tf.reshape(y, [-1]) 将矩阵变为一行

output = np.reshape(aa, -1)

output = np.reshape(aa, [-1,5]) -1表示一个占位符,分为5列。

  1. RNNcell:

它是TensorFlow中实现RNN的基本单元,每个RNNCell都有一个call方法,使用方式是:(output, next_state) = call(input, state)。

借助图片来说可能更容易理解。假设我们有一个初始状态h0,还有输入x1,调用call(x1, h0)后就可以得到(output1, h1):

1.jpg

再调用一次call(x2, h1)就可以得到(output2, h2):

2.jpg

[图片上传失败...(image-e4cb03-1533547159062)]

也就是说,每调用一次RNNCell的call方法,就相当于在时间上“推进了一步”,这就是RNNCell的基本功能。

两个子类:BasicRNNCell和BasicLSTMCell

state_size****:隐层的大小

output_size****:输出的大小

设输入数据的形状为(batch_size, input_size),那么计算时得到的隐层状态就是(batch_size, state_size),输出就是(batch_size, output_size)。

对于BasicLSTMCell,情况有些许不同,因为LSTM可以看做有两个隐状态h和c,对应的隐层就是一个Tuple,每个都是(batch_size, state_size)的形状。

tf.nn.dynamic_rnn:

RNNCELL是一次前进一步,如果我们的序列长度为10,就要调用10次call函数。

TensorFlow提供了一个tf.nn.dynamic_rnn函数:

设我们输入数据的格式为(batch_size, time_steps, input_size),其中time_steps表示序列本身的长度,如在Char RNN中,长度为10的句子对应的time_steps就等于10。最后的input_size就表示输入数据单个序列单个时间维度上固有的长度。另外我们已经定义好了一个RNNCell,调用该RNNCell的call函数time_steps次,对应的代码就是:

inputs: shape = (batch_size, time_steps, input_size)

cell: RNNCell

initial_state: shape = (batch_size, cell.state_size)。初始状态。一般可以取零矩阵

outputs, state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)

得到的outputs就是time_steps步里所有的输出。它的形状为(batch_size, time_steps, cell.output_size)。state是最后一步的隐状态,它的形状为(batch_size, cell.state_size)。

堆叠RNNCell: MultiRNNCell

将x输入第一层RNN的后得到隐层状态h,这个隐层状态就相当于第二层RNN的输入,第二层RNN的隐层状态又相当于第三层RNN的输入,以此类推。在TensorFlow中,可以使用tf.nn.rnn_cell.MultiRNNCell函数对RNNCell进行堆叠,相应的示例程序如下:

在经典RNN结构中有这样的图:

3.jpg

通过MultiRNNCell得到的cell并不是什么新鲜事物,它实际也是RNNCell的子类,因此也有call方法、state_size和output_size属性。同样可以通过tf.nn.dynamic_rnn来一次运行多步。

在上面的代码中,我们好像有意忽略了调用call或dynamic_rnn函数后得到的output的介绍。找到源码中BasicRNNCell的call函数实现:

说明在BasicRNNCell中,output其实和隐状态的值是一样的。因此,我们还需要额外对输出定义新的变换,才能得到图中真正的输出y。由于output和隐状态是一回事,所以在BasicRNNCell中,state_size永远等于output_size。TensorFlow是出于尽量精简的目的来定义BasicRNNCell的,所以省略了输出参数,我们这里一定要弄清楚它和图中原始RNN定义的联系与区别。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 基于TP-GAN的侧脸人像恢复

    中科院自动化所(CASIA),中科院大学和南昌大学的一项合作研究,提出了双路径 GAN(TP-GAN),通过单一侧面照片合成正面人脸图像,取得了当前较好的结果。...

    微风、掠过
  • 《白话深度学习与Tensorflow》学习笔记(5)强化学习(reinforcement learning)

    强化学习(reinforcement learning)本身是一种人工智能在训练中得到策略的训练过程。 有这样一种比喻:如果你教一个孩子学古筝,他可以躺着,趴着...

    微风、掠过
  • 【AAAI 2020】RiskOracle: 一种时空细粒度交通事故预测方法

    【前言】城市计算领域中,智能交通、智慧出行一直是一备受关注的话题,而交通事故在交通中扮演越来越着重要的角色,据WHO统计,已逐渐成为人类第8大杀手。传统的基础交...

    微风、掠过
  • 使用pytorch进行文本分类——ADGCNN

    在文本分类任务中常用的网络是RNN系列或Transformer的Encoder,很久没有看到CNN网络的身影(很久之前有TextCNN网络)。本文尝试使用CNN...

    Dendi
  • 通过一个简单的ABAP报表窥探ABAP内存分配和管理机制

    Jerry Wang
  • 二叉树的深度

    题目描述 输入一棵二叉树,求该树的深度。从根结点到叶结点依次经过的结点(含根、叶结点)形成树的一条路径,最长路径的长度为树的深度。 代码实现 递归实现 # ...

    致Great
  • Pyqt phonon的使用

    Qt phonon地址:http://wenku.baidu.com/link?url=nH_dZ8lZbXHy8N5__8jAWLXcuMYf4yRjdCK...

    bear_fish
  • Hadoop数据分析平台实战——220项目结构整体概述离线数据分析平台实战——220项目结构整体概述

    离线数据分析平台实战——220项目结构整体概述 数据展示系统(bf_dataapi)总述 bf_dataapi项目的主要目标有两个: 第一个目标就是我们需要提...

    Albert陈凯
  • Python学习 :深浅拷贝

    只拷贝第一层数据(不可变的数据类型),并创建新的内存空间进行储蓄,例如:字符串、整型、布尔

    py3study
  • C++17 fold expression

    C++11增加了一个新特性变参模板(variadic template),它可以接受任意个模版参数,参数包不能直接展开,需要通过一些特殊的方法,比如函数参数包的...

    Dabelv

扫码关注云+社区

领取腾讯云代金券