MXNET学习笔记(一):Module类(1)

Module 是 mxnet 提供给用户的一个高级封装的类。有了它,我们可以很容易的来训练模型。

Module 包含以下单元的一个 wraper

  • symbol : 用来表示网络前向过程的 symbol
  • optimizer: 优化器,用来更新网络。
  • exec_group: 用来执行 前向和反向计算。

所以 Module 可以帮助我们做

  • 前向计算,(由 exec_group 提供支持)
  • 反向计算,(由 exec_group 提供支持)
  • 更新网络,(由 optimizer 提供支持)

一个 Demo

下面来看 MXNET 官网上提供的一个 Module 案例

第一部分:准备数据

import logging
logging.getLogger().setLevel(logging.INFO)
import mxnet as mx
import numpy as np

fname = mx.test_utils.download('http://archive.ics.uci.edu/ml/machine-learning-databases/letter-recognition/letter-recognition.data')
data = np.genfromtxt(fname, delimiter=',')[:,1:]
label = np.array([ord(l.split(',')[0])-ord('A') for l in open(fname, 'r')])

batch_size = 32
ntrain = int(data.shape[0]*0.8)
train_iter = mx.io.NDArrayIter(data[:ntrain, :], label[:ntrain], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(data[ntrain:, :], label[ntrain:], batch_size)

第二部分:构建网络

net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(net, name='fc1', num_hidden=64)
net = mx.sym.Activation(net, name='relu1', act_type="relu")
net = mx.sym.FullyConnected(net, name='fc2', num_hidden=26)
net = mx.sym.SoftmaxOutput(net, name='softmax')
mx.viz.plot_network(net)

第三部分:创建Module

mod = mx.mod.Module(symbol=net,
                    context=mx.cpu(),
                    data_names=['data'],
                    label_names=['softmax_label'])

# 通过data_shapes 和 label_shapes 推断其余参数的 shape,然后给它们分配空间
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
# 初始化模型的参数
mod.init_params(initializer=mx.init.Uniform(scale=.1))
# 初始化优化器,优化器用来更新模型
mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ))
# use accuracy as the metric
metric = mx.metric.create('acc')
# train 5 epochs, i.e. going over the data iter one pass
for epoch in range(5):
    train_iter.reset()
    metric.reset()
    for batch in train_iter:
        mod.forward(batch, is_train=True)       # 前向计算
        mod.update_metric(metric, batch.label)  # accumulate prediction accuracy
        mod.backward()                          # 反向传导
        mod.update()                            # 更新参数
    print('Epoch %d, Training %s' % (epoch, metric.get()))

关于 bind 的参数:

  • data_shapes : list of (str, tuple), str 是 数据 Symbol 的名字,tuple是 mini-batch 的形状,所以一般参数是[('data', (64, 3, 224, 224))]
  • label_shapes: list of (str, tuple),str 是 标签 Symbol 的名字,tuple是 mini-batch 标签的形状,一般 分类任务的 参数为 [('softmax_label'),(64,)]
  • 为什么上面两个参数都是 list 呢? 因为可能某些网络架构,不止一个 数据,不止一种 标签。

关于 forward的参数

  • data_batch : 一个 mx.io.DataBatch-like 对象。只要一个对象,可以 .data返回 mini-batch 训练数据, .label 返回相应的标签,就可以作为 data_batch 的实参 。
  • 关于 DataBatch对象:.data 返回的是 list of NDArray(网络可能有多个输入数据),.label 也一样。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏简书专栏

基于tensorflow+CNN的MNIST数据集手写数字分类预测

tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流。 CNN是convolutional neural netwo...

2493
来自专栏大数据挖掘DT机器学习

除了写烂的手写数据分类,你会不会做自定义图像数据集的识别?!

网上看的很多教程都是几个常见的例子,从内置模块或在线download数据集,要么是iris,要么是MNIST手写识别数字,或是UCI ,数据集不需要自己准备,所...

4584
来自专栏漫漫深度学习路

pytorch学习笔记(十一):fine-tune 预训练的模型

torchvision 中包含了很多预训练好的模型,这样就使得 fine-tune 非常容易。本文主要介绍如何 fine-tune torchvision 中预...

6759
来自专栏前端儿

鸡兔同笼

已知鸡和兔的总数量为n,总腿数为m。输入n和m,依次输出鸡和兔的数目,如果无解,则输出“No answer”(不要引号)。

1151
来自专栏红色石头的机器学习之路

Coursera吴恩达《优化深度神经网络》课程笔记(3)-- 超参数调试、Batch正则化和编程框架

上节课我们主要介绍了深度神经网络的优化算法。包括对原始数据集进行分割,使用mini-batch gradient descent。然后介绍了指数加权平均(Exp...

4870
来自专栏Petrichor的专栏

TensorFlow大本营

1954
来自专栏天天P图攻城狮

Android终端上视频转GIF的实现及GIF质量讨论

在生成 GIF 的过程中,最关键的步骤就是生成调色板以及像素到调色板的映射关系。

77711
来自专栏CreateAMind

合成动态视频效果及声音合成

图片大小限制,更多可访问 http://www.stat.ucla.edu/~jxie/STGConvNet/STGConvNet.html

1412
来自专栏AI研习社

Github 项目推荐 | Basel Face Model 2017 完全参数化人脸

本软件可以从 Basel Face Model 2017 里生成完全参数化的人脸,论文链接: https://arxiv.org/abs/1712.01619 ...

5357
来自专栏深度学习思考者

DL开源框架Caffe | 目标检测Faster-rcnn问题全解析

一 工程目录 在github上clone下来的代码,可以看到根目录下有以下几个文件夹,其中output为训练完之后才会有的文件夹。 caffe-fast-rcn...

4228

扫码关注云+社区