首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

PyTorch快速使用介绍–分类的实现

微信公众号:文艺数学君

关注可了解更多的教程及。问题或建议,欢迎公众号留言;

如果你觉得文艺数学君对你有帮助,欢迎转载评论

前言

特别说明

阅读正文之前一定要仔细阅读下边这些建议:

一些其他话

我发现好像发这种技术类的文章就没什么人看,发一些感想看的人就很多。但是,我还是要发。之后的话两者都会有更新,其实技术类的还是比较喜欢发在网站上,如果不想发公众号的话也会在公众号里提及一下的。

如果从文艺君的域名注册时间开始算起,8月8日就是文艺君一周年的生日了(公众号的话会更早一些,之前是在csdn上更新一些文章),之后希望可以越做越好,也是希望可以慢慢把一个系列做完,每周都可以定期更新一下。

下面就开始正式讲关于PyTorch的内容,我们这次来讲一下分类,使用PyTorch来搭建。

这一次介绍的四个模型分别是会比较两种激活函数,ReLU和Sigmoid,模型二和模型三;同时会比较不同层数的区别,模型一与模型二;同时会看一种数据增维的方法,模型四。

视频

每篇文章一个视频,这一篇也放一个之前在校园里拍的视频。(吐槽一下这个视频上传大小只能是20M,很是不方便,每次上传要压缩一遍)

实现分类

数据准备

我们首先使用sklearn中的datasets生成我们需要的数据;

我们可以画图看一下数据的样子:

从上面的图像可以看到,我们的数据呈现月牙的形状,这种形状的数据直接使用k-means聚类是不能区分的,下面我们来搭建神经网络来实现以下数据的分类;

我们首先把要使用到的库进行导入:

网络一(使用ReLU作为激活函数、三层)

下面开始介绍第一个模型,由于后面每个模型的代码内容基本差不多,即定义优化器,损失函数,训练绘图那里,就网络定义的地方有点区别,我们就着重先讲网络的结构,后面就贴一下全部的代码:

上面是我们这里使用的网络的结构,输入是2是因为数据有x和y,最后输出也是2是因为分为两类,整个网络一共有三层,下面贴一下完整的代码,注释已经写得很详细了,可以复制下来仔细看一下;

我们来看一下上面代码的输出,左边是训练集上的结果,右边是测试集上的结果:

可以看到最终可以全部分开,我们再看一下loss的情况:

网络二(使用ReLU作为激活函数、五层)

这个网络和上面的区别在于网络层数变深了,我们看一下这时候的情况,还是先看一下网络的定义:

和上面的差不多,就是变深了,看一下完整的代码:

还是看一下训练过程中结果的变化,可以看到收敛还是很快的:

接下来可以看一下loss的变化情况,和上面的模型比起来,loss下降更加快了。

网络三(使用Sigmoid作为激活函数、五层)

这个网络我们使用Sigmoid作为激活函数来尝试一下,其他的不变。我们还是先把网络结构打印出来:

接下来看一下完整的代码:

看一下这个训练的动图,可以看到epoch是一直在变化的,前面模型收敛的比较慢:

下面我们看一下loss的图,来看一下上面的想法是否正确:

可以看到使用Sigmoid作为激活函数确实会收敛的慢一些,但是也不是所有的都适合用ReLU,有很多情况我也不确定,具体用的时候可以看一下别人写的论文之类的。

网络四(使用Sigmoid作为激活函数、五层、扩充数据集)

这个模型我们会重点介绍一下数据集扩展的方式,我们之前使用的数据集只有两个维度,可以认为是x1和x2,但是有的时候只有这两个数据是不够的,我们可以通过扩充数据集的方式使得模型更快的收敛:

可以看到上面的代码,我们增加了五个特征,x1^2, x2^2, x1*x2, sin(x1)和sin(x1) ,其余的我们不变,还是使用上面的网络结构进行训练,只不过输入要改成7。

我们来看一下上面数据合并的方法,使用一个简单的例子进行查看:

接下来我们来看一下代码:

我们就直接看一下loss变化趋势的图像,可以看到这次的收敛会比上面做数据扩展之前快一些。

结语

关于上面四个模型的比较,我们可以看到当模型层数较深,且使用ReLU作为激活函数时效果会比较好,同时我们还学习了一种数据增维的方式。

这次的关于PyTorch的介绍到这里为止,其实内容也不是很多,很多代码片段都是重复的,我就只是多写了几次,方便之后直接复制来进行运行。

关于这些的学习自己还是一定要上手练习,代码不能只用看的。

我们下一篇文章再见,如果觉得好的话欢迎点赞转载!

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180726G0MHJT00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券