Caffe in Practice - Training and Deployment of Multi-label Annotation Based on ResNet101

Previously I tried reading multiple labels by modifying the Caffe ImageDataLayer source code - ImageMultilabelDataLayer [Caffe in Practice - Training and Deployment of Multi-label Classification Based on VGG16].

Modifying the source code is somewhat cumbersome, since Caffe has to be recompiled.

Here a new approach to automatic multi-label annotation is tried.

Unlike [Caffe in Practice - Training and Deployment of Multi-label Classification Based on VGG16], where the problem is handled in a multi-task manner with each task being a separate single-label classification problem, here it is treated as multi-label annotation.

1. Dataset

1.1 Data Format Conversion

The data takes a form like the following, e.g. images_labels.txt:

img1.jpg 1 0 1 ... 0
img2.jpg 0 1 0 ... 1
img3.jpg 1 1 0 ... 0
......

Each line is one data sample, whose multi-labels are given as a 0/1 vector. The multi-label vectors of all samples have the same length.

The multi-label data could be read in a way similar to [Caffe in Practice - Training and Deployment of Multi-label Classification Based on VGG16].

Here, however, a different approach is taken:

First, convert the data into two files, with the following formats (a conversion sketch follows the list):

  • imageslist.txt - each line corresponds to one data sample, in the form of image filename + number of labels:
    img1.jpg 2
    img2.jpg 5
    img3.jpg 3
    ......
  • labelslist.txt - each line is the multi-label 0/1 vector of the corresponding sample:
    1 0 1 ... 0
    0 1 0 ... 1
    1 1 0 ... 0
    ......
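A few lines of Python are enough to split images_labels.txt into these two files (a minimal sketch; the file names are the ones used in the examples above):

# split images_labels.txt into imageslist.txt and labelslist.txt
with open('images_labels.txt') as f_in, \
     open('imageslist.txt', 'w') as f_imgs, \
     open('labelslist.txt', 'w') as f_labels:
    for line in f_in:
        items = line.strip().split()
        name, labels = items[0], items[1:]
        num_tags = sum(int(x) for x in labels)      # number of positive labels
        f_imgs.write('%s %d\n' % (name, num_tags))  # image name + label count
        f_labels.write(' '.join(labels) + '\n')     # the 0/1 label vector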

1.2 Generating the lmdb Data

gen_label_lmdb.py:

# -*- coding: utf-8 -*-
import sys
import numpy
import argparse

from caffe.proto import caffe_pb2
import lmdb


def parse_args():
    parser = argparse.ArgumentParser(description='End-to-end inference')
    parser.add_argument('--labels', dest='labels',
                        help='label txt file',
                        default=None, type=str )
    parser.add_argument('--images', dest='images',
                        help='image txt file, for keys of datum',
                        default=None, type=str )
    parser.add_argument('--lmdb', dest='lmdb',
                        help='label lmdb file',
                        default=None, type=str )
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args()


def array_to_datum(arr, label=0):
    """
    Converts a 3-dimensional array to datum. 
    If the array has dtype uint8,
    the output data will be encoded as a string. 
    Otherwise, the output data will be stored in float format.
    """
    if arr.ndim != 3:
        raise ValueError('Incorrect array shape.')
    datum = caffe_pb2.Datum()
    datum.channels, datum.height, datum.width = arr.shape
    if arr.dtype == numpy.uint8:
        datum.data = arr.tostring()
    else:
        datum.float_data.extend(arr.flat)
    datum.label = label
    return datum


if __name__ == '__main__':
    print 'Start generating label lmdb'
    args = parse_args()

    labels_array = numpy.loadtxt(args.labels)
    num_samples, num_classes = labels_array.shape
    images_keys_list = open(args.images).readlines()
    assert num_samples == len(images_keys_list)

    labels_db = lmdb.open(args.lmdb, map_size=20 * 1024 * 1024 * 1024)
    with labels_db.begin(write=True) as txn:
        for idx in range(0, num_samples):
            # label data
            labels_cur = labels_array[idx,:]
            num_labels = labels_cur.size
            labels_cur = labels_cur.reshape(num_labels, 1, 1)
            # keys for label
            key_cur = images_keys_list[idx].split(' ')[0] # image name
            key_cur = "{:0>8d}".format(idx) + '_' + key_cur
            num_tags = int(images_keys_list[idx].split(' ')[1])
            assert num_tags == labels_cur.sum()
            # create datum holding the 0/1 vector; datum.label stores the tag count
            labels_datum = array_to_datum(labels_cur, num_tags)
            # write datum to lmdb
            txn.put(key_cur.encode('ascii'), labels_datum.SerializeToString())
            if (idx+1)%5000 == 0:
                print 'Processed ', idx + 1, 'files in total.'

    num_label_db = 0
    with labels_db.begin() as txn:
        cursor = txn.cursor()
        for key, value in cursor:
            num_label_db = num_label_db + 1
    print 'Total # of items in label lmdb:', num_label_db

Run from the command line:

python gen_label_lmdb.py --labels ./data/labelslist.txt --images ./data/imageslist.txt --lmdb ./data/output_lmdb/
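Each entry in the resulting lmdb is a Datum whose float_data holds the 0/1 vector (shape num_classes x 1 x 1) and whose label field stores the number of tags. A quick sanity check could read one entry back (a sketch, reusing the lmdb path from the command above):

import lmdb
import numpy
from caffe.proto import caffe_pb2

env = lmdb.open('./data/output_lmdb/', readonly=True)
with env.begin() as txn:
    for key, value in txn.cursor():
        datum = caffe_pb2.Datum()
        datum.ParseFromString(value)
        labels = numpy.array(datum.float_data)
        print key, labels.shape, int(labels.sum()), datum.label
        break  # only check the first entry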

2. Model Definition and Training

2.1 train_val.prototxt

name: "resnet101"

layer {
  name: "data"
  type: "ImageData"
  top: "data"
  top: "dummylabel"
  image_data_param {
    source: "/path/to/train_imagelist.txt"
    root_folder: "/path/to/images/"
    batch_size: 16
    new_height: 256
    new_width: 256
  }
  transform_param {
    mirror: true
    crop_size: 224
    mean_value: 104
    mean_value: 117
    mean_value: 123
  }
  include {
    phase: TRAIN
  }
}
layer {
  name: "label"
  type: "Data"
  top: "label"
  data_param {
    source: "/path/to/train_output_lmdb"
    batch_size: 16
    backend: LMDB
  }
  include {
    phase: TRAIN
  }
}

##
layer {
  name: "data"
  type: "ImageData"
  top: "data"
  top: "dummylabel"
  image_data_param {
    source: "/path/to/test_imagelist.txt"
    root_folder: "/path/to/images/"
    batch_size: 4
    new_height: 256
    new_width: 256
  }
  transform_param {
    crop_size: 224
    mean_value: 104
    mean_value: 117
    mean_value: 123
  }
  include {
    phase: TEST
  }
}
layer {
  name: "label"
  type: "Data"
  top: "label"
  data_param {
    source: "/path/to/test_output_lmdb"
    batch_size: 4
    backend: LMDB
  }
  include {
    phase: TEST
  }
}
####### dummylabel #######
# the dummylabel produced by the ImageData layer is not used, so it is silenced here
layer {
  name: "label_silence"
  type: "Silence"
  bottom: "dummylabel"
}

#### network definition ####

......

####
layer {
  name: "fc_labels"
  type: "InnerProduct"
  bottom: "Pooling_"
  top: "fc_labels"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 1
    decay_mult: 0
  }
  inner_product_param {
    num_output: 30 # number of labels
    weight_filler {
      type: "gaussian"
      std: 0.005
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}

layer {
  name: "loss"
  type: "SigmoidCrossEntropyLoss" # 
  bottom: "fc_labels"
  bottom: "label"
  top: "loss"
}
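SigmoidCrossEntropyLoss treats each of the 30 outputs of fc_labels as an independent binary classifier, so the 0/1 vector from the label lmdb can be used as the target directly:

loss = -Σ_j [ y_j · log σ(x_j) + (1 - y_j) · log(1 - σ(x_j)) ]

where x_j is the j-th output of fc_labels, σ is the sigmoid, and y_j ∈ {0, 1} is the j-th entry of the label vector. Note that the image list read by the ImageData layer and the label lmdb must list the samples in exactly the same order, and the two layers must use the same batch_size, otherwise images and label vectors get mismatched.

For completeness, a minimal solver.prototxt sketch for fine-tuning (all hyper-parameter values here are assumptions, not taken from the original setup):

net: "/path/to/train_val.prototxt"
test_iter: 100
test_interval: 1000
base_lr: 0.001
lr_policy: "step"
gamma: 0.1
stepsize: 20000
momentum: 0.9
weight_decay: 0.0005
display: 100
max_iter: 100000
snapshot: 10000
snapshot_prefix: "/path/to/snapshots/multilabel_resnet101"
solver_mode: GPU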

2.2 deploy.prototxt

name: "resnet101"
input: "data"
input_shape {
  dim: 1
  dim: 3
  dim: 224
  dim: 224
}

#### network definition ####

......

####
layer {
  name: "fc_labels"
  type: "InnerProduct"
  bottom: "Pooling_"
  top: "fc_labels"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 1
    decay_mult: 0
  }
  inner_product_param {
    num_output: 30 # number of labels
    weight_filler {
      type: "gaussian"
      std: 0.005
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}
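The deploy network ends at fc_labels and outputs raw scores (logits); the sigmoid is applied in Python in the deployment script below. Alternatively, a Sigmoid layer could be appended so that the network outputs per-label probabilities directly (a sketch):

layer {
  name: "prob"
  type: "Sigmoid"
  bottom: "fc_labels"
  top: "prob"
}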

3. Model Deployment

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from PIL import Image
import scipy.misc
import matplotlib.pyplot as plt

import sys
caffe_root = '/path/to/caffe/'
sys.path.insert(0, caffe_root + 'python')
import caffe

caffe.set_mode_gpu()
caffe.set_device(0)
# caffe.set_mode_cpu()

class SimpleTransformer(object):

    """
    SimpleTransformer is a simple class for preprocessing and deprocessing
    images for caffe. (Kept here as a lightweight alternative; the script
    below uses caffe.io.Transformer.)
    """

    def __init__(self, mean=[128, 128, 128]):
        self.mean = np.array(mean, dtype=np.float32)
        self.scale = 1.0

    def set_mean(self, mean):
        """
        Set the mean to subtract for centering the data.
        """
        self.mean = mean

    def set_scale(self, scale):
        """
        Set the data scaling.
        """
        self.scale = scale

    def preprocess(self, im):
        """
        preprocess() emulates the pre-processing occurring in the Caffe
        prototxt.
        """

        im = np.float32(im)
        im = im[:, :, ::-1]  # change to BGR
        im -= self.mean
        im *= self.scale
        im = im.transpose((2, 0, 1))

        return im


if __name__ == '__main__':
    print 'Start...'
    multilabels = ['a', 'b', 'c']  # ... the full list of label names, in training order

    test_image = '/home/sh/Pictures/upper/10.jpg'
    im = np.asarray(Image.open(test_image))
    im = scipy.misc.imresize(im, [224, 224])

    model_def = '/path/to/deploy.prototxt'
    weight_def = '/path/to/multilabel_resnet101_iter_100000.caffemodel'
    net = caffe.Net(model_def, weight_def, caffe.TEST)

    # caffe.io.Transformer expects the input image scaled to [0, 1];
    # PIL + imresize give uint8 in [0, 255], so rescale first.
    im_input = im / 255.0
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2, 0, 1))    # h*w*c -> c*h*w
    transformer.set_mean('data', np.array((104, 117, 123)))  # mean pixel (BGR)
    transformer.set_raw_scale('data', 255)  # the net operates on images in [0,255]
    transformer.set_channel_swap('data', (2, 1, 0))  # RGB -> BGR

    transformed_image = transformer.preprocess('data', im_input)
    net.blobs['data'].data[...] = transformed_image
    outputs = net.forward()

    img_scores = outputs[net.outputs[0]]
    img_scores = np.array([1./(1+np.exp(-tmp)) for tmp in img_scores[0]])

    pred_labels_index = np.where(img_scores > 0.5)[0].tolist()
    print 'pred_labels:', pred_labels_index, '----', len(pred_labels_index)
    for tmp in pred_labels_index:
        print 'pred_labels:', multilabels[tmp], '----', len(pred_labels_index)
    print '--------------------------------------------'

    plt.imshow(im)
    plt.axis('off')
    plt.show()
    print 'Done.'
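The fixed 0.5 threshold is one possible decision rule. Since the number of tags per image is known at training time, another common choice at inference time is to keep the k highest-scoring labels (a small sketch; k is an assumed value):

k = 5  # assumed number of labels to keep per image
topk_index = np.argsort(img_scores)[::-1][:k].tolist()
print 'top-%d labels:' % k, [multilabels[i] for i in topk_index]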

Related

[1] - Caffe in Practice - Training and Deployment of Multi-label Classification Based on VGG16
