caffe随记(三) --- solver 简析

1、概述

solver算是caffe中比较核心的一个概念,在我们训练train我们的网络时,就必须要带上这个参数

如下例是我要对Lenet进行训练的时候要调用的程序,现在不知道什么意思没关系,只需要知道这个solver.prototxt是个必不可少的东西就ok

./build/tools/caffe train--solver=examples/mnist/lenet_solver.prototxt 

Solver通过协调Net的前向推断计算和反向梯度计算对参数进行更新,从而达到减小loss的目的

Solver的主要功能:

○ 设计好需要优化的对象,创建train网络和test网络。(通过调用另外一个配置文件prototxt来进行)

○ 通过forward和backward迭代的进行优化来更新新参数。

○ 定期的评价测试网络。 (可设定多少次训练后,进行一次测试)。

○ 在优化过程中记录模型和solver的状态的快照。

在每一次的迭代过程中,solver做了这几步工作:

1、调用forward算法来计算最终的输出值,以及对应的loss

2、调用backward算法来计算每层的梯度

3、根据选用的slover方法,利用梯度进行参数更新

4、根据学习率、历史数据、求解方法更新solver状态,使得权重从初始化状态逐步更新到最终的状态。

2、caffe.proto关于solver的描述

虽然内容很多,但是基本上是注释占的篇幅多,而且我也基本上都翻译成了中文注释,建议仔细阅读,这是一切solver的模版

