前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【MMDetection 超全专栏】二,配置类和注册器&数据处理&训练pipline

【MMDetection 超全专栏】二,配置类和注册器&数据处理&训练pipline

作者头像
BBuf
发布2020-06-04 10:26:56
2.2K0
发布2020-06-04 10:26:56
举报
文章被收录于专栏:GiantPandaCVGiantPandaCV

0. 目录

目录,第一节和第二节请看上篇推文

第三节 配置类和注册器

这两个东西可变为自用+练习。

0.3.1 配置类

配置方式支持python/json/yaml,从mmcv的Config解析,其功能同maskrcnn-benchmark的yacs类似,将字典的取值方式属性化.这里贴部分代码,以供学习。

代码语言:javascript
复制
class Config(object):
    ...
    @staticmethod
    def _file2dict(filename):
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        if filename.endswith('.py'):
            with tempfile.TemporaryDirectory() as temp_config_dir:
                shutil.copyfile(filename,
                                osp.join(temp_config_dir, '_tempconfig.py'))
                sys.path.insert(0, temp_config_dir)
                mod = import_module('_tempconfig')
                sys.path.pop(0)
                cfg_dict = {
                    name: value
                    for name, value in mod.__dict__.items()
                    if not name.startswith('__')
                }
                # delete imported module
                del sys.modules['_tempconfig']
        elif filename.endswith(('.yml', '.yaml', '.json')):
            import mmcv
            cfg_dict = mmcv.load(filename)
        else:
            raise IOError('Only py/yml/yaml/json type are supported now!')

        cfg_text = filename + '\n'
        with open(filename, 'r') as f:
            cfg_text += f.read()
        # 2.0新增的配置文件的组合继承
        if '_base_' in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop('_base_')
            base_filename = base_filename if isinstance(
                base_filename, list) else [base_filename]

            cfg_dict_list = list()
            cfg_text_list = list()
            for f in base_filename:
                # 递归,可搜索staticmethod and recursion
                # 静态方法调静态方法,类方法调静态方法
                _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                if len(base_cfg_dict.keys() & c.keys()) > 0:
                    raise KeyError('Duplicate key is not allowed among bases')
                base_cfg_dict.update(c)
            # 合并
            Config._merge_a_into_b(cfg_dict, base_cfg_dict)
            cfg_dict = base_cfg_dict

            # merge cfg_text
            cfg_text_list.append(cfg_text)
            cfg_text = '\n'.join(cfg_text_list)

        return cfg_dict, cfg_text
    
    ...
    # 获取key值
    def __getattr__(self, name):
        return getattr(self._cfg_dict, name)
    # 序列化
    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)
    # 将字典属性化主要用了__setattr__
    def __setattr__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)
    # 更新key值
    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)
    # 迭代器
    def __iter__(self):
        return iter(self._cfg_dict)
        

主要考虑点是自己怎么实现类似的东西,核心点就是python的基本魔法函数的应用,可同时参考yacs。

0.3.2 注册器

把基本对象放到一个继承了字典的对象中,实现了对象的灵活管理。

代码语言:javascript
复制
import inspect
from functools import partial

import mmcv


