MXNET学习笔记(二):模型的保存与加载

当序列化 NDArray 的时候,我们序列化的是NDArray 中保存的 tensor 值。当序列化 Symbol 的时候,我们序列化的是 Graph。

Symbol序列化

当序列化 Symbol 的时候,通常使用 json 文件作为序列化后的文件,因为可读性好。

import mxnet as mx
a = mx.sym.Variable('a', shape=[2,])
b = mx.sym.Variable('b', shape=[3,])
c = a+b
print(c.tojson()) # 打印出来 json 文件,看看里面是啥
c.save('symbol-c.json') # 保存文件

c2 = mx.sym.loads('symbol-c.json') # 加载 json 文件,此时 c2 就代表一个 symbol
{
  "nodes": [
    {
      "op": "null", 
      "name": "a", 
      "attr": {"__shape__": "[2]"}, 
      "inputs": []
    }, 
    {
      "op": "null", 
      "name": "b", 
      "attr": {"__shape__": "[3]"}, 
      "inputs": []
    }, 
    {
      "op": "elemwise_add", 
      "name": "_plus0", 
      "inputs": [[0, 0, 0], [1, 0, 0]]
    }
  ], 
  "arg_nodes": [0, 1], 
  "node_row_ptr": [0, 1, 2, 3], 
  "heads": [[2, 0, 0]], 
  "attrs": {"mxnet_version": ["int", 1000]}
}
  • heads : 表示输出
  • [2, 0, 0], [1, 0, 0] 这些应该是表示的 Symbol 的 id。

NDArray 序列化

ndarray 序列化是序列化 ndarray 中的 tensor 值。

序列化 NDArray 有两种方法:

  • 使用 pickle , (python)
    • 序列化:pkl.dumps() pkl.dump()
    • 加载:pkl.load(), pkl.loads()
  • 使用 NDArray 自带的 方法
    • 序列化:mx.nd.save()
    • 加载:mx.nd.load()
import pickle as pkl
a = mx.nd.ones((2, 3))
# pack and then dump into disk
data = pkl.dumps(a)
pkl.dump(data, open('tmp.pickle', 'wb'))
# load from disk and then unpack
data = pkl.load(open('tmp.pickle', 'rb'))
b = pkl.loads(data)
b.asnumpy()

a = mx.nd.ones((2,3))
b = mx.nd.ones((5,6))
mx.nd.save("temp.ndarray", [a,b])
c = mx.nd.load("temp.ndarray")
c

d = {'a':a, 'b':b}
mx.nd.save("temp.ndarray", d)
c = mx.nd.load("temp.ndarray")
c

Module 保存参数与加载参数

保存

使用 checkpoint callback 在每个 epoch 之后保存一次参数。

# construct a callback function to save checkpoints
model_prefix = 'mx_mlp'
checkpoint = mx.callback.do_checkpoint(model_prefix)

mod = mx.mod.Module(symbol=net)
mod.fit(train_iter, num_epoch=5, epoch_end_callback=checkpoint)

如果不用 fit 的话,如何保存呢?

先看下fit部分的代码

# sync aux params across devices
arg_params, aux_params = self.get_params()
self.set_params(arg_params, aux_params)

if epoch_end_callback is not None:
    for callback in _as_list(epoch_end_callback):
        callback(epoch, self.symbol, arg_params, aux_params)

我们只需要模拟这部分代码,手动调用 callback 就可以了

# construct a callback function to save checkpoints
model_prefix = 'mx_mlp'
checkpoint = mx.callback.do_checkpoint(model_prefix)

mod = mx.mod.Module(symbol=net)

# ...
mod.bind(...)

# 调用这个函数来 保存参数就可以了
def save_checkpoint(epoch, module, callback):
    arg_params, aux_params = module.get_params()
    module.set_params(arg_params, aux_params)
    callback(epoch, module.symbol, arg_params, aux_params)

加载

加载保存了的 模型参数,使用 load_checkpoint 方法

# 不仅加载了 参数,同时加载了 Symbol
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
assert sym.tojson() == net.tojson()

# 然后创建一个 module
# assign the loaded parameters to the module
mod.set_params(arg_params, aux_params)

参考资料

https://mxnet.incubator.apache.org/tutorials/basic/module.html#save-and-load

https://mxnet.incubator.apache.org/tutorials/basic/ndarray.html#serialize-from-to-distributed-filesystems

https://mxnet.incubator.apache.org/tutorials/basic/symbol.html#load-and-save

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏java相关

并发基本概念介绍

12150
来自专栏kl的专栏

skywalking源码分析之javaAgent工具ByteBuddy的应用

关于skywalking请看我上一篇博文,其使用javaAgent技术,使得应用接入监控0耦合。今天在分析skywaking过程中,对javaAgent技术有了...

82280
来自专栏Web行业观察

NodeJS中的异步编程经验

问题引入:今天在 Gulp 构建任务中出现一个 html 解析错误,但是并没有报错,也没有中断 gulp 构建任务的执行,而是出现 UnhandledPromi...

27620
来自专栏Urahara Blog

Web For Pentester - Directory traversal & File Include Part Tips

29960
来自专栏老马寒门IT

Node入门教程(8)第六章:path 模块详解

path 模块详解 path 模块提供了一些工具函数,用于处理文件与目录的路径。由于windows和其他系统之间路径不统一,path模块还专门做了相关处理,屏蔽...

30580
来自专栏Linyb极客之路

Spring中@Async用法总结

引言: 在Java应用中,绝大多数情况下都是通过同步的方式来实现交互处理的;但是在处理与第三方系统交互的时候,容易造成响应迟缓的情况,之前大部分都是使用多线程来...

30930
来自专栏Java技术栈

Spring 获取 request 的几种方法及其线程安全性分析

本文将介绍在Spring MVC开发的Web系统中,获取request对象的几种方法,并讨论其线程安全性。

10640
来自专栏软件开发 -- 分享 互助 成长

很经典的GDB调试命令,包括查看变量,查看内存

在你调试程序时,当程序被停住时,你可以使用print命令(简写命令为p),或是同义命令inspect来查看当前程序的运行数据。print命令的格式是: prin...

1.6K60
来自专栏前端杂货铺

node中的Stream-Readable和Writeable解读

在node中,只要涉及到文件IO的场景一般都会涉及到一个类-Stream。Stream是对IO设备的抽象表示,其在JAVA中也有涉及,主要体现在四个类-Inpu...

38790
来自专栏老马寒门IT

Node入门教程(8)第六章:path 模块详解

18940

扫码关注云+社区

领取腾讯云代金券