首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Caffe工厂模式解析

Caffe工厂模式解析

作者头像
chaibubble
发布2019-09-06 09:24:10
7760
发布2019-09-06 09:24:10
举报

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

本文链接:https://blog.csdn.net/chaipp0607/article/details/100188814

Caffe有五个基本组件,分别是Blob,Solver,Net,Layer和Proto,其中Solver和Layer使用了工厂模式,下面以Slover为例说明下。 Solver的工厂模式在注册和调用的过程中体现,所以在说明工厂模式之前,我们首先要弄明白Solver在Caffe内部是如何被使用的。

Solver注册机制

什么是Solver注册

我们都知道Layer和Slover是需要被注册的,而所谓的注册就是把这个类型的Slover(比如SDGSlover)找个地方记录下来,好告诉后面的过程,有这个Slover了,需要的话可以来这里调用。 这就和在CSDN注册会员一样,我们成功注册为会员,“用户名”和“密码”就被记录下来了,然后可以进一步的完善信息,写博客等等,这些都是我们这个账户里面的内容了。下一次登录的时候,我们需要使用“用户名”来匹配,登录我们的账户,而密码只是一个安全措施。 Caffe中Slover有SGDSlover,AdaGradSolver,AdaDeltaSolver,AdamSolver,NesterovSolver,RMSPropSolver这六种,注册的代码在它们各自的源文件中,比如SGDSlover的注册在sgd_solver.cpp的最下面:

REGISTER_SOLVER_CLASS(SGD);

SGD的就是solver.proto中type对应的字符串。 下面我们就从这行代码开始,往前追踪SGDSlover的注册。

Solver如何被注册

在这里插入图片描述
在这里插入图片描述

solver_factory.hpp中可以找到REGISTER_SOLVER_CLASS的定义,它是一个宏

    #define REGISTER_SOLVER_CLASS(type)                                            \
      template <typename Dtype>                                                    \
      Solver<Dtype>* Creator_##type##Solver(                                       \
          const SolverParameter& param)                                            \
      {                                                                            \
        return new type##Solver<Dtype>(param);                                     \
      }                                                                            \
      REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)

define 里的 ##是一个连接符号,用于把参数连在一起 。而type其实就是SGD,编译的时候这个宏会被替换,并将type换成SGD ,所以实际上这个宏就是完成了。

  template <typename Dtype>                                                    
  Solver<Dtype>* Creator_SGDSolver(const SolverParameter& param){                                                                            
    return new SGDSolver<Dtype>(param);                                     
  }                                                                            
  REGISTER_SOLVER_CREATOR(SGD, Creator_SGDSolver)

它定义了一个函数Creator_SGDSolver(),参数为SolverParameter&类型的引用,返回值为SGDSolver<Dtype>(param)

最后又调用了另一个宏REGISTER_SOLVER_CREATOR

#define REGISTER_SOLVER_CREATOR(type, creator)                                 \
  static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>);    \
  static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>)   \

还是想上面那样替换它:

  static SolverRegisterer<float> g_creator_f_SGD("SGD", Creator_SGDSolver<float>);    
  static SolverRegisterer<double> g_creator_d_SGD("SGD", Creator_SGDSolver<double>);  

最后的目的就是要实例化SolverRegisterer类的两个对象。SolverRegisterer是一个模板类,所以在实例化时候有SolverRegisterer<float>SolverRegisterer<double>,以支持两种Slove的数据类型,分别对应float和double。 实例化时会调用SolverRegisterer类的构造函数,通过SolverRegisterer类定义,发现构造函数里面调用了AddCreator()方法。

template <typename Dtype>
class SolverRegisterer {
 public:
  SolverRegisterer(const string& type,
      Solver<Dtype>* (*creator)(const SolverParameter&)) {
    // LOG(INFO) << "Registering solver type: " << type;
    SolverRegistry<Dtype>::AddCreator(type, creator);
  }
};

AddCreator()方法是另一个类SolverRegistry的成员,我们暂时只看SolverRegistry类下面这些成员就够了,细节的地方做了注释。

// LayerRegistry:注册类,主要实现两个方法,AddCreator()和CreateSolver(),下面代码只有AddCreator()
template <typename Dtype>
class SolverRegistry {
 public:
  //定义名为Creator的函数指针类型,参数为SolverParameter&类型的引用,返回值为一个Solver类型的指针
  typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
  //将一个map类型定义一个别名,叫做CreatorRegistry
  //map将“字符串-函数指针”行成映射
  typedef std::map<string, Creator> CreatorRegistry;

 // Registry()静态函数,只创建一个map实例,仅第一次调用时会new,其它直接return
 //创建的map其实就是solver的内部注册表
  static CreatorRegistry& Registry() {
    static CreatorRegistry* g_registry_ = new CreatorRegistry();
    return *g_registry_;
  }

