caffe随记(三) --- solver 简析

1、概述

solver算是caffe中比较核心的一个概念,在我们训练train我们的网络时,就必须要带上这个参数

如下例是我要对Lenet进行训练的时候要调用的程序,现在不知道什么意思没关系,只需要知道这个solver.prototxt是个必不可少的东西就ok

./build/tools/caffe train--solver=examples/mnist/lenet_solver.prototxt 

Solver通过协调Net的前向推断计算和反向梯度计算对参数进行更新,从而达到减小loss的目的

Solver的主要功能:

○ 设计好需要优化的对象,创建train网络和test网络。(通过调用另外一个配置文件prototxt来进行)

○ 通过forward和backward迭代的进行优化来更新新参数。

○ 定期的评价测试网络。 (可设定多少次训练后,进行一次测试)。

○ 在优化过程中记录模型和solver的状态的快照。

在每一次的迭代过程中,solver做了这几步工作:

1、调用forward算法来计算最终的输出值,以及对应的loss

2、调用backward算法来计算每层的梯度

3、根据选用的slover方法,利用梯度进行参数更新

4、根据学习率、历史数据、求解方法更新solver状态,使得权重从初始化状态逐步更新到最终的状态。

2、caffe.proto关于solver的描述

虽然内容很多,但是基本上是注释占的篇幅多,而且我也基本上都翻译成了中文注释,建议仔细阅读,这是一切solver的模版

