专栏首页AIUAICaffe Data层 - ImageDataLayer

Caffe Data层 - ImageDataLayer

Caffe Data 层 - ImageDataLayer

ImageDataLayer 是 Caffe 官方提供的数据层,可直接从图像文件读取图像数据及其对应的 label(标签),无需事先转换为 LMDB/LevelDB。

1. 数据格式及 prototxt 定义

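数据格式为(每行一个样本:图像路径 + 空格 + 整数 label;label 取最后一个空格之后的内容。图像路径会与 root_folder 直接做字符串拼接,中间不会自动添加 "/",因此 root_folder 应以 "/" 结尾):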

# train.txt
001.jpg 1
002.jpg 2
003.jpg 3

网络层定义:

# train_val.prototxt
layer {
  name: "data"
  type: "ImageData"
  top: "data"    # image batch
  top: "label"   # integer labels read from the list file
  include {
    phase: TRAIN
  }
  transform_param {
    # Random horizontal flip.
    mirror: true
    # Applied after mean subtraction: (pixel - mean) * scale, i.e. / 256 here.
    scale: 0.00390625
    #crop_size: 224
    #mean_value: 128
    # One mean per channel, in the decoded image's channel order
    # (presumably BGR, since images are loaded with OpenCV -- confirm).
    mean_value: 104
    mean_value: 117
    mean_value: 123
  }
  image_data_param {
    # List file: one "<relative/path> <integer label>" entry per line.
    source: "/path/to/train.txt"
    # Prepended verbatim to every path in the list (no separator is
    # inserted), so it must end with '/'.
    root_folder: "/path/to/images/"
    # Resize every image to 224x224; both must be set (or both left at 0).
    new_height: 224
    new_width: 224
    batch_size: 32
    # Shuffle the list at setup and again after every full pass.
    shuffle: true
  }
}

2. caffe.proto 定义

message LayerParameter {
    // Excerpt: only the field relevant to ImageDataLayer is shown here.
    // Parameters for layers of type "ImageData".
    optional ImageDataParameter image_data_param = 115;
}
message ImageDataParameter {
  // Path of the list file (e.g. train.txt); each line holds an image path
  // and an integer label.
  optional string source = 1;
  // Number of images per batch.
  optional uint32 batch_size = 4 [default = 1];
  // Randomly skip a few samples at start-up so that multiple (asynchronous)
  // SGD clients do not all begin at the same sample. ImageDataLayer draws the
  // skip count as caffe_rng_rand() % rand_skip, i.e. in [0, rand_skip), and
  // requires it to be smaller than the number of listed images, so rand_skip
  // should not exceed the number of images in the list.
  optional uint32 rand_skip = 7 [default = 0];
  // Shuffle the image list once at setup and again after every full pass
  // (epoch) over the list.
  optional bool shuffle = 8 [default = false];
  // Resize every image to new_height x new_width. Both must be 0 (no resize)
  // or both must be > 0.
  optional uint32 new_height = 9 [default = 0];
  optional uint32 new_width = 10 [default = 0];
  // Whether images are loaded as color (true) or grayscale (false).
  optional bool is_color = 11 [default = true];
  // DEPRECATED. See TransformationParameter. For data pre-processing, simple
  // scaling and mean subtraction can be applied; mean subtraction is always
  // carried out before scaling.
  // NOTE(review): the ImageDataLayer implementation does not appear to read
  // scale/mean_file/crop_size/mirror from this message; set them in
  // transform_param instead.
  optional float scale = 2 [default = 1];
  optional string mean_file = 3;
  // DEPRECATED. See TransformationParameter.
  // Randomly crop images to this size.
  optional uint32 crop_size = 5 [default = 0];
  // DEPRECATED. See TransformationParameter.
  // Randomly mirror (flip horizontally) images.
  optional bool mirror = 6 [default = false];
  // Prefix prepended verbatim to every path in the list file (no separator
  // is inserted), so it normally ends with '/'.
  optional string root_folder = 12 [default = ""];
}

3. image_data_layer.hpp

#include "caffe/layers/image_data_layer.hpp"

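其继承关系图(原图缺失,以文字表示):ImageDataLayer → BasePrefetchingDataLayer → BaseDataLayer → Layer;其中 BasePrefetchingDataLayer 同时继承 InternalThread,负责在后台线程中预取数据(对应析构函数中调用的 StopInternalThread())。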

公有成员函数:

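explicit ImageDataLayer(const LayerParameter& param) // 构造函数
virtual ~ImageDataLayer() // 析构函数,停止后台预取线程
virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) // 读取 source 列表文件,并根据第一张图像设置 top blob 的形状
virtual inline const char* type() const // 返回层类型 "ImageData"
virtual inline int ExactNumBottomBlobs() const // 该层要求的 bottom blob 精确数目;ImageDataLayer 返回 0(不需要 bottom 输入)。基类 Layer 的默认实现返回 -1,表示不限定数目。
virtual inline int ExactNumTopBlobs() const // 该层要求的 top blob 精确数目;ImageDataLayer 返回 2(data 和 label)。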

protected成员函数:

可以在派生类的成员函数中访问,但不能被类外的用户代码访问。

virtual void ShuffleImages() // 使用 prefetch_rng_ 随机打乱 lines_ 的顺序
virtual void load_batch(Batch<Dtype>* batch) // 在预取线程中加载一个 batch 的图像数据与标签

Protected Attributes:

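shared_ptr<Caffe::RNG> prefetch_rng_ // 打乱数据顺序所用的随机数生成器(仅在 shuffle 为 true 时创建)
vector<std::pair<std::string, int> > lines_ // 从 source 读取的 (图像路径, 标签) 列表
int lines_id_ // 下一个待读取样本在 lines_ 中的下标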

4. image_data_layer.cpp

#ifdef USE_OPENCV
#include <opencv2/core/core.hpp>

#include <fstream>  // NOLINT(readability/streams)
#include <iostream>  // NOLINT(readability/streams)
#include <string>
#include <utility>
#include <vector>

#include "caffe/data_transformer.hpp"
#include "caffe/layers/base_data_layer.hpp"
#include "caffe/layers/image_data_layer.hpp"
#include "caffe/util/benchmark.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"

namespace caffe {

template <typename Dtype>
ImageDataLayer<Dtype>::~ImageDataLayer() {
  // Stop the background prefetch thread (load_batch() runs on it and reads
  // lines_/lines_id_) before this layer's members are destroyed.
  // The destructor is spelled ~ImageDataLayer() rather than
  // ~ImageDataLayer<Dtype>(): a template-id as destructor name is invalid
  // since C++20 (CWG 2237).
  this->StopInternalThread();
}

template <typename Dtype>
void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  const int new_height = this->layer_param_.image_data_param().new_height();
  const int new_width  = this->layer_param_.image_data_param().new_width();
  const bool is_color  = this->layer_param_.image_data_param().is_color();
  string root_folder = this->layer_param_.image_data_param().root_folder();

  CHECK((new_height == 0 && new_width == 0) ||
      (new_height > 0 && new_width > 0)) << "Current implementation requires "
      "new_height and new_width to be set at the same time.";
  // Read the file with filenames and labels
  const string& source = this->layer_param_.image_data_param().source();
  LOG(INFO) << "Opening file " << source;
  std::ifstream infile(source.c_str());
  string line;
  size_t pos;
  int label;
  while (std::getline(infile, line)) {
    pos = line.find_last_of(' ');
    label = atoi(line.substr(pos + 1).c_str());
    lines_.push_back(std::make_pair(line.substr(0, pos), label));
  }

  CHECK(!lines_.empty()) << "File is empty";

  if (this->layer_param_.image_data_param().shuffle()) {
    // randomly shuffle data
    // 随机打乱数据顺序
    LOG(INFO) << "Shuffling data";
    const unsigned int prefetch_rng_seed = caffe_rng_rand();
    prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));
    ShuffleImages();
  } else {
    if (this->phase_ == TRAIN && Caffe::solver_rank() > 0 &&
        this->layer_param_.image_data_param().rand_skip() == 0) {
      LOG(WARNING) << "Shuffling or skipping recommended for multi-GPU";
    }
  }
  LOG(INFO) << "A total of " << lines_.size() << " images.";

  lines_id_ = 0;
  // Check if we would need to randomly skip a few data points
  // 随机跳过部分数据
  if (this->layer_param_.image_data_param().rand_skip()) {
    unsigned int skip = caffe_rng_rand() %
        this->layer_param_.image_data_param().rand_skip();
    LOG(INFO) << "Skipping first " << skip << " data points.";
    CHECK_GT(lines_.size(), skip) << "Not enough points to skip";
    lines_id_ = skip;
  }
  // Read an image, and use it to initialize the top blob.
  // 读取图片,并放入 top blob.
  cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,
                                    new_height, new_width, is_color);
  CHECK(cv_img.data) << "Could not load " << lines_[lines_id_].first;
  // Use data_transformer to infer the expected blob shape from a cv_image.
  vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);
  this->transformed_data_.Reshape(top_shape);
  // Reshape prefetch_data and top[0] according to the batch_size.
  const int batch_size = this->layer_param_.image_data_param().batch_size();
  CHECK_GT(batch_size, 0) << "Positive batch size required";
  top_shape[0] = batch_size;
  for (int i = 0; i < this->prefetch_.size(); ++i) {
    this->prefetch_[i]->data_.Reshape(top_shape);
  }
  top[0]->Reshape(top_shape);

  LOG(INFO) << "output data size: " << top[0]->num() << ","
      << top[0]->channels() << "," << top[0]->height() << ","
      << top[0]->width();
  // label
  // 数据标签
  vector<int> label_shape(1, batch_size);
  top[1]->Reshape(label_shape);
  for (int i = 0; i < this->prefetch_.size(); ++i) {
    this->prefetch_[i]->label_.Reshape(label_shape);
  }
}

template <typename Dtype>
void ImageDataLayer<Dtype>::ShuffleImages() {
  // Permute the (path, label) list in place with the layer's own generator.
  // prefetch_rng_ is only created in DataLayerSetUp when shuffle is enabled,
  // which is also the only situation in which this method is invoked.
  shuffle(lines_.begin(), lines_.end(),
          static_cast<caffe::rng_t*>(prefetch_rng_->generator()));
}

// This function is called on the prefetch thread. It fills `batch` with
// batch_size consecutive entries of lines_ (image data and labels), starting
// at lines_id_ and wrapping around to the start of the list (reshuffling
// first, if enabled) when the end is reached.
template <typename Dtype>
void ImageDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
  CPUTimer batch_timer;
  batch_timer.Start();
  double read_time = 0;
  double trans_time = 0;
  CPUTimer timer;
  CHECK(batch->data_.count());
  CHECK(this->transformed_data_.count());
  // Bind by reference: copying the proto message on every batch is wasted work.
  const ImageDataParameter& image_data_param =
      this->layer_param_.image_data_param();
  const int batch_size = image_data_param.batch_size();
  const int new_height = image_data_param.new_height();
  const int new_width = image_data_param.new_width();
  const bool is_color = image_data_param.is_color();
  const string& root_folder = image_data_param.root_folder();
  const bool shuffle_every_epoch = image_data_param.shuffle();

  // Set once batch->data_ has been reshaped for the first image below.
  Dtype* prefetch_data = NULL;
  Dtype* prefetch_label = batch->label_.mutable_cpu_data();

  const int lines_size = lines_.size();
  for (int item_id = 0; item_id < batch_size; ++item_id) {
    // Read and decode each image exactly once (the first image used to be
    // decoded twice: once for shape inference and again here).
    timer.Start();
    CHECK_GT(lines_size, lines_id_);
    cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,
        new_height, new_width, is_color);
    CHECK(cv_img.data) << "Could not load " << lines_[lines_id_].first;
    read_time += timer.MicroSeconds();

    if (item_id == 0) {
      // Reshape according to the first image of each batch, so inputs may
      // vary in dimension from batch to batch; the remaining images of this
      // batch are transformed into the same shape.
      vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);
      this->transformed_data_.Reshape(top_shape);
      top_shape[0] = batch_size;
      batch->data_.Reshape(top_shape);
      prefetch_data = batch->data_.mutable_cpu_data();
    }

    // Apply transformations (mirror, crop, ...) directly into this item's
    // slot of the batch buffer.
    timer.Start();
    const int offset = batch->data_.offset(item_id);
    this->transformed_data_.set_cpu_data(prefetch_data + offset);
    this->data_transformer_->Transform(cv_img, &(this->transformed_data_));
    trans_time += timer.MicroSeconds();

    prefetch_label[item_id] = lines_[lines_id_].second;

    // Advance to the next entry; at the end of the list restart from the
    // first one, reshuffling first if requested.
    lines_id_++;
    if (lines_id_ >= lines_size) {
      DLOG(INFO) << "Restarting data prefetching from start.";
      lines_id_ = 0;
      if (shuffle_every_epoch) {
        ShuffleImages();
      }
    }
  }
  batch_timer.Stop();
  DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
  DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";
  DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}

// Explicitly instantiate the ImageDataLayer class template for the numeric
// types Caffe builds layers for (presumably float and double, per the
// INSTANTIATE_CLASS macro in caffe/common.hpp -- confirm).
INSTANTIATE_CLASS(ImageDataLayer);
// Register the layer under the type name "ImageData", the name used by
// `type: "ImageData"` in a network prototxt.
REGISTER_LAYER_CLASS(ImageData);

}  // namespace caffe
#endif  // USE_OPENCV

Reference

[1] - caffe::ImageDataLayer

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • Tensorflow - tfrecords 文件的创建

    这里主要提供了 Tensorflow 创建 tfrecords 文件的辅助函数,以用于图像分类、检测和关键点定位.

    AIHGF
  • 拉普拉斯矩阵及谱聚类

    拉普拉斯矩阵及谱聚类(Laplacian Matrix and Spectral Clustering)

    AIHGF
  • Caffe2 - (十六) 创建 LMDB 数据库

    AIHGF
  • 【IoT应用创新大赛】LoRaWAN在工业互联网中的应用

    工业互联网,是智能制造发展的基础,可以提供共性的基础设施和能力;我国已经将工业互联网作为重要基础设施,为工业智能化提供支撑。

    Zach_iot
  • 微信小程序 wx.request 的封装

    自学转行到前端也已近两年,也算是简书和掘金的忠实粉丝,但是以前一直惜字如金(实在是胆子小,水平又低),现在我决定视金钱如粪土(就只是脸皮厚了,水平就那样),好了...

    极乐君
  • bootstrap 弹出框 提示框

    <div class="container" style="padding: 100px 50px 10px;" > <button type="but...

    用户5760343
  • what is conversion exit defined in ABAP domain

    我们之前用了这个data element。 UI framework的getter 会自动检测data type的domain上是否维护conversion e...

    Jerry Wang
  • R语言词频统计与词云显示

    治电小白菜
  • 新手学Linux(二)----使用 Vagrant 打造跨平台开发环境(一)

    做Web开发少不了要在本地搭建好开发环境,虽然说目前各种脚本都有对应的Windows版,甚至是一键安装包,但很多时候和Windows环境的相性并不是那么好,各...

    令仔很忙
  • 吴军北京来信:人工智能应该变成通识教育,区块链不是炒概念

    大数据文摘

扫码关注云+社区

领取腾讯云代金券