前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >caffe详解之工具篇

caffe详解之工具篇

作者头像
AI异构
发布2020-07-29 11:02:24
5630
发布2020-07-29 11:02:24
举报
文章被收录于专栏:AI异构AI异构

从零开始,一步一步学习caffe的使用,期间贯穿深度学习和调参的相关知识!

数据格式LMDB文件制作

convert_imageset是将我们准备的数据集文件转换为caffe接口更快读取的LMDBHDF5数据类型。

文件结构
代码语言:javascript
复制
|____face-lmdb.sh
|____train.txt
|____train
| |____0 // 存放有人脸图片
| |____1 // 存放无人脸图片
|____val.txt
|____val //直接存放待测试图片
train.txt文件格式
代码语言:javascript
复制
1/23039_nonface_0image30595.jpg 1
0/34245_faceimage06120.jpg 0
0/10649_faceimage50035.jpg 0
...
val.txt文件格式
代码语言:javascript
复制
39830_nonface_0image49143.jpg 1
22467_nonface_0image15785.jpg 1
29202_nonface_1image71337.jpg 1
4629_nonface_1image07224.jpg 1
20093_faceimage20295.jpg 0
18655_faceimage09968.jpg 0
...
转换脚本文件
代码语言:javascript
复制
#!/usr/bin/env sh
# Create the face_48 lmdb inputs
# N.B. set the path to the face_48 train + val data dirs

EXAMPLE=/home/xuke/data/face_detect  #项目位置
DATA=/home/xuke/data/face_detect     #数据项目位置
TOOLS=/home/xuke/caffe/build/tools   #caffe tool目录位置

TRAIN_DATA_ROOT=/home/xuke/data/face_detect/train/ #训练数据集根目录
VAL_DATA_ROOT=/home/xuke/data/face_detect/val/ #测试数据集根目录

# Set RESIZE=true to resize the images to 60 x 60. Leave as false if images have
# already been resized using another tool.
RESIZE=true  #是否进行图片尺寸变换
if $RESIZE; then
  RESIZE_HEIGHT=227
  RESIZE_WIDTH=227
else
  RESIZE_HEIGHT=0
  RESIZE_WIDTH=0
fi

if [ ! -d "$TRAIN_DATA_ROOT" ]; then
  echo "Error: TRAIN_DATA_ROOT is not a path to a directory: $TRAIN_DATA_ROOT"
  echo "Set the TRAIN_DATA_ROOT variable in create_face_48.sh to the path" \
       "where the face_48 training data is stored."
  exit 1
fi

if [ ! -d "$VAL_DATA_ROOT" ]; then
  echo "Error: VAL_DATA_ROOT is not a path to a directory: $VAL_DATA_ROOT"
  echo "Set the VAL_DATA_ROOT variable in create_face_48.sh to the path" \
       "where the face_48 validation data is stored."
  exit 1
fi

echo "Creating train lmdb..."

GLOG_logtostderr=1 $TOOLS/convert_imageset \
    --resize_height=$RESIZE_HEIGHT \
    --resize_width=$RESIZE_WIDTH \
    --shuffle \
    $TRAIN_DATA_ROOT \
    $DATA/train.txt \
    $EXAMPLE/face_train_lmdb

echo "Creating val lmdb..."

GLOG_logtostderr=1 $TOOLS/convert_imageset \
    --resize_height=$RESIZE_HEIGHT \
    --resize_width=$RESIZE_WIDTH \
    --shuffle \
    $VAL_DATA_ROOT \
    $DATA/val.txt \
    $EXAMPLE/face_val_lmdb

echo "Done."
Status API Training Shop Blog About

生成多标签数据格式HDF5

前面讲到的是LMDB,在Caffe中,如果使用LMDB数据格式的话,默认是只支持“图像+整数单标签”这种形式的数据的如果训练网络需要一些其他形式的数据或标签(如浮点数据,多标签等等),可以将其制作成HDF5格式HDF5数据格式比较灵活,但缺点是占用空间较大。将229多张512x512的图像制作成一个HDF5文件,能达到1.4GB。 因此建议图像文件的话,最好还是用LMDB格式,快速且节省空间。因此建议的方法是将图像存储为LMDB格式,而多标签存储为HDF5格式

图像数据转换为LMDB格式

制作流程与前面思路一致,准备图像文件名列表list.txt,需要注意的是因为我们将多标签的值与图像的存储分开,对于图像我们可以不写对应的Label值。convert_imageset会在运行时把每个Label都赋为0.

多标签数据格式

定义name_label.txt文件,文件的格式如下所示:

代码语言:javascript
复制
8/ballet_106_104.jpg 010011101010011110000010111100001001010100000
8/ballet_106_13.jpg 110000101100001010000010111110001111011100000
8/ballet_106_20.jpg 011100101000001010000010001110000001010100100
8/ballet_106_28.jpg 010110101100001010000000101110001101010000000
8/ballet_106_35.jpg 110110101010001110001000011100001111010000100
8/ballet_106_42.jpg 011110111000001001000100011110001111010100001
8/ballet_106_5.jpg 010101101100101010100101101100011101010000101
8/ballet_106_57.jpg 010000001000001010000010100010001101001100100

注意,list.txt中的文件名一定要和name_label.txt中标签一一对应。这样的话,hdf5_train.h5里面就储存了所有图像对应的标签,每个标签包含多个0或1的值。

转换脚本文件
代码语言:javascript
复制
import h5py
import numpy as np
import os

label_dim = 45

# 存放标签值的文件
list_txt = 'name_label.txt'
# 要生成的HDFS文件名
hdf5_file_name = 'hdf5_train.h5'

