前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【技术分享】从Tensorflow源码中学习设计模式

【技术分享】从Tensorflow源码中学习设计模式

原创
作者头像
腾讯云TI平台
修改2019-08-02 10:53:07
1.6K0
修改2019-08-02 10:53:07
举报

本文原作者:于洋,经授权后发布。

1. 开篇


通常,我们在使用Tensorflow低级API编程时(非Eager模式), 一般有下面三个步骤:

  • 使用tensorflow python侧的API构建图。图通常包括了两部分:正向计算图和反向计算图; 构建的关键字是:新建的 tf.Operation(节点)和 tf.Tensor(边)对象并将它们添加到 tf.Graph 实例中。例如,典型添加op操作就是tf.matmul
  • 创建tf.Session会话; 此步骤的关键字是:创建默认本地会话with tf.Session() as sess:,创建分布式会话with tf.Session("grpc://example.org:2222"):
  • 在tf.Session会话中,初始化全局变量,并批量运行图。 此步骤的关键语句是:sess.run(init_op), sess.run(train_op)

参考链接:图和会话线性回归例子

众所周知,tensorflow使用支持多种前端语言(python,js,swift,go等),执行引擎为C/C++后端实现。 那么,在上述三个步骤中,当用户python构建图,以及运行的图的时候。C/C++后端有在执行哪些工作呢? 按照对应的三个步骤,我们做如下拆解:

  • python在构建图的过程中,也是C/C++构造图的过程。 即python在新增的tf.Operation(节点)和 tf.Tensor(边)的同时,C/C++的后端也生成对应的节点和边,从而构造后端的图。
  • 图创建好后,python调用tf.Session语句,C/C++端会根据参数创建对应本地Session运行图,或者分布式Session运行图。
  • 通过sess.run触发一次图的正向计算,以及反向计算。

本次分享的设计模式,就是在上述第二阶段时:创建本地session和分布式session时,tensorflow是怎样利用抽象工厂设计模式的?

