Theano学习笔记(一)——scan函数

一 scan的介绍

       函数scan是Theano中迭代的一般形式,所以可以用于类似循环(looping)的场景。Reduction和map都是scan的特殊形式,即将某函数依次作用一个序列的每个元素上。但scan在计算的时候,可以访问以前n步的输出结果,所以比较适合RNN网络。        看起来scan完全可以用for… loop来代替,然而scan有其自身的优点:        ① Number of iterations to be part of the symbolic graph.        ② Minimizes GPU transfers (if GPU is involved).        ③ Computes gradients through sequential steps.     ④Slightly faster than using a for loop in Python with a compiled Theano function.        ⑤ Can lower the overall memory usage by detecting the actual amount of memory needed.

二 scan的一般形式

theano.scan()

results, updates = theano.scan(fn = lambda y, p, x_tm2, x_tm1,A: y+p+x_tm2+xtm1+A,
sequences=[Y, P[::-1]], 
outputs_info=[dict(initial=X, taps=[-2, -1])]), 
non_sequences=A)

其中的参数

fn:函数可以在外部定义好,也可以在内部再定义。在内部在定义的fn一般用lambda来定义需要用到的参数,在外部就def好的函数,fn直接函数名即可。        构造出描述一步迭代的输出的变量。同样还需要看成是 theano 的输入变量,表示输入序列的所有分片和过去的输出值,以及所有赋给 scan 的 non_sequences 的这些其他参数。

sequences:scan进行迭代的变量;序列是 Theano 变量或者字典的列表,告诉程序 scan 必须迭代的序列,scan会在T.arange()生成的list上遍历。        任何在 sequence 列表的 Theano 变量都会自动封装成一个字典,其 taps 被设置为 [0]

outputs_info:初始化fn的输出变量,描述了需要用到的初始化值,以及是否需要用到前几次迭代输出的结果,dict(initial=X, taps=[-2, -1])表示使用序列x作为初始化值,taps表示会用到前一次和前两次输出的结果。如果当前迭代输出为x(t),则计算中使用了(x(t-1)和x(t-2)。

non_sequences:fn函数用到的其他变量,迭代过程中不可改变(unchange),即A是一个固定的输入,每次迭代加的A都是相同的。如果Y是一个向量,A就是一个常数。总之,A比Y少一个维度。

n_steps:fn的迭代次数。

三 具体实例

例子1:

# A的K次方
k = T . iscalar('k')
A = T . vector( 'A')
outputs, updates = theano.scan(lambda result, A : result * A,
             non_sequences = A, outputs_info=T.ones_like(A), n_steps = k)
result = outputs [-1]
fn_Ak = theano . function([A,k ], result, updates=updates )
print fn_Ak( range(10 ), 2 )

Result:
[  0.   1.   4.   9.  16.  25.  36.  49.  64.  81.]

例子2:tanh(wx+b)

import theano
import theano.tensor as T
import numpy as np

# defining the tensor variables
X = T.matrix("X")
W = T.matrix("W")
b_sym = T.vector("b_sym")

results, updates = theano.scan(lambda v: T.tanh(T.dot(v, W) + b_sym), sequences=X)
compute_elementwise = theano.function(inputs=[X, W, b_sym], outputs=[results])

# test values
x = np.eye(2, dtype=theano.config.floatX)
w = np.ones((2, 2), dtype=theano.config.floatX)
b = np.ones((2), dtype=theano.config.floatX)
b[1] = 2

print(compute_elementwise(x, w, b)[0])

# comparison with numpy
print(np.tanh(x.dot(w) + b))

[[ 0.96402758  0.99505475]
 [ 0.96402758  0.99505475]]
[[ 0.96402758  0.99505475]
 [ 0.96402758  0.99505475]]

参考文献

  1. http://deeplearning.net/software/theano/library/scan.html#lib-scan-shared-variables
  2. http://deeplearning.net/software/theano/tutorial/loop.html

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Python攻城狮

Python-生成器1.什么是生成器2.创建生成器方法 3.send 4.实现多任务 5.迭代器 6.闭包

通过列表生成式,我们可以直接创建一个列表。但是,受到内存限制,列表容量肯定是有限的。而且,创建一个包含100万个元素的列表,不仅占用很大的存储空间,如果我们仅仅...

851
来自专栏机器学习算法与Python学习

Python: numpy总结(2)

11、xrange 例子: for i in xrange(3): print i test=[1,2,3,4] print test...

3125
来自专栏人工智能LeadAI

Python生成器

通过列表生成式,我们可以直接创建一个列表。但是,受到内存限制,列表容量肯定是有限的。而且,创建一个包含100万个元素的列表,不仅占用很大的存储空间,如果我们仅仅...

932
来自专栏人工智能LeadAI

Python数据分析模块 | pandas做数据分析(一):基本数据对象

pandas有两个最主要的数据结构,分别是Series和DataFrame,所以一开始的任务就是好好熟悉一下这两个数据结构。 1、Series 官方文档: pa...

3295
来自专栏王小雷

Python之数据规整化:清理、转换、合并、重塑

Python之数据规整化:清理、转换、合并、重塑 1. 合并数据集 pandas.merge可根据一个或者多个不同DataFrame中的行连接起来。 panda...

2026
来自专栏决胜机器学习

PHP数据结构(六) ——数组的相乘、广义表

PHP数据结构(六)——数组的相乘、广义表 (原创内容,转载请注明来源,谢谢) 本文接PHP数据结构(五)的内容。 4.2 行逻辑链接的顺序表 行逻辑链接的顺...

3559
来自专栏Fred Liang

Numpy

对数组运算相当于对数组每一个元素进行运算 a = np.arange(24).reshape((2,3,4))

712
来自专栏无所事事者爱嘲笑

常用的sort打乱数组方法真的有用?

1796
来自专栏数据小魔方

左手用R右手Python系列之——迭代器与迭代对象

接触过Python的小伙伴儿肯定都知道,Python中关于迭代器和可迭代对象运用的很广泛。迭代器可以以一种非常友好的方式使用在循环中,不仅节省内存,还能优化代码...

3538
来自专栏老秦求学

快速排序

快速排序: 设要排序的数组是A[0]……A[N-1], 思想:分治法(递归实现)关键是求出基准记录所在的位置(由于两个数之间进行交换,导致原来基准的位置发生改变...

2696

扫码关注云+社区