【学术】从零开始,教你如何编写一个神经网络分类器

高水平的深度学习库,如TensorFlow,Keras和Pytorch,通过隐藏神经网络的许多乏味的内部工作细节,使深度学习从业者的生活变得更容易。尽管这是深度学习的好方法,但它仍然有一个小缺点:让许多基础理解较差的新来者在其他地方学习。我们的目标是提供从头开始编写的一个隐藏层全连接神经网络分类器(没有深度学习库),以帮助消除神经网络中的黑箱。

项目地址:https://github.com/ankonzoid/NN-from-scratch

所提供的神经网络对描述属于小麦的三类内核的几何属性的数据集进行分类(你可以轻松地将其替换为自己的自定义数据集)。假设有一个L2损失函数,并且在隐藏和输出层中的每个节点上使用sigmoid传递函数。权值更新方式使用具有L2范数的梯度下降的差量规则。

本文的其余部分,概述了我们的代码为构建和训练神经网络进行类预测所采取的一般步骤。关于深度学习和强化学习的博客,教程和项目,请查看Medium和Github。

Medium地址:https://medium.com/@ankonzoid Github地址:https://github.com/ankonzoid

我们逐步建立单层神经网络分类器

1.设置n次交叉验证

对于N次交叉验证,我们随机地排列N个样本指标,然后取连续大小为~ N/ n的块作为折叠。每个折叠作为一个交叉验证实验的测试集,补码(complement )指标作为训练集。

2.创建和训练神经网络模型

我们有2个完全连通的权值层:一个连接输入层节点与隐藏层节点,另一个连接隐藏层节点与输出层节点。如果没有任何偏项,这应该是神经网络中权值数量的总和(n_input *n_hidden + n_hidden* n_output)。我们通过对正态分布进行采样来初始化每个权值。

每个节点(神经元)具有存储到存储器中的3个属性:连接到其输入节点的权重列表,由正向传递的一些输入计算得到的输出值,以及表示其输出的反向传递分类不匹配的增量值层。这3个属性是相互交织的,并通过以下三个过程循环进行更新:

(A)正向传递一个训练示例,以更新当前给定节点权值的节点输出。每个节点输出被计算为其上一层输入(无偏项)的加权和,然后是sigmoid传递函数。

(B)反向传递分类错误,以更新当前给出节点权值的节点增量。因为我们使用从L2损失函数应用梯度下降导出的相同的增量规则方程。

(C)我们通过更新节点输出和增量来执行正向传递以更新当前的权值。

训练周期过程为(A)→(B)→(C),对每个训练样本执行该过程。

3.进行类预测

在训练之后,我们可以简单地使用这个模型来对我们的测试样本进行类预测,方法是将文本示例传递给经过训练的神经网络,获取输出的argmax函数。准确性分数是示例(在训练和测试集的n倍交叉验证中)数量的直观分数,在该示例中神经网络分类正确地除以了样本总数。

原文发布于微信公众号 - ATYUN订阅号(atyun_com)

原文发表时间:2017-10-26

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI研习社

Github 项目推荐 | 100+ Chinese Word Vectors 上百种预训练中文词向量

该项目提供了不同表征(密集和稀疏)上下文特征(单词,ngram,字符等)和语料库训练的中文单词向量。开发者可以轻松获得具有不同属性的预先训练的向量,并将它们用于...

992
来自专栏深度学习自然语言处理

基于汉语短文本对话的立场检测系统理论与实践

汉语短文本对话立场检测的主要任务就是通过以对话的一个人的立场为主要立场,而判断另一个人针对该人的回话的立场。立场包括支持,反对,中立三种立场。基于对话的立场检测...

791
来自专栏AI研习社

机器学习可以生成任何线条图片的 ASCII 码绘画

回顾 1960 年代,贝尔实验室的天才们想出了用计算机语言来绘画的方法。这种绘画形式叫做 ASCII 绘画,尽管这种绘画需要使用计算机,但很难让计算机自动生成图...

832
来自专栏小鹏的专栏

01 TensorFlow入门(2)

Working with Matrices:         了解TensorFlow如何使用矩阵对于通过计算图理解数据流非常重要。 Getting read...

2306
来自专栏智能算法

多目标模板匹配

一. 模板匹配 模板匹配是数字图像处理的重要组成部分之一。把不同传感器或同一传感器在不同时间、不同成像条件下对同一景物获取的两幅或多幅图像在空间上对准,或根据已...

2935
来自专栏算法channel

机器学习|K-Means算法

01 — K-Means算法 在数据挖掘中,K-Means算法是一种 cluster analysis 的算法,主要通过不断地取离种子点最近均值的算法。 如下...

2686
来自专栏小詹同学

深度学习入门笔记系列 ( 五 )

本系列将分为 8 篇 。本次为第 5 篇 ,结合上一篇的应用实例 ,将前边学到一些基础知识用到手写数字的识别分类上 。

562
来自专栏人工智能头条

数据科学与机器学习管道中预处理的重要性(一):中心化、缩放和K近邻

1633
来自专栏利炳根的专栏

学习笔记CB012: LSTM 简单实现、完整实现、torch、小说训练word2vec lstm机器人

LSTM(Long Short Tem Memory)特殊递归神经网络,神经元保存历史记忆,解决自然语言处理统计方法只能考虑最近n个词语而忽略更久前词语的问题。...

4146
来自专栏进击的程序猿

经典检索算法:BM25原理

bm25 是一种用来评价搜索词和文档之间相关性的算法,它是一种基于概率检索模型提出的算法,再用简单的话来描述下bm25算法:我们有一个query和一批文档Ds,...

771

扫描关注云+社区