Caffe学习笔记(五):使用pycaffe生成solver.prototxt文件并进行训练

Python版本: Python2.7 运行平台: Ubuntu14.04

    上几篇笔记记录了如何将图片数据转换成db(leveldb/lmdb)文件,计算图片数据的均值,train.prototxt和test.prototxt文件的编写。本篇笔记主要记录如何生成sovler文件,solver文件是训练的时候,需要用到的prototxt文件,它指明了train.prototxt和test.prototxt或train_test.prototxt。solver就是用来是loss最小化的优化方法。

一、solver.prototxt参数说明

    依然是以cifar10_quick_solver.prototxt为例,内容如下:

# reduce the learning rate after 8 epochs (4000 iters) by a factor of 10

# The train/test net protocol buffer definition
net: "examples/cifar10/cifar10_quick_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.001
momentum: 0.9
weight_decay: 0.004
# The learning rate policy
lr_policy: "fixed"
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 4000
# snapshot intermediate results
snapshot: 4000
snapshot_format: HDF5
snapshot_prefix: "examples/cifar10/cifar10_quick"
# solver mode: CPU or GPU
solver_mode: GPU

    这些参数,都是有根据进行设置的,从上到下依次进行说明:

  • net:指定配置文件,cifar10_quick_solver.prototx文件中指定的prototxt文件为examples/cifar10/cifar10_quick_train_test.prototxt,可以使用train_net和test_net分别指定。
  • test_iter:测试迭代数。例如:有10000个测试样本,batch_size设为32,那么就需要迭代 10000/32=313次才完整地测试完一次,所以设置test_iter为313。
  • test_interval:每训练迭代test_interval次进行一次测试,例如50000个训练样本,batch_size为64,那么需要50000/64=782次才处理完一次全部训练样本,记作1 epoch。所以test_interval设置为782,即处理完一次所有的训练数据后,才去进行测试。
  • base_lr:基础学习率,学习策略使用的参数。
  • momentum:动量。
  • weight_decay:权重衰减。
  • lr_policy:学习策略。可选参数:fixed、step、exp、inv、multistep。

lr_prolicy参数说明:

  • fixed: 保持base_lr不变;
  • step: step: 如果设置为step,则需要设置一个stepsize,返回base_lr * gamma ^ (floor(iter / stepsize)),其中iter表示当前的迭代次数;
  • exp: 返回base_lr * gamma ^ iter,iter为当前的迭代次数;
  • inv: 如何设置为inv,还需要设置一个power,返回base_lr * (1 + gamma * iter) ^ (- power);
  • multistep: 如果设置为multistep,则还需要设置一个stepvalue,这个参数和step相似,step是均匀等间隔变化,而multistep则是根据stepvalue值变化;
  • stepvalue参数说明: poly: 学习率进行多项式误差,返回base_lr (1 - iter/max_iter) ^ (power); sigmoid: 学习率进行sigmod衰减,返回base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))。
  • display:每迭代display次显示结果。
  • max_iter:最大迭代数,如果想训练100 epoch,则需要设置max_iter为100*test_intervel=78200。
  • snapshot:保存临时模型的迭代数。
  • snapshot_format:临时模型的保存格式。有两种选择:HDF5 和BINARYPROTO ,默认为BINARYPROTO
  • snapshot_prefix:模型前缀,就是训练好生成model的名字。不加前缀为iter_迭代数.caffemodel,加之后为lenet_iter_迭代次数.caffemodel。
  • solver_mode:优化模式。可以使用GPU或者CPU。

二、使用python生成solver.prototxt文件

    以分析的cifar10_quick_solver.prototxt文件为例,使用python程序,生成这个文件。

1.代码如下:

# -*- coding: UTF-8 -*-
import caffe                                                     #导入caffe包

