一个优雅的框架 | Pytorch 初体验

pytorch是啥呢?其实pytorch是一个python优先的深度学习框架,是一个和tensorflow,Caffe,MXnet一样,非常底层的框架,它的前身是torch,主要的语言接口是Lua,在如今github上前10的机器学习项目有9个都是python的时代,一直没有太多的人使用,比较小众。而pytorch如今重新归来,用python重写了整个框架,又重新回到了我的视线。

现在流行的深度学习框架都有着金主爸爸的支持,tensorflow是Google开发的,当然是他的官方框架,MXnet是Amazon的官方框架,那么pytorch后面站着的男人是谁呢?那就是Facebook了,其同样也只是Deep Learning领域的巨头,近期FAIR(Facebook Artificial Intelligence Research)也出了很多大作如mask rcnn,所以说pytorch背后的力量也是很大的。

说完了每个框架的支持者之外,我们来说说为什么我们还要学习不同的框架。首先在如今这个百花齐放的时代,任何一家公司想要独大都是不可能的,因为大家都意识到了这是一个随时可能爆发巨大变革的时代,所以每家大公司都希望自己能够在这场变革中扮演主导的地位,这就导致了不同的公司就会自己开发框架,或者至少不会使用竞争的公司的框架。在如今这个框架百出的时代,并没有哪个框架是最好的,每个框架都有各自的有点,比如tensorflow的工程能力很强,Theano特别适合科研等等,所以我们有必要掌握不同的框架,不要说精通每个框架,至少能够看看这个框架下的代码,因为github上不断地有牛人论文复现,而他们用的框架肯定不会都是一样的,所以你至少要能够阅读别人写的在各个框架下的代码。

说完了为什么要使用不同的框架之后,我们再来介绍一下今天的主角pytorch。之前我们介绍过keras,pytorch不同于keras,keras是一个很高层的结构,它的后端支持theano和tensorflow,它本质上并不是一个框架,只是对框架的操作做了一个封装,你在写keras的时候其实是对其后端进行调用,相当于你还是在tensorflow或者theano上跑程序,只不过你把你的语言交给keras处理了一下变成tensorflow听得懂的语言,然后再交给tensorflow处理,这样的后果当然方便你构建网络,方便定义模型做训练,极快的构建你的想法,工程实现很强,但是这样也有一个后果,那就是细节你没有办法把控,训练过程高度封装,导致你没有办法知道里面的具体细节,以及每个参数的具体细节,使得调试和研究变得很困难。

所以说作为初学者,我们可以用一个模块化的第三方插件帮助我们快速进入深度学习这个领域,但是如果我们真的想要好好去研究里面的问题,好好去做分析,我们还是需要用到我们的底层框架。

这个时候你就会说那我们就用tensorflow就好了啊,这不是最流行的框架吗。tensorflow确实是现在用的人最多的框架,不可否认,但是我们多掌握多了解一些框架也是有必要的,说不定你可以找到你最钟爱的那个框架呢。

相对tensorflow而言,pytorch就优雅多了,通过它的名字你就知道其对python支持特别好,虽然它的底层优化仍然实在c上的,但是它基本所有的框架都是用python写的,这就使得你去看它的源码比较简洁。但是它的缺点也和明显,就是框架刚刚发布没有多久,还没有太多人使用,文档也还在完善当中,但是也绝对够用了。有一个有好处就是你可以去官方论坛上面提问,基本上很快就有人回答了,这也算是新框架的一个好处吧,就是开发者对用户比较在意。

聊完了这么多好与不好,不知道你是不是动心了呢,是不是想学习pytorch了呢。如果你想学习pytorch,很简单,你直接去pytorch的官方教程就可以了,这是教程的链接 http://pytorch.org/tutorials/ 这里是是官方网站的连接 http://pytorch.org/ 最多1个小时,你就能入门了,比tensorflow简单太多了,如果你很牛逼,你还可以在pytorch的github开源项目上贡献你的代码,是不是很酷。这是pytorch的github主页 https://github.com/pytorch/pytorch 最后放上一段pytorch写的Lenet,可以和上一篇keras写的Lenet对比一下,看看有哪些差别。

import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.autograd 
import Variablefrom torch 
import optimimport torch.nn as nn
import torch.nn.functional as F 
learning_rate = 1e-3
batch_size = 100
epoches = 50
trans_img = transforms.Compose([        
transforms.ToTensor()    ])
 trainset = MNIST('./data', train=True, transform=trans_img)
 testset = MNIST('./data', train=False, transform=trans_img) 
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4) testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)
# build network
class Lenet(nn.Module):    
def __init__(self):        
super(Lenet, self).__init__()        
self.conv = nn.Sequential(            
nn.Conv2d(1, 6, 3, stride=1, padding=1),            
nn.MaxPool2d(2, 2),           
 nn.Conv2d(6, 16, 5, stride=1, padding=0),           
 nn.MaxPool2d(2, 2)        
)  
 