2. 抽象工厂设计模式(Abstract Factory


在《设计模式》中描述的23设计模式,分为三类:创建型、结构型、行为型。其中,抽象工厂设计模式属于创建性设计模式。即是解决对象的创建需求。关于抽象工厂模式我的理解是这样的:

调用者有创建不同对象的需求(对象有一定相似性,例如轿车、卡车),调用者无需关注具体的实现类,而是通过抽象类定义的接口,就能创造不同对象。

当然,个人抽象理解和描述还是很难理解的。我们根据GOF书中,抽象工厂的模式结构图(图需要从右上角看起)在来理解一下:

  • 调用者(Client)有创建对象ProductA1或ProductA2的需求,
  • 但是Client类没有直接调用实现类CreateProductA1、CreateProductA2。
  • 而是通过抽象工厂AbstractFactory的接口创建了不同的对象(即:创建对象ProductA1或ProductA2)。

[ 抽象工厂的模式结构图 - 《设计模式》58页 ]

有了上面粗浅的理解后,我们看一下tensorflow是如何使用抽象工厂模式,创建本地session和分布式session?

首先,我们看一下python创建Session调用栈:

NewSession的代码如下:

Status NewSession(const SessionOptions& options, Session** out_session) {
  SessionFactory* factory;
  Status s = SessionFactory::GetFactory(options, &factory);
  if (!s.ok()) {
    *out_session = nullptr;
    LOG(ERROR) << s;
    return s;
  }
  s = factory->NewSession(options, out_session);
  if (!s.ok()) {
    *out_session = nullptr;
  }
  return s;
}

代码很枯燥,我们看一下上述代码的时序图(以创建DirectSessione为例)。 上述代码对应着时序图的阶段2和阶段3。其中:

  • 阶段2对应代码SessionFactory::GetFactory(options, &factory);
  • 阶段3对应代码factory->NewSession(options, out_session);

[ NewSession的时序图 ]

看到这里,我们温习一下抽象工厂的理解:

  • Client(NewSession)有创建GrpcSession或者DirectSession的需求;
  • 但是,Client没有直接调用new DirectSession或者new GrpcSession创建;
  • 而是,通过调用抽象工厂(SessionFactory)接口GetFactory找到DirectSessionFactory。最终通过DirectSessionFactory->NewSession创建;
  • 最终返回实例为Session型(多态可以到GrpcSesion或者DirectSession对象)。

值得说明的是:Client在整个过程中,并不清楚里面不同的Factory(GrpcSessionFactoryDirectSessionFactory),也不清楚不同的Session类型(GrpcSessionDirectSession)。

最后,参考抽象工厂结构图,大致画了如下Session的创建环节,大家可以在回味一下该设计模式(图也是从右上角看起):

[ 抽象工厂模式创建Session ]

至此,创建Session的主题框架已经大致梳理出来了。但是,上面的时序图中的阶段1一直还没有说明吧? 好,这部分涉及了单件设计模式。

后记:按照下面的定义,上述创建Session的模式(因为只创建了一种Session产品)是不是叫“工厂方法”会好一点?

  • 简单工厂:一个工厂类,一个产品抽象类。
  • 工厂方法:多个工厂类,一个产品抽象类。
  • 抽象工厂:多个工厂类,多个产品抽象类。

说一下个人理解,tensorflow在设计这段代码的时候,做了很高程度的抽象,具备完成多个产品抽象的能力。我这里姑且认为应用的是抽象工厂模式。 大家也可以按照“工厂方法”模式理解上述代码,宗旨是:希望大家在学习tensorflow代码的过程,能了解里面蕴含的设计模式。

3. 单件设计模式(Singleton


NewSession中有这样的代码,不知道大家是否有注意到SessionFactory::GetFactory(options, &factory);?这段代码的含义也就是根据传递的options信息,选择是DirectSessionFactory还是分布式GrpcSessionFactory

但是,大家在看时候,有没有这样的疑问:不同的SessionFactory的是什么时候写入到SessionFactory map中的?何况tensorflow这种没有main函数的程序?这个问题曾经一直很困扰我,在gdb debug后,我发现了下面的小trick。

诀窍在这行代码中static DirectSessionRegistrar registrar;

SessionFactory map初始化的能量蕴含在这个static变量的构造函数。下面的流程图揭示所有的秘密。结合代码,从图的左下角看起(下面的代码对应上面NewSession的时序图)。

和全局变量一样,static变量一直存储在程序的静态存储区。当程序初始化static变量时,通过DirectSessionRegistrarGrpcSessionFactory的构造函数完成初始化,将不同的SessionFactory(工厂对象)写入到SessionFactory map中。

[ SessionFactory map的初始化过程 ]

囧~~~,扯了半天的代码和流程,貌似一点都没有提及单件设计模式。其实,单件设计模式在还是比较简单的。GOF中定义如下:

保证一个类仅有一个实例,并提供一个访问它的全局访问点。

tensorflow这里使用了单例中一种更灵活的模式:单件注册表,也就是使用的一个Singleton类的集合(从上图看到存储结构是std::unordered_map),Singleton类通过一个注册接口将自己的单件实例注册到集合中。而这里的tensorflow是通过DirectSessionRegistrarGrpcSessionFactory构造函数中的SessionFactory::Register接口完成注册。

4. 进阶


其实,在tensorflow中,上述模式还有很多资源管理的场景中使用。如下给出代码指引,感兴趣的同学可自行学习:

  • DeviceFactory //Tensorflow设备管理的代码
  • ExecutorFactory //Tensorflow图执行单元的代码

5. 参考


  1. 代码参考:tensorflow v1.12.0
  2. 画图:draw.io

更多优质内容请关注官方微信公众号

长按/识别关注我们
长按/识别关注我们

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 开篇
  • 2. 抽象工厂设计模式(Abstract Factory)
  • 3. 单件设计模式(Singleton)
  • 4. 进阶
  • 5. 参考
相关产品与服务
腾讯云 TI 平台
腾讯云 TI 平台(TencentCloud TI Platform)是基于腾讯先进 AI 能力和多年技术经验,面向开发者、政企提供的全栈式人工智能开发服务平台,致力于打通包含从数据获取、数据处理、算法构建、模型训练、模型评估、模型部署、到 AI 应用开发的产业 + AI 落地全流程链路,帮助用户快速创建和部署 AI 应用,管理全周期 AI 解决方案,从而助力政企单位加速数字化转型并促进 AI 行业生态共建。腾讯云 TI 平台系列产品支持公有云访问、私有化部署以及专属云部署。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档