专栏首页marsggboDetectron2代码阅读笔记-(二)
原创

Detectron2代码阅读笔记-(二)

Trainer解析

我们继续Detectron2代码阅读笔记-(一)中的内容。

上图画出了detectron2文件夹中的三个子文件夹(tools,config,engine)之间的关系。那么剩下的文件夹又是如何起作用的呢?

def main(args):
    cfg = setup(args)

    if args.eval_only:
		...
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()

build_*方法

我们从trainer = Trainer(cfg)开始进一步了解。

Detectron2代码阅读笔记-(一)中已经提到过一连串的Trainer的继承关系如下:

tools.train_net.Trainer->detectron2.engine.default.DefaultTrainer->detectron2.engine.train_loop.SimpleTrainer->detectron2.engine.train_loop.TrainerBase,而detectron2.engine.default.DefaultTrainer在其__init__(self, cfg)函数中定义了解析cfg。如下面代码所示,cfg会作为参数倍若干个build_*方法解析,得到解析后的model,optimizer,data_loader等。

from detectron2.modeling import build_model
class DefaultTrainer(SimpleTrainer):
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)
		
		... 
		
        self.register_hooks(self.build_hooks())
		
	@classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:
        """
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model

下面我们以DefaultTrainer.build_model为例来介绍注册机制,该方法调用了detectron2/modeling/meta_arch/build_model.pybuild_model函数,其源代码如下:

from detectron2.utils.registry import Registry

META_ARCH_REGISTRY = Registry("META_ARCH")
META_ARCH_REGISTRY.__doc__ = """
def build_model(cfg):
    """
    Built the whole model, defined by `cfg.MODEL.META_ARCHITECTURE`.
    """
    meta_arch = cfg.MODEL.META_ARCHITECTURE
    return META_ARCH_REGISTRY.get(meta_arch)(cfg)
  • meta_arch = cfg.MODEL.META_ARCHITECTURE: 根据超参数获得网络结构的名字model = META_ARCH_REGISTRY.get(meta_arch) return model(cfg)
  • return META_ARCH_REGISTRY.get(meta_arch)(cfg):META_ARCH_REGISTRY是一个Registry类(这个在后面会详细介绍),可以将这一行代码拆成如下几个步骤:

注册机制Registry

那么Registry到底是什么呢?在分析源代码之前我们先了解一下如何使用它,假如你想自己实现一个新的backbone网络,那么你可以这样做:

首先在detectron2中定义好如下(实际上已经定义了):

# detectron2/modeling/backbone/build.py
BACKBONE_REGISTRY = Registry('BACKBONE')

之后在你创建的新的文件下按如下方式创建你的backbone

# detectron2/modeling/backbone/your_backbone.py
from .build import BACKBONE_REGISTRY

# 方式1
@BACKBONE_REGISTRY.register()
class MyBackbone():
	...
		
# 方式2
class MyBackbone():
	...
BACKBONE_REGISTRY.register(MyBackbone)

Registry源代码如下(有删减):

class Registry(object):
    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj):
        assert (
            name not in self._obj_map
        ), "An object named '{}' was already registered in '{}' registry!".format(name, self._name)
        self._obj_map[name] = obj

    def register(self, obj=None):
        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                name = func_or_class.__name__
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj)

    def get(self, name):
        ret = self._obj_map.get(name)
        if ret is None:
            raise KeyError("No object named '{}' found in '{}' registry!".format(name, self._name))
        return ret
  • 首先是__init__部分: - self._name则是你要注册的名字,例如对于完整的模型而言,name一般取META_ARCH。当然如果你需要自定义backbone网络,你也可以定义一个Registry('BACKBONE') - self._obj_map:其实就是一个字典。以模型为例,key就是你的模型名字,而value就是对应的模型类。这样你在传参时只需要修改一下模型名字就能使用不同的模型了。具体实现方法就是后面这几个函数。
  • register: 可以看到该方法定义了注册的两种方式,一种是当obj==None的时候,使用装饰器的方式注册,另外一种就是直接将obj作为参数调用_do_register进行注册。
  • _do_register:真正注册的函数,可以看到它首先会判断name是否已经存在于self._obj_map了。什么意思呢?还是以backbone为例,我们定义了一个BACKBONE_REGISTRY = Registry('BACKBONE'),然后又定义了很多种backbone,而这些backbone都使用@BACKBONE_REGISTRY.register()的方式注册到了BACKBONE_REGISTRY._obj_map中了,所以才取名为Registry,还是蛮形象的吼。
  • get: 这个其实就是根据key值对字典进行取值。

Detectron2 整体代码架构

虽然Detectron2还有很多部分没有介绍到,但是源代码分析到这应该对整体架构有了一定的理解了,具体的一些细节会在后续的文章中进行分析。现对Detectron2 整体代码架构总结一下:

<footer style="color:white;;background-color:rgb(24,24,24);padding:10px;border-radius:10px;"><br>

<h3 style="text-align:center;color:tomato;font-size:16px;" id="autoid-2-0-0"><br>

<b>MARSGGBO</b><b style="color:white;"><span style="font-size:25px;">♥</span>原创</b>

<b style="color:white;">

2019-10-15 13:16:32

<p></p>

</b><p><b style="color:white;"></b>

</p></h3><br>

</footer>

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • Python types.MethodType动态更改类方法

    动态编程语言是高级程序设计语言的一个类别,在计算机科学领域已被广泛应用。它是一类在运行时可以改变其结构的语言:例如新的函数、对象、甚至代码可以被引进,已有的函数...

    marsggbo
  • 论文笔记系列-AutoFPN

    之前的AutoML都是应用在图像分类或者语言模型上,AutoFPN成功地将这技术应用到了目标检测任务上。

    marsggbo
  • Pytorch Sampler详解

    其原理是首先在初始化的时候拿到数据集data_source,之后在__iter__方法中首先得到一个和data_source一样长度的range可迭代器。每次只...

    marsggbo
  • python之反射

    python中的反射功能是由以下四个内置函数提供:hasattr、getattr、setattr、delattr,改四个函数分别用于对对象内部执行:检查是否含有...

    菲宇
  • python3 类、对象的基础概念

    py3study
  • 开源数据库中间件-MyCa初探与分片实践

    rpm -ivh MySQL-server-5.5.49-1.linux2.6.i386.rpm

    蒋老湿
  • 基于飞桨复现CVPR 2020 GhostNet的全程解析

    论文作者发现在传统的深度学习网络中存在着大量冗余,但是对模型的精度至关重要的特征图。这些特征图是由卷积变化得到,又输入到下一个卷积层进行运算,这个过程包含大量的...

    用户1386409
  • mybatis的接口实现操作数据库

            测试项目结构如图所示:其中UserDao.java为一个接口,以后的userMapper.xml的配置就是围绕这个接口类展开的:       ...

    yawn
  • 平面四节点等参元(Q4)有限元程序算例

    如图所示悬臂梁,假定为平面应力条件。材料弹性模量E=1e6MPa,泊松比v=0.3,板厚度t=10mm,长度l=400mm,高h=100mm。划分8个单元,分别...

    fem178
  • JVM的垃圾收集器策略

    程序可以通过判断引用队列中是否已经加入了虚引用,来了解被引用的对象是否将要被垃圾回收。如果程序发现某个虚引用已经被加入到引用队列,那么就可以在所引用的对象的内存...

    大大大大大先生

扫码关注云+社区

领取腾讯云代金券