Tensorflow教程-线性模型

“ 再小的你,也有自己的个性”

通过阅读本文,您可以学会:

1、了解线性回归

2、两种实现线性模型的Tensorflow方法

源代码:

第一种方式:

https://github.com/PrivateThink/tensorflow_tutorial/blob/master/02.py

第二种方式:

https://github.com/PrivateThink/tensorflow_tutorial/blob/master/03.py

1、了解线性回归

线性模型是机器学习中最简单的模型,通俗的说就是用线性方程去拟合一些数据,这个方程要满足的条件就是这些数据到这个方程的距离要最小

在上面的坐标图上,分布着一些点,怎么才能找出一条线(这条线可以是直线,也可是曲线)使这些点到这条线的距离的和最短,这就是线性模型学习的过程。本文主要是讲述怎么用tensorflow去实现线性模型,因此,更多线性模型的知识留给读者去了解。

2、两种实现线性模型的Tensorflow方法

比较底层的实现

API的实现

机器学习的有个通用的过程,即:准备数据、构建模型、模型训练以及测试,通过本例子,您将了解机器学习的一般过程。

比较底层的实现

首先导入包和设置参数,学习率为0.01,训练次数为1000。

准备数据

train_x代表的是x轴的值集合,train_y代表的是y轴的值集合,它们的shape大小是相同的。

建立模型

在建立模型步骤中,首先创建输入x、y,权重w和偏置b。然后设置线性模型的表达式,x和w相乘再加上偏置b。损失函数使用平方误差,tf.reduce_mean函数代表的是平均操作,tf.square函数是平方操作,然后使用随机梯度方法进行优化损失函数。

训练阶段

在训练阶段,将每一对t_x、t_y进行优化训练,然后每迭代五十次就显示一次结果。当训练完成的时候,就打印损失函数,权重w以及偏置b,权重w以及偏置b就是线性模型所学到的。

最终打印结果如下:

最后用将线性模型用图画出来,代码如下:

从上图可以看出,拟合的效果还是挺不错的,接下来就将API的实现方式,只讲重点。

API实现方式

API实现方式是我给的名字,在理解这种方式之前,首先让我们了解下Tensorflow的Eager模式,这种模式非常的方便,也比较符合我们传统的编程思维。只要开启了这种模式,Tensorflow就会从原来的声明式编程形式转为命令式的编程形式。这个Eager模式有什么优点?

搭建模型简单方便,以前搭建模型的时候,要记住每一步的Tensor和shape的大小,只要一步搞错,就很难调试。现在可以直接将Tensor的形状和大小打印出来。

可以抛弃sess.run了。在Eager模式下,可以直接打印变量的值了。

以将tf开头的函数当成普通的函数使用

好了,现在开始Eager模式下的线性模型。

首先引入Eager相关的包,然后调用enable.eager_execution函数开启eager模式,这个模式一旦开启,中途就不能中断,另外,这个模型的代码和之前的代码是不能兼容的,开启这个模式后,使用tf.placeholder会直接报错。

数据准备阶段是一样的,除了不再用占位符tf.placeholder声明了。具体可以参照公开的源代码。

模型和损失函数如下城市表示,跟上面的方式是一样的,只是封装成了函数了。

封装成函数,方便调用。

这里计算梯度的方式跟第一种方式不太一样。tfe.implicit_gradients的参数是一个函数f,这个函数通常是要求导的,这里要求导的函数就是损失函数(平方误差),作用是返回一个对函数f求导的函数。求的结果grad将会在训练过程调用。

在上述程序中,不同调用了sess.run就可以打印出损失函数、权重和w和偏置的b的值,不用在创建新的变量每次都要初始化sess了。

准备好数据,搭建哈模型,下一步就是训练模型。

在训练阶段,调用grad,传入的参数是mean_square函数的三个参数。然后将得到的结果g传入优化器的apply_gradients函数中进行求导运算优化。

最后还是将拟合效果图显示出来。

第一个polt画的是原始数据的散点图,第二个plot画的是拟合的线性图。

好了,已经讲完了Tensorflow创建线性模型的方式,源代码已经公布到github上了。后续将会持续更新其他的Tensorflow教程,欢迎关注和分享!

源代码:

第一种方式:

https://github.com/PrivateThink/tensorflow_tutorial/blob/master/02.py

第二种方式:

https://github.com/PrivateThink/tensorflow_tutorial/blob/master/03.py

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180623G0LTXA00?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 yunjia_community@tencent.com 删除。

扫码关注云+社区

领取腾讯云代金券