前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >扩展之Tensorflow2.0 | 19 TF2模型的存储与载入

扩展之Tensorflow2.0 | 19 TF2模型的存储与载入

作者头像
机器学习炼丹术
发布2020-10-15 14:24:30
9250
发布2020-10-15 14:24:30
举报

【机器学习炼丹术】的学习笔记分享

参考目录:

  • 1 模型的构建
  • 2 结构参数的存储与载入
  • 3 参数的存储与载入
  • 4 结构的存储与载入

本文主要讲述TF2.0的模型文件的存储和载入的多种方法。主要分成两类型:模型结构和参数一起载入,模型的结构载入。

1 模型的构建

import tensorflow.keras as keras

class CBR(keras.layers.Layer):
    def __init__(self,output_dim):
        super(CBR,self).__init__()
        self.conv = keras.layers.Conv2D(filters=output_dim, kernel_size=4, padding='same', strides=1)
        self.bn = keras.layers.BatchNormalization(axis=3)
        self.ReLU = keras.layers.ReLU()

    def call(self, inputs):
        inputs = self.conv(inputs)
        inputs = self.ReLU(self.bn(inputs))
        return inputs

class MyNet(keras.Model):
    def __init__ (self):
        super(MyNet,self).__init__()
        self.cbr1 = CBR(16)
        self.maxpool1 = keras.layers.MaxPool2D(pool_size=(2,2))
        self.cbr2 = CBR(32)
        self.maxpool2 = keras.layers.MaxPool2D(pool_size=(2,2))

    def call(self, inputs):
        inputs = self.maxpool1(self.cbr1(inputs))
        inputs = self.maxpool2(self.cbr2(inputs))
        return inputs

model = MyNet()

部分朋友可以发现,上面的代码就是上一次课程所构建的一个自定义的网络。

我们现在需要展示这个模型的框架:

model.build((16,224,224,3))
print(model.summary())

运行结果为:

这里需要对网络执行一个构建.build()函数,之后才能生成model.summary()这样的模型的描述。 这是因为模型的参数量是需要知道输入数据的通道数的,假如我们输入的是单通道的图片,那么就是:

model.build((16,224,224,1))
print(model.summary())

输出结果为:

2 结构参数的存储与载入

model.save('save_model.h5')
new_model = keras.models.load_model('save_model.h5')

这里并不能保存成功,出现这样的错误:

大概的意思就是:因为你的模型不是官方的模型,是自定义的,所以并不能同时保存结构和参数。只有官方的模型可以时候上面的保存的方法,同时保存参数和权重;自定义的模型建议只保存参数

3 参数的存储与载入

model.save_weights('model_weight')
new_model = MyNet()
new_model.load_weights('model_weight')

这样子就可以保存自定义的模型了。在对应的目录下会出现这几个文件:

我们来看一下原来的模型和载入的模型对于同一个样本给出的结果是否相同:

# 看一下原来的模型和载入的模型预测相同的样本的输出
test = tf.ones((1,8,8,3))
prediction = model.predict(test)
new_prediction = new_model.predict(test)
print(prediction,new_prediction)
>>> [[[[0.02559286]]]] [[[[0.02559286]]]]

结果相同,载入的没有问题~

4 结构的存储与载入

结构的存储有两种方法:

  • model.get_config()
  • model.to_json()

需要注意的是,上面的两个方法和save的问题一样,是不能用在自定义的模型中的,如果你在其中使用了自定义的Layer类,那么只能!只能用save_weights的方式进行保存

下面依然给出这两种方法的代码,对于简单的、已经封装好的一些网络层构成的网络,是可以使用这些的。我个人还是常用save_weights啦

# 第一种方法
config = model.get_config()
reinitialized_model = keras.Model.from_config(config)
# 第二种方法
json_config = model.to_json()
# 把json写的文件中
with open('model_config.json', 'w') as json_file:
    json_file.write(json_config)
# 读取本地json文件
with open('model_config.json') as json_file:
    json_config = json_file.read()
reinitialized_model = keras.models.model_from_json(json_config)

今天的内容就是这么多,虽然提供了四种方法,但是对于自定义程度较高的模型,还是要使用save_weights哦~

- END -

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-10-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器学习炼丹术 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1 模型的构建
  • 2 结构参数的存储与载入
  • 3 参数的存储与载入
  • 4 结构的存储与载入
相关产品与服务
对象存储
对象存储(Cloud Object Storage,COS)是由腾讯云推出的无目录层次结构、无数据格式限制,可容纳海量数据且支持 HTTP/HTTPS 协议访问的分布式存储服务。腾讯云 COS 的存储桶空间无容量上限,无需分区管理,适用于 CDN 数据分发、数据万象处理或大数据计算与分析的数据湖等多种场景。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档