专栏首页用户5033944的专栏【TensorFlow2.x开发—基础】 模型保存、加载、使用
原创

【TensorFlow2.x开发—基础】 模型保存、加载、使用

前言

本文主要介绍在TensorFlow2 中使用Keras API保存整个模型,以及如果使用保存好的模型。保存整个模型时,有两种格式可以实现,分别是SaveModel和HDF5;在TF2.x中默认使用SavedModel格式。

文章分为简约版、实践版、代码版,首先从简约版认识基本流程、要点、代码等;再到实践版查看每一步的代码调试结果;最后通过代码版,可以自己尝试实践。

简约版

一、HDF5格式

HDF5标准提供了一种基本保存模型格式,也是常见的模型xxx.h5;通过HDF5格式会保存整个模型的权值值、模型的架构、模型的训练配置、优化器及状态等。

使用model.save() 保存,使用tf.keras.models.loda_model加载模型;

首先安装一下相关的依赖库,执行如下命令即可:

pip install pyyaml h5py

1.1)保存模型

# 创建并训练一个新的模型实例
model = create_model()
model.fit(train_images, train_labels, epochs = 5)

# 以HDF5 格式保存模型,保存后是xxx.h5的文件
model.save("my_model.h5")

1.2)加载使用模型

加载模型:

# 重新创建完成相同的模型,包括权值和优化程序等
new_model = tf.keras.models.load_model("my_model.h5")

# 查看模型的结构
new_model.summary()

检查其准确率(accuracy):

loss, acc = new_model.evaluate(test_images, test_labels, verbose = 2)
print("评估保存好的模型 准确率:{:5.2f}%".format(100 * acc))

二、SavedMode格式

SavedModel格式是序列化模型的一种方法,是一个包含Protobuf二进制文件和Tensorflow检查点(checkpoint)的目录;

SavedModel格式也是使用model.save() 保存模型,使用tf.keras.models.loda_model加载模型;这种方式于Tensorflow Serving兼容。

2.1)保存模型

创建并训练一个新的模型实例,然后把训练好模型保存在saved_model 目录下,保存模型的名称为:my_model

# 创建并训练一个新的模型实例。
model = create_model()
model.fit(train_images, train_labels, epochs = 5)

# 以SavedModel格式保存整个模型
model.save("saved_model/my_model")

SavedModel 格式是一个包含 protobuf 二进制文件和 Tensorflow 检查点(checkpoint)的目录。检查保存的模型目录:

# 首先查看 保存模型的目录saved_model下有那些文件
ls saved_model

# 查看我们刚才保存的模型my_model
ls saved_model/my_model

能看到一个assets文件夹,saved_model.pd,和变量文件夹。

2.2)加载使用模型

加载保存好的模型:

new_model = tf.keras.models.load_model("saved_model/my_model")

# 看到模型的结构
new_model.summary()

使用模型:

# 评估模型
loss, acc = new_model.evaluate(test_images, test_labels, verbose = 2)
print("评估保存好的模型 准确率:{:5.2f}%".format(100 * acc))

实践版

一、HDF5格式

HDF5标准提供了一种基本保存模型格式,也是常见的模型xxx.h5;通过HDF5格式会保存整个模型的权值值、模型的架构、模型的训练配置、优化器及状态等。

首先安装一下相关的依赖库,执行如下命令即可:

pip install pyyaml h5py

1.1)保存模型

1.2)加载使用模型

加载模型:

检查其准确率(accuracy):

二、SavedMode格式

SavedModel格式是序列化模型的一种方法,是一个包含Protobuf二进制文件和Tensorflow检查点(checkpoint)的目录;

其使用model.save() 保存,使用tf.keras.models.loda_model加载模型;这种方式于Tensorflow Serving兼容。

2.1)保存模型

创建并训练一个新的模型实例,然后把训练好模型保存在saved_model 目录下,保存模型的名称为:my_model

SavedModel 格式是一个包含 protobuf 二进制文件和 Tensorflow 检查点(checkpoint)的目录。检查保存的模型目录:

能看到一个assets文件夹,saved_model.pd,和变量文件夹。

2.2)加载使用模型

加载保存好的模型:

使用模型:

代码版

HDF5格式:

# 导入Tensorflow和依赖项
import os
import tensorflow as tf
from tensorflow import keras


# 获取示例数据集,使用 MNIST 数据集,主要使用使用前1000个示例
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0


