前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >tensorflow的模型持久化

tensorflow的模型持久化

作者头像
狼啸风云
修改2022-09-04 22:17:40
1.8K0
修改2022-09-04 22:17:40
举报

1.持久化代码实现

tensorflow提供了一个非常简单的API来保存和还原一个神经网络模型。这个API就是tf.train.Saver类。以下代码给出了保存tensorflow计算图的方法。

代码语言:javascript
复制
import tenosrflow as tf

# 声明两个变量并计算它们的和。
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(1.0, shape=[1]), name="v2")
result = v1 + v2


init_op = tf.global_variables_initializer()

# 声明tf.train.Saver类用于保存模型。
saver = tf.train.Saver()


with tf.Session() as sess:
   sess.run(init_op):
   # 将模型保存到/path/to/model/model.ckpt文件。
   saver.save(sess, "/path/to/model/model.ckpt")

以上代码实现了持久化一个简单tensorflow模型的功能。在这段代码中,通过saver.save函数将tensorflow模型保存到了/path/to/model/model.ckpt文件中。tensorflow模型一般会保存在后缀为.ckpt的文件中。虽然以上程序只指定了一个文件路径,但是在这个文件目录下会出现三个文件。这是因为tensorflow会将计算图的结构和图上参数取值分来保存。

上面这段代码会生成的第一个文件为model.ckpt.meta,它保存了tensorflow计算图的结构。第二个文件为model.ckpt,这个文件中保存了tensorflow程序中每一个变量的取值。最后一个文件为checkpoint文件,这个文件中保存了一个目录下所有的模型文件列表。以下代码中给出了加载这个已经保存的tensorflow模型的方法。

代码语言:javascript
复制
import tensorflow as tf


# 使用核保存模型代码中一样的方式来声明变量。
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2


saver = tf.train.Saver()



with tf.Session() as sess:
     # 加载已经保存的模型,并通过已经保存的模型中变量的值来计算加法。
     saver.restore(sess, "/path/to/model/model.ckpt")
     print sess.run(result)

这段加载模型的代码基本上和保存模型的代码时一样的。在加载模型的程序中也是先定义了tensorflow计算图上的所有运算,并声明了一个tf.train.Saver类。两段代码唯一不同的是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。如果不希望重复定义图上的运算,也可以直接加载已经持久化的图。以下代码给出了一个样例。

代码语言:javascript
复制
import tensorflow as tf
# 直接加载持久化的图。
saver = tf.train.import_meta_graph("path/to/model/model.ckpt/model.ckpt.meta")
       

with tf.session() as sess:
     saver.restore(sess, "/path/to/model/model.ckpt")
     # 通过张量的名称来获取张量。
     print sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))
     # 输出[ 3.]

在上面的程序中,默认保存和加载了tensorflow计算图上定义的全部变量。但有可能只需要保存或者加载部分变量。比如,可能有一个之前训练好的五层神经网络模型,但现在想尝试一个六层的神经网络,那么可以将前面五层神将网络中的参数直接加载到新的模型,而仅仅将最后一层神将网路重新训练。

为了保存或者加载部分变量,在声明tf.train.Saver类时可以提供一个列表来指定需要保存或者加载的变量。比如在加载模型的代码中使用saver = tf.train.Saver([v1])命令来构建tf.train.Saver类,那么只有变量v1会被加载进来。如果运行修改后只加载了v1的代码会得到变量未初始化的错误:

代码语言:javascript
复制
tensorflow.python.framework.errors.FailedPreconditionError: 
Attempting to use uninitialized value v2

因为v2没有加载,所以v2在运行初始化之前是没有值的。除了可以选取需要被加载的变量,tf.train.Saver类也支持在保存或者加载时给变量重命名。下面给出了一个简单的样例程序说明变量重命名是如何被使用的。

代码语言:javascript
复制
# 这里声明的变量名称和已经保存的模型中变量的名称不同。
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")



# 如果直接使用tf.train.Saver()来加载模型会报变量找不到的错误。下面显示了报错信息:
# tensorflow.python.framework.errors.NotFoundError: Tensor name "other=v2"
# not found in checkpoint files  /path/to/model.ckpt

# 使用一个字典(dictionary)来重命名变量就可以加载原来的模型了。这个字典指定了
# 原来名称为v1的变量现在加载到变量v1中(名称为other-v1),名称为v2的变量
# 加载到变量v2中(名称为other=v2)。
saver = tf.train.Saver({"v1": v1, "v2": v2})

