首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow2.0(12):模型保存与序列化

TensorFlow2.0(12):模型保存与序列化

作者头像
统计学家
发布2019-12-30 14:11:11
1.6K0
发布2019-12-30 14:11:11
举报

TensorFlow2.0(1):基本数据结构——张量

TensorFlow2.0(2):数学运算

TensorFlow2.0(3):张量排序、最大最小值

TensorFlow2.0(4):填充与复制

TensorFlow2.0(5):张量限幅

TensorFlow2.0(6):利用data模块进行数据预处理

TensorFlow2.0(7):4种常用的激活函数

TensorFlow2.0(8):误差计算:损失函数总结

TensorFlow2.0(9):神器级可视化工具TensorBoard

TensorFlow2.0(10):加载自定义图片数据集到Dataset

TensorFlow2.0(11):tf.keras建模三部曲

模型训练好之后,我们就要想办法将其持久化保存下来,不然关机或者程序退出后模型就不复存在了。本文介绍两种持久化保存模型的方法:

在介绍这两种方法之前,我们得先创建并训练好一个模型,还是以mnist手写数字识别数据集训练模型为例:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, Sequential
model = Sequential([  # 创建模型
    layers.Dense(256, activation=tf.nn.relu),
    layers.Dense(128, activation=tf.nn.relu),
    layers.Dense(64, activation=tf.nn.relu),
    layers.Dense(32, activation=tf.nn.relu),
    layers.Dense(10)
    ]
)
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop())
history = model.fit(x_train, y_train,  # 进行简单的1次迭代训练
                    batch_size=64,
                    epochs=1)
Train on 60000 samples
60000/60000 [==============================] - 3s 46us/sample - loss: 2.3700

方法一:model.save()

通过模型自带的save()方法可以将模型保存到一个指定文件中,保存的内容包括:

  • 模型的结构
  • 模型的权重参数
  • 通过compile()方法配置的模型训练参数
  • 优化器及其状态
model.save('mymodels/mnist.h5')

使用save()方法保存后,在mymodels目录下就会有一个mnist.h5文件。需要使用模型时,通过keras.models.load_model()方法从文件中再次加载即可。

new_model = keras.models.load_model('mymodels/mnist.h5')
WARNING:tensorflow:Sequential models without an `input_shape` passed to the first layer cannot reload their optimizer state. As a result, your model isstarting with a freshly initialized optimizer.

新加载出来的new_model在结构、功能、参数各方面与model是一样的。

通过save()方法,也可以将模型保存为SavedModel 格式。SavedModel格式是TensorFlow所特有的一种序列化文件格式,其他编程语言实现的TensorFlow中同样支持:

model.save('mymodels/mnist_model', save_format='tf')  # 将模型保存为SaveModel格式
WARNING:tensorflow:From /home/chb/anaconda3/envs/study_python/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1781: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: mymodels/mnist_model/assets
new_model = keras.models.load_model('mymodels/mnist_model')  # 加载模型
print(keras.models.__dir__())
['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__path__', '__file__', '__cached__', '__builtins__', '_sys', 'Sequential', 'Model', 'clone_model', 'model_from_config', 'model_from_json', 'model_from_yaml', 'load_model', 'save_model']

方法二:model.save_weights()

save()方法会保留模型的所有信息,但有时候,我们仅对部分信息感兴趣,例如仅对模型的权重参数感兴趣,那么就可以通过save_weights()方法进行保存。

model.save_weights('mymodels/mnits_weights')  # 保存模型权重信息
new_model = Sequential([  # 创建新的模型
    layers.Dense(256, activation=tf.nn.relu),
    layers.Dense(128, activation=tf.nn.relu),
    layers.Dense(64, activation=tf.nn.relu),
    layers.Dense(32, activation=tf.nn.relu),
    layers.Dense(10)
    ]
)
new_model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop())
new_model.load_weights('mymodels/mnits_weights')  # 将保存好的权重信息加载的新的模型中
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f49c42b87d0>

注:本系列所有博客将持续更新并发布在github上,您可以通过github下载本系列所有文章笔记文件。

https://github.com/ChenHuabin321/tensorflow2_tutorials

作者博客:

https://www.cnblogs.com/chenhuabin

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-12-27,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器学习与统计学 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 方法一:model.save()
  • 方法二:model.save_weights()
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档