Caffe2 - (22) Detectron: Dataset Loading and Processing Functions

Detectron is built around the standard COCO json dataset format.

When working with a new dataset, it is strongly recommended to convert it to the COCO json format so that the existing data-loading code can be reused.

Rewriting the loading code for a new dataset format is not recommended.
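For reference, the sketch below shows a minimal COCO-style annotation file with the three top-level lists the loader in Section 2 actually reads ('images', 'categories', 'annotations'); the ids, file name, and single category are made up for illustration:

# A minimal COCO-style annotation file, written out from Python.
# All ids, the file name, and the category below are illustrative.
import json

coco_like = {
    'images': [
        {'id': 1, 'file_name': '000001.jpg', 'width': 640, 'height': 480},
    ],
    'categories': [
        {'id': 1, 'name': 'person'},
    ],
    'annotations': [
        {'id': 1, 'image_id': 1, 'category_id': 1,
         'bbox': [100, 100, 50, 80],  # (x1, y1, w, h)
         'area': 4000.0,
         'iscrowd': 0,
         # one polygon: a flattened list of (x, y) coordinates, >= 6 values
         'segmentation': [[100, 100, 150, 100, 150, 180, 100, 180]]},
    ],
}

with open('coco_newdatasets_train.json', 'w') as f:
    json.dump(coco_like, f)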

1. Dataset Definitions - dataset_catalog.py

"""Collection of available datasets."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os


# Path to data dir
_DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')

# Required dataset entry keys
IM_DIR = 'image_directory'
ANN_FN = 'annotation_file'

# Optional dataset entry keys
IM_PREFIX = 'image_prefix'
DEVKIT_DIR = 'devkit_directory'
RAW_DIR = 'raw_dir'

# Root directory of a custom dataset that has been converted to COCO json format
NEW_DATASETS_DIR = '/path/to/new_datasets/Images/'

# Available datasets
DATASETS = {
    'coco_newdatasets_train': {
        IM_DIR:
            NEW_DATASETS_DIR + '/Img',  # image directory
        ANN_FN:
            NEW_DATASETS_DIR + '/Anno/coco_newdatasets_train.json',  # annotations in COCO json format
    },
    'coco_newdatasets_val': {
        IM_DIR:
            NEW_DATASETS_DIR + '/Img',
        ANN_FN:
            NEW_DATASETS_DIR + '/Anno/coco_newdatasets_val.json',
    },
    'cityscapes_fine_instanceonly_seg_train': {
        IM_DIR:
            _DATA_DIR + '/cityscapes/images',
        ANN_FN:
            _DATA_DIR + '/cityscapes/annotations/instancesonly_gtFine_train.json',
        RAW_DIR:
            _DATA_DIR + '/cityscapes/raw'
    },
    'cityscapes_fine_instanceonly_seg_val': {
        IM_DIR:
            _DATA_DIR + '/cityscapes/images',
        # use filtered validation as there is an issue converting contours
        ANN_FN:
            _DATA_DIR + '/cityscapes/annotations/instancesonly_filtered_gtFine_val.json',
        RAW_DIR:
            _DATA_DIR + '/cityscapes/raw'
    },
    'cityscapes_fine_instanceonly_seg_test': {
        IM_DIR:
            _DATA_DIR + '/cityscapes/images',
        ANN_FN:
            _DATA_DIR + '/cityscapes/annotations/instancesonly_gtFine_test.json',
        RAW_DIR:
            _DATA_DIR + '/cityscapes/raw'
    },
    'coco_2014_train': {
        IM_DIR:
            _DATA_DIR + '/coco/coco_train2014',
        ANN_FN:
            _DATA_DIR + '/coco/annotations/instances_train2014.json'
    },
    'coco_2014_val': {
        IM_DIR:
            _DATA_DIR + '/coco/coco_val2014',
        ANN_FN:
            _DATA_DIR + '/coco/annotations/instances_val2014.json'
    },
    'coco_2014_minival': {
        IM_DIR:
            _DATA_DIR + '/coco/coco_val2014',
        ANN_FN:
            _DATA_DIR + '/coco/annotations/instances_minival2014.json'
    },
    'coco_2014_valminusminival': {
        IM_DIR:
            _DATA_DIR + '/coco/coco_val2014',
        ANN_FN:
            _DATA_DIR + '/coco/annotations/instances_valminusminival2014.json'
    },
    'coco_2015_test': {
        IM_DIR:
            _DATA_DIR + '/coco/coco_test2015',
        ANN_FN:
            _DATA_DIR + '/coco/annotations/image_info_test2015.json'
    },
    'coco_2015_test-dev': {
        IM_DIR:
            _DATA_DIR + '/coco/coco_test2015',
        ANN_FN:
            _DATA_DIR + '/coco/annotations/image_info_test-dev2015.json'
    },
    'coco_2017_test': {  # 2017 test uses 2015 test images
        IM_DIR:
            _DATA_DIR + '/coco/coco_test2015',
        ANN_FN:
            _DATA_DIR + '/coco/annotations/image_info_test2017.json',
        IM_PREFIX:
            'COCO_test2015_'
    },
    'coco_2017_test-dev': {  # 2017 test-dev uses 2015 test images
        IM_DIR:
            _DATA_DIR + '/coco/coco_test2015',
        ANN_FN:
            _DATA_DIR + '/coco/annotations/image_info_test-dev2017.json',
        IM_PREFIX:
            'COCO_test2015_'
    },
    'coco_stuff_train': {
        IM_DIR:
            _DATA_DIR + '/coco/coco_train2014',
        ANN_FN:
            _DATA_DIR + '/coco/annotations/coco_stuff_train.json'
    },
    'coco_stuff_val': {
        IM_DIR:
            _DATA_DIR + '/coco/coco_val2014',
        ANN_FN:
            _DATA_DIR + '/coco/annotations/coco_stuff_val.json'
    },
    'keypoints_coco_2014_train': {
        IM_DIR:
            _DATA_DIR + '/coco/coco_train2014',
        ANN_FN:
            _DATA_DIR + '/coco/annotations/person_keypoints_train2014.json'
    },
    'keypoints_coco_2014_val': {
        IM_DIR:
            _DATA_DIR + '/coco/coco_val2014',
        ANN_FN:
            _DATA_DIR + '/coco/annotations/person_keypoints_val2014.json'
    },
    'keypoints_coco_2014_minival': {
        IM_DIR:
            _DATA_DIR + '/coco/coco_val2014',
        ANN_FN:
            _DATA_DIR + '/coco/annotations/person_keypoints_minival2014.json'
    },
    'keypoints_coco_2014_valminusminival': {
        IM_DIR:
            _DATA_DIR + '/coco/coco_val2014',
        ANN_FN:
            _DATA_DIR + '/coco/annotations/person_keypoints_valminusminival2014.json'
    },
    'keypoints_coco_2015_test': {
        IM_DIR:
            _DATA_DIR + '/coco/coco_test2015',
        ANN_FN:
            _DATA_DIR + '/coco/annotations/image_info_test2015.json'
    },
    'keypoints_coco_2015_test-dev': {
        IM_DIR:
            _DATA_DIR + '/coco/coco_test2015',
        ANN_FN:
            _DATA_DIR + '/coco/annotations/image_info_test-dev2015.json'
    },
    'voc_2007_trainval': {
        IM_DIR:
            _DATA_DIR + '/VOC2007/JPEGImages',
        ANN_FN:
            _DATA_DIR + '/VOC2007/annotations/voc_2007_trainval.json',
        DEVKIT_DIR:
            _DATA_DIR + '/VOC2007/VOCdevkit2007'
    },
    'voc_2007_test': {
        IM_DIR:
            _DATA_DIR + '/VOC2007/JPEGImages',
        ANN_FN:
            _DATA_DIR + '/VOC2007/annotations/voc_2007_test.json',
        DEVKIT_DIR:
            _DATA_DIR + '/VOC2007/VOCdevkit2007'
    },
    'voc_2012_trainval': {
        IM_DIR:
            _DATA_DIR + '/VOC2012/JPEGImages',
        ANN_FN:
            _DATA_DIR + '/VOC2012/annotations/voc_2012_trainval.json',
        DEVKIT_DIR:
            _DATA_DIR + '/VOC2012/VOCdevkit2012'
    }
}
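Each entry is keyed by a dataset name; JsonDataset in Section 2 reads IM_DIR and ANN_FN from it. A minimal lookup sketch (using the same import path that json_dataset.py uses below):

from datasets.dataset_catalog import ANN_FN, DATASETS, IM_DIR

name = 'coco_2014_train'
assert name in DATASETS, 'Unknown dataset name: {}'.format(name)
print('image dir:       {}'.format(DATASETS[name][IM_DIR]))
print('annotation file: {}'.format(DATASETS[name][ANN_FN]))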

2. Dataset Loading and Processing - json_dataset.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import copy
from six.moves import cPickle as pickle  # Python 2/3 compatible pickle import
import logging
import numpy as np
import os
import scipy.sparse

# Must happen before importing COCO API (which imports matplotlib)
import utils.env as envu
# envu.set_up_matplotlib()
# COCO API
from pycocotools import mask as COCOmask
from pycocotools.coco import COCO

from core.config import cfg
from datasets.dataset_catalog import ANN_FN
from datasets.dataset_catalog import DATASETS
from datasets.dataset_catalog import IM_DIR
from datasets.dataset_catalog import IM_PREFIX
from utils.timer import Timer
import utils.boxes as box_utils

logger = logging.getLogger(__name__)


class JsonDataset(object):
    """
    A class representing a COCO json dataset.
    """
    def __init__(self, name):
        # Dataset initialization and sanity checks
        assert name in DATASETS.keys(), \
            'Unknown dataset name: {}'.format(name)
        assert os.path.exists(DATASETS[name][IM_DIR]), \
            'Image directory \'{}\' not found'.format(DATASETS[name][IM_DIR])
        assert os.path.exists(DATASETS[name][ANN_FN]), \
            'Annotation file \'{}\' not found'.format(DATASETS[name][ANN_FN])
        logger.debug('Creating: {}'.format(name))
        self.name = name
        self.image_directory = DATASETS[name][IM_DIR]
        self.image_prefix = ('' if IM_PREFIX not in DATASETS[name] else DATASETS[name][IM_PREFIX])
        self.COCO = COCO(DATASETS[name][ANN_FN])
        self.debug_timer = Timer()
        # Set up dataset classes
        category_ids = self.COCO.getCatIds()
        categories = [c['name'] for c in self.COCO.loadCats(category_ids)]
        self.category_to_id_map = dict(zip(categories, category_ids))
        self.classes = ['__background__'] + categories
        self.num_classes = len(self.classes)
        self.json_category_id_to_contiguous_id = {v: i + 1 for i, v in enumerate(self.COCO.getCatIds()) }
        self.contiguous_category_id_to_json_id = {v: k for k, v in self.json_category_id_to_contiguous_id.items() }
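        # e.g. for COCO the 80 json category ids (1..90, with gaps) map to the
        # contiguous ids 1..80; contiguous id 0 is reserved for __background__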
        self._init_keypoints()  # keypoint metadata (person category only)

    def get_roidb(self, gt=False, proposal_file=None, min_proposal_size=2,
                  proposal_limit=-1, crowd_filter_thresh=0):
        """ 
        返回对应与 json 数据集的 roidb. 包括的处理:
            - 将 ground truth boxes 加入 roidb
            - 添加 proposals 文件中给定的 proposals
            - 基于最小边长度(minimum side length) 过滤 proposals
            - 基于与 crowd 区域交集过滤 proposals
        """
        assert gt is True or crowd_filter_thresh == 0,  'Crowd filter threshold must be 0 if ground-truth annotations are not included.'
        image_ids = self.COCO.getImgIds()  # image ids
        image_ids.sort()
        # Load image metadata from the COCO json as the initial roidb
        roidb = copy.deepcopy(self.COCO.loadImgs(image_ids))
        for entry in roidb:
            self._prep_roidb_entry(entry)  # add empty metadata fields to each entry
        if gt:
            # Include ground-truth object annotations
            self.debug_timer.tic()
            for entry in roidb:
                self._add_gt_annotations(entry)
            logger.debug('_add_gt_annotations took {:.3f}s'.
                         format(self.debug_timer.toc(average=False)))
        if proposal_file is not None:
            # Include proposals from a file, if one was given
            self.debug_timer.tic()
            self._add_proposals_from_file(roidb, proposal_file, min_proposal_size, proposal_limit, crowd_filter_thresh)
            logger.debug('_add_proposals_from_file took {:.3f}s'.
                         format(self.debug_timer.toc(average=False)) )
        _add_class_assignments(roidb)  # compute the object class assigned to each box
        return roidb

    def _prep_roidb_entry(self, entry):
        """
        Adds empty metadata fields to a roidb entry.
        """
        # Reference back to the parent dataset
        entry['dataset'] = self
        # Absolute path to the image file
        entry['image'] = os.path.join(self.image_directory, 
                                      self.image_prefix + entry['file_name'])
        entry['flipped'] = False  # the original data is not horizontally flipped
        entry['has_visible_keypoints'] = False
        entry['boxes'] = np.empty((0, 4), dtype=np.float32)
        entry['segms'] = []
        entry['gt_classes'] = np.empty((0), dtype=np.int32)
        entry['seg_areas'] = np.empty((0), dtype=np.float32)
        entry['gt_overlaps'] = scipy.sparse.csr_matrix(
            np.empty((0, self.num_classes), dtype=np.float32) )
        entry['is_crowd'] = np.empty((0), dtype=np.bool)
        # 'box_to_gt_ind_map': Shape is (#rois). Maps from each roi to the index
        # in the list of rois that satisfy np.where(entry['gt_classes'] > 0)
        entry['box_to_gt_ind_map'] = np.empty((0), dtype=np.int32)
        if self.keypoints is not None:
            entry['gt_keypoints'] = np.empty((0, 3, self.num_keypoints), dtype=np.int32 )
        # Remove unwanted fields that come from the json file (if they exist)
        for k in ['date_captured', 'url', 'license', 'file_name']:
            if k in entry:
                del entry[k]

    def _add_gt_annotations(self, entry):
        """
        Add ground-truth annotation metadata to a roidb entry.
        """
        ann_ids = self.COCO.getAnnIds(imgIds=entry['id'], iscrowd=None)
        objs = self.COCO.loadAnns(ann_ids)  # load annotations for this image
        # Sanitize bboxes -- some are invalid
        valid_objs = []
        valid_segms = []
        width = entry['width']
        height = entry['height']
        for obj in objs:
            # crowd regions are RLE encoded and stored as dicts
            if isinstance(obj['segmentation'], list):
                # Valid polygons have >= 3 points, so require >= 6 coordinates
                obj['segmentation'] = [p for p in obj['segmentation'] if len(p) >= 6 ]
            if obj['area'] < cfg.TRAIN.GT_MIN_AREA:  # skip objects that are too small
                continue
            if 'ignore' in obj and obj['ignore'] == 1: 
                continue
            # Convert the bbox annotation from (x1, y1, w, h) to (x1, y1, x2, y2)
            x1, y1, x2, y2 = box_utils.xywh_to_xyxy(obj['bbox'])
            x1, y1, x2, y2 = box_utils.clip_xyxy_to_image(x1, y1, x2, y2, height, width )
            # Keep only boxes with positive seg area and valid, non-degenerate coordinates
            if obj['area'] > 0 and x2 > x1 and y2 > y1:
                obj['clean_bbox'] = [x1, y1, x2, y2]
                valid_objs.append(obj)
                valid_segms.append(obj['segmentation'])
        num_valid_objs = len(valid_objs) 

        boxes = np.zeros((num_valid_objs, 4), dtype=entry['boxes'].dtype)
        gt_classes = np.zeros((num_valid_objs), dtype=entry['gt_classes'].dtype)
        gt_overlaps = np.zeros((num_valid_objs, self.num_classes),
                               dtype=entry['gt_overlaps'].dtype )
        seg_areas = np.zeros((num_valid_objs), dtype=entry['seg_areas'].dtype)
        is_crowd = np.zeros((num_valid_objs), dtype=entry['is_crowd'].dtype)
        box_to_gt_ind_map = np.zeros((num_valid_objs), dtype=entry['box_to_gt_ind_map'].dtype )
        if self.keypoints is not None:
            gt_keypoints = np.zeros((num_valid_objs, 3, self.num_keypoints),
                                    dtype=entry['gt_keypoints'].dtype )

        im_has_visible_keypoints = False
        for ix, obj in enumerate(valid_objs):
            cls = self.json_category_id_to_contiguous_id[obj['category_id']]
            boxes[ix, :] = obj['clean_bbox']
            gt_classes[ix] = cls
            seg_areas[ix] = obj['area']
            is_crowd[ix] = obj['iscrowd']
            box_to_gt_ind_map[ix] = ix
            if self.keypoints is not None:
                gt_keypoints[ix, :, :] = self._get_gt_keypoints(obj)
                if np.sum(gt_keypoints[ix, 2, :]) > 0:
                    im_has_visible_keypoints = True
            if obj['iscrowd']:
                # Set overlap to -1 for all classes for crowd objects
                # so they will be excluded during training
                gt_overlaps[ix, :] = -1.0
            else:
                gt_overlaps[ix, cls] = 1.0
        entry['boxes'] = np.append(entry['boxes'], boxes, axis=0)
        entry['segms'].extend(valid_segms)
        # To match the original implementation:
        # entry['boxes'] = np.append(
        #     entry['boxes'], boxes.astype(np.int).astype(np.float), axis=0)
        entry['gt_classes'] = np.append(entry['gt_classes'], gt_classes)
        entry['seg_areas'] = np.append(entry['seg_areas'], seg_areas)
        entry['gt_overlaps'] = np.append(entry['gt_overlaps'].toarray(), gt_overlaps, axis=0 )
        entry['gt_overlaps'] = scipy.sparse.csr_matrix(entry['gt_overlaps'])
        entry['is_crowd'] = np.append(entry['is_crowd'], is_crowd)
        entry['box_to_gt_ind_map'] = np.append(entry['box_to_gt_ind_map'], box_to_gt_ind_map )
        if self.keypoints is not None:
            entry['gt_keypoints'] = np.append(entry['gt_keypoints'], gt_keypoints, axis=0 )
            entry['has_visible_keypoints'] = im_has_visible_keypoints

    def _add_proposals_from_file(self, roidb, proposal_file, min_proposal_size, top_k, crowd_thresh):
        """
        Add proposals from a proposals file to a roidb.
        """
        logger.info('Loading proposals from: {}'.format(proposal_file))
        with open(proposal_file, 'r') as f:
            proposals = pickle.load(f)
        id_field = 'indexes' if 'indexes' in proposals else 'ids'  # compat fix
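        # Expected layout (inferred from the code below): proposals is a dict
        # with keys 'boxes', 'scores', and 'ids' (or 'indexes'), where
        # proposals['boxes'][i] is an (N, 4) array of (x1, y1, x2, y2) boxes
        # for the image whose id is proposals[id_field][i]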
        _sort_proposals(proposals, id_field)  # sort proposals by id_field
        box_list = []
        for i, entry in enumerate(roidb):
            if i % 2500 == 0:
                logger.info(' {:d}/{:d}'.format(i + 1, len(roidb)))
            boxes = proposals['boxes'][i]
            # Sanity check that the proposal boxes correspond to this image id
            assert entry['id'] == proposals[id_field][i]
            # Remove duplicate boxes and very small boxes, then take the top k
            boxes = box_utils.clip_boxes_to_image(boxes, entry['height'], entry['width'])
            keep = box_utils.unique_boxes(boxes)
            boxes = boxes[keep, :]
            keep = box_utils.filter_small_boxes(boxes, min_proposal_size)
            boxes = boxes[keep, :]
            if top_k > 0:
                boxes = boxes[:top_k, :]
            box_list.append(boxes)
        _merge_proposal_boxes_into_roidb(roidb, box_list)
        if crowd_thresh > 0:
            _filter_crowd_proposals(roidb, crowd_thresh)

    def _init_keypoints(self):
        """
        Initialize COCO keypoint annotation metadata.
        """
        self.keypoints = None
        self.keypoint_flip_map = None
        self.keypoints_to_id_map = None
        self.num_keypoints = 0
        # Thus far only the 'person' category has keypoints
        if 'person' in self.category_to_id_map:
            cat_info = self.COCO.loadCats([self.category_to_id_map['person']])
        else:
            return

        # Check if the annotations contain keypoint data or not
        if 'keypoints' in cat_info[0]:
            keypoints = cat_info[0]['keypoints']
            self.keypoints_to_id_map = dict(
                zip(keypoints, range(len(keypoints))))
            self.keypoints = keypoints
            self.num_keypoints = len(keypoints)
            self.keypoint_flip_map = {
                'left_eye': 'right_eye',
                'left_ear': 'right_ear',
                'left_shoulder': 'right_shoulder',
                'left_elbow': 'right_elbow',
                'left_wrist': 'right_wrist',
                'left_hip': 'right_hip',
                'left_knee': 'right_knee',
                'left_ankle': 'right_ankle'}

    def _get_gt_keypoints(self, obj):
        """
        Return ground-truth keypoints for one annotation object.
        """
        if 'keypoints' not in obj:
            return None
        kp = np.array(obj['keypoints'])
        x = kp[0::3]  # 0-indexed x coordinates
        y = kp[1::3]  # 0-indexed y coordinates
        # 0: not labeled; 1: labeled, not inside mask;
        # 2: labeled and inside mask
        v = kp[2::3]
        num_keypoints = len(obj['keypoints']) // 3  # (x, y, v) triplets
        assert num_keypoints == self.num_keypoints
        gt_kps = np.ones((3, self.num_keypoints), dtype=np.int32)
        for i in range(self.num_keypoints):
            gt_kps[0, i] = x[i]
            gt_kps[1, i] = y[i]
            gt_kps[2, i] = v[i]
        return gt_kps


def add_proposals(roidb, rois, scales, crowd_thresh):
    """ 
    将只有 groundtruth 标注但没有 proposals 的 proposal boxes(rois) 添加到 roidb.
    如果 proposals 不是原始的图片尺度scale,则指定对应的 scale factor - inv_im_scale.
    """
    box_list = []
    for i in range(len(roidb)):
        inv_im_scale = 1. / scales[i]
        idx = np.where(rois[:, 0] == i)[0]
        box_list.append(rois[idx, 1:] * inv_im_scale)
    _merge_proposal_boxes_into_roidb(roidb, box_list)
    if crowd_thresh > 0:
        _filter_crowd_proposals(roidb, crowd_thresh)
    _add_class_assignments(roidb)


def _merge_proposal_boxes_into_roidb(roidb, box_list):
    """
    Add proposal boxes to each roidb entry.
    """
    assert len(box_list) == len(roidb)
    for i, entry in enumerate(roidb):
        boxes = box_list[i]
        num_boxes = boxes.shape[0]
        gt_overlaps = np.zeros((num_boxes, entry['gt_overlaps'].shape[1]),
                               dtype=entry['gt_overlaps'].dtype )
        box_to_gt_ind_map = -np.ones((num_boxes), dtype=entry['box_to_gt_ind_map'].dtype )

        # Note: all gt rois are added to the roidb entry here, even those marked
        # as crowd. Boxes that overlap with crowd regions are filtered out later
        # by _filter_crowd_proposals.
        gt_inds = np.where(entry['gt_classes'] > 0)[0]
        if len(gt_inds) > 0:
            gt_boxes = entry['boxes'][gt_inds, :]
            gt_classes = entry['gt_classes'][gt_inds]
            proposal_to_gt_overlaps = box_utils.bbox_overlaps(
                boxes.astype(dtype=np.float32, copy=False),
                gt_boxes.astype(dtype=np.float32, copy=False) )
            # Gt box that overlaps each input box the most
            # (ties are broken arbitrarily by class order)
            argmaxes = proposal_to_gt_overlaps.argmax(axis=1)
            # Amount of that overlap
            maxes = proposal_to_gt_overlaps.max(axis=1)
            # Those boxes with non-zero overlap with gt boxes
            I = np.where(maxes > 0)[0]
            # Record max overlaps with the class of the appropriate gt box
            gt_overlaps[I, gt_classes[argmaxes[I]]] = maxes[I]
            box_to_gt_ind_map[I] = gt_inds[argmaxes[I]]
        entry['boxes'] = np.append(entry['boxes'], boxes.astype(entry['boxes'].dtype, copy=False), axis=0 )
        entry['gt_classes'] = np.append(entry['gt_classes'], np.zeros((num_boxes), dtype=entry['gt_classes'].dtype) )
        entry['seg_areas'] = np.append(entry['seg_areas'], np.zeros((num_boxes), dtype=entry['seg_areas'].dtype) )
        entry['gt_overlaps'] = np.append(entry['gt_overlaps'].toarray(), gt_overlaps, axis=0 )
        entry['gt_overlaps'] = scipy.sparse.csr_matrix(entry['gt_overlaps'])
        entry['is_crowd'] = np.append(entry['is_crowd'], np.zeros((num_boxes), dtype=entry['is_crowd'].dtype) )
        entry['box_to_gt_ind_map'] = np.append(entry['box_to_gt_ind_map'], box_to_gt_ind_map.astype(entry['box_to_gt_ind_map'].dtype, copy=False ) )


def _filter_crowd_proposals(roidb, crowd_thresh):
    """ 
    Find proposals that are inside crowd regions and mark them with overlap = -1,
    so that they are excluded during training.
    """
    for entry in roidb:
        gt_overlaps = entry['gt_overlaps'].toarray()
        crowd_inds = np.where(entry['is_crowd'] == 1)[0]
        non_gt_inds = np.where(entry['gt_classes'] == 0)[0]
        if len(crowd_inds) == 0 or len(non_gt_inds) == 0:
            continue
        crowd_boxes = box_utils.xyxy_to_xywh(entry['boxes'][crowd_inds, :])
        non_gt_boxes = box_utils.xyxy_to_xywh(entry['boxes'][non_gt_inds, :])
        iscrowd_flags = [int(True)] * len(crowd_inds)
        ious = COCOmask.iou(non_gt_boxes, crowd_boxes, iscrowd_flags)
        bad_inds = np.where(ious.max(axis=1) > crowd_thresh)[0]
        gt_overlaps[non_gt_inds[bad_inds], :] = -1
        entry['gt_overlaps'] = scipy.sparse.csr_matrix(gt_overlaps)


def _add_class_assignments(roidb):
    """
    Compute the object class assigned to each box associated with each roidb entry.
    """
    for entry in roidb:
        gt_overlaps = entry['gt_overlaps'].toarray()
        # max overlap with gt over classes (columns)
        max_overlaps = gt_overlaps.max(axis=1)
        # gt class that had the max overlap
        max_classes = gt_overlaps.argmax(axis=1)
        entry['max_classes'] = max_classes
        entry['max_overlaps'] = max_overlaps
        # Sanity checks
        # if the max overlap is 0, the class must be background (class 0)
        zero_inds = np.where(max_overlaps == 0)[0]
        assert all(max_classes[zero_inds] == 0)
        # if the max overlap is > 0, the class must be a fg class (not class 0)
        nonzero_inds = np.where(max_overlaps > 0)[0]
        assert all(max_classes[nonzero_inds] != 0)


def _sort_proposals(proposals, id_field):
    """
    Sort proposals by the specified id_field.
    """
    order = np.argsort(proposals[id_field])
    fields_to_sort = ['boxes', id_field, 'scores']
    for k in fields_to_sort:
        proposals[k] = [proposals[k][i] for i in order]
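
Putting the two files together, a roidb is typically built as sketched below; the module path follows the same package layout as the imports above, and the dataset name and inspection lines are only illustrative:

from datasets.json_dataset import JsonDataset

ds = JsonDataset('coco_2014_train')  # any name registered in DATASETS
roidb = ds.get_roidb(gt=True)

# Each roidb entry now carries the metadata prepared above
entry = roidb[0]
print(entry['image'])        # absolute path to the image file
print(entry['boxes'].shape)  # (num_boxes, 4), boxes in (x1, y1, x2, y2) form
print(entry['gt_classes'])   # contiguous class ids (>= 1 for gt boxes)
print(entry['max_classes'])  # per-box class assignment from _add_class_assignments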