在这个程序中,对变量v1和v2名称进行了修改。如果直接通过tf.train.Saver默认的构造函数来加载保存的模型,那么程序会报变量找不到的错误。因为保存时候变量的名称和加载时变量的名称不一致。为了解决这个问题,tensorflow可以通过字典(dictionary)将模型保存时的变量名的需要加载的变量联系起来。

这样做只要目的之一是方便使用变量的滑动平均值,滑动平均值可以让神经网络更加健壮(robust)。在tensorflow中,每一个变量的滑动均值是通过影子变量维护的,所以要获取变量的滑动平均值实际上就是获取这个影子变量的取值。如果在加载模型时直接将影子变量映射到变量自身,那么在使用训练好的模型就不需要再调用函数来获取变量的滑动平均值了。这样大大方便了滑动平均模型的使用。以下代码给出了一个保存滑动平均模型的样例。

代码语言:javascript
复制
import tensorflow as tf

v = tf.Variable(0, dtype = tf.float32, name = "v")
for variables in tf.global_variable():
    print variables.name


ema = tf.train.ExponentialMovingAverage(0.99)
maintian_average_op = ema.apply(tf.global_variables())
# 在申明滑动平均模型之后,tensorflow会自动生成一个影子变量
# v/ExponentialMovingAverage。于是以下语句会输出
# "v:0"和"v/ExponentialMovingAverage:0"。
for variables in tf.global_variables():
    print variable.name


saver = tf.train.Saver()
with tf.Session() as sess:
     init_op = tf.global_variables_initializer()
     sess.run(init_op)
    
     sess.run(tf.assign(v, 10))
     sess.run(maintain_average_op)
     # 保存时,tensorflow会将v:0和v/ExponentialMovingAverage:0两个变量都存下来。
     saver.save(sess, "/path/to/model/model.ckpt")
     print sess.run([v, ema.average(v)])   # 输出[10.0, 0.0999999005]

以下代码给出了如何通过变量重命名直接读取变量的滑动平均。从下面的输出可以看出,读取的变量v的值实际上是上面代码中变量v的滑动平均值。通过这个方法,就可以使用完全一样的代码来计算滑动平均模型前向传播的结果。

代码语言:javascript
复制
v = tf.Variable(0, dtype=tf.float32, name="v")
# 通过变量重命名将原来变量v的滑动平均值直接赋给v。
saver = tf.train.Saver(("v/ExponentialMovingAverage": v))
with tf.Session() as sess:
     saver.restore(sess, "/path/to/model/model.ckpt")
     print sess.run(v)  # 输出0.099999905,这个值就是原来模型中变量v的滑动平均值。

为了方便加载时重命名滑动平均变量,tf.train.ExponentialMovingAverage类提供了variables_to_restore函数来生成tf.train_Saver类所需要的变量重命名字典。以下代码给出类variables_to_restore函数的使用样例。

代码语言:javascript
复制
import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)


# 通过使用variables_to_restore函数可以直接生成上面代码中提供的字典。
# {"v/ExponentialMovingAverage": v}。
# 以下代码会输出:
# {'v/ExponentialMovingAverage': <tensorflow.Variable 'v:0' shape=()
# dtpye=float32_ref>}
# 其中后面的Variable类就代表了变量v。
print ema.variables_to_restore()


saver = tf.train.Saver(ema.variables_to_restore())

with tf.Session() as sess:
     saver.restore(sess, "/path/to/model/model.ckpt")
     print sess.run(v)  # 输出0.099999905,即原来模型中变量v的滑动平均值

使用tf.train.Saver会保存运行tensorflow程序所需要的全部信息,然而有时候并不需要某些信息。比如在测试或者离线预测试时,只需要知道如何从神经网络的输入层经过前向传播稀疏得到输出层即可,而不需要类似于变量初始化、模型保存等辅助节点的信息。而且,将变量取值和计算图结构分成不同的文件存储有时候也不方便,于是tensorflow提供了convert_variable_to_constants函数,通过这个函数可以将计算图中的变量及其取值通过常量的方式保存,这样整个tensorflow计算图可以统一存放在一个文件中。以下程序提供了一个样例。

代码语言:javascript
复制
import tensorflow as tf 
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

init_op = tf.global_variable_iniitializer()