def write_sovler():
    my_project_root = "/home/Jack-Cui/caffe-master/my-caffe-project/"        #my-caffe-project目录
    sovler_string = caffe.proto.caffe_pb2.SolverParameter()                    #sovler存储
    solver_file = my_project_root + 'solver.prototxt'                        #sovler文件保存位置
    sovler_string.train_net = my_project_root + 'train.prototxt'            #train.prototxt位置指定
    sovler_string.test_net.append(my_project_root + 'test.prototxt')         #test.prototxt位置指定
    sovler_string.test_iter.append(100)                                        #测试迭代次数
    sovler_string.test_interval = 500                                        #每训练迭代test_interval次进行一次测试
    sovler_string.base_lr = 0.001                                            #基础学习率   
    sovler_string.momentum = 0.9                                            #动量
    sovler_string.weight_decay = 0.004                                        #权重衰减
    sovler_string.lr_policy = 'fixed'                                        #学习策略           
    sovler_string.display = 100                                                #每迭代display次显示结果
    sovler_string.max_iter = 4000                                            #最大迭代数
    sovler_string.snapshot = 4000                                             #保存临时模型的迭代数
    sovler_string.snapshot_format = 0                                        #临时模型的保存格式,0代表HDF5,1代表BINARYPROTO
    sovler_string.snapshot_prefix = 'examples/cifar10/cifar10_quick'        #模型前缀
    sovler_string.solver_mode = caffe.proto.caffe_pb2.SolverParameter.GPU    #优化模式

    with open(solver_file, 'w') as f:
        f.write(str(sovler_string))   

if __name__ == '__main__':
    write_sovler()

2.运行结果:

三、训练模型

    从第一篇笔记至此,我们已经了解到如何将jpg图片转换成Caffe使用的db(levelbd/lmdb)文件,如何计算数据均值,如何使用python生成solver.prototxt、train.prototxt、test.prototxt文件。接下来,就可以进行训练的最后一步,使用caffe提供的python接口训练生成模型。如果不进行可视化,只想得到一个最终的训练model,可以使用如下代码:

import caffe

my_project_root = "/home/Jack-Cui/caffe-master/my-caffe-project/"        #my-caffe-project目录
solver_file = my_project_root + 'solver.prototxt'                        #sovler文件保存位置
caffe.set_device(0)                                                      #选择GPU-0
caffe.set_mode_gpu()
solver = caffe.SGDSolver(solver_file)
solver.solve()

     现在,如何训练生成模型的简单步骤已经讲完。接下来,以mnist实例,整合所学内容,训练生成model,并使用生成的model进行预测。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏生信宝典

CIRCOS增加热图、点图、线图和区块属性

CIRCOS图在有了染色体信息界定绘图区域后,就可以向里面添加离散数据如标记特定的区域或连续数据如展示修饰的丰度等。 经过前面部分对CIRCOS基本安装,最简单...

2137
来自专栏素质云笔记

图像增强︱window7+opencv3.2+keras/theano简单应用(函数解读)

在服务器上安装opencv遇到跟CUDA8.0不适配的问题,于是不得不看看其他机器是否可以预装并使用。 . 一、python+opencv3.2安装 ope...

31310
来自专栏简书专栏

基于Excel2013的数据转换和清洗

数字可以被设成的格式有12种:常规、数值、货币、会计专用、日期、时间、百分比、分数、科学记数、文本、特殊、自定义

972
来自专栏简书专栏

基于xgboost的风力发电机叶片结冰分类预测

xgboost中文叫做极致梯度提升模型,官方文档链接:https://xgboost.readthedocs.io/en/latest/tutorials/mo...

441
来自专栏机器学习实践二三事

Tensorflow实现word2vec

大名鼎鼎的word2vec,相关原理就不讲了,已经有很多篇优秀的博客分析这个了. 如果要看背后的数学原理的话,可以看看这个: https://wenku.b...

2657
来自专栏Jack-Cui

Python3《机器学习实战》学习笔记(一):k-近邻算法(史诗级干货长文)

运行平台: Windows Python版本: Python3.x IDE: Sublime text3 一 简单k-近邻算法     本文将从k-邻近算法...

8437
来自专栏瓜大三哥

直方图操作(一)

如果要对图像分辨率为640x512位宽的图像进行直方图统计,则有 AWDPRAM≥8 DWDPRAM≥log2(Pixelttotal)=log2(640x51...

1908
来自专栏Petrichor的专栏

tensorflow: 畅玩tensorboard图表(SCALARS)

这篇博客建立在你已经会使用tensorboard的基础上。如果你还不会记录数据并使用tensorboard,请移步我之前的另一篇博客:tensorflow: t...

1003
来自专栏有趣的Python和你

机器学习实战之KNN算法

本系列教程为《机器学习实战》的读书笔记。首先,讲讲写本系列教程的原因:第一,《机器学习实战》的代码由Python2编写,有些代码在Python3上运行已会报错,...

1315
来自专栏mathor

LeetCode258. 各位相加

 ab = (a*10+b)  ab%9 = (a*9+a+b)%9 = (a+b)%9  abc = (a*100+b*10+c)  abc%9 = ...

871

扫码关注云+社区