// NOTE
// Update the next available ID when you add a new SolverParameter field.
// ## 注意,如果你要增加一个新的sovler参数,需要给它更新ID,就是下面内容中的数字
// SolverParameter next available ID: 41 (last added: type)  ##下一个可用的ID是41,上一次caffe增加的是type参数,就是下文中的40
message SolverParameter {
  //////////////////////////////////////////////////////////////////////////////
  // Specifying the train and test networks
  // ##指定训练和测试网络
  // Exactly one train net must be specified using one of the following fields:
  //     train_net_param, train_net, net_param, net
  // One or more test nets may be specified using any of the following fields:
  //     test_net_param, test_net, net_param, net
  // If more than one test net field is specified (e.g., both net and
  // test_net are specified), they will be evaluated in the field order given
  // above: (1) test_net_param, (2) test_net, (3) net_param/net.
  // A test_iter must be specified for each test_net.
  // A test_level and/or a test_stage may also be specified for each test_net.
  //////////////////////////////////////////////////////////////////////////////

  // Proto filename for the train net, possibly combined with one or more
  // test nets.  ##这个训练网络的Proto文件名,可能结合一个或多个测试网络。
  optional string net = 24;
  // Inline train net param, possibly combined with one or more test nets. ## 对应的训练网络的参数,可能结合一个或多个测试网络
  optional NetParameter net_param = 25;

  optional string train_net = 1; // Proto filename for the train net.     ## train net的proto文件名
  repeated string test_net = 2; // Proto filenames for the test nets.     ## test nets的proto文件名
  optional NetParameter train_net_param = 21; // Inline train net params. ## 与上面train网络一致对应的参数
  repeated NetParameter test_net_param = 22; // Inline test net params.   ## 与上面test网络一致对应的参数


  // The states for the train/test nets. Must be unspecified or
  // specified once per net.
  // ## train/test网络的状态。 必须是未指定或每个网络指定一次
  // By default, all states will have solver = true;            ##默认情况下,所有状态都将有solver = true;
  // train_state will have phase = TRAIN,                       ##train_state会有phase = TRAIN,
  // and all test_state's will have phase = TEST.               ##所有的test_state都将进行phase = TEST
  // Other defaults are set according to the NetState defaults. ##其他默认值是根据NetState默认设置的。
  optional NetState train_state = 26;
  repeated NetState test_state = 27;

  // The number of iterations for each test net.                ## test网络的迭代次数:
  repeated int32 test_iter = 3;

  // The number of iterations between two testing phases.
  // ## 两次test之间(train)的迭代次数 
  //## <训练test_interval个批次,再测试test_iter个批次,为一个回合(epoch), 合理设置应使得每个回合内,遍历覆盖到全部训练样本和测试样本 >
  optional int32 test_interval = 4 [default = 0];
  optional bool test_compute_loss = 19 [default = false]; //   ## 默认不计算测试时损失

  // If true, run an initial test pass before the first iteration,
  // ensuring memory availability and printing the starting value of the loss.
  // ##如设置为真,则在训练前运行一次测试,以确保内存足够,并打印初始损失值
  optional bool test_initialization = 32 [default = true];
  optional float base_lr = 5; // The base learning rate              ##基本学习速率
  // the number of iterations between displaying info. If display = 0, no info
  // will be displayed.       ##打印信息的遍历间隔,遍历多少个批次打印一次信息。设置为0则不打印。
  optional int32 display = 6;
  // Display the loss averaged over the last average_loss iterations ## 打印最后一个迭代批次下的平均损失
  optional int32 average_loss = 33 [default = 1];
  optional int32 max_iter = 7; // the maximum number of iterations   ##train的最大迭代次数  
  // accumulate gradients over `iter_size` x `batch_size` instances 
  // ## 累积梯度误差基于“iter_size×batchSize”个样本实例,< “批次数×批量数”=“遍历的批次数×每批的样本数”个样本实例  >
  optional int32 iter_size = 36 [default = 1];

  // The learning rate decay policy. The currently implemented learning rate
  // policies are as follows:
  //##学习率衰退策略.目前实行的学习率策略如下:
  //    - fixed: always return base_lr.                         ##保持base_lr不变.
  //    - step: return base_lr * gamma ^ (floor(iter / step))   ##返回 base_lr * gamma ^(floor(iter / stepsize)),
  //    - exp: return base_lr * gamma ^ iter                    ##返回base_lr * gamma ^ iter, iter为当前迭代次数
  //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)  ##如果设置为inv,还需设置一个power,返回return 后的内容
  //    - multistep: similar to step but it allows non uniform steps defined by ##这个参数和step很相似,还需要设置一个stepvalue。
  //      stepvalue                                                             ##但step是均匀等间隔变化,而此参数根据stepvalue变化
  //    - poly: the effective learning rate follows a polynomial decay, to be   ##学习率进行多项式衰减,由max_iter变为0
  //      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) ,  ##返回 base_lr (1- iter/max_iter) ^ (power)
  //    - sigmoid: the effective learning rate follows a sigmod decay           ##学习率进行sigmod衰减,
  //      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))             ##返回return 后的内容
  //
  // where base_lr, max_iter, gamma, step, stepvalue and power are defined
  // in the solver parameter protocol buffer, and iter is the current iteration.
  // ## 在上述参数中,base_lr, max_iter, gamma, step, stepvalue and power 被定义  
  //    在solver.prototxt文件中,iter是当前迭代次数。
  optional string lr_policy = 8;
  optional float gamma = 9; // The parameter to compute the learning rate.
  optional float power = 10; // The parameter to compute the learning rate.
  optional float momentum = 11; // The momentum value.   ## 动量
  optional float weight_decay = 12; // The weight decay. ##权值衰减系数
  // regularization types supported: L1 and L2
  // controlled by weight_decay 
  // ## 由权值衰减系数所控制的正则化类型:L1或L2范数,默认L2
  optional string regularization_type = 29 [default = "L2"];
  // the stepsize for learning rate policy "step"       ##"step"策略下,学习率的步长值
  optional int32 stepsize = 13;
  // the stepsize for learning rate policy "multistep"  ##  "multistep"策略下的步长值
  repeated int32 stepvalue = 34;

  // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm,
  // whenever their actual L2 norm is larger.
  optional float clip_gradients = 35 [default = -1];

  optional int32 snapshot = 14 [default = 0]; // The snapshot interval ##快照间隔<遍历多少次对模型和求解器状态保存一次>
  optional string snapshot_prefix = 15; // The prefix for the snapshot.
  // whether to snapshot diff in the results or not. Snapshotting diff will help
  // debugging but the final protocol buffer size will be much larger.
  // ## 是否对diff快照,有助调试,但最终的protocol buffer尺寸会很大
  optional bool snapshot_diff = 16 [default = false];
  // ## 快照数据保存格式{ hdf5,binaryproto(默认) }
  enum SnapshotFormat {
    HDF5 = 0;
    BINARYPROTO = 1;
  }
  optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO];
 // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default.       ##选CPU或GPU模式,默认是GPU
  enum SolverMode {
    CPU = 0;
    GPU = 1;
  }
  optional SolverMode solver_mode = 17 [default = GPU]; 
  // the device_id will that be used in GPU mode. Use device_id = 0 in default. ##如果选了GPU模式,此参数指定哪个GPU,默认是0号GPU
  optional int32 device_id = 18 [default = 0];
  // If non-negative, the seed with which the Solver will initialize the Caffe
  // random number generator -- useful for reproducible results. Otherwise,
  // (and by default) initialize using a seed derived from the system clock.
  optional int64 random_seed = 20 [default = -1];

 
  // type of the solver                   ## 求解器类型=SGD(默认),目前一共有6种
  optional string type = 40 [default = "SGD"];


  // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam
  //## RMSProp,AdaGrad和AdaDelta和Adam的数值稳定性
  optional float delta = 31 [default = 1e-8];
  // parameters for the Adam solver       ## Adam类型时的参数
  optional float momentum2 = 39 [default = 0.999];


  // RMSProp decay value                  ##RMSProp的衰减值
  // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
  optional float rms_decay = 38;


  // If true, print information about the state of the net that may help with
  // debugging learning problems.
  //## 此参数默认为false,若为true,则打印网络状态信息,有助于调试问题
  optional bool debug_info = 23 [default = false];


  // If false, don't save a snapshot after training finishes.
  //## 此参数默认为true,若为false,则不会在训练后保存快照
  optional bool snapshot_after_train = 28 [default = true];


  // DEPRECATED: old solver enum types, use string instead ##已经弃用,本来表示6种sovler类型,现在用string type中的string代替
  enum SolverType {
    SGD = 0;
    NESTEROV = 1;
    ADAGRAD = 2;
    RMSPROP = 3;
    ADADELTA = 4;
    ADAM = 5;
  }
  // DEPRECATED: use type instead of solver_type  ##已经弃用,用string type中的type代替
  optional SolverType solver_type = 30 [default = SGD];
}