with tf.Session() as sess:
     sess.run(init_op)
     # 导出当前计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算
     # 过程。
     graph_def = tf.get_default_graph().as_graph_def()
     
     # 将图中的变量及其取值转化为常量,同时将图中不必要的节点去掉。
     # 如果只关心程序中定义的某些计算时,和这些计算无关的节点就没有必要导出并保存了。在下面一行
     # 代码中,最后一个参数['add']给出了需要保存的节点名称。add节点是上面定义的两个变量相加的
     # 操作。注意这里给出的是计算节点的名称,所以没有后面的:0。
     output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
     # 将导出的模型存入文件。
     with tf.gfile.GFile("/path/to/model/combined_model.pb", "wb") as f:
         f.write(output_graph_def.SerializeToString())

通过以下程序可以直接计算定义的加法运算的结果。当只需要得到计算图中某个节点的取值时,这提供了一个更加方便的用法。

代码语言:javascript
复制
import tensorflow as tf 

from tensorflow.python.platfrom import gfile

with tf.Session() as sess:
    model_filename = "/path/to/model/combined_model.pb"
    # 读取保存的模型文件,并将文件解析成对应的GraphDef Protocol Buffer。
    with gfile.FastGFile(model_filename, 'rb') as f:
         graph_def = tf.GraphDef()
         graph_def.ParseFromString(f.read())
    
    # 将graph_def中保存的图加载到当前的图像中。return_elements = ["add:0"]给出了
    # 返回的张量的名称。在保存的时候给出的是计算节点的名称,所以为"add"。在加载的时候
    # 给出的是张量的名称,所以是add:0。
    result = tf.import_graph_def(graph_def, return_elements = ["add:0"])
    # 输出[3.0]
    print sess.run(result)

2、持久化原理及数据格式

