利用随机数种子来使pytorch中的结果可以复现

在神经网络中,参数默认是进行随机初始化的。不同的初始化参数往往会导致不同的结果,当得到比较好的结果时我们通常希望这个结果是可以复现的,在pytorch中,通过设置随机数种子也可以达到这么目的。

在百度如何设置随机数种子时,搜到的方法通常是:

SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

自己在按照这种方法尝试后进行两次训练所得到的loss和误差都不同,结果并没有复现。

也搜过一些方法,比如设置参数:

torch.backends.cudnn.deterministic = True

但是在自己的网络中这样设置并没有用,依然得到不同的结果。

后面偶然在google中搜到有人在设置随机数种子时还加上了np.random.seed(SEED),经过尝试后发现结果是可复现的了。但检查自己网络的实现发现并没有直接调用numpy来产生随机数的地方,推测可能是pytorch内部调用了numpy的一些函数。去查看了一些pytorch中关于参数初始化的代码,比如normal的初始化:

点开source查看源码:

发现是调用了tensor.normal_函数,再去文档查看这个函数发现查看不了源码:

通过这些还是没能发现pytorch和numpy除了之前众所周知的接口外的内在联系,希望在以后的学习中随着对这两个库的理解与应用的深入能够了解,届时会对这篇文章做再次更新,毕竟知其然还要知其所以然嘛~

后面补充更新:在整理代码时,发现自己在处理数据时用上了这样一行:

data1 = data1.sample(frac=1).reset_index(drop=True) 

当时是用来打乱数据。这里是调用的pandas里面的方法,把这行代码注释掉再把np.random.seed(SEED)注释掉发现结果可以复现。可以推断是这里的随机需要给numpy也设置随机数种子。

如果没有涉及其他随机处理的话这两行可以固定pytorch中的随机数。

SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏我和未来有约会

Silverlight第三方控件专题

这里我收集整理了目前网上silverlight第三方控件的专题,若果有所遗漏请告知我一下。 名称 简介 截图 telerik 商 RadC...

3985
来自专栏陈仁松博客

ASP.NET Core 'Microsoft.Win32.Registry' 错误修复

今天在发布Asp.net Core应用到Azure的时候出现错误InvalidOperationException: Cannot find compilati...

4818
来自专栏我和未来有约会

Kit 3D 更新

Kit3D is a 3D graphics engine written for Microsoft Silverlight. Kit3D was inita...

2516
来自专栏跟着阿笨一起玩NET

c#实现打印功能

2682
来自专栏张善友的专栏

Mix 10 上的asp.net mvc 2的相关Session

Beyond File | New Company: From Cheesy Sample to Social Platform Scott Hansel...

2537
来自专栏张善友的专栏

LINQ via C# 系列文章

LINQ via C# Recently I am giving a series of talk on LINQ. the name “LINQ via C...

2625
来自专栏闻道于事

js登录滑动验证,不滑动无法登陆

js的判断这里是根据滑块的位置进行判断,应该是用一个flag判断 <%@ page language="java" contentType="text/html...

6728
来自专栏大内老A

The .NET of Tomorrow

Ed Charbeneau(http://developer.telerik.com/featured/the-net-of-tomorrow/) Exciti...

31210
来自专栏一个爱瞎折腾的程序猿

sqlserver使用存储过程跟踪SQL

USE [master] GO /****** Object: StoredProcedure [dbo].[sp_perfworkload_trace_s...

2020
来自专栏Golang语言社区

【Golang语言社区】GO1.9 map并发安全测试

var m sync.Map //全局 func maintest() { // 第一个 YongHuomap := make(map[st...

4698

扫码关注云+社区