Theano学习笔记(一)——scan函数

一 scan的介绍

       函数scan是Theano中迭代的一般形式,所以可以用于类似循环(looping)的场景。Reduction和map都是scan的特殊形式,即将某函数依次作用一个序列的每个元素上。但scan在计算的时候,可以访问以前n步的输出结果,所以比较适合RNN网络。        看起来scan完全可以用for… loop来代替,然而scan有其自身的优点:        ① Number of iterations to be part of the symbolic graph.        ② Minimizes GPU transfers (if GPU is involved).        ③ Computes gradients through sequential steps.     ④Slightly faster than using a for loop in Python with a compiled Theano function.        ⑤ Can lower the overall memory usage by detecting the actual amount of memory needed.

二 scan的一般形式

theano.scan()

results, updates = theano.scan(fn = lambda y, p, x_tm2, x_tm1,A: y+p+x_tm2+xtm1+A,
sequences=[Y, P[::-1]], 
outputs_info=[dict(initial=X, taps=[-2, -1])]), 
non_sequences=A)

其中的参数

fn:函数可以在外部定义好,也可以在内部再定义。在内部在定义的fn一般用lambda来定义需要用到的参数,在外部就def好的函数,fn直接函数名即可。        构造出描述一步迭代的输出的变量。同样还需要看成是 theano 的输入变量,表示输入序列的所有分片和过去的输出值,以及所有赋给 scan 的 non_sequences 的这些其他参数。

sequences:scan进行迭代的变量;序列是 Theano 变量或者字典的列表,告诉程序 scan 必须迭代的序列,scan会在T.arange()生成的list上遍历。        任何在 sequence 列表的 Theano 变量都会自动封装成一个字典,其 taps 被设置为 [0]

outputs_info:初始化fn的输出变量,描述了需要用到的初始化值,以及是否需要用到前几次迭代输出的结果,dict(initial=X, taps=[-2, -1])表示使用序列x作为初始化值,taps表示会用到前一次和前两次输出的结果。如果当前迭代输出为x(t),则计算中使用了(x(t-1)和x(t-2)。

non_sequences:fn函数用到的其他变量,迭代过程中不可改变(unchange),即A是一个固定的输入,每次迭代加的A都是相同的。如果Y是一个向量,A就是一个常数。总之,A比Y少一个维度。

n_steps:fn的迭代次数。

三 具体实例

例子1:

# A的K次方
k = T . iscalar('k')
A = T . vector( 'A')
outputs, updates = theano.scan(lambda result, A : result * A,
             non_sequences = A, outputs_info=T.ones_like(A), n_steps = k)
result = outputs [-1]
fn_Ak = theano . function([A,k ], result, updates=updates )
print fn_Ak( range(10 ), 2 )

Result:
[  0.   1.   4.   9.  16.  25.  36.  49.  64.  81.]

例子2:tanh(wx+b)

import theano
import theano.tensor as T
import numpy as np

# defining the tensor variables
X = T.matrix("X")
W = T.matrix("W")
b_sym = T.vector("b_sym")

results, updates = theano.scan(lambda v: T.tanh(T.dot(v, W) + b_sym), sequences=X)
compute_elementwise = theano.function(inputs=[X, W, b_sym], outputs=[results])

# test values
x = np.eye(2, dtype=theano.config.floatX)
w = np.ones((2, 2), dtype=theano.config.floatX)
b = np.ones((2), dtype=theano.config.floatX)
b[1] = 2

print(compute_elementwise(x, w, b)[0])

# comparison with numpy
print(np.tanh(x.dot(w) + b))

[[ 0.96402758  0.99505475]
 [ 0.96402758  0.99505475]]
[[ 0.96402758  0.99505475]
 [ 0.96402758  0.99505475]]

参考文献

  1. http://deeplearning.net/software/theano/library/scan.html#lib-scan-shared-variables
  2. http://deeplearning.net/software/theano/tutorial/loop.html

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏计算机视觉与深度学习基础

Leetcode 31 Next Permutation

Implement next permutation, which rearranges numbers into the lexicographically...

1875
来自专栏xiaoxi666的专栏

和为0的最长连续子数组【转载+优化代码】

题目描述和思路来自博客:http://www.cnblogs.com/coding-wtf/p/5849222.html,在此表示感谢。

472
来自专栏desperate633

LintCode 完美平方题目分析代码

给一个正整数 n, 找到若干个完全平方数(比如1, 4, 9, ... )使得他们的和等于 n。你需要让平方数的个数最少。

722
来自专栏武培轩的专栏

Leetcode#561. Array Partition I(数组拆分 I)

给定长度为 2n 的数组, 你的任务是将这些数分成 n 对, 例如 (a1, b1), (a2, b2), ..., (an, bn) ,使得从1 到 n 的 ...

1222
来自专栏我是业余自学C/C++的

原 三对角矩阵

1393
来自专栏赵俊的Java专栏

恢复旋转排序数组

1212
来自专栏老秦求学

题目1054:字符串内排序

题目描述: 输入一个字符串,长度小于等于200,然后将输出按字符顺序升序排序后的字符串。 输入: 测试数据有多组,输入字符串。 输出: 对于每组输入,输出处理后...

3287
来自专栏前端儿

ASCII码排序

输入第一行输入一个数N,表示有N组测试数据。后面的N行输入多组数据,每组输入数据都是占一行,有三个字符组成,之间无空格。输出对于每组输入数据,输出一行,字符中间...

1272
来自专栏我的博客

C编程笔记

1.编译命令gcc test.c -o test 带上参数o就是指定编译文件名 2.printf(“%.2lf”,b) 其中前面2是小数点后位数,l是字母...

3395
来自专栏King_3的技术专栏

leetcode-46-全排列

vector<vector<int>> permute(vector<int>& nums) 

1313

扫码关注云+社区