利用随机数种子来使pytorch中的结果可以复现

在神经网络中,参数默认是进行随机初始化的。不同的初始化参数往往会导致不同的结果,当得到比较好的结果时我们通常希望这个结果是可以复现的,在pytorch中,通过设置随机数种子也可以达到这么目的。

在百度如何设置随机数种子时,搜到的方法通常是:

SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

自己在按照这种方法尝试后进行两次训练所得到的loss和误差都不同,结果并没有复现。

也搜过一些方法,比如设置参数:

torch.backends.cudnn.deterministic = True

但是在自己的网络中这样设置并没有用,依然得到不同的结果。

后面偶然在google中搜到有人在设置随机数种子时还加上了np.random.seed(SEED),经过尝试后发现结果是可复现的了。但检查自己网络的实现发现并没有直接调用numpy来产生随机数的地方,推测可能是pytorch内部调用了numpy的一些函数。去查看了一些pytorch中关于参数初始化的代码,比如normal的初始化:

点开source查看源码:

发现是调用了tensor.normal_函数,再去文档查看这个函数发现查看不了源码:

通过这些还是没能发现pytorch和numpy除了之前众所周知的接口外的内在联系,希望在以后的学习中随着对这两个库的理解与应用的深入能够了解,届时会对这篇文章做再次更新,毕竟知其然还要知其所以然嘛~

后面补充更新:在整理代码时,发现自己在处理数据时用上了这样一行:

data1 = data1.sample(frac=1).reset_index(drop=True) 

当时是用来打乱数据。这里是调用的pandas里面的方法,把这行代码注释掉再把np.random.seed(SEED)注释掉发现结果可以复现。可以推断是这里的随机需要给numpy也设置随机数种子。

如果没有涉及其他随机处理的话这两行可以固定pytorch中的随机数。

SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏嵌入式程序猿

运算放大器使用必须遵循的六条军规

运算放大器是作为最通用的模拟器件,广泛用于信号变换调理、ADC采样前端、电源电路等场合中。虽然运放外围电路简单,不过在使用过程中还是有很多需要注意的地方。 1...

3616
来自专栏专知

【干货】快速上手图像识别:用TensorFlow API实现图像分类实例

【导读】1月17日,Arduino社区的编辑SAGAR SHARMA发布一篇基于TensorFlow API的图像识别实例教程。作者通过TensorFlow A...

9047
来自专栏IT派

干掉照片中那些讨厌的家伙!Mask R-CNN助你一键“除”人!

【导读】:看过英剧《黑镜》吗?圣诞特别版《白色圣诞节》中有这样一个场景:其中一个未来科技有自由屏蔽人像的功能,可以让你屏蔽任何一个不想看见或不喜欢的人,然后留下...

1190
来自专栏简书专栏

基于xgboost的波士顿房价预测kaggle实战

2018年8月24日笔记 这是作者在波士顿房价预测项目的第3篇文章,在查看此篇文章之前,请确保已经阅读前2篇文章。 第2篇文章链接:https://www....

2K3
来自专栏Java进阶架构师

dubbo源码解析-详解LoadBalance

终于到了集群容错中的最后一个关键词,也就是LoadBalance(负载均衡),负载均衡必然会涉及一些算法.但是也不用太担心,算法这个词虽然高大上,但是算法也有简...

1643
来自专栏小巫技术博客

Python 中文图片OCR

6713
来自专栏专知

【最新TensorFlow1.4.0教程01】TF1.4.0介绍与动态图机制 Eager Execution使用

【导读】主题链路知识是我们专知的核心功能之一,为用户提供AI领域系统性的知识学习服务,一站式学习人工智能的知识,包含人工智能( 机器学习、自然语言处理、计算机视...

4008
来自专栏ATYUN订阅号

Machine Box创始人教你快速建立一个ML图像分类器

AiTechYun 编辑:Yining Machine Box的创始人Mat Ryer在medium上分享了一篇博文,意在教你在硬盘上快速的建立一个机器学习图像...

3646
来自专栏ATYUN订阅号

自定义对象检测问题:使用TensorFlow追踪星球大战中的千年隼号宇宙飞船

大多数的大型科技公司(如IBM,谷歌,微软,亚马逊)都有易于使用的视觉识别API。一些规模较小的公司也提供类似的产品,如Clarifai。但没有公司能够提供对象...

4835
来自专栏芋道源码1024

Dubbo 源码解析 —— LoadBalance

前言 终于到了集群容错中的最后一个关键词,也就是 LoadBalance(负载均衡),负载均衡必然会涉及一些算法.但是也不用太担心,算法这个词虽然高大上,但是算...

3994

扫码关注云+社区

领取腾讯云代金券