3、举例说明

我仍以caffe/examples/mnist/lenet_solver.prototxt这个文件为例,下图是我的截图

我把上图的内容复制过来看的清楚一些,并把注释翻译了一下:

---------------这一部分可以对照着2中proto中的描述看你会发现其实solver的编写也就是对着模版填参数的一个过程,----------

# 我们需要的Net的模型,这个模型定义在另一个prototxt文件中,这个就是我上一篇博文举的Net的例子

# 显然这里根据需要你可以选择其他的一些Net

net: "examples/mnist/lenet_train_test.prototxt"

#test_iter 设置了test一共迭代多少次,这里是100

# 至于test每一次迭代处理多少张图片,在Net那个prototxt里面batch_size规定了的

test_iter: 100

# 训练每迭代500次,测试一次(这每一次测试要迭代100次).

test_interval:500

#设置学习率。base_lr用于设置基础学习率,在迭代的过程中,可以对基础学习率进行调整。怎么样进行调整,就是调整的策略,由lr_policy来设置。

#momentum称为动量,使得权重更新更为平缓

#weight_decay称为衰减率因子,防止过拟合的一个参数

base_lr: 0.01

momentum: 0.9

# 这里省略了一个内容 type: SGD  ,就是solver方法的选择,因为默认就是SGD,所以这个 solver. prototxt 中省略没写, 如果你想用其他的sovler方法就要指明写出来

weight_decay:0.0005

# 学习率调整的策略,详细见我下面的补充

lr_policy: "inv"

gamma: 0.0001