class Registry(object):
    # 2.0 放到mmcv中

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        return self._module_dict.get(key, None)

    def _register_module(self, module_class, force=False):
        """Register a module.

        Args:
            module (:obj:`nn.Module`): Module to be registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but got {}'.format(
                type(module_class)))
        module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class   # 类名:类

    def register_module(self, cls=None, force=False):
        # 作为类cls的装饰器
        if cls is None:
            # partial函数(类)固定参数,返回新对象,递归不是很清楚
            return partial(self.register_module, force=force)
        self._register_module(cls, force=force)   # 将cls装进当前Registry对象的中_module_dict
        return cls    # 返回类

def build_from_cfg(cfg, registry, default_args=None):
	assert isinstance(cfg, dict) and 'type' in cfg
	assert isinstance(default_args, dict) or default_args is None
	args = cfg.copy()
	obj_type = args.pop('type')
	if mmcv.is_str(obj_type):
		# 从注册类中拿出obj_type类
		obj_cls = registry.get(obj_type)
		if obj_cls is None:
			raise KeyError('{} is not in the {} registry'.format(
				obj_type, registry.name))
	elif inspect.isclass(obj_type):
		obj_cls = obj_type
	else:
		raise TypeError('type must be a str or valid type, but got {}'.format(
			type(obj_type)))
	if default_args is not None:
		# 增加一些新的参数
		for name, value in default_args.items():
			args.setdefault(name, value)
	return obj_cls(**args)    # **args是将字典解析成位置参数(k=v)。

第四节 数据处理

数据处理可能是炼丹师接触最为密集的了,因为通常情况,除了数据的离线处理,写个数据类,就可以炼丹了。但本节主要涉及数据的在线处理,更进一步应该是检测分割数据的pytorch处理方式。虽然mmdet将常用的数据都实现了,而且也实现了中间通用数据格式,但,这和模型,损失函数,性能评估的实现也相关,比如你想把官网的centernet完整的改成mmdet风格,就能看到(看起来没必要)。

0.4.1 检测分割数据

看看配置文件,数据相关的有datadict,里面包含了train,val,test的路径信息,用于数据类初始化,有pipeline,将各个函数及对应参数以字典形式放到列表里,是对pytorch原装的transforms+compose,在检测,分割相关数据上的一次封装,使得形式更加统一。

从builder.py中build_dataset函数能看到,构建数据有三种方式,ConcatDataset,RepeatDataset和从注册器中提取。其中dataset_wrappers.py中ConcatDataset和RepeatDataset意义自明,前者继承自pytorch原始的ConcatDataset,将多个数据集整合到一起,具体为把不同序列(可参考容器的抽象基类https://docs.python.org/zh-cn/3/library/collections.abc.html)的长度相加,__getitem__函数对应index替换一下,后者就是单个数据类(序列)的多次重复。就功能来说,前者提高数据丰富度,后者可解决数据太少使得loading时间长的问题(见代码注释)。而被注册的数据类在datasets下一些熟知的数据名文件中。其中,基类为custom.py中的CustomDataset,coco继承自它,cityscapes继承自coco,xml_style的XMLDataset继承CustomDataset,然后wider_face,voc均继承自XMLDataset。因此这里先分析一下CustomDataset。

CustomDataset 记录数据路径等信息,解析标注文件,将每一张图的所有信息以字典作为数据结构存在results中,然后进入pipeline:数据增强相关操作,代码如下:

代码语言:javascript
复制
self.pipeline = Compose(pipeline)
	# Compose是实现了__call__方法的类,其作用是使实例能够像函数一样被调用,同时不影响实例本身的生命周期
def pre_pipeline(self, results):
	# 扩展字典信息
	results['img_prefix'] = self.img_prefix
	results['seg_prefix'] = self.seg_prefix
	results['proposal_file'] = self.proposal_file
	results['bbox_fields'] = []
	results['mask_fields'] = []
	results['seg_fields'] = []

def prepare_train_img(self, idx):
	img_info = self.img_infos[idx]
	ann_info = self.get_ann_info(idx)
	# 基本信息,初始化字典
	results = dict(img_info=img_info, ann_info=ann_info)
	if self.proposals is not None:
		results['proposals'] = self.proposals[idx]
	self.pre_pipeline(results)
	return self.pipeline(results)    # 数据增强

def __getitem__(self, idx):
	if self.test_mode:
		return self.prepare_test_img(idx)
	while True:
		data = self.prepare_train_img(idx)
		if data is None:
			idx = self._rand_another(idx)
			continue
		return data

这里数据结构的选取需要注意一下,字典结构,在数据增强库albu中也是如此处理,因此可以快速替换为albu中的算法。另外每个数据类增加了各自的evaluate函数。evaluate基础函数在mmdet.core.evaluation中,后做补充。

mmdet的数据处理,字典结构pipelineevaluate是三个关键部分。其他所有类的文件解析部分,数据筛选等,看看即可。因为我们知道,pytorch读取数据,是将序列转化为迭代器后进行io操作,所以在dataset下除了pipelines外还有loader文件夹,里面实现了分组,分布式分组采样方法,以及调用了mmcv中的collate函数(此处为1.x版本,2.0版本将loader移植到了builder.py中),且build_dataloader封装的DataLoader最后在 train_detector中被调用,这部分将在后面补充,这里说说pipelines。

返回maskrcnn的配置文件(1.x,2.0看base config),可以看到训练和测试的不同之处:LoadAnnotations,MultiScaleFlipAug,DefaultFormatBundle和Collect。额外提示,虽然测试没有LoadAnnotations,根据CustomDataset可知,它仍需标注文件,这和inference的pipeline不同,也即这里的test实为evaluate。

代码语言:javascript
复制
# 序列中的dict可以随意删减,增加,属于数据增强调参内容
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]

最后这些所有操作被Compose串联起来,代码如下:

代码语言:javascript
复制
@PIPELINES.register_module
class Compose(object):

	def __init__(self, transforms):
		assert isinstance(transforms, collections.abc.Sequence)  # 列表是序列结构
		self.transforms = []
		for transform in transforms:
			if isinstance(transform, dict):
				transform = build_from_cfg(transform, PIPELINES)
				self.transforms.append(transform)
			elif callable(transform):
				self.transforms.append(transform)
			else:
				raise TypeError('transform must be callable or a dict')

	def __call__(self, data):
		for t in self.transforms:
			data = t(data)
			if data is None:
				return None
		return data

上面代码能看到,配置文件中pipeline中的字典传入build_from_cfg函数,逐一实现了各个增强类(方法)。扩展的增强类均需实现__call__方法,这和pytorch原始方法是一致的。

有了以上认识,重新梳理一下pipelines的逻辑,由三部分组成,load,transforms,和format。load相关的LoadImageFromFile,LoadAnnotations都是字典results进去,字典results出来。具体代码看下便知,LoadImageFromFile增加了'filename','img','img_shape','ori_shape','pad_shape','scale_factor','img_norm_cfg'字段。其中img是numpy格式。LoadAnnotations从 results['ann_info']中解析出bboxs,masks,labels等信息。注意coco格式的原始解析来自pycocotools,包括其评估方法,这里关键是字典结构(这个和模型损失函数,评估等相关,统一结构,使得代码统一)。transforms中的类作用于字典的values,也即数据增强。format中的DefaultFormatBundle是将数据转成mmcv扩展的容器类格式DataContainer。另外Collect会根据不同任务的不同配置,从results中选取只含keys的信息生成新的字典,具体看下该类帮助文档。这里看一下从numpy转成tensor的代码:

代码语言:javascript
复制
def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.
    """
    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    elif isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    elif isinstance(data, int):
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        raise TypeError('type {} cannot be converted to tensor.'.format(
			type(data)))
	以上代码告诉我们,基本数据类型,需掌握。

