首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

搭建一个简单的神经网络

前言

我们已经知道,深度学习是在机器学习的基础上发展的,神经网络的层级比机器学习的多而复杂。神经网络结构,正是受到生物学领域中的神经网络的启发,才有了今天机器学习、深度学习中的神经网络的结构。

生物学领域神经网络中的单个神经元 ↑

今天我们就利用谷歌的深度学习框架TensorFlow,来搭建一个自己的简单的神经网络。

神经网络结构图

上面是一个3层的神经网络,其中:每一个⭕️代表一个神经元,也可以叫做神经节点、神经结点,每个神经元有一个偏置bias。每一条线,有一个权重weight。

神经网络学习的目标:就是通过减少损失loss或cost,来确定权重和偏置。理论上,如果有足够多的隐藏层和足够大的训练集,则可以模拟出任何方程。

先来看一下今天所搭建的神经网络,让其自己去学习达到的目标:

搭建一个简单的神经网络

搭建一个简单的神经网络的基本思路:

1.准备数据

2.搭建模型

3.训练模型

4.使用模型

先来搭建一个只有两层(不包含输入层)的神经网络,也即:输入层--隐藏层--输出层。

为了方便层的添加,我们把添加层,抽到一个方法中,参数解释如下:

完整的添加层,代码如下所示:

说明:

1.权重weights的初始化,通常是通过生产随机数作为权重的初始化值。

2.偏置bias的初始化,通常是将偏置初始化为0。

神经元

3.我们已经知道,每一个⭕️代表着一个神经元,也可以叫做神经节点、神经结点,每个神经元有一个偏置bias。每一条线,有一个权重weight。如上图所示。神经网络学习的目标:就是通过减少损失loss或cost,来确定权重和偏置,所以权重和偏置都是用中的来创建的。有一个很重要的特性,是可被训练。

4.整体表示:权重 * 自变量 + 偏置,也就是通常我们见到的线性函数y = a * x + b;权重即这里的a,偏置即这里的b。在中,表示矩阵相乘。

5.通常,一层中的神经元经过加权求和,然后再经过非线性方程得到的结果转化为输出,或者作为下一层的输入。注意:非线性方程,即是通常所说的激励函数。

常见的激励函数:sigmoid函数、tanh函数、ReLu函数、SoftMax函数等等。这里我们设计的函数需要调用者传入一个激励函数,如果不传入激励函数,则按照线性处理。

补充:tensor flow中的随机数

tensor flow中的随机数,常用的有两种:正态分布的随机数和均匀分布随机数。

正态分布随机数:tf.random_normal(shape, mean=平均值,stddev=标准差,seed=随机种子)

均匀分布随机数:tf.random_uniform(shape,minval=下边界,maxval=上边界,dtype=数据类型,seed=随机种子)

每个参数的含义,都在上面写明了。其中,随机种子是产生伪随机数的一种,当设置了相同的随机种子后,每次运行得到的随机数都相同。

准备数据

通过准备x数据,y数据,其中为了数据的略显杂乱,加入了一些噪音数据,也即噪点。

搭建模型

按照预先设计的神经网络:输入层 --- 隐藏层(1层)--- 输出层,所以搭建模型这一步骤中,我们加入了:

一个隐藏层,这一层有20个神经元,且这一层使用的激励函数为。当然,你也可以使用其他激励函数或者,来对比下他们的拟合效果。

一个输出层,这一层只有一个神经元,即:预测值。另外,这一层没有使用激励函数。

神经网络学习的目标:就是通过减少损失loss或cost:使用均方误差来计算的损失函数。

本步骤的完整代码如下:

注意:

1.本示例,采用的是梯度下降优化器,让损失最小化:。传入的学习率,通常是个小数,通常不同的学习率对最终的学习效果是有影响的,且不同的模型影响不同。开发过程中,可以修改学习率的大小,来看看具体的影响。

2.tensorflow中必须对变量进行显示初始化:,否则在会话中运行会报错。

3.由于x、y的值需要根据后面的值来确定,所以用了tensorflow中的占位。

训练模型

训练模型,必须在tensorflow中的Session会话的上下文中进行。因为涉及多次训练,我们设计了一个循环。每学习200次后打印一下损失值,并且更新拟合的图像。

运行结果如下:

学习效果图

说明:

1.是捕捉到的异常。因为第一次删除实线的图像时,我们还没有绘制,所以有异常抛出,被我们捕捉到了。

3.通过学习效果图,可以看到,最初偏离数据很多,随着学习次数增多,拟合的效果也越来越好。

小结

本文通过谷歌的深度学习框架TensorFlow,来搭建一个自己的简单的神经网络。

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20181111G0ABXL00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券