前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >深度学习入门笔记(2)-线性回归 Linear Regression with autograd

深度学习入门笔记(2)-线性回归 Linear Regression with autograd

作者头像
企鹅号小编
发布2018-02-26 10:12:00
9930
发布2018-02-26 10:12:00
举报
文章被收录于专栏:人工智能人工智能

一同前行!

假设我们有一个曲线(或者平面)y=wx+b

我们给定它一个特定的w,和b

w = [2,51]

b = 21.2

即y=2x1+51x2+21.2

目标是通过数据训练使得w和b靠近w =[2,51],b = 21.2,换句话说就是通过训练得到一个平面能够跟实际的平面(y=2x1+51x2+21.2)一致。

-代码实现-

回顾深度学习的套路

准备数据集dataset

构建网络(激活函数activation function)

初始化

训练(epochs,更新权重)

预测

所用的深度学习框架为Mxnet

需要用到的库为 mxnet 中的nd,gluon,gutograd

还有用于图形化显示的matplotlib

from mxnet import nd,gluon,autograd

import matplotlib.pyplot as plt

1.准备数据集

正太分布随机得到300个样本数

每个样本由X1,X2及结果y组成。

形成300个数据集,给Y加了一点噪声(0.01倍的正太分布数据)来模拟真实数据。

2.构建网络(激活函数activation function)

激活函数为参数的输入到输出的关系。

这里是y=wx+b的关系

def net(x): #激活函数

return nd.dot(x,w)+b

3.初始化

初始化真实的参数为

true_w=nd.array([2,51])

true_b=nd.array([21.2])

变化的参数,需要迭代所求的参数初始化如下:

w=nd.random_normal(shape=(2))

b=nd.zeros((1))

params=[w,b]

w初值为随机的数[w1,w2]

b初值为[0].

初值本可以任意由自己定义,对结果都是没有影响的,可能会影响迭代收敛的速度。

4.训练(epochs,更新权重)

训练时,还是利用的梯度下降法

def SGD(params,eta): #梯度下降法

for param in params:

param[:]=param-eta*param.grad

param 为训练中的【w,b】

eta 为训练步长

损失函数定义为均方误差:

训练:

迭代epochs=10次

步长为eta=0.01

其中用到了autograd自动求导函数。

attach_grad()给参数附上梯度,

向系统申请空间

with autograd.record():

记录需要求导的函数

backward()

回传求导

举例如下图z=y*x,y=2*x:

继续

epochs=10

eta=0.01

for param in params:

param.attach_grad() #要求系统申请对应的空间

for e in range(epochs):

for x,y in data_iter:

with autograd.record():

yhat=net(x)

loss=square_loss(yhat,y)

loss.backward()

SGD(params,eta)

#break

plot(xs)

5.预测:

选取部分数据(50个点),以x2为横坐标,Y为纵坐标。

由预估曲线和真实曲线进行可视化对比。

def plot(xs,sample_size=50):

_,fig=plt.subplots()

plotxs=xs[:sample_size,:]

plotxn=xs[:sample_size,1].asnumpy()

#以x2为横坐标

yhatn=net(plotxs).asnumpy()

fig.plot(plotxn,yhatn,'or') #估计曲线

ys=nd.dot(plotxs,true_w)+true_b

fig.plot(plotxn,ys.asnumpy(),'*g') #实际数据曲线

plt.show()

为了更好的说明参数迭代接近目标,分步截图了如下过程。

结果

迭代后的参数如下:

还记得我们设定的真实参数吗,

true_w=nd.array([2,51])

true_b=nd.array([21.2])

迭代后的参数已经趋近了!

至此,一个线性回归算法就算是完成了!

本文来自企鹅号 - 机器学习算法与Python精研媒体

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文来自企鹅号 - 机器学习算法与Python精研媒体

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档