tensorflow是一个通过图的形式来表达计算的编程系统,tensflow程序中的所有计算都会被表达为计算图上的节点。tensorflow通过原图(MateGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的原结构。tensorflow中元图是由MetaGraphDef Proticol Buffer定义的。MetaGraphDef中的内容就构成了tensorflow持久化时的第一个文件。以下代码给出了MetaGraphDef类型的定义。

代码语言:javascript
复制
message MetaGrapher {
    MetaInfoDef meta_info_def = 1;
    
   GraphDef graph_def = 2;
   SaverDef saver_def = 3;
   map<string, CollectionDef> collection_def = 4; 
   map<string, SignatureDef>  signature_def = 5;
   repeated AssetFileDef asset_file_def = 6; 
}

从以上代码可以看到,元图主要记录了6类信息。保存MetaGraphDef信息的文件默认以.meta为后缀名,文件model.ckpt.mate中存储的就是元图的数据。tensorflow提供了export_meta_graph函数,这个函数支持以json格式导出MetaGraphDef Protocol Buffer。以下代码展示了如何使用这个函数。

代码语言:javascript
复制
import tensorflow as tf


# 定义变量相加的计算。
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2


saver = tf.train.Saver()
# 通过export_meta_graph函数导出tensorflow计算图的元图,并保存为json格式。
saver.export_meta_graph("/path/to/model.ckpt.meda.json", as_text=True)

通过上面给出的代码,可以将计算图以json的格式导出并存储在model.ckpt.meta.json文件中。下文将结合model.ckpt.meta.json文件具体介绍tensorflow元图中存储的信息。

meta_info_def属性

meta_info_def属性是通过MetaInfoDef定义的,它记录了tensorflow计算图中的元数据以及tensorflow程序中所有使用到的运算方法的信息。下面是MetaInfoDef Protocol Buffer的定义:

代码语言:javascript
复制
message MetaInfoDef {
    string meta_graph_version = 1;
    OpList stripped_op_list = 2;
    google.protobuf.Any any_info = 3;
    repeated string tags = 4;
    string tensorflow_version = 5;
    string tensorflow_git_version = 6;
}

tensorflow计算图的元数据包括了计算图的版本号(meta_graph_version属性)以及用户指定的一些标签(tags属性)。如果没有在saver中特殊指定,那么这些属性都默认为空。在model.ckpt.meta.json文件中,meta_info_def属性里只有stripped_op_list属性是不为空的。stripped_op_list属性记录了tensorflow计算图上使用到的所有运算方法的信息。注意stripped_op_list属性保存的是tensorflow运算方法的信息,所以如果某一个运算在tensorflow计算图中出现了多次,那么在stripped_op_list也只出现一次。比如在model.ckpt.meta.json文件中的stripped_op_list属性中只有一个Variable运算,但这个运算在程序中被使用了两次,stripped_op_list属性的类型是Oplist。Oplist类型是一个OpDef的列表,以下代码给出了OpDef类型的定义:

代码语言:javascript
复制
message OpDef {
    string name = 1;


    repeated ArgDef input_arg = 2;
    repeated ArgDef output_arg = 3;
    repeated AttrDef attr = 4;

    OpDeprecation deprecation = 8;
    string summary = 5;
    string description = 6;
    
    bool is_commutative = 18;
    bool is_aggregate   = 16;
    bool is_statefull   = 17;
    bool allows_uninitialized_input = 19;
};

OpDef类型中前4个属性定义了一个运算最核心的信息。OpDef中的第一个属性name定义了运算的名称,这也是一个运算唯一的标识符。在tensorflow计算图元图的其他属性中,比如下面将要介绍的GraphDef属性,将通过运算名称来引用不同的运算。OpDef的第二和第三个属性为input_arg和output_arg,它们定义了运算的输入和输出。因为输入输出都可以有多个,所以这两个属性都是列表(repeated)。第四个属性attr个西湖了其他的运算参数信息。在model.ckpt,meta,json文件中总共定义了8个运算,下面将给出比较有代表性的一个运算来辅助说明OpDef的数据结构。

代码语言:javascript
复制
op  {
    name: "Add"
    input_arg {
        name: "x"
        type_attr: "T"
    }    
    input_arg {
        name: "y"
        type_attr: "T"
    }
    output_arg{
        name: "x"
        type_attr: "T"
    }
    attr {
       name: "T"
       type: "type"  
       allowed_values  {
          list  {
             type: DT_HALF
             type: DT_FLOAT
             ...
             }
         }
    }
}

上面给出了名称为Add的运算。这个运算有2个输入和1个输出。输入输出属性都指定了属性type_attr,并且这个属性的值为T。在OpDef的attr属性中,必须要出现名称为(name)为T的属性。以上样例中,这个属性指定了运算输入输出允许的参数类型(allowed_values)。MetaInfoDef中的tensorflow_version和tesnorflow_git_version属性记录了生成当前计算图的tensorflow版本。

graph_def属性

graph_def属性主要记录了tensorflow计算图上的节点信息。tensorflow计算图的每一个节点对应了tensorflow程序中的一个运算。因为在meta_info_def属性中已经包含了所有运算的具体信息,所以graph_def属性只关注运算的连接结构。graph_def属性是通过GraphDef Protocol Buffer定义的,GraphDef主要包含了一个NodeDef类型的列表。以下代码给出了GraphDef和NodeDef类型中包含的信息:

代码语言:javascript
复制
message GraphDef{
   repeated NodeDef node = 1;
   VersionDef versions = 4;
};


message NodeDef {
   string name = 1;
   string op = 2;
   repeated string input = 3;
   string device = 4;
   map<string, AttrValue> attr = 5;
};

GraphDef中的versions属性比较简单,它主要存储了tensorflow的版本号。GraphDef的主要信息都存在node属性中,它记录了tensorflow计算图上所有的节点信息。和其他属性类似,NodeDef类型中有一个名称属性name,它是一个节点的唯一标识符。在tensorflow程序中可以通过节点的名称来获取想用的节点。NodeDef类型中的op属性给出了该节点使用的tensorflow运算方法的名称,通过这个名称可以在tenosrflow计算图元图的meta_info_def属性中找到该运算的具体信息。

NodeDef类型中input属性是一个字符串列表,它定义了运算的输入。input属性中每个字符串的取值格式为node:src_output,其中node部分给出了一个节点的名称,src_ouput部分表明了这个输入时指定节点的第几个输出。当src_output为0时,可以省略:src_output这个部分。比如node:0表示名称为node的节点的第一个输出,它可以被记为node。

NodeDef类型中的device属性指定了处理这个运算的设备。运行tensorflow运算的设备可以是本地机器的CPU或者GPU,也可以是一台远程的机器CPU或者GPU。当device属性为空时,tensorflow在运行时会自动选取一个最合适的设备来运行这个运算。最后NodeDef类型中的attr属性指定了和当前运算相关的配置信息。下面列举了model.ckpt.meta.json文件中的一些计算节点来更加具体地介绍graph_def属性。

代码语言:javascript
复制
graph_def  {
  node {
    name: "v1"
    op: "VariableV2"
    attr  { 
      key: "_output_shapes" 
      value { 
        list  {  shape  {   dim   {  size:  1    }   }   }
      }
    }
    
    attr  {
       keys: "dtype"
       value  {
          type: DT_FLOAT
       }
    }
    ...
    }
    node  {
       name: "add"
       op: "Add"
       input: "v1/read"
       input: "v2/read"
       ...
    }
    
    node  {
      named: "save/control_dependency"
      op: "Identity"
      ...
    }
    
    version {
      preducer: 24
    }


 }

上面给出了model.ckpt.meta.json文件中graph_def属性里比较有代表性的几个节点。第一个节点给出的是变量定义的运算。在tensorflow中变量定义也是一个运算,这个运算的名称为v1(name:"v1"),运算方法的名称为Variable(op:"VariableV2")。定义变量的运算可以有很多个,于是在NodeDef类型的node属性中可以有多个变量定义的节点。但定义变量的运算方法只用到了一个,于是在MetaInfoDef类型的stripped_op_list属性中只有一个名称为VariableV2的运算方法。除了指定计算图中节点的名称和运算方法,NodeDef类型还定义了运算相关的属性。在节点v1中,attr属性指定了这个变量的维度以及类型。

给出的第二个节点是代表加法的节点。它指定了2个输入,一个为v1/read,另一个为v2/read。其中v1/read代表的节点可以读取变量v1的值。因为v1的值是节点v1/read的第一个输出,所以后面的:0就可以省略了。v2/read也类似的代表了变量v2的取值。以上样例文件中给出的最后一个名称为save/control_dependency,该节点是系统在完成tensorflow模型持久化过程中自动生成的一个运损。在样例文件的最后,属性version给出了生成model.skpt.meta.json文件时使用的tensorflow版本号。

saver_def属性

save_def属性中记录了持久化模型时需要用到的一些参数,比如保存到文件的文件名、保存操作和加载操作和加载操作的名称以及保存频率、清理历史记录等。saver_def属性的类型为SaverDef,其定义如下。

代码语言:javascript
复制
message SaverDef {
    string filename_tensor_name = 1;
    string save_tensor_name = 2;
    string restore_op = 3;
    int32 max_to_keep = 4; 
    bool sharded = 5;
    float keep_checkpoint_every_n_hours = 6;
    
    enum CheckpointFormatVersion {
      LEGACY = 0; 
      V1 = 1;
      v2 = 2;
    }
    CheckpointFormatVersion version = 7;
}

下面给出了model.ckpt.mate.json文件中saver_def属性的内容。

代码语言:javascript
复制
saver_def {
   filename_tensor_name: "save/Const:0"
   save_tensor_name: "save/control_dependency:0"
   restore_op_name: "save/restore_all"
   max_to_keep: 5
   keep_checkpoint_every_n_hours: 10000.0
   version: V2
}

filename_tensor_name属性给出了保存文件名的张量名称,这个张量就是节点save/Const的第一个输出。save_tensor_name属性给出了持久化tensorflow模型的运算所对应的节点名称。从以上文件可以看出,这个节点就是在graph_def属性中给出的save/control_dependency节点。和持久化tensorflow模型运算对应的是加载tensorflow模型的运算,这个运算的名称是由restore_op_name属性指定。max_to_keep属性和keep_checkpoint_every_n_hours属性设定了tf,train.Saver类清理之前保存的模型的策略。比如当max_to_keep为5的时候,在第六次调用saver.save时,第一次保存的模型就会被自动删除。通过设置keep_checkpoint_every_n_hours,每n小时可以在max_to_keep的基础上多保存一个模型。

collection_def属性

在tensorflow的计算图(tf.Graph)中可以维护不同集合,而维护这些集合的底层实现就是通过collection_def这个属性。collection_def属性是一个从集合内容的映射,其中集合名称为字符串,而集合内容为CollectionDef Protocol Buffer。以下代码给出了CollectionDef类型的定义。

代码语言:javascript
复制
message CollectionDef{
   message NodeList {
      repeated string value = 1;
   }

   message BytesList {
      repeated bytes value = 1;
   }
   
   message Int64List {
      repeated int64 value = 1 [packed = true];
   
   message FloatList {
      repeated float value = 1 [packed = true];
   }
   
   message AnyList {
      repeated google.protobuf.Any value = 1;
   } 
   
   oneof kind {
     NoneList node_list = 1;
     BytesList bytes_list = 2;
     Int64List int64_list = 3;
     FloatList float_list = 4;
     AnyList any_list = 5;
   }
}

通过以上定义可以看出,tensorflow计算图上的集合主要可以维护4类不同的集合。NodeLIst用于维护计算图上节点的集合。ByteList可以维护字符串或者系列化之后的Procotol Buffer的集合。比如张量是通过Protocol Buffer表示的,而张量的集合是通过BytesList维护的,我们将在model.ckpt.meta.json文件中看到具体样例。Int64List用于维护整数集合,FloatLIst用于维护实数集合。下面给出了model.ckpt.meta.json文件中collection_def属性的内容。

代码语言:javascript
复制
collection_def {
   key: "trainable_variables"
   value {
      bytes_list {
         value: "\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
         value: "\n\004v1:0\022\tv1/Assign\032\tv2/read:0"
      }
   }
}
collection_def {
   key: "variables"
   value {
      bytes_list {
         value: "\n\004v1:0\022\tv1/Assign\032\tv1/read:0" 
         value: "\n\004v1:0\022\tv1/Assign\032\tv2/read:0"
     }
   }
}

从以上文件可以看出样例程序中维护了两个集合。一个是所有变量的集合。这个集合的名称为variable。另外一个是可训练变量的集合,名为trainable_variables。在样例程序中,这两个集合中的元素是一样的,都是变量v1和v2,它们都是系统自动维护的。

通过对MetaGraphDef类型中主要属性的讲解,本节已经介绍了tensorflow模型持久化得到的第一个文件的内容。除了持久化tensorflow计算图的结构,持久化tensorflow中变量的取值也是非常重要的一个部分。其中model.ckpt.data文件时通过SSTable格式存储的,可以大致理解为就是一个(key, value)列表。tensorflow提供了tf.train.NewCheckpointReader类来查看保存的变量信息。以下代码展示了如何使用tf.train.NewCheckpointReader类。

代码语言:javascript
复制
import tensorflow as tf

# tf.train.NewCheckpointReader可以读取checkpoint文件中保存的所有变量。
# 注意后面的.data和.index可以省去。
reader = tf.train.NewCheckpointReader('/path/to/model/model.test')


# 获取所有变量列表。这个是一个从变量名到变量维度的字典。
global_variable = reader.get_variable_to_shape_map()
for variable_name in global_variables:
    # variable_name 为变量名称,global_variable[variable_name]为变量的维度。
    print variable_name, global_variables[variable_name]


# 获取名称为v1的变量的取值。
print "Value for variable v1 is ", reader.get_tensor("v1")


'''
这个程序将输出:
v1 [1]                            # 变量v1的维度为[1]。
v2 [2]                            # 变量v2的维度为[1]。
Value for variable v1 is [ 1.]    # 变量v1的取值为1。

最后一个文件的名字是固定的,叫checkpoint。这个文件是tf.train.Saver类自动生成且自动维护的。在checkpoint文件中维护了由一个tf.train.Saver类持久化的所有tensorflow模型文件的文件名。当某个保存的tensorflow模型文件被删除时,这个模型所对应的文件名也从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState Protocol Buffer,下面给出了CheckpointState类型的定义。

代码语言:javascript
复制
message CheckpointState {
   string model_checkpoint_path = 1;
   repeated string all_model_checkpoint_paths = 2;
}

model_checkpoint_path属性保存了最新的tensorflow模型文件的文件名。all_model_checkpoint_paths属性列出了当前还没有被删除的所有tensorflow模型文件的文件名,下面给出了checkpoint文件。

代码语言:javascript
复制
model_checkpoint_path: "/path/to/model/model.ckpt"
all_model_checkpoint_paths: "/path/to/model/model.ckpt"
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019年02月23日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1.持久化代码实现
  • 2、持久化原理及数据格式
相关产品与服务
文件存储
文件存储(Cloud File Storage,CFS)为您提供安全可靠、可扩展的共享文件存储服务。文件存储可与腾讯云服务器、容器服务、批量计算等服务搭配使用,为多个计算节点提供容量和性能可弹性扩展的高性能共享存储。腾讯云文件存储的管理界面简单、易使用,可实现对现有应用的无缝集成;按实际用量付费,为您节约成本,简化 IT 运维工作。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档