# 定义模型,首先构建一个简单的序列(sequential)模型
# 定义一个简单的序列模型
def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10)
  ])

  model.compile(optimizer='adam',
                loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

  return model

# 创建并训练一个新的模型实例
model = create_model()
model.fit(train_images, train_labels, epochs = 5)

# 以HDF5 格式保存模型,保存后是xxx.h5的文件
model.save("my_model.h5")

# 重新创建完成相同的模型,包括权值和优化程序等
new_model = tf.keras.models.load_model("my_model.h5")

# 查看模型的结构
new_model.summary()

loss, acc = new_model.evaluate(test_images, test_labels, verbose = 2)
print("评估保存好的模型 准确率:{:5.2f}%".format(100 * acc))

SavedMode格式:

# 导入Tensorflow和依赖项
import os
import tensorflow as tf
from tensorflow import keras


# 获取示例数据集,使用 MNIST 数据集,主要使用使用前1000个示例
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0


# 定义模型,首先构建一个简单的序列(sequential)模型
# 定义一个简单的序列模型
def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10)
  ])

  model.compile(optimizer='adam',
                loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

  return model


# 创建并训练一个新的模型实例。
model = create_model()
model.fit(train_images, train_labels, epochs = 5)

# 以SavedModel格式保存整个模型
model.save("saved_model/my_model")


new_model = tf.keras.models.load_model("saved_model/my_model")

# 看到模型的结构
new_model.summary()

# 评估模型
loss, acc = new_model.evaluate(test_images, test_labels, verbose = 2)
print("评估保存好的模型 准确率:{:5.2f}%".format(100 * acc))

print(new_model.predict(test_images).shape)

小结

保存整个模型时,有两种方式实现,分别是SaveModel和HDF5;两种都是使用model.save() 保存模块,使用tf.keras.models.loda_model加载模型;

HDF5格式 保存模型后,生成xxx.h5,比较常用。

SavedModel格式 保存模型后,是一个包含Protobuf二进制文件和Tensorflow检查点(checkpoint)的目录;

加油加油~~ 欢迎交流呀

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • [阿里DIN] 模型保存,加载和使用

    Deep Interest Network(DIN)是阿里妈妈精准定向检索及基础算法团队在2017年6月提出的。其针对电子商务领域(e-commerce ind...

    罗西的思考
  • Tensorflow Object Detection API 终于支持tensorflow1.x与tensorflow2.x了

    基于tensorflow框架构建的快速对象检测模型构建、训练、部署框架,是针对计算机视觉领域对象检测任务的深度学习框架。之前tensorflow2.x一直不支持...

    OpenCV学堂
  • TensorFlow2 开发指南 | 01 手写数字识别快速入门

    在上一个专栏【TF2.0深度学习实战——图像分类】中,我分享了各种经典的深度神经网络的搭建和训练过程,比如有:LeNet-5、AlexNet、VGG系列、Goo...

    AI菌
  • 掌握TensorFlow1与TensorFlow2共存的秘密,一篇文章就够了

    TensorFlow是Google推出的深度学习框架,也是使用最广泛的深度学习框架。目前最新的TensorFlow版本是2.1。可能有很多同学想跃跃欲试安装Te...

    蒙娜丽宁
  • 『带你学AI』开发环境配置之Windows10篇:一步步带你在Windows10平台开发深度学习

    1. 章节一:初探AI(《带你学AI与TensorFlow2实战一之深度学习初探》):(已完成)

    小宋是呢
  • 资源|Keras框架速查表

    Keras是强大、易用的深度学习库,基于Theano和Tensorflow提供的高阶神经网络API,用于开发和评估深度学习模型。由于keras创始人加入goog...

    触摸壹缕阳光
  • 如何 30 天吃掉 TensorFlow2.0 ?

    Keras可以看成是一种深度学习框架的高阶接口规范,它帮助用户以更简洁的形式定义和训练深度学习网络。

    double
  • 基于Tensorflow2 Lite在Android手机上实现图像分类

    Tensorflow2之后,训练保存的模型也有所变化,基于Keras接口搭建的网络模型默认保存的模型是h5格式的,而之前的模型格式是pb。Tensorflow2...

    夜雨飘零
  • TensorFlow2.x目标检测API测试代码使用演示

    TensorFlow2.x Object Detection API 的安装与配置可参考前面的两篇文章:

    Color Space
  • TensorFlow2.x目标检测API安装配置步骤详细教程

    TensorFlow Object Detection API支持TensorFlow2.x版本已经有一段时间了,这里对安装配置步骤做详细说明。

    Color Space
  • 【深度学习】Tensorflow2.x入门(一)建立模型的三种模式

    最近做实验比较焦虑,因此准备结合推荐算法梳理下Tensorflow2.x的知识。介绍Tensorflow2.x的文章有很多,但本文(系列)是按照作者构建模型的思...

    黄博的机器学习圈子
  • TensorFlow2.x GPU版安装与CUDA版本选择指南

    目前Python最新release版本为3.9.0,配合TensorFlow2版本使用目前常见的以Python3.6和3.7,大家根据自己的开发平台选择...

    Color Space
  • 使用OpenCV加载TensorFlow2模型

    Suaro希望使用OpenCV来实现模型加载与推演,但是没有成功,因此开了issue寻求我的帮助。

    小白学视觉
  • Pytorh与tensorflow对象检测模型如何部署到CPU端,实现加速推理

    对象检测是计算机视觉最常见的任务之一,应用非常广泛,本文主要给给大家价绍两条快速方便的自定义对象检测模型的训练与部署的技术路径,供大家实际项目中可以参考。

    OpenCV学堂
  • 微信小程序|调用tensorflow自定义模型

    在成功调用官网打包好的tensorflowjs模型后,怎么调用自己的模型呢?又需要做哪些处理呢?

    算法与编程之美
  • Tensorflow + OpenCV4 安全帽检测模型训练与推理

    如何安装tensorflow object detection API框架,看这里:

    OpenCV学堂
  • 带你入门机器学习与TensorFlow2.x

    本文主要介绍人工智能、机器学习和深度学习的区别,以及软硬件环境的搭建,包括Tensorflow1.x和Tensorflow2.x在同一台机器上如何共存。在后续的...

    蒙娜丽宁
  • TensorFlow2 开发指南 | 02 回归问题之汽车燃油效率预测

    这个专栏我将分享我的 TensorFlow2 学习过程,力争打造一个的轻松而高效的TensorFlow2入门学习教程,想学习的小伙伴可以关注我的动态!我们一起学...

    AI菌
  • 【TensorFlow2.x开发—基础】 简介、安装、入门应用案例

    本文介绍最新版本的TensorFlow开发与应用,目前最新版本是TensorFlow2.5.0;首先简单介绍一下TensorFlow,然后安装TensorFlo...

    一颗小树x

扫码关注云+社区

领取腾讯云代金券