// NOTE
// Update the next available ID when you add a new SolverParameter field.
// ## 注意,如果你要增加一个新的sovler参数,需要给它更新ID,就是下面内容中的数字
// SolverParameter next available ID: 41 (last added: type)  ##下一个可用的ID是41,上一次caffe增加的是type参数,就是下文中的40
message SolverParameter {
  //////////////////////////////////////////////////////////////////////////////
  // Specifying the train and test networks
  // ##指定训练和测试网络
  // Exactly one train net must be specified using one of the following fields:
  //     train_net_param, train_net, net_param, net
  // One or more test nets may be specified using any of the following fields:
  //     test_net_param, test_net, net_param, net
  // If more than one test net field is specified (e.g., both net and
  // test_net are specified), they will be evaluated in the field order given
  // above: (1) test_net_param, (2) test_net, (3) net_param/net.
  // A test_iter must be specified for each test_net.
  // A test_level and/or a test_stage may also be specified for each test_net.
  //////////////////////////////////////////////////////////////////////////////

  // Proto filename for the train net, possibly combined with one or more
  // test nets.  ##这个训练网络的Proto文件名,可能结合一个或多个测试网络。
  optional string net = 24;
  // Inline train net param, possibly combined with one or more test nets. ## 对应的训练网络的参数,可能结合一个或多个测试网络
  optional NetParameter net_param = 25;

  optional string train_net = 1; // Proto filename for the train net.     ## train net的proto文件名
  repeated string test_net = 2; // Proto filenames for the test nets.     ## test nets的proto文件名
  optional NetParameter train_net_param = 21; // Inline train net params. ## 与上面train网络一致对应的参数
  repeated NetParameter test_net_param = 22; // Inline test net params.   ## 与上面test网络一致对应的参数


  // The states for the train/test nets. Must be unspecified or
  // specified once per net.
  // ## train/test网络的状态。 必须是未指定或每个网络指定一次
  // By default, all states will have solver = true;            ##默认情况下,所有状态都将有solver = true;
  // train_state will have phase = TRAIN,                       ##train_state会有phase = TRAIN,
  // and all test_state's will have phase = TEST.               ##所有的test_state都将进行phase = TEST
  // Other defaults are set according to the NetState defaults. ##其他默认值是根据NetState默认设置的。
  optional NetState train_state = 26;
  repeated NetState test_state = 27;

  // The number of iterations for each test net.                ## test网络的迭代次数:
  repeated int32 test_iter = 3;

  // The number of iterations between two testing phases.
  // ## 两次test之间(train)的迭代次数 
  //## <训练test_interval个批次,再测试test_iter个批次,为一个回合(epoch), 合理设置应使得每个回合内,遍历覆盖到全部训练样本和测试样本 >
  optional int32 test_interval = 4 [default = 0];
  optional bool test_compute_loss = 19 [default = false]; //   ## 默认不计算测试时损失

  // If true, run an initial test pass before the first iteration,
  // ensuring memory availability and printing the starting value of the loss.
  // ##如设置为真,则在训练前运行一次测试,以确保内存足够,并打印初始损失值
  optional bool test_initialization = 32 [default = true];
  optional float base_lr = 5; // The base learning rate              ##基本学习速率
  // the number of iterations between displaying info. If display = 0, no info
  // will be displayed.       ##打印信息的遍历间隔,遍历多少个批次打印一次信息。设置为0则不打印。
  optional int32 display = 6;
  // Display the loss averaged over the last average_loss iterations ## 打印最后一个迭代批次下的平均损失
  optional int32 average_loss = 33 [default = 1];
  optional int32 max_iter = 7; // the maximum number of iterations   ##train的最大迭代次数  
  // accumulate gradients over `iter_size` x `batch_size` instances 
  // ## 累积梯度误差基于“iter_size×batchSize”个样本实例,< “批次数×批量数”=“遍历的批次数×每批的样本数”个样本实例  >
  optional int32 iter_size = 36 [default = 1];

  // The learning rate decay policy. The currently implemented learning rate
  // policies are as follows:
  //##学习率衰退策略.目前实行的学习率策略如下:
  //    - fixed: always return base_lr.                         ##保持base_lr不变.
  //    - step: return base_lr * gamma ^ (floor(iter / step))   ##返回 base_lr * gamma ^(floor(iter / stepsize)),
  //    - exp: return base_lr * gamma ^ iter                    ##返回base_lr * gamma ^ iter, iter为当前迭代次数
  //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)  ##如果设置为inv,还需设置一个power,返回return 后的内容
  //    - multistep: similar to step but it allows non uniform steps defined by ##这个参数和step很相似,还需要设置一个stepvalue。
  //      stepvalue                                                             ##但step是均匀等间隔变化,而此参数根据stepvalue变化
  //    - poly: the effective learning rate follows a polynomial decay, to be   ##学习率进行多项式衰减,由max_iter变为0
  //      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) ,  ##返回 base_lr (1- iter/max_iter) ^ (power)
  //    - sigmoid: the effective learning rate follows a sigmod decay           ##学习率进行sigmod衰减,
  //      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))             ##返回return 后的内容
  //
  // where base_lr, max_iter, gamma, step, stepvalue and power are defined
  // in the solver parameter protocol buffer, and iter is the current iteration.
  // ## 在上述参数中,base_lr, max_iter, gamma, step, stepvalue and power 被定义  
  //    在solver.prototxt文件中,iter是当前迭代次数。
  optional string lr_policy = 8;
  optional float gamma = 9; // The parameter to compute the learning rate.
  optional float power = 10; // The parameter to compute the learning rate.
  optional float momentum = 11; // The momentum value.   ## 动量
  optional float weight_decay = 12; // The weight decay. ##权值衰减系数
  // regularization types supported: L1 and L2
  // controlled by weight_decay 
  // ## 由权值衰减系数所控制的正则化类型:L1或L2范数,默认L2
  optional string regularization_type = 29 [default = "L2"];
  // the stepsize for learning rate policy "step"       ##"step"策略下,学习率的步长值
  optional int32 stepsize = 13;
  // the stepsize for learning rate policy "multistep"  ##  "multistep"策略下的步长值
  repeated int32 stepvalue = 34;

  // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm,
  // whenever their actual L2 norm is larger.
  optional float clip_gradients = 35 [default = -1];

  optional int32 snapshot = 14 [default = 0]; // The snapshot interval ##快照间隔<遍历多少次对模型和求解器状态保存一次>
  optional string snapshot_prefix = 15; // The prefix for the snapshot.
  // whether to snapshot diff in the results or not. Snapshotting diff will help
  // debugging but the final protocol buffer size will be much larger.
  // ## 是否对diff快照,有助调试,但最终的protocol buffer尺寸会很大
  optional bool snapshot_diff = 16 [default = false];
  // ## 快照数据保存格式{ hdf5,binaryproto(默认) }
  enum SnapshotFormat {
    HDF5 = 0;
    BINARYPROTO = 1;
  }
  optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO];
 // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default.       ##选CPU或GPU模式,默认是GPU
  enum SolverMode {
    CPU = 0;
    GPU = 1;
  }
  optional SolverMode solver_mode = 17 [default = GPU]; 
  // the device_id will that be used in GPU mode. Use device_id = 0 in default. ##如果选了GPU模式,此参数指定哪个GPU,默认是0号GPU
  optional int32 device_id = 18 [default = 0];
  // If non-negative, the seed with which the Solver will initialize the Caffe
  // random number generator -- useful for reproducible results. Otherwise,
  // (and by default) initialize using a seed derived from the system clock.
  optional int64 random_seed = 20 [default = -1];

 
  // type of the solver                   ## 求解器类型=SGD(默认),目前一共有6种
  optional string type = 40 [default = "SGD"];


  // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam
  //## RMSProp,AdaGrad和AdaDelta和Adam的数值稳定性
  optional float delta = 31 [default = 1e-8];
  // parameters for the Adam solver       ## Adam类型时的参数
  optional float momentum2 = 39 [default = 0.999];


  // RMSProp decay value                  ##RMSProp的衰减值
  // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
  optional float rms_decay = 38;


  // If true, print information about the state of the net that may help with
  // debugging learning problems.
  //## 此参数默认为false,若为true,则打印网络状态信息,有助于调试问题
  optional bool debug_info = 23 [default = false];


  // If false, don't save a snapshot after training finishes.
  //## 此参数默认为true,若为false,则不会在训练后保存快照
  optional bool snapshot_after_train = 28 [default = true];


  // DEPRECATED: old solver enum types, use string instead ##已经弃用,本来表示6种sovler类型,现在用string type中的string代替
  enum SolverType {
    SGD = 0;
    NESTEROV = 1;
    ADAGRAD = 2;
    RMSPROP = 3;
    ADADELTA = 4;
    ADAM = 5;
  }
  // DEPRECATED: use type instead of solver_type  ##已经弃用,用string type中的type代替
  optional SolverType solver_type = 30 [default = SGD];
}

