MXNET学习笔记(二):模型的保存与加载

当序列化 NDArray 的时候,我们序列化的是NDArray 中保存的 tensor 值。当序列化 Symbol 的时候,我们序列化的是 Graph。

Symbol序列化

当序列化 Symbol 的时候,通常使用 json 文件作为序列化后的文件,因为可读性好。

import mxnet as mx
a = mx.sym.Variable('a', shape=[2,])
b = mx.sym.Variable('b', shape=[3,])
c = a+b
print(c.tojson()) # 打印出来 json 文件,看看里面是啥
c.save('symbol-c.json') # 保存文件

c2 = mx.sym.loads('symbol-c.json') # 加载 json 文件,此时 c2 就代表一个 symbol
{
  "nodes": [
    {
      "op": "null", 
      "name": "a", 
      "attr": {"__shape__": "[2]"}, 
      "inputs": []
    }, 
    {
      "op": "null", 
      "name": "b", 
      "attr": {"__shape__": "[3]"}, 
      "inputs": []
    }, 
    {
      "op": "elemwise_add", 
      "name": "_plus0", 
      "inputs": [[0, 0, 0], [1, 0, 0]]
    }
  ], 
  "arg_nodes": [0, 1], 
  "node_row_ptr": [0, 1, 2, 3], 
  "heads": [[2, 0, 0]], 
  "attrs": {"mxnet_version": ["int", 1000]}
}
  • heads : 表示输出
  • [2, 0, 0], [1, 0, 0] 这些应该是表示的 Symbol 的 id。

NDArray 序列化

ndarray 序列化是序列化 ndarray 中的 tensor 值。

序列化 NDArray 有两种方法:

  • 使用 pickle , (python)
    • 序列化:pkl.dumps() pkl.dump()
    • 加载:pkl.load(), pkl.loads()
  • 使用 NDArray 自带的 方法
    • 序列化:mx.nd.save()
    • 加载:mx.nd.load()
import pickle as pkl
a = mx.nd.ones((2, 3))
# pack and then dump into disk
data = pkl.dumps(a)
pkl.dump(data, open('tmp.pickle', 'wb'))
# load from disk and then unpack
data = pkl.load(open('tmp.pickle', 'rb'))
b = pkl.loads(data)
b.asnumpy()

a = mx.nd.ones((2,3))
b = mx.nd.ones((5,6))
mx.nd.save("temp.ndarray", [a,b])
c = mx.nd.load("temp.ndarray")
c

d = {'a':a, 'b':b}
mx.nd.save("temp.ndarray", d)
c = mx.nd.load("temp.ndarray")
c

Module 保存参数与加载参数

保存

使用 checkpoint callback 在每个 epoch 之后保存一次参数。

# construct a callback function to save checkpoints
model_prefix = 'mx_mlp'
checkpoint = mx.callback.do_checkpoint(model_prefix)

mod = mx.mod.Module(symbol=net)
mod.fit(train_iter, num_epoch=5, epoch_end_callback=checkpoint)

如果不用 fit 的话,如何保存呢?

先看下fit部分的代码

# sync aux params across devices
arg_params, aux_params = self.get_params()
self.set_params(arg_params, aux_params)

if epoch_end_callback is not None:
    for callback in _as_list(epoch_end_callback):
        callback(epoch, self.symbol, arg_params, aux_params)

我们只需要模拟这部分代码,手动调用 callback 就可以了

# construct a callback function to save checkpoints
model_prefix = 'mx_mlp'
checkpoint = mx.callback.do_checkpoint(model_prefix)

mod = mx.mod.Module(symbol=net)

# ...
mod.bind(...)

# 调用这个函数来 保存参数就可以了
def save_checkpoint(epoch, module, callback):
    arg_params, aux_params = module.get_params()
    module.set_params(arg_params, aux_params)
    callback(epoch, module.symbol, arg_params, aux_params)

加载

加载保存了的 模型参数,使用 load_checkpoint 方法

# 不仅加载了 参数,同时加载了 Symbol
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
assert sym.tojson() == net.tojson()

# 然后创建一个 module
# assign the loaded parameters to the module
mod.set_params(arg_params, aux_params)

参考资料

https://mxnet.incubator.apache.org/tutorials/basic/module.html#save-and-load

https://mxnet.incubator.apache.org/tutorials/basic/ndarray.html#serialize-from-to-distributed-filesystems

https://mxnet.incubator.apache.org/tutorials/basic/symbol.html#load-and-save

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏柠檬先生

Sass 基础(五)

@if   @if 指令是一个SassScript,它可以根据条件处理样式块,如果条件为true返回一个样式块,反之   false 返回另一个样式块,在S...

2168
来自专栏互联网大杂烩

二分查找

在对线性表的操作中,经常需要查找某一个元素在线性表中的位置。此问题的输入是待查元素x和线性表L,输出为x在L中的位置或者x不在L中的信息。

643
来自专栏Java帮帮-微信公众号-技术文章全总结

Java基础-03(01).总结运算符、键盘录入、if语句

? 1:运算符(掌握) (1)算术运算符 A:+,-,*,/,%,++,-- B:+的用法 a:加法 b:正号 c:字符串连接符 C:/和%的区...

2624
来自专栏Java帮帮-微信公众号-技术文章全总结

14(01)正则表达式,Pattern,Mactcher,Math,BigInteger,BigDeximal,System等

学正则表达式之前qq号问题: package cn.itcast_01; import java.util.Scanner; /* * 校验qq号码. * ...

2695
来自专栏PHP在线

php数组操作(回顾)

1. 合并数组 array_merge()函数将数组合并到一起,返回一个联合的数组。所得到的数组以第一个输入数组参数开始,按后面数组参数出现的顺序依次迫加。其形...

3357
来自专栏C/C++基础

C++ struct与union

编码运行环境:VS2012+Win32+Debug Win32既表示运行平台是Windows 32bits操作系统,又表示生成32bits的应用程序。

521
来自专栏偏前端工程师的驿站

(cljs/run-at (JSVM. :browser) "简单类型可不简单啊~")

前言  每逢学习一个新的语言时总要先了解这门语言支持的数据类型,因为数据类型决定这门语言所针对的问题域,像Bash那样内置只支持字符串的脚步明显就是用于文本处理...

1847
来自专栏黄Java的地盘

should.js源码分析与学习

为了研究与学习某些测试框架的工作原理,同时也为了完成培训中实现一个简单的测试框架的原因,我对should.js的代码进行了学习与分析,现在与大家来进行交流下。

791
来自专栏开发之途

Android IPC机制(1)-序列化机制

1255
来自专栏Java帮帮-微信公众号-技术文章全总结

13(02)总结StringBuffer,StringBuilder,数组高级,Arrays,Integer,Character

(3)Arrays工具类 A:是针对数组进行操作的工具类。包括排序和查找等功能。 B:要掌握的方法(自己补齐方法) 把数组转成字符串:public sta...

2215

扫码关注云+社区