本文原作者:于洋,经授权后发布。
通常,我们在使用Tensorflow低级API编程时(非Eager模式), 一般有下面三个步骤:
tf.Operation
(节点)和 tf.Tensor
(边)对象并将它们添加到 tf.Graph
实例中。例如,典型添加op操作就是tf.matmul
。with tf.Session() as sess:
,创建分布式会话with tf.Session("grpc://example.org:2222"):
sess.run(init_op)
, sess.run(train_op)
众所周知,tensorflow使用支持多种前端语言(python,js,swift,go等),执行引擎为C/C++后端实现。 那么,在上述三个步骤中,当用户python构建图,以及运行的图的时候。C/C++后端有在执行哪些工作呢? 按照对应的三个步骤,我们做如下拆解:
tf.Operation
(节点)和 tf.Tensor
(边)的同时,C/C++的后端也生成对应的节点和边,从而构造后端的图。本次分享的设计模式,就是在上述第二阶段时:创建本地session和分布式session时,tensorflow是怎样利用抽象工厂设计模式的?
在《设计模式》中描述的23设计模式,分为三类:创建型、结构型、行为型。其中,抽象工厂设计模式属于创建性设计模式。即是解决对象的创建需求。关于抽象工厂模式我的理解是这样的:
调用者有创建不同对象的需求(对象有一定相似性,例如轿车、卡车),调用者无需关注具体的实现类,而是通过抽象类定义的接口,就能创造不同对象。
当然,个人抽象理解和描述还是很难理解的。我们根据GOF书中,抽象工厂的模式结构图(图需要从右上角看起)在来理解一下:
[ 抽象工厂的模式结构图 - 《设计模式》58页 ]
有了上面粗浅的理解后,我们看一下tensorflow是如何使用抽象工厂模式,创建本地session和分布式session?
首先,我们看一下python创建Session调用栈:
NewSession
的代码如下:
Status NewSession(const SessionOptions& options, Session** out_session) {
SessionFactory* factory;
Status s = SessionFactory::GetFactory(options, &factory);
if (!s.ok()) {
*out_session = nullptr;
LOG(ERROR) << s;
return s;
}
s = factory->NewSession(options, out_session);
if (!s.ok()) {
*out_session = nullptr;
}
return s;
}
代码很枯燥,我们看一下上述代码的时序图(以创建DirectSessione为例)。 上述代码对应着时序图的阶段2和阶段3。其中:
SessionFactory::GetFactory(options, &factory);
,factory->NewSession(options, out_session);
[ NewSession的时序图 ]
看到这里,我们温习一下抽象工厂的理解:
NewSession
)有创建GrpcSession
或者DirectSession
的需求;new DirectSession
或者new GrpcSession
创建;SessionFactory
)接口GetFactory
找到DirectSessionFactory
。最终通过DirectSessionFactory->NewSession
创建;Session
型(多态可以到GrpcSesion
或者DirectSession
对象)。值得说明的是:Client在整个过程中,并不清楚里面不同的Factory(GrpcSessionFactory
和DirectSessionFactory
),也不清楚不同的Session类型(GrpcSession
和DirectSession
)。
最后,参考抽象工厂结构图,大致画了如下Session的创建环节,大家可以在回味一下该设计模式(图也是从右上角看起):
[ 抽象工厂模式创建Session ]
至此,创建Session的主题框架已经大致梳理出来了。但是,上面的时序图中的阶段1一直还没有说明吧? 好,这部分涉及了单件设计模式。
后记:按照下面的定义,上述创建Session的模式(因为只创建了一种Session产品)是不是叫“工厂方法”会好一点?
说一下个人理解,tensorflow在设计这段代码的时候,做了很高程度的抽象,具备完成多个产品抽象的能力。我这里姑且认为应用的是抽象工厂模式。 大家也可以按照“工厂方法”模式理解上述代码,宗旨是:希望大家在学习tensorflow代码的过程,能了解里面蕴含的设计模式。
NewSession
中有这样的代码,不知道大家是否有注意到SessionFactory::GetFactory(options, &factory);
?这段代码的含义也就是根据传递的options
信息,选择是DirectSessionFactory
还是分布式GrpcSessionFactory
。
但是,大家在看时候,有没有这样的疑问:不同的SessionFactory
的是什么时候写入到SessionFactory map
中的?何况tensorflow这种没有main函数的程序?这个问题曾经一直很困扰我,在gdb debug后,我发现了下面的小trick。
诀窍在这行代码中static DirectSessionRegistrar registrar;
。
SessionFactory map
初始化的能量蕴含在这个static
变量的构造函数。下面的流程图揭示所有的秘密。结合代码,从图的左下角看起(下面的代码对应上面NewSession的时序图)。
和全局变量一样,static变量一直存储在程序的静态存储区。当程序初始化static变量时,通过
DirectSessionRegistrar
和GrpcSessionFactory
的构造函数完成初始化,将不同的SessionFactory
(工厂对象)写入到SessionFactory map
中。
[ SessionFactory map的初始化过程 ]
囧~~~,扯了半天的代码和流程,貌似一点都没有提及单件设计模式。其实,单件设计模式在还是比较简单的。GOF中定义如下:
保证一个类仅有一个实例,并提供一个访问它的全局访问点。
tensorflow这里使用了单例中一种更灵活的模式:单件注册表,也就是使用的一个Singleton类的集合(从上图看到存储结构是std::unordered_map
),Singleton类通过一个注册接口将自己的单件实例注册到集合中。而这里的tensorflow是通过DirectSessionRegistrar
和GrpcSessionFactory
构造函数中的SessionFactory::Register
接口完成注册。
其实,在tensorflow中,上述模式还有很多资源管理的场景中使用。如下给出代码指引,感兴趣的同学可自行学习:
更多优质内容请关注官方微信公众号
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。