专栏首页AutoML(自动机器学习)Detectron2源码阅读笔记-(三)Dataset
原创

Detectron2源码阅读笔记-(三)Dataset

构建data_loader原理步骤

# engine/default.py
from detectron2.data import (
    MetadataCatalog,
    build_detection_test_loader,
    build_detection_train_loader,
)
class DefaultTrainer(SimpleTrainer):
    def __init__(self, cfg):
        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        ...    
    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable
        """
        return build_detection_train_loader(cfg)

函数调用关系如下图:

结合前面两篇文章的内容可以看到detectron2在构建model,optimizer和data_loader的时候都是在对应的build.py文件里实现的。我们看一下build_detection_train_loader是如何定义的(对应上图中紫色方框内的部分(自下往上的顺序)):

def build_detection_train_loader(cfg, mapper=None):
    """
    A data loader is created by the following steps:

    1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
    2. Start workers to work on the dicts. Each worker will:
      * Map each metadata dict into another format to be consumed by the model.
      * Batch them by simply putting dicts into a list.
    The batched ``list[mapped_dict]`` is what this dataloader will return.

    Args:
        cfg (CfgNode): the config
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            By default it will be `DatasetMapper(cfg, True)`.

    Returns:
        a torch DataLoader object
    """
	# 获得dataset_dicts
    dataset_dicts = get_detection_dataset_dicts(
        cfg.DATASETS.TRAIN,
        filter_empty=True,
        min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
        if cfg.MODEL.KEYPOINT_ON
        else 0,
        proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
    )
	
	# 将dataset_dicts转化成torch.utils.data.Dataset
    dataset = DatasetFromList(dataset_dicts, copy=False)

	# 进一步转化成MapDataset,每次读取数据时都会调用mapper来对dict进行解析
    if mapper is None:
        mapper = DatasetMapper(cfg, True)
    dataset = MapDataset(dataset, mapper)
	
	# 采样器
    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
    if sampler_name == "TrainingSampler":
        sampler = samplers.TrainingSampler(len(dataset))
		...
    batch_sampler = build_batch_data_sampler(
        sampler, images_per_worker, group_bin_edges, aspect_ratios
    )
	
	# 数据迭代器 data_loader
    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        batch_sampler=batch_sampler,
        collate_fn=trivial_batch_collator,
        worker_init_fn=worker_init_reset_seed,
    )
    return data_loader

由上面的源代码可以看出总共是五个步骤,我们只对前面三个部分进行详细介绍,后面的采样器和data_loader可以参阅一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

获得dataset_dicts

get_detection_dataset_dicts(dataset_names)函数需要传递的一个重要参数是dataset_names,这个参数其实就是一个字符串,用来指定数据集的名称。通过这个字符串,该函数会调用data/catalog.pyDatasetCatalog类来进行解析得到一个包含数据信息的字典。

解析的原理是:DatasetCatalog有一个字典_REGISTERED,默认已经注册好了例如coco,voc这些数据集的信息。如果你想要使用你自己的数据集,那么你需要在最开始前你需要定义你的数据集名字以及定义一个函数(这个函数不需要传参,而且最后会返回一个dict,该dict包含你的数据集信息),举个栗子:

from detectron2.data import DatasetCatalog
my_dataset_name = 'apple'
def get_dicts():
	...
	return dict

DatasetCatalog.register(my_dataset_name, get_dicts)

当然,如果你的数据集已经是COCO的格式了,那么你也可以使用如下方法进行注册:

from detectron2.data.datasets import register_coco_instances
my_dataset_name = 'apple'
register_coco_instances(my_dataset_name, {}, "json_annotation.json", "path/to/image/dir")

最后,get_detection_dataset_dicts会返回一个包含若干个dict的list,之所以是list是因为参数dataset_names也是一个list,这样我们就可以制定多个names来同时对数据进行读取。

解析成DatasetFromList

DatasetFromList(dataset_dict)函数定义在detectron2/data/common.py中,它其实就是一个torch.utils.data.Dataset类,其源码如下

class DatasetFromList(data.Dataset):
    """
    Wrap a list to a torch Dataset. It produces elements of the list as data.
    """

    def __init__(self, lst: list, copy: bool = True):
        """
        Args:
            lst (list): a list which contains elements to produce.
            copy (bool): whether to deepcopy the element when producing it,
                so that the result can be modified in place without affecting the
                source in the list.
        """
        self._lst = lst
        self._copy = copy

    def __len__(self):
        return len(self._lst)

    def __getitem__(self, idx):
        if self._copy:
            return copy.deepcopy(self._lst[idx])
        else:
            return self._lst[idx]

这个很简单就不加赘述了

DatsetFromList转化成MapDataset

其实DatsetFromListMapDataset都是torch.utils.data.Dataset的子类,那他们的区别是什么呢?很简单,区别就是后者使用了mapper

在解释mapper是什么之前我们首先要知道的是,在detectron2中,一张图片对应的是一个dict,那么整个数据集就是listdict。之后我们再看DatsetFromList,它的__getitem__函数非常简单,它只是简单粗暴地就返回了指定idx的元素。显然这样是不行的,因为在把数据扔给模型训练之前我们肯定还要对数据做一定的处理,而这个工作就是由mapper来做的,默认情况下使用的是detectron2/data/dataset_mapper.py中定义的DatasetMapper,如果你需要自定义一个mapper也可以参考这个写。

DatasetMapper(cfg, is_train=True)

我们继续了解一下DatasetMapper的实现原理,首先看一下官方给的定义:

A callable which takes a dataset dict in Detectron2 Dataset format, and map it into a format used by the model.

简单概括就是这个类是可调用的(callable),所以在下面的源码中可以看到定义了__call__方法。

该类主要做了这三件事:

The callable currently does the following: 1. Read the image from "file_name" 2. Applies cropping/geometric transforms to the image and annotations 3. Prepare data and annotations to Tensor and :class:Instances

其源码如下(有删减):

class DatasetMapper:
    def __init__(self, cfg, is_train=True):
		# 读取cfg的参数
		...

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
		
		# 1. 读取图像数据
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
		
		# 2. 对image和box等做Transformation
        if "annotations" not in dataset_dict:
            image, transforms = T.apply_transform_gens(
                ([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
            )
        else:
			...
            image, transforms = T.apply_transform_gens(self.tfm_gens, image)
            if self.crop_gen:
                transforms = crop_tfm + transforms

        image_shape = image.shape[:2]  # h, w
		
		# 3.将数据转化成tensor格式
        dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
		...

        return dataset_dict

MapDataset

class MapDataset(data.Dataset):
    def __init__(self, dataset, map_func):
        self._dataset = dataset
        self._map_func = PicklableWrapper(map_func)  # wrap so that a lambda will work

        self._rng = random.Random(42)
        self._fallback_candidates = set(range(len(dataset)))

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        retry_count = 0
        cur_idx = int(idx)

        while True:
            data = self._map_func(self._dataset[cur_idx])
            if data is not None:
                self._fallback_candidates.add(cur_idx)
                return data

            # _map_func fails for this idx, use a random new index from the pool
            retry_count += 1
            self._fallback_candidates.discard(cur_idx)
            cur_idx = self._rng.sample(self._fallback_candidates, k=1)[0]

            if retry_count >= 3:
                logger = logging.getLogger(__name__)
                logger.warning(
                    "Failed to apply `_map_func` for idx: {}, retry count: {}".format(
                        idx, retry_count
                    )
                )
  • self._fallback_candidates是一个set,它的特点是其中的元素是独一无二的,定义这个的作用是记录可正常读取的数据索引,因为有的数据可能无法正常读取,所以这个时候我们就可以把这个坏数据的索引从_fallback_candidates中剔除,并随机采样一个索引来读取数据。
  • __getitem__中的逻辑就是首先读取指定索引的数据,如果正常读取就把该所索引值加入到_fallback_candidates中去;反之,如果数据无法读取,则将对应索引值删除,并随机采样一个数据,并且尝试3次,若3次后都无法正常读取数据,则报错,但是好像也没有退出程序,而是继续读数据,可能是以为总有能正常读取的数据吧hhh。

<footer style="color:white;;background-color:rgb(24,24,24);padding:10px;border-radius:10px;"><br>

<h3 style="text-align:center;color:tomato;font-size:16px;" id="autoid-2-0-0"><br>

<b>MARSGGBO</b><b style="color:white;"><span style="font-size:25px;">♥</span>原创</b>

<b style="color:white;">

2019-10-23 13:37:13

<p></p>

</b><p><b style="color:white;"></b>

</p></h3><br>

</footer>

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • Detectron2源码阅读笔记-(三)Dataset pipeline

    结合前面两篇文章的内容可以看到detectron2在构建model,optimizer和data_loader的时候都是在对应的build.py文件里实现的。我...

    marsggbo
  • Pytorch Sampler详解

    其原理是首先在初始化的时候拿到数据集data_source,之后在__iter__方法中首先得到一个和data_source一样长度的range可迭代器。每次只...

    marsggbo
  • python多线程学习笔记(超详细)

    python threading 多线程 一. Threading简介 首先看下面的没有用Threading的程序 import threading,time ...

    marsggbo
  • Detectron2源码阅读笔记-(三)Dataset pipeline

    结合前面两篇文章的内容可以看到detectron2在构建model,optimizer和data_loader的时候都是在对应的build.py文件里实现的。我...

    marsggbo
  • 网上课程管理系统---大致框架(伪代码)

          python3中有一个super方法,根据广度优先的继承顺序查找上一个类

    py3study
  • 抽象工厂模式

    当每个抽象产品都有多于一个的具体子类的时候,工厂角色怎么知道实例化哪一个子类呢?比如每个抽象产品角色都有两个具体产品。抽象工厂模式提供两个具体工厂角色,分别对应...

    用户2936342
  • python备份目录脚本

    #!/usr/bin/env python #backup app python script. import os import time import sy...

    py3study
  • 使用 plotly 绘制数据图表

    不少小伙伴在开发过程中都有对模块进行压测的经历,压测结束后大家往往喜欢使用Excel处理压测数据并绘制数据可视化视图,但这样不能很方便的使用web页面进行数据展...

    邵靖
  • 腾讯云-轻量应用服务器SaaS交付Discuz! Q

    本文提供视频讲解,详细见地址:https://www.bilibili.com/video/BV1Hh411Z7gw

    研究僧
  • 什么叫做微内核?与安卓系统有什么区别?

    从事嵌入式开发多年,要讲清楚这个事情真需要一定开发经验特别是关于linux上面的,首先微内核是相对于强内核而言,linux属于典型的强内核架构,从第一版本开始就...

    程序员互动联盟

扫码关注云+社区

领取腾讯云代金券