with open(list_txt, 'r') as f:
    lines = f.readlines()
    samples_num = len(lines)

    # 此处可以指定数据类型,如 dtype=float
    labels = np.zeros((len(lines), label_dim))

    for index in range(samples_num):
        img_name, label = lines[index].strip().split()
        label_int = [int(l) for l in label]
        labels[index, :] = label_int

    # 将标签数据写入hdf5文件
    h5_file = h5py.File(hdf5_file_name, 'w')
    # 此处'multi_label'和网络定义文件中HDF5Data层的top名字是一致的
    h5_file.create_dataset('multi_label', data=labels)
    h5_file.close()

print 'Complete.'
网络定义

采用两个数据层,一个Data层用于读取图像数据,一个HDF5Data层,用于读取图像对应的多Lable数据。具体Net定义如下:

代码语言:javascript
复制
name: "TRAIN_NET"

layer {
  name: "data"
  type: "Data"
  top: "data"
  top: "label"
  data_param {
    source: "lmdb/train_lmdb"
    backend: LMDB
    batch_size: 1
  }
  transform_param {
    mirror: true
    mean_value: 104.0
    mean_value: 117.0
    mean_value: 123.0
  }
}

layer {
  name: "multi_label"
  type: "HDF5Data"
  top: "multi_label"
  hdf5_data_param {
    source: "hdf5_train_list.txt"
    batch_size: 1
  }
}

计算数据集均值

通过caffetool文件夹中自带的工具compute_image_mean可以生成我们需要的均值文件。

代码语言:javascript
复制
sudo /home/xuke/caffe/build/tools/compute_image_mean \
/home/xuke/caffe/examples/mnist/mnist_train_lmdb \
/home/xuke/caffe_case/mean.binaryproto
  • 寻找compute_image_mean工具
  • 锁定训练数据集目标mnist_train_lmdb
  • 生成均值文件的路径及名称

绘制网络结构

使用caffe Python接口中的draw_net.py工具,将设计出的网络模型结构进行可视化。另外推荐一下网页版在线的caffe结构可视化工具——netscope(https://ethereon.github.io/netscope/#/editor)

安装graphViz
代码语言:javascript
复制
sudo apt-get install graphviz
安装pydot
代码语言:javascript
复制
sudo pip install pydot
绘制网络指令
代码语言:javascript
复制
sudo python /home/xuke/caffe/python/draw_net.py \
/home/xuke/caffe/examples/mnist/lenet_train_test.prototxt \
/home/xuke/caffe_case/lenet.png --rankdir=BT
  • 第一个参数:网络模型的prototxt文件
  • 第二个参数:保存的图片路径及名字
  • 第三个参数:--rankdir=x , x有四种选项,分别是LR, RL, TB, BT 。用来表示网络的方向,分别是从左到右,从右到左,从上到小,从下到上。默认为LR
结果

绘制Loss曲线图

通过matplotlibcaffe训练过程中的loss值与Accuracy值进行图形绘制,便于查看模型训练结果。

加载库文件设置路径
代码语言:javascript
复制
#加载必要的库
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import sys,os,caffe
#设置当前目录
caffe_root = '/home/xuke/caffe/'
sys.path.insert(0, caffe_root + 'python')
os.chdir(caffe_root)
设置求解器
代码语言:javascript
复制
caffe.set_mode_cpu()
solver = caffe.SGDSolver('/home/xuke/caffe/examples/mnist/lenet_solver.prototxt')
训练并保存Loss与Accuracy值
代码语言:javascript
复制
niter =1000
test_interval = 200
train_loss = np.zeros(niter)
test_acc = np.zeros(int(np.ceil(niter / test_interval)))
# the main solver loop
for it in range(niter):
    solver.step(1)  # SGD by Caffe

    # store the train loss
    train_loss[it] = solver.net.blobs['loss'].data
    solver.test_nets[0].forward(start='conv1')

    if it % test_interval == 0:
        acc=solver.test_nets[0].blobs['accuracy'].data
        print 'Iteration', it, 'testing...','accuracy:',acc
        test_acc[it // test_interval] = acc
代码语言:javascript
复制
Iteration 0 testing... accuracy: 0.109999999404
Iteration 200 testing... accuracy: 0.899999976158
Iteration 400 testing... accuracy: 0.949999988079
Iteration 600 testing... accuracy: 0.959999978542
Iteration 800 testing... accuracy: 0.949999988079
绘制Loss及Accuracy曲线
代码语言:javascript
复制
print(test_acc)
_, ax1 = plt.subplots()
ax2 = ax1.twinx()
ax1.plot(np.arange(niter), train_loss)
ax2.plot(test_interval * np.arange(len(test_acc)), test_acc, 'r')
ax1.set_xlabel('iteration')
ax1.set_ylabel('train loss')
ax2.set_ylabel('test accuracy')
plt.show()

代码语言:javascript
复制
[ 0.11        0.89999998  0.94999999  0.95999998  0.94999999]

参考

Caffe中使用HDF5制作多标签数据 http://blog.csdn.net/u011321962/article/details/77868348

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2018-02-11,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI异构 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 数据格式LMDB文件制作
    • 文件结构
      • train.txt文件格式
        • val.txt文件格式
          • 转换脚本文件
          • 生成多标签数据格式HDF5
            • 图像数据转换为LMDB格式
              • 多标签数据格式
                • 转换脚本文件
                  • 网络定义
                  • 计算数据集均值
                  • 绘制网络结构
                    • 安装graphViz
                      • 安装pydot
                        • 绘制网络指令
                          • 结果
                          • 绘制Loss曲线图
                            • 加载库文件设置路径
                              • 设置求解器
                                • 训练并保存Loss与Accuracy值
                                  • 绘制Loss及Accuracy曲线
                                  • 参考
                                  领券
                                  问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档