self.fc = nn.Sequential(            
nn.Linear(400, 120),            
nn.Linear(120, 84),            
nn.Linear(84, 10)        
)   
 def forward(self, x):        
out = self.conv(x)        
out = out.view(out.size(0), -1)        
out = self.fc(out)        
return out 
lenet = Lenet() 
lenet.cuda() 
criterian = nn.CrossEntropyLoss(size_average=False)
optimizer = optim.SGD(lenet.parameters(), lr=learning_rate)
# trainfor 
i in range(epoches):    
running_loss = 0.    
running_acc = 0.    
for (img, label) in trainloader:        
img = Variable(img).cuda()       
label = Variable(label).cuda()        
optimizer.zero_grad()        
output = lenet(img)        
loss = criterian(output, label)        
# backward        
loss.backward()        
optimizer.step()        
running_loss += loss.data[0]       
 _, predict = torch.max(output, 1)       
 correct_num = (predict == label).sum()        
running_acc += correct_num.data[0]    
running_loss /= len(trainset)    
running_acc /= len(trainset)    
print("[%d/%d] Loss: %.5f, Acc: %.2f" %(i+1, epoches, running_loss, 100*running_acc))

这上面的代码定义了网络并进行了训练,下面是训练结果

训练结果

# evaluate
lenet.eval() 
testloss = 0.
testacc = 0.
for (img, label) in testloader:   
 img = Variable(img).cuda()    
label = Variable(label).cuda()    
output = lenet(img)    
loss = criterian(output, label)    
testloss += loss.data[0]    
_, predict = torch.max(output, 1)    
num_correct = (predict == label).sum()    
testacc += num_correct.data[0] 
testloss /= len(testset) 
testacc /= len(testset) 
print("Test: Loss: %.5f, Acc: %.2f %%" %(testloss, 100*testacc))

这是测试代码,以及测试结果

测试结果

本文代码已经上传到github上,这是传送门 https://github.com/SherlockLiao/lenet 欢迎访问的我的github主页 https://github.com/SherlockLiao

原文发布于微信公众号 - 人工智能LeadAI(atleadai)

原文发表时间:2017-10-15

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器人网

七大工业机器人离线编程软件大PK

通常来讲,机器人编程可分为示教在线编程和离线编程。我们今天讲解的重点是离线编程,通过示教在线编程在实际应用中主要存在的问题,来说说机器人离线编程软件的优势和主流...

2896
来自专栏大数据文摘

快问快答 | 助教带你学习数据科学(附答疑视频领取)

1232
来自专栏灯塔大数据

学神之路 —— Python数据科学全攻略

Python菜鸟到Python Kaggler 如果你梦想成为一名数据科学家,或者已然是数据科学家的你想扩展自己的工具库,那么,你找对地方啦。本文旨在为做数据分...

2677
来自专栏生信技能树

一篇文章学会miRNA-seq分析

第一讲:文献选择与解读 前阵子逛BioStar论坛的时候看到了一个关于miRNA分析的问题,提问者从NCBI的SRA中下载文献提供的原始数据,然后处理的时候出现...

5036
来自专栏互联网数据官iCDO

不懂Google Featured Snippets?搜索引擎C位出道的机会别再错过了!

引言: 本文将教您如何针对Google最近的一项更新来进行内容优化,提升搜索排名。

963
来自专栏云市场·精选汇

批改孩子作业一小时?速算小程序一秒搞定!

学校里的课程作为父母插不上手,那辅导孩子写作业就非常关键了。跟让孩子主动写作业比起来,辅导家庭作业以及批阅检查孩子的作业才是让很多家长头疼的事情。一题一题的,虽...

1.1K42
来自专栏腾讯移动品质中心TMQ的专栏

YIYA语义测试方面总结探讨

1 产品介绍 YIYA是一个语音助手,根据用户输入语音内容,进行对应的操作或返回对应的结果,比如询问天气,返回所在地的天气结果。目前使用在微桌面及TOS手表中。...

2009
来自专栏AI科技大本营的专栏

想让视频网站乖乖帮你推内容?看看这位小哥是如何跟YouTube斗法的

编译 | AI科技大本营(rgznai100) 参与 | reason_W 当下视频网站的火热程度大家都是有目共睹的,因此也产生了一些网红视频博主,比如深受营长...

2503
来自专栏玉树芝兰

VOSviewer中文视频教程

因为这一篇文章,是我和几个研究生一起合作的。作者这一栏,最大可以写8个汉字。我让他们几个商议,选贡献度最高的2个人署名。结果他们头脑风暴的结果,就是起了个“V字...

491
来自专栏PPV课数据科学社区

数据可视化实践之美

开篇主要是介绍了一些常用的数据可视化工具和图表,让各位看官对数据可视化有一个较为全面的认识。后续篇章会深入介绍如何运用工具绘制精美图表的技术细节。 随着DT时代...

3296

扫描关注云+社区