那么DataContainer是什么呢?它是对tensor的封装,将results中的tensor转成DataContainer格式,实际上只是增加了几个property函数,cpu_only,stack,padding_value,pad_dims,其含义自明,以及size,dim用来获取数据的维度,形状信息。 考虑到序列数据在进入DataLoader时,需要以batch方式进入模型,那么通常的collate_fn会要求tensor数据的形状一致。但是这样不是很方便,于是有了DataContainer。它可以做到载入GPU的数据可以保持统一shape,并被stack,也可以不stack,也可以保持原样,或者在非batch维度上做pad。当然这个也要对default_collate进行改造,mmcv在parallel.collate中实现了这个。

collate_fn是DataLoader中将序列dataset组织成batch大小的函数,这里帖三个普通例子:

代码语言:javascript
复制
def collate_fn_1(batch):
	# 这是默认的,明显batch中包含相同形状的img\_tensor和label
	return tuple(zip(*batch))
	
def coco_collate_2(batch):
	# 传入的batch数据是被albu增强后的(字典结构)
    imgs = [s['image'] for s in batch]    # tensor, h, w, c->c, h, w , handle at transform in __getitem__
    annots = [s['bboxes'] for s in batch]
    labels = [s['category_id'] for s in batch]

	# 以当前batch中图片annot数量的最大值作为标记数据的第二维度值,空出的就补-1。
    max_num_annots = max(len(annot) for annot in annots)
    annot_padded = np.ones((len(annots), max_num_annots, 5))*-1

    if max_num_annots > 0:
        for idx, (annot, lab) in enumerate(zip(annots, labels)):
            if len(annot) > 0:
                annot_padded[idx, :len(annot), :4] = annot
				# 不同模型,损失值计算可能不同,这里ssd结构需要改为xyxy格式并且要做尺度归一化
				# 这一步完全可以放到\_\_getitem\_\_中去,只是albu的格式需求问题。
                annot_padded[idx, :len(annot), 2] += annot_padded[idx, :len(annot), 0]    #  xywh-->x1,y1,x2,y2 for general box,ssd target assigner
                annot_padded[idx, :len(annot), 3] += annot_padded[idx, :len(annot), 1]    # contains padded -1 label
                annot_padded[idx, :len(annot), :] /=  640    # priorbox for ssd primary target assinger
                annot_padded[idx, :len(annot), 4] = lab
	return torch.stack(imgs, 0), torch.FloatTensor(annot_padded)
	
def detection_collate_3(batch):
    targets = []
    imgs = []
    for _, sample in enumerate(batch):
        for _, img_anno in enumerate(sample):
            if torch.is_tensor(img_anno):
                imgs.append(img_anno)
            elif isinstance(img_anno, np.ndarray):
                annos = torch.from_numpy(img_anno).float()
                targets.append(annos)
    return torch.stack(imgs, 0), targets    # 做了stack, DataContainer可以不做stack

以上就是数据处理的相关内容。最后再用DataLoader封装拆成迭代器,其相关细节,sampler等暂略。

代码语言:javascript
复制
data_loader = DataLoader(
	dataset,
	batch_size=batch_size,
	sampler=sampler,
	num_workers=num_workers,
	collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
	pin_memory=False,
	worker_init_fn=init_fn,
	**kwargs)