power: 0.75

# train每迭代100次就显示一次

display: 100

#train最大迭代次数

max_iter: 10000

#快照。将训练出来的model和solver状态进行保存,snapshot用于设置训练多少次后进行保存,默认为0,不保存。snapshot_prefix设置保存路径。

还可以设置snapshot_diff,是否保存梯度值,默认为false,不保存。

也可以设置snapshot_format,保存的类型。有两种选择:HDF5和BINARYPROTO,默认为BINARYPROTO

#这里设置train每迭代5000次就存储一次数据

snapshot: 5000

snapshot_prefix: "examples/mnist/lenet"

#设置运行模式,默认为GPU,如果你没有GPU,则需要改成CPU,否则会出错.

solver_mode: GPU

4、solver方法

Solver方法就是计算最小化损失值(loss)的方法,也就是我上面解析中说的省略掉的一行,其实一共有6种sovler方法:

· Stochastic Gradient Descent (type: "SGD"),

· AdaDelta (type: "AdaDelta"),

· Adaptive Gradient (type: "AdaGrad"),

· Adam (type: "Adam"),

· Nesterov’s Accelerated Gradient (type: "Nesterov") and

· RMSprop (type: "RMSProp")

默认设置的是SGD 随机梯度下降,所以就可以不写,但是如果想用其他的,就必须要写出来,比如type:Adam

这个方法对于我这种小白来说暂时没有研究的必要,而且SGD方法的数学原理至少我是知道的,所以我这里就只把这几种方法列出来了,没有详细解读,如果有兴趣可以参考下面这篇博客:

http://www.cnblogs.com/denny402/p/5074212.html

这篇关于sovler讲解的博文就写完了,下面一篇就准备来将一下caffe中的hello world--使用Lenet来识别mnist手写数据。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏专知

推荐一些有助于理解TensorFlow机制的资料(二)

1503
来自专栏漫漫深度学习路

pytorch 学习笔记(一)

pytorch是一个动态的建图的工具。不像Tensorflow那样,先建图,然后通过feed和run重复执行建好的图。相对来说,pytorch具有更好的灵活性。...

3516
来自专栏PPV课数据科学社区

Python学习手册:NumPy快速参考表

如果你想用Python做数据分析,那么NumPy是你必须掌握的其中一个基础计算包。它可以很好的替代Python列表,因为NumPy数组更紧凑,允许快速读写访问,...

2797
来自专栏大数据智能实战

tensorflow 1.01中GAN(生成对抗网络)手写字体生成例子(MINST)的测试

为了更好地掌握GAN的例子,从网上找了段代码进行跑了下,测试了效果。具体过程如下: 代码文件如下: import tensorflow as tf from ...

31510
来自专栏机器之心

从框架优缺点说起,这是一份TensorFlow入门极简教程

2048
来自专栏漫漫深度学习路

tensorflow学习笔记(三十六):learning rate decay

learning rate decay 在训练神经网络的时候,通常在训练刚开始的时候使用较大的learning rate, 随着训练的进行,我们会慢慢的减小le...

3796
来自专栏文武兼修ing——机器学习与IC设计

有基础(Pytorch/TensorFlow基础)mxnet+gluon快速入门mxnet基本数据结构mxnet的数据载入网络搭建模型训练准确率计算模型保存与载入

import numpy as np import mxnet as mx import logging logging.getLogger().setLeve...

7068
来自专栏小鹏的专栏

01 TensorFlow入门(1)

tensorflow_cookbook--第1章 TensorFlow入门         Google的TensorFlow引擎具有独特的解决问题的方法。 ...

23210
来自专栏Coding迪斯尼

详解神经网络算法所需最基础数据结构Tensor及其相关操作

644
来自专栏CSDN技术头条

实战Google深度学习框架:TensorFlow计算加速

要将深度学习应用到实际问题中,一个非常大的问题在于训练深度学习模型需要的计算量太大。比如Inception-v3模型在单机上训练到78%的正确率需要将近半年的时...

3058

扫码关注云+社区