MXNET学习笔记(二):模型的保存与加载

当序列化 NDArray 的时候,我们序列化的是NDArray 中保存的 tensor 值。当序列化 Symbol 的时候,我们序列化的是 Graph。

Symbol序列化

当序列化 Symbol 的时候,通常使用 json 文件作为序列化后的文件,因为可读性好。

import mxnet as mx
a = mx.sym.Variable('a', shape=[2,])
b = mx.sym.Variable('b', shape=[3,])
c = a+b
print(c.tojson()) # 打印出来 json 文件,看看里面是啥
c.save('symbol-c.json') # 保存文件

c2 = mx.sym.loads('symbol-c.json') # 加载 json 文件,此时 c2 就代表一个 symbol
{
  "nodes": [
    {
      "op": "null", 
      "name": "a", 
      "attr": {"__shape__": "[2]"}, 
      "inputs": []
    }, 
    {
      "op": "null", 
      "name": "b", 
      "attr": {"__shape__": "[3]"}, 
      "inputs": []
    }, 
    {
      "op": "elemwise_add", 
      "name": "_plus0", 
      "inputs": [[0, 0, 0], [1, 0, 0]]
    }
  ], 
  "arg_nodes": [0, 1], 
  "node_row_ptr": [0, 1, 2, 3], 
  "heads": [[2, 0, 0]], 
  "attrs": {"mxnet_version": ["int", 1000]}
}
  • heads : 表示输出
  • [2, 0, 0], [1, 0, 0] 这些应该是表示的 Symbol 的 id。

NDArray 序列化

ndarray 序列化是序列化 ndarray 中的 tensor 值。

序列化 NDArray 有两种方法:

  • 使用 pickle , (python)
    • 序列化:pkl.dumps() pkl.dump()
    • 加载:pkl.load(), pkl.loads()
  • 使用 NDArray 自带的 方法
    • 序列化:mx.nd.save()
    • 加载:mx.nd.load()
import pickle as pkl
a = mx.nd.ones((2, 3))
# pack and then dump into disk
data = pkl.dumps(a)
pkl.dump(data, open('tmp.pickle', 'wb'))
# load from disk and then unpack
data = pkl.load(open('tmp.pickle', 'rb'))
b = pkl.loads(data)
b.asnumpy()

a = mx.nd.ones((2,3))
b = mx.nd.ones((5,6))
mx.nd.save("temp.ndarray", [a,b])
c = mx.nd.load("temp.ndarray")
c

d = {'a':a, 'b':b}
mx.nd.save("temp.ndarray", d)
c = mx.nd.load("temp.ndarray")
c

Module 保存参数与加载参数

保存

使用 checkpoint callback 在每个 epoch 之后保存一次参数。

# construct a callback function to save checkpoints
model_prefix = 'mx_mlp'
checkpoint = mx.callback.do_checkpoint(model_prefix)

mod = mx.mod.Module(symbol=net)
mod.fit(train_iter, num_epoch=5, epoch_end_callback=checkpoint)

如果不用 fit 的话,如何保存呢?

先看下fit部分的代码

# sync aux params across devices
arg_params, aux_params = self.get_params()
self.set_params(arg_params, aux_params)

if epoch_end_callback is not None:
    for callback in _as_list(epoch_end_callback):
        callback(epoch, self.symbol, arg_params, aux_params)

我们只需要模拟这部分代码,手动调用 callback 就可以了

# construct a callback function to save checkpoints
model_prefix = 'mx_mlp'
checkpoint = mx.callback.do_checkpoint(model_prefix)

mod = mx.mod.Module(symbol=net)

# ...
mod.bind(...)

# 调用这个函数来 保存参数就可以了
def save_checkpoint(epoch, module, callback):
    arg_params, aux_params = module.get_params()
    module.set_params(arg_params, aux_params)
    callback(epoch, module.symbol, arg_params, aux_params)

加载

加载保存了的 模型参数,使用 load_checkpoint 方法

# 不仅加载了 参数,同时加载了 Symbol
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
assert sym.tojson() == net.tojson()

# 然后创建一个 module
# assign the loaded parameters to the module
mod.set_params(arg_params, aux_params)

参考资料

https://mxnet.incubator.apache.org/tutorials/basic/module.html#save-and-load

https://mxnet.incubator.apache.org/tutorials/basic/ndarray.html#serialize-from-to-distributed-filesystems

https://mxnet.incubator.apache.org/tutorials/basic/symbol.html#load-and-save

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Linyb极客之路

Spring中@Async用法总结

引言: 在Java应用中,绝大多数情况下都是通过同步的方式来实现交互处理的;但是在处理与第三方系统交互的时候,容易造成响应迟缓的情况,之前大部分都是使用多线程来...

1672
来自专栏Coding01

看 Laravel 源代码了解 Container

自从上文《看 Laravel 源代码了解 ServiceProvider 的加载》,我们知道 Application (or Container) 充当 Lar...

2205
来自专栏kl的专栏

skywalking源码分析之javaAgent工具ByteBuddy的应用

关于skywalking请看我上一篇博文,其使用javaAgent技术,使得应用接入监控0耦合。今天在分析skywaking过程中,对javaAgent技术有了...

6178
来自专栏Android群英传

Retrofit源码分析

1204
来自专栏WebHub

NodeJS中的异步编程经验

问题引入:今天在 Gulp 构建任务中出现一个 html 解析错误,但是并没有报错,也没有中断 gulp 构建任务的执行,而是出现 UnhandledPromi...

1462
来自专栏Urahara Blog

Web For Pentester - Directory traversal & File Include Part Tips

2126
来自专栏技术墨客

Hazelcast集群服务(4)——分布式Map

    在第一篇介绍Hazelcast的文章已经提到,Hazelcast为Java中绝大部分数据结构提供了分布式实现。我们常用的Map、List、Queue等数...

1543
来自专栏xingoo, 一个梦想做发明家的程序员

VS报错:DEBUG Assertion Failed!

使用vs2010时,遇到如下错误 ? 然后点击继续后: ? 点击终止: ? 观察变量: ? 根据提示发现,有可能是断点问题,于是猜想可能是指针的错误。 goog...

4509
来自专栏前端杂货铺

node中的Stream-Readable和Writeable解读

在node中,只要涉及到文件IO的场景一般都会涉及到一个类-Stream。Stream是对IO设备的抽象表示,其在JAVA中也有涉及,主要体现在四个类-Inpu...

3669
来自专栏老马寒门IT

Node入门教程(8)第六章:path 模块详解

1524

扫码关注云+社区