  // Adds a creator.
  // AddCreator函数用来向Registry列表中添加一组<type, creator>
  static void AddCreator(const string& type, Creator creator) {
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 0)
        << "Solver type " << type << " already registered.";
    // 向map中加入一个映射
    registry[type] = creator;
  }
};

所以,当我们看到了 registry[type] = creator;这一行代码时,也就找到了slover的注册到底在做什么,他其实就是在往registry变量里添加一组映射,registry是静态的,它只有一个,就是slover的注册表;一组映射是CreatorRegistry,它实际是一个map,建立映射的两个值分别stringCreator,string不用说,他就是像“SGD”,“Adam”,“AdaDelta”这样的一个字符串,关键是和它建立映射的东西:CreatorCreator是一个函数指针,这个指针可以指向的函数要以SolverParameter&类型的引用作为参数,并且返回值为一个Solver类型的指针,Caffe里面那个函数是这个样子呢?就是在宏里定义的那个函数:Creator_SGDSolver()。 最终,SGDSlover的注册是将字符串"SGD"和指向函数Creator_SGDSolver()的指针成对存储到registry变量里面。

Solver的调用

在这里插入图片描述
在这里插入图片描述

说完了注册的部分,下面说明下调用,也就是程序的运行过程。 caffe的程序入库在caffe.cpp的main()函数中,比如执行train的时候,调用了SolverRegistry类的CreateSolver()函数:

  shared_ptr<caffe::Solver<float> >
      solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

此时的Dtype已经指定为了float类型,solver_param是从slover.proto里面解析出来的。 CreateSolver()也在SolverRegistry类中定义:

template <typename Dtype>
class SolverRegistry {
 public:
  // Get a solver using a SolverParameter.
  static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
    const string& type = param.type();
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
        << " (known types: " << SolverTypeListString() << ")";
    return registry[type](param);
  }
}

它实现了registry[type](param)的操作,实际上就是AddCreator()反过来的过程,一个是取,一个是存。同样在"SGD"的时候,取出来的就应该是上面提到的Creator_SGDSolver(),而Creator_SGDSolver()的返回值是SGDSolver<Dtype>(param)。 这个SGDSolver<Dtype>(param)就在sgd_solvers.hpp中定义,就是SGDSolver的构造函数:

/**
 * @brief Optimizes the parameters of a Net using
 *        stochastic gradient descent (SGD) with momentum.
 */
template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
 public:
  explicit SGDSolver(const SolverParameter& param)
      : Solver<Dtype>(param) { PreSolve(); }
  explicit SGDSolver(const string& param_file)
      : Solver<Dtype>(param_file) { PreSolve(); }
  virtual inline const char* type() const { return "SGD"; }

  const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
}

通过main()中的调用,Dtype指定为了float。

Solver注册发生在什么时候

通过上面的分析,我们知道了所谓的注册就是往map里面存入,调用就是从map取出来,那就会有一个问题,注册是在什么时候发生的? 因为registry就是个静态变量,它的生命周期的开始一定在程序运行起来之后,但是程序运行起来就要从入口执行train了,这就要求在这之前registry里就要完成注册了,我们加个断点调试一下。

在这里插入图片描述
在这里插入图片描述

一个断点打在程序的入口处:

在这里插入图片描述
在这里插入图片描述

一个断点打在注册的地方:

在这里插入图片描述
在这里插入图片描述

启动调试之后,先断到了注册的地方:

在这里插入图片描述
在这里插入图片描述

此时的type是"AdaDelta",因为还没有存入,所以registy的size=0,再走一步的话:

在这里插入图片描述
在这里插入图片描述

type变成了"AdaGrad",因为已经存入了"AdaDelta",所以registy的size=1。 于是可以得到一个结论是,注册的过程是在进入main函数之前完成。

此外,还可以用代码图的当时看下,首先改一下断点的位置到:

在这里插入图片描述
在这里插入图片描述

开始执行调试,直到代码执行到main中,生成代码图,就像下面这样:

在这里插入图片描述
在这里插入图片描述

Solver的工厂模式

最后就是Solver的工厂模式了,上面的说明包含了工厂模式思想,下面我们工厂模式的角度再说明下。 Caffe中Slover的工厂模式是一种简单工厂模式,只有一个工厂,负责生产多种产品。在solver_factory.hppSolverRegistry类定义了一个工厂,前面提到的注册,是在完善工厂中选择的逻辑,在很多简单工厂的例子中,这个逻辑可以靠switch,case来实现,只是在caffe中它变成了一个“字符串”-“函数指针”的映射。 上面提到的调用的过程,就是工厂生产产品的过程,还拿SDG的例子:

  shared_ptr<caffe::Solver<float> >
      solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

尽管solver_param参数的不同,但是都调用工厂中的方法CreateSolver(),最终将生产的过程交给了产品的子类去实现,产品的子类实现就在各个优化器对应的源码中。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019年09月03日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • Solver注册机制
    • 什么是Solver注册
      • Solver如何被注册
      • Solver的调用
      • Solver注册发生在什么时候
      • Solver的工厂模式
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档