首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

谈谈我的深度学习代码经验……

先给大家推荐一个网站http://devdocs.io/

在这里可以很方便地查询各种api,本文里提到的所有的函数都可以在这里找到详细使用说明。

本文所有代码都可以在我的github主页找到,有的是前两篇推文里的,有的是我下一篇推文里会详细说明的

训练一个自己的网络大概需要四个部分:data loader、network、model和train。data loader就是数据接口,用来从磁盘读取数据到GPU,data loader的效率是整体训练效率的一个瓶颈,如果这里的IO没有做好,整体速度会很慢,GPU的使用率也很低;network主要实现各种网络的架构,有时候也叫feed forward,就是神经网络里的前传;model实现整个训练架构,包括loss的计算、训练步骤的设计、优化器的选择等;最后train就是训练的main函数,主要实现循环训练、计算validation loss等。

1

network

先从最简单的部分说起。Network里装的就是组成整个model需要的各种零件,比如训练一个GAN网络,需要一个生成网络,还需要一个判别网络,有时候还要一个特征提取网络等,这些网络自成一体,可以分别用函数的形式定义在一个py文件里。tensorflow里面提供了大量的集成函数,各种类型,各种层次都有,所以定义一个网络结构特别简单。比如常用的vgg19,只需要十几行。

这里使用了slim包。这个包十分简洁,集成度高,对于一些常规网络,写起来十分方便。

除了slim,tensorflow里还有很多其他的包,比如layers,contrib等,都是十分方便的。

这里的conv2d,batch_norm也都是集成度很高的函数。

2

Model

有了各种网络部件,就可以用这些组件拼接一个模型了,也是我们常说的“搭积木”

这里需要提一下tensorflow的variable scope,这里不仅仅可以让模型结构清晰明了,还方便之后调用各个部分的参数,在共享变量的时候也十分重要(这里之后可以单独说一说)。我这里是学习一个大佬的用法,用collection的形式把模型组织起来,十分简洁。当然也可以写成一个类,但是个人觉得写一个类比较麻烦,不够简洁。

4

data loader

tf里用的最多也是最常见的读入数据的方法应该就是feed_dict了,用placeholder构建网络,然后在训练的时候用numpy数组feed进去。这样使用是十分方便的,但是效率不是特别高,大概只能剥夺70%左右的GPU使用率。因为每一次调用feed_dict的时候,程序都需要从硬盘里读取数据到CPU,然后再从CPU把数据拷进GPU,这个过程就十分耗时。当然,当数据量比较小的时候,比如所有的训练数据只有2G,可以在训练开始就把所有的数据全装进内存,然后每次feed数据的时候就不用从硬盘读数据。但是大部分情况下,特别是做图像处理,训练数据都是几十G往上走,就不可能一次全读进内存了。

那怎么样才能加快数据读取这一过程呢?计算机从硬盘读数据到CPU再拷贝进GPU这个过程是没法加速的,这个取决于硬件性能。那是不是就不能加速了呢?注意到数据读取和后面的训练是独立进行的,这个独立性就让我们可以用多线程来加速。可以让多个准备数据的线程同时读取数据,然后存到一个队列里,训练用的线程就可以专心做训练计算了,数据有人给他准备好了。

tf还有一个更强大灵活的库——tf.data

看见没,这就是传说中的流水线作业。这里的parse_jpg和_compute_weight都是自定义的函数,数据读进来之后,可以不停地使用map函数来对数据进行处理,十分高效。最后只需要调用iterator.get_next()就可以返回一个batch的数据了。

5

Train

这个部分说难不难,说简单不简单,关键看你想把“自动化”做到哪个程度。因为前面数据接口和模型都做好了,剩下只要不停地sess.run(train_op)就行了。还要考虑训练过程中需要输出的信息,比如loss,acc等。这里推荐一种方法。

个人觉得这样写比较简洁,所有的结果都以字典的形式存储在results里。

除了输出信息,还要记录训练的中间过程,tf里有专门的函数

然后在训练的时候

这样这些变量就能用tensorboard很方便的可视化了。

注意到我这里用了flags,这是tensorflow里专门用来做参数管理的。

然后在定义每个函数的时候,只需要传递flags一个参数就行,要用哪个全局参数就用哪个,十分方便,再也不会被长长的函数定义给迷惑了,也不需要到处找需要传递多少参数了,统统用一个flag搞定。

但是,严格上说,训练的是需要做validation的,就是在训练的时候还要在测试集上测试,有时候训练的loss在下降,但是测试集的loss可能开始上升了,这个时候就说明网络开始过拟合了。Tensorflow在这一块好像没有特别集成的函数,需要自己手动实现。如果是直接用feed_dict,也很方便,但是如果用的是多线程,这个时候数据集是写到图里的,在运行过程中是不能改图的,也就是不能把数据更换到测试数据集上的。这一块我尝试了好久,总算做得自己比较满意了,在我的上一篇推文里的github上有源码,大家可以去围观。

读原文获取github源码

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180706G0E9EU00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券