5. 训练pipeline

训练流程的包装过程大致如下:tools/train.py->apis/train.py->mmcv/runner.py->mmcv/hook.py(后面是分散的),其中runner维护了数据信息,优化器,日志系统,训练loop中的各节点信息,模型保存,学习率等.另外补充一点,以上包装过程,在mmdet中无处不在,包括mmcv的代码也是对日常频繁使用的函数进行了统一封装.

0.5.1 训练逻辑

图见Figure2:

Figure 2

注意它的四个层级.代码上,主要查看apis/train.py,mmcv中的runner相关文件.核心围绕Runner,Hook两个类.Runner将模型,批处理函数batch_processor,优化器作为基本属性,训练过程中与训练状态,各节点相关的信息被记录在mode,_hooks,_epoch,_iter,_inner_iter,_max_epochs,_max_iters中,这些信息维护了训练过程中插入不同hook的操作方式.理清训练流程只需看Runner的成员函数run.在run里会根据mode按配置中workflow的epoch循环调用train和val函数,跑完所有的epoch.比如train:

代码语言:javascript
复制
def train(self, data_loader, **kwargs):
	self.model.train()
	self.mode = 'train'    # 改变模式
	self.data_loader = data_loader
	self._max_iters = self._max_epochs * len(data_loader)    # 最大batch循环次数
	self.call_hook('before_train_epoch')    # 根据名字获取hook对象函数
	for i, data_batch in enumerate(data_loader):
		self._inner_iter = i    # 记录训练迭代轮数
		self.call_hook('before_train_iter')    # 一个batch前向开始
		outputs = self.batch_processor(
			self.model, data_batch, train_mode=True, **kwargs)
		self.outputs = outputs
		self.call_hook('after_train_iter')    # 一个batch前向结束
		self._iter += 1    # 方便resume时,知道从哪一轮开始优化

	self.call_hook('after_train_epoch')    # 一个epoch结束
	self._epoch += 1    # 记录训练epoch状态,方便resume

上面需要说明的是自定义hook类,自定义hook类需继承mmcv的Hook类,其默认了6+8+4个成员函数,也即Figure2所示的6个层级节点,外加2*4个区分train和val的节点记录函数,以及4个边界检查函数.从train.py中容易看出,在训练之前,已经将需要的hook函数注册到Runner的self._hook中了,包括从配置文件解析的优化器,学习率调整函数,模型保存,一个batch的时间记录等(注册hook算子在self._hook中按优先级升序排列).这里的call_hook函数定义如下:

代码语言:javascript
复制
def call_hook(self, fn_name):
	for hook in self._hooks:
		getattr(hook, fn_name)(self)

容易看出,在训练的不同节点,将从注册列表中调用实现了该节点函数的类成员函数.比如

代码语言:javascript
复制
class OptimizerHook(Hook):

    def __init__(self, grad_clip=None):
        self.grad_clip = grad_clip

    def clip_grads(self, params):
        clip_grad.clip_grad_norm_(
            filter(lambda p: p.requires_grad, params), **self.grad_clip)

    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()
        if self.grad_clip is not None:
            self.clip_grads(runner.model.parameters())
        runner.optimizer.step()

将在每个train_iter后实现反向传播和参数更新.学习率优化相对复杂一点,其基类LrUpdaterHook,实现了before_run,before_train_epoch, before_train_iter三个hook函数,意义自明.这里选一个余弦式变化,稍作说明:

代码语言:javascript
复制
class CosineLrUpdaterHook(LrUpdaterHook):

    def __init__(self, target_lr=0, **kwargs):
        self.target_lr = target_lr
        super(CosineLrUpdaterHook, self).__init__(**kwargs)

    def get_lr(self, runner, base_lr):
        if self.by_epoch:
            progress = runner.epoch
            max_progress = runner.max_epochs
        else:
            progress = runner.iter    # runner需要管理各节点信息的原因之一
            max_progress = runner.max_iters
        return self.target_lr + 0.5 * (base_lr - self.target_lr) * \
			(1 + cos(pi * (progress / max_progress)))

从get_lr可以看到,学习率变换周期有两种,epoch->max_epoch,或者更大的iter->max_iter,后者表明一个epoch内不同batch的学习率可以不同,因为没有什么理论,所有这两种方式都行.其中base_lr为初始学习率,target_lr为学习率衰减的上界,而当前学习率即为返回值.

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-05-30,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 GiantPandaCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 0. 目录
  • 第三节 配置类和注册器
    • 0.3.1 配置类
      • 0.3.2 注册器
      • 第四节 数据处理
        • 0.4.1 检测分割数据
        • 5. 训练pipeline
          • 0.5.1 训练逻辑
          相关产品与服务
          批量计算
          批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档