MXNET学习笔记(二):模型的保存与加载

当序列化 NDArray 的时候,我们序列化的是NDArray 中保存的 tensor 值。当序列化 Symbol 的时候,我们序列化的是 Graph。

Symbol序列化

当序列化 Symbol 的时候,通常使用 json 文件作为序列化后的文件,因为可读性好。

import mxnet as mx
a = mx.sym.Variable('a', shape=[2,])
b = mx.sym.Variable('b', shape=[3,])
c = a+b
print(c.tojson()) # 打印出来 json 文件,看看里面是啥
c.save('symbol-c.json') # 保存文件

c2 = mx.sym.loads('symbol-c.json') # 加载 json 文件,此时 c2 就代表一个 symbol
{
  "nodes": [
    {
      "op": "null", 
      "name": "a", 
      "attr": {"__shape__": "[2]"}, 
      "inputs": []
    }, 
    {
      "op": "null", 
      "name": "b", 
      "attr": {"__shape__": "[3]"}, 
      "inputs": []
    }, 
    {
      "op": "elemwise_add", 
      "name": "_plus0", 
      "inputs": [[0, 0, 0], [1, 0, 0]]
    }
  ], 
  "arg_nodes": [0, 1], 
  "node_row_ptr": [0, 1, 2, 3], 
  "heads": [[2, 0, 0]], 
  "attrs": {"mxnet_version": ["int", 1000]}
}
  • heads : 表示输出
  • [2, 0, 0], [1, 0, 0] 这些应该是表示的 Symbol 的 id。

NDArray 序列化

ndarray 序列化是序列化 ndarray 中的 tensor 值。

序列化 NDArray 有两种方法:

  • 使用 pickle , (python)
    • 序列化:pkl.dumps() pkl.dump()
    • 加载:pkl.load(), pkl.loads()
  • 使用 NDArray 自带的 方法
    • 序列化:mx.nd.save()
    • 加载:mx.nd.load()
import pickle as pkl
a = mx.nd.ones((2, 3))
# pack and then dump into disk
data = pkl.dumps(a)
pkl.dump(data, open('tmp.pickle', 'wb'))
# load from disk and then unpack
data = pkl.load(open('tmp.pickle', 'rb'))
b = pkl.loads(data)
b.asnumpy()

a = mx.nd.ones((2,3))
b = mx.nd.ones((5,6))
mx.nd.save("temp.ndarray", [a,b])
c = mx.nd.load("temp.ndarray")
c

d = {'a':a, 'b':b}
mx.nd.save("temp.ndarray", d)
c = mx.nd.load("temp.ndarray")
c

Module 保存参数与加载参数

保存

使用 checkpoint callback 在每个 epoch 之后保存一次参数。

# construct a callback function to save checkpoints
model_prefix = 'mx_mlp'
checkpoint = mx.callback.do_checkpoint(model_prefix)

mod = mx.mod.Module(symbol=net)
mod.fit(train_iter, num_epoch=5, epoch_end_callback=checkpoint)

如果不用 fit 的话,如何保存呢?

先看下fit部分的代码

# sync aux params across devices
arg_params, aux_params = self.get_params()
self.set_params(arg_params, aux_params)

if epoch_end_callback is not None:
    for callback in _as_list(epoch_end_callback):
        callback(epoch, self.symbol, arg_params, aux_params)

我们只需要模拟这部分代码,手动调用 callback 就可以了

# construct a callback function to save checkpoints
model_prefix = 'mx_mlp'
checkpoint = mx.callback.do_checkpoint(model_prefix)

mod = mx.mod.Module(symbol=net)

# ...
mod.bind(...)

# 调用这个函数来 保存参数就可以了
def save_checkpoint(epoch, module, callback):
    arg_params, aux_params = module.get_params()
    module.set_params(arg_params, aux_params)
    callback(epoch, module.symbol, arg_params, aux_params)

加载

加载保存了的 模型参数,使用 load_checkpoint 方法

# 不仅加载了 参数,同时加载了 Symbol
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
assert sym.tojson() == net.tojson()

# 然后创建一个 module
# assign the loaded parameters to the module
mod.set_params(arg_params, aux_params)

参考资料

https://mxnet.incubator.apache.org/tutorials/basic/module.html#save-and-load

https://mxnet.incubator.apache.org/tutorials/basic/ndarray.html#serialize-from-to-distributed-filesystems

https://mxnet.incubator.apache.org/tutorials/basic/symbol.html#load-and-save

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏牛肉圆粉不加葱

(3) - Scala case class那些你不知道的知识

除了在模式匹配中使用之外,unapply 方法可以让你结构 case class 来提取它的字段,如:

661
来自专栏星回的实验室

js重修课[六]:客户端JavaScript一些琐事

562
来自专栏ImportSource

为什么实现了equals()就必须实现hashCode()?

我们先来看下面这个简单的例子,然后运行她: class Person{ private String name; private int age; ...

3704
来自专栏微信公众号:Java团长

深入理解Java:String

按照官方的说法:Java 虚拟机具有一个堆,堆是运行时数据区域,所有类实例和数组的内存均从此处分配。

771
来自专栏Python、Flask、Django

python中的filter函数

732
来自专栏云霄雨霁

Java--lambda(λ)表达式

2026
来自专栏技术博客

C#基础知识系列五(构造函数)

  2、不带参数的构造函数称为“默认构造函数”。 无论何时,只要使用 new 运算符实例化对象,并且不为 new 提供任何参数,就会调用默认构造函数。除非类是s...

633
来自专栏Python、Flask、Django

TP前台调用后台验证方法(跨模块继承控制器)

1063
来自专栏好好学java的技术栈

Java基础提升篇:equals()与hashCode()方法详解

992
来自专栏小灰灰

Java学习之深拷贝浅拷贝及对象拷贝的两种方式

I. Java之Clone 0. 背景 对象拷贝,是一个非常基础的内容了,为什么会单独的把这个领出来讲解,主要是先前遇到了一个非常有意思的场景 有一个任务,需要...

2359

扫码关注云+社区