3、举例说明

我仍以caffe/examples/mnist/lenet_solver.prototxt这个文件为例,下图是我的截图

我把上图的内容复制过来看的清楚一些,并把注释翻译了一下:

---------------这一部分可以对照着2中proto中的描述看你会发现其实solver的编写也就是对着模版填参数的一个过程,----------

# 我们需要的Net的模型,这个模型定义在另一个prototxt文件中,这个就是我上一篇博文举的Net的例子

# 显然这里根据需要你可以选择其他的一些Net

net: "examples/mnist/lenet_train_test.prototxt"

#test_iter 设置了test一共迭代多少次,这里是100

# 至于test每一次迭代处理多少张图片,在Net那个prototxt里面batch_size规定了的

test_iter: 100

# 训练每迭代500次,测试一次(这每一次测试要迭代100次).

test_interval:500

#设置学习率。base_lr用于设置基础学习率,在迭代的过程中,可以对基础学习率进行调整。怎么样进行调整,就是调整的策略,由lr_policy来设置。

#momentum称为动量,使得权重更新更为平缓

#weight_decay称为衰减率因子,防止过拟合的一个参数

base_lr: 0.01

momentum: 0.9

# 这里省略了一个内容 type: SGD  ,就是solver方法的选择,因为默认就是SGD,所以这个 solver. prototxt 中省略没写, 如果你想用其他的sovler方法就要指明写出来

weight_decay:0.0005

# 学习率调整的策略,详细见我下面的补充

lr_policy: "inv"

gamma: 0.0001

power: 0.75

# train每迭代100次就显示一次

display: 100

#train最大迭代次数

max_iter: 10000

#快照。将训练出来的model和solver状态进行保存,snapshot用于设置训练多少次后进行保存,默认为0,不保存。snapshot_prefix设置保存路径。

还可以设置snapshot_diff,是否保存梯度值,默认为false,不保存。

也可以设置snapshot_format,保存的类型。有两种选择:HDF5和BINARYPROTO,默认为BINARYPROTO

#这里设置train每迭代5000次就存储一次数据

snapshot: 5000

snapshot_prefix: "examples/mnist/lenet"

#设置运行模式,默认为GPU,如果你没有GPU,则需要改成CPU,否则会出错.

solver_mode: GPU

4、solver方法

Solver方法就是计算最小化损失值(loss)的方法,也就是我上面解析中说的省略掉的一行,其实一共有6种sovler方法:

· Stochastic Gradient Descent (type: "SGD"),

· AdaDelta (type: "AdaDelta"),

· Adaptive Gradient (type: "AdaGrad"),

· Adam (type: "Adam"),

· Nesterov’s Accelerated Gradient (type: "Nesterov") and

· RMSprop (type: "RMSProp")

默认设置的是SGD 随机梯度下降,所以就可以不写,但是如果想用其他的,就必须要写出来,比如type:Adam

这个方法对于我这种小白来说暂时没有研究的必要,而且SGD方法的数学原理至少我是知道的,所以我这里就只把这几种方法列出来了,没有详细解读,如果有兴趣可以参考下面这篇博客:

http://www.cnblogs.com/denny402/p/5074212.html

这篇关于sovler讲解的博文就写完了,下面一篇就准备来将一下caffe中的hello world--使用Lenet来识别mnist手写数据。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏PPV课数据科学社区

TensorFlow 数据集和估算器介绍

TensorFlow 1.3 引入了两个重要功能,您应当尝试一下: 数据集:一种创建输入管道(即,将数据读入您的程序)的全新方式。 估算器:一种创建 Ten...

3069
来自专栏大数据挖掘DT机器学习

支持中文文本数据挖掘的开源项目PyMining

最近一个月,过年的时候天天在家里呆着,年后公司的事情也不断,有一段时间没有更新博客了。PyMining是我最近一段时间构思的一个项目,虽然目前看来比较微型。该项...

3546
来自专栏AI研习社

Github 代码实践:Pytorch 实现的语义分割器

使用Detectron预训练权重输出 *e2e_mask_rcnn-R-101-FPN_2x* 的示例

1002
来自专栏CSDN技术头条

利用GPU和Caffe训练神经网络

本文为利用GPU和Caffe训练神经网络的实战教程,介绍了根据Kaggle的“奥托集团产品分类挑战赛”的数据进行训练一种多层前馈网络模型的方法,如何将模型应用于...

19110
来自专栏漫漫深度学习路

mxnet-Gluon(一):mxnet-Gluon 入门

沐神已经提供了一份官方的文档,为什么要写这么一篇博客: 沐神提供的中文文档质量是非常高的,地址,但是感觉需要看一段时间才能上手 Gluon, 本博客结构模仿 p...

3456
来自专栏Deep learning进阶路

caffe随记(二) --- 数据结构简介

caffe随记(二) --- 数据结构简介 注:这篇文章博文我写的内容有点多,建议看一下左上角的目录,对本文结构有个大致了解。 1、Blob Blob其实...

2210
来自专栏人工智能LeadAI

TensorFlow中的Nan值的陷阱

之前在TensorFlow中实现不同的神经网络,作为新手,发现经常会出现计算的loss中,出现Nan值的情况,总的来说,TensorFlow中出现Nan值的情况...

4405
来自专栏大数据学习笔记

TensorFlow学习笔记:2、TensorFlow超简单入门程序

TensorFlow学习笔记:2、TensorFlow超简单入门程序 2.1 HelloWorld代码说明 import tensorflow as tf ...

2155
来自专栏机器学习实践二三事

Tensorflow实现word2vec

大名鼎鼎的word2vec,相关原理就不讲了,已经有很多篇优秀的博客分析这个了. 如果要看背后的数学原理的话,可以看看这个: https://wenku.b...

2617
来自专栏AI研习社

让系统自动选择空闲的GPU设备!帮你一次解决抢卡争端

项目地址:QuantumLiu / tf_gpu_manager 更新:支持pytorch 使用 git clone https://github.com/...

41111

扫码关注云+社区