MXNet 源码解读系列之一 C++端如何解析NDArray参数文件

本文相关代码: parsingNDArray

要想弄清楚MXNet 是如何解析参数文件,并从中提取预训练好的权值,首先第一步要看

MXNet Python端是如何是调用C接口来完成读取NDArray参数文件的。

这部分代码见源码 python/mxnet/ndarray/utils.py 第149行:

def load(fname):
    """Loads an array from file.

    See more details in ``save``.

    Parameters
    ----------
    fname : str
        The filename.

    Returns
    -------
    list of NDArray, RowSparseNDArray or CSRNDArray, or \
    dict of str to NDArray, RowSparseNDArray or CSRNDArray
        Loaded data.
    """
    if not isinstance(fname, string_types):
        raise TypeError('fname required to be a string')
    out_size = mx_uint()
    out_name_size = mx_uint()
    handles = ctypes.POINTER(NDArrayHandle)()
    names = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXNDArrayLoad(c_str(fname),                                         
                                  ctypes.byref(out_size),
                                  ctypes.byref(handles),
                                  ctypes.byref(out_name_size),
                                  ctypes.byref(names)))
    if out_name_size.value == 0:
        return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(out_size.value)]
    else:
        assert out_name_size.value == out_size.value
        return dict(
            (py_str(names[i]), _ndarray_cls(NDArrayHandle(handles[i])))
            for i in range(out_size.value))

这个 load 函数接收参数路径作为输入,然后根据参数文件中有没有包含参数的名字选择返回

NDArray参数数组或者字典。然后可以看到是调用了 MXNDArrayLoad 这个C接口函数,这个函数的

代码见 src/c_api/c_api.cc 第308行:

int MXNDArrayLoad(const char* fname,
                  mx_uint *out_size,
                  NDArrayHandle** out_arr,
                  mx_uint *out_name_size,
                  const char*** out_names) {
  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
  ret->ret_vec_str.clear();
  API_BEGIN();
  std::vector<NDArray> data;
  std::vector<std::string> &names = ret->ret_vec_str;
  {
    std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
    mxnet::NDArray::Load(fi.get(), &data, &names);
  }
  ret->ret_handles.resize(data.size());
  for (size_t i = 0; i < data.size(); ++i) {
    NDArray *ptr = new NDArray();
    *ptr = data[i];
    ret->ret_handles[i] = ptr;
  }
  ret->ret_vec_charp.resize(names.size());
  for (size_t i = 0; i < names.size(); ++i) {
    ret->ret_vec_charp[i] = names[i].c_str();
  }
  *out_size = static_cast<mx_uint>(data.size());
  *out_arr = dmlc::BeginPtr(ret->ret_handles);
  *out_name_size = static_cast<mx_uint>(names.size());
  *out_names = dmlc::BeginPtr(ret->ret_vec_charp);
  API_END();
}

然后可以看到最核心的代码就是第319行调用了NDArray类的静态Load函数获得参数的名字和

内容,Load函数具体实现见:src/ndarray/ndarray.cc 第 1812行:

void NDArray::Load(dmlc::Stream* fi,
                   std::vector<NDArray>* data,
                   std::vector<std::string>* keys) {
  uint64_t header, reserved;
  CHECK(fi->Read(&header))
      << "Invalid NDArray file format";
  CHECK(fi->Read(&reserved))
      << "Invalid NDArray file format";
  CHECK(header == kMXAPINDArrayListMagic)
      << "Invalid NDArray file format";
  CHECK(fi->Read(data))
      << "Invalid NDArray file format";
  CHECK(fi->Read(keys))
      << "Invalid NDArray file format";
  CHECK(keys->size() == 0 || keys->size() == data->size())
      << "Invalid NDArray file format";
}

从这里读取内容的过程可以大概看出NDArray参数文件存储的内容的顺序是什么了,首先是会

存两个uint64_t类型的数字,然后就是NDArray数组,接着是每个NDArray对应的名字的数组。

好了接下来就是解读源码中是如何从Stream中解析出内容的,首先我们来看下Stream类的

Read函数,具体见 io.h 第435行:

template<typename T>
inline bool Stream::Read(T *out_data) {
  return serializer::Handler<T>::Read(this, out_data);
}

这里可以看到,Read 函数内部又调用了 Handler这个类的Read静态函数,这个静态函数对应的

代码见 serializer.h 第262行:

inline static bool Read(Stream *strm, T *data) {
    return IfThenElse<dmlc::is_pod<T>::value,
                      PODHandler<T>,
                      IfThenElse<dmlc::has_saveload<T>::value,
                                 SaveLoadClassHandler<T>,
                                 UndefinedSerializerFor<T>, T>,
                      T>
    ::Read(strm, data);
  }
};

这里代码我第一次看的时候有点蒙,后来仔细研究了下也看懂了。首先我们要看 IfThenElse

是什么东西,这里还是看到 io.h 的第 38 到 66行:

//! \cond Doxygen_Suppress
/*!
 * \brief Serializer that redirect calls by condition
 * \tparam cond the condition
 * \tparam Then the serializer used for then condition
 * \tparam Else the serializer used for else condition
 * \tparam Return the type of data the serializer handles
 */
template<bool cond, typename Then, typename Else, typename Return>
struct IfThenElse;

template<typename Then, typename Else, typename T>
struct IfThenElse<true, Then, Else, T> {
  inline static void Write(Stream *strm, const T &data) {
    Then::Write(strm, data);
  }
  inline static bool Read(Stream *strm, T *data) {
    return Then::Read(strm, data);
  }
};
template<typename Then, typename Else, typename T>
struct IfThenElse<false, Then, Else, T> {
  inline static void Write(Stream *strm, const T &data) {
    Else::Write(strm, data);
  }
  inline static bool Read(Stream *strm, T *data) {
    return Else::Read(strm, data);
  }
};

这里可以看到 IfThenElse 就是一个结构体,有四个模板参数,意思很明显了,如果第一个参数

为true,则会调用Then这个类的Read静态函数,如果第一个参数为false,则会调用Else这个类的

Read静态函数。看完 IfThenElse 的定义之后,我们看回 262 行的Read函数就很清楚了,

inline static bool Read(Stream *strm, T *data) {
    return IfThenElse<dmlc::is_pod<T>::value,
                      PODHandler<T>,
                      IfThenElse<dmlc::has_saveload<T>::value,
                                 SaveLoadClassHandler<T>,
                                 UndefinedSerializerFor<T>, T>,
                      T>
    ::Read(strm, data);
  }
};

意思就是,如果 dmlc::is_pod<T>::value 这个值为 true,那么就会调用 PODHandler 的Read

函数,否则就会走到下一个条件判断,下一个条件判断是当 dmlc::has_saveload<T>::value 这个

值为true的话就调用 SaveLoadClassHandler 的 Read 静态函数,否则就走到

UndefinedSerializerFor。好了,那么现在就是要看具体走了哪个分支,首先我们要知道T在运行时

是什么类型,看回上面的NDArray 的 Load 函数,知道了首先读取得两个数字的类型是 uint64_t,

接着跳转到源码 type_traits.h,看第126和第152行:

/*! \brief macro to quickly declare traits information */
#define DMLC_DECLARE_TRAITS(Trait, Type, Value)       \
  template<>                                          \
  struct Trait<Type> {                                \
    static const bool value = Value;                  \
  }

DMLC_DECLARE_TRAITS(is_pod, uint64_t, true);

很明显可以看到,dmlc::is_pod<uint64_t>::value 的值为 true,因此会调用 PODHandler 的

Read 函数,代码:

/*! \brief Serializer for POD(plain-old-data) data */
template<typename T>
struct PODHandler {
  inline static void Write(Stream *strm, const T &data) {
    strm->Write(&data, sizeof(T));
  }
  inline static bool Read(Stream *strm, T *dptr) {
    return strm->Read((void*)dptr, sizeof(T)) == sizeof(T);  // NOLINT(*)
  }
};

PODHandler 的Read函数就是调用 Stream 的Read,这里如果读者想再详细了解 Stream 类

Read 函数的工作原理可以自己再去细看,不过对于本文来说,到这里知道了会根据T的字节数读取

内容到dptr里面就够了。

Ok,现在已经读取完两个数字 header, reserved,然后就是读 NDArray Vector 了,然后这

里还是跳转到,调用 Handler<T>::Read 函数,不过这里和读数字不一样的地方在于,这里传入的

模板参数是vector<NDArray>,所以调用的是下面这个Handler定义的Read函数:

//! \cond Doxygen_Suppress
template<typename T>
struct Handler<std::vector<T> > {
  inline static void Write(Stream *strm, const std::vector<T> &data) {
    IfThenElse<dmlc::is_pod<T>::value,
               PODVectorHandler<T>,
               ComposeVectorHandler<T>, std::vector<T> >
    ::Write(strm, data);
  }
  inline static bool Read(Stream *strm, std::vector<T> *data) {
    return IfThenElse<dmlc::is_pod<T>::value,
                      PODVectorHandler<T>,
                      ComposeVectorHandler<T>,
                      std::vector<T> >
    ::Read(strm, data);
  }
};

然后这里的判断分支是会调用 ComposeVectorHandler 的 Read 函数:

/*!
 * \brief Serializer handler for std::vector<T> where T can be composed type
 * \tparam T element type
 */
template<typename T>
struct ComposeVectorHandler {
  inline static void Write(Stream *strm, const std::vector<T> &vec) {
    uint64_t sz = static_cast<uint64_t>(vec.size());
    strm->Write(&sz, sizeof(sz));
    for (size_t i = 0; i < vec.size(); ++i) {
      Handler<T>::Write(strm, vec[i]);
    }
  }
  inline static bool Read(Stream *strm, std::vector<T> *out_vec) {
    uint64_t sz;
    if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false;
    size_t size = static_cast<size_t>(sz);
    out_vec->resize(size);
    for (size_t i = 0; i < size; ++i) {
      if (!Handler<T>::Read(strm, &(*out_vec)[i])) return false;
    }
    return true;
  }
};

首先先读出 vector 数组的大小,然后分别读取每个 NDArray,这里在读每个 NDArray 的时

候又会调用 Handler<T>::Read 函数,这次 IfThenElse 分支判断那里会走

SaveLoadClassHandler这个分支:

// serializer for class that have save/load function
template<typename T>
struct SaveLoadClassHandler {
  inline static void Write(Stream *strm, const T &data) {
    data.Save(strm);
  }
  inline static bool Read(Stream *strm, T *data) {
    return data->Load(strm);
  }
};

最后看到其实就是调用了 NDArray 类本身的 Load 函数,见源码 src/ndarray/ndarray.cc

bool NDArray::Load(dmlc::Stream *strm) {
  uint32_t magic;
  if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
  if (magic != NDARRAY_V2_MAGIC) {
    return LegacyLoad(strm, magic);
  }

  // load storage type
  int32_t stype;
  if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false;
  const int32_t nad = num_aux_data(static_cast<NDArrayStorageType>(stype));

  // load storage shape
  TShape sshape;
  if (nad > 0) {
    if (!sshape.Load(strm)) return false;
  }

  // load shape
  TShape shape;
  if (!shape.Load(strm)) return false;
  if (shape.ndim() == 0) {
    *this = NDArray(); return true;
  }

  // load context
  Context ctx;
  if (!ctx.Load(strm)) return false;

  // load type flag
  int32_t type_flag;
  if (strm->Read(&type_flag, sizeof(type_flag)) != sizeof(type_flag)) return false;

  // load aux_types and aux_shapes
  std::vector<int32_t> aux_types;
  std::vector<TShape> aux_shapes;
  if (nad > 0) {
    aux_types.resize(nad);
    aux_shapes.resize(nad);
    for (int i = 0; i < nad; ++i) {
      // load aux_type(i)
      if (strm->Read(&aux_types[i], sizeof(aux_types[i])) != sizeof(aux_types[i])) return false;
      // load aux_shapes(i)
      if (!aux_shapes[i].Load(strm)) return false;
    }
  }

  // load data into CPU
  NDArray temp;
  if (0 == nad) {
    temp = NDArray(shape, Context::CPU(), false, type_flag);
  } else {
    temp = NDArray(static_cast<NDArrayStorageType>(stype), shape,
                   Context::CPU(), false, type_flag,
                   aux_types, aux_shapes, sshape);
  }
  // load data
  TBlob load_data = temp.data();
  size_t type_size = mshadow::mshadow_sizeof(type_flag);
  size_t nread = type_size * load_data.Size();
  if (strm->Read(load_data.dptr_, nread) != nread) return false;

  // load aux_data
  if (nad > 0) {
    for (int i = 0; i < nad; ++i) {
      load_data = temp.aux_data(i);
      type_size = mshadow::mshadow_sizeof(load_data.type_flag_);
      nread = type_size * load_data.Size();
      if (strm->Read(load_data.dptr_, nread) != nread) return false;
    }
  }

  if (ctx.dev_mask() == cpu::kDevMask) {
    *this = std::move(temp); return true;
  } else {
#if MXNET_USE_CUDA
    *this = temp.Copy(ctx); return true;
#else
    *this = std::move(temp); return true;
#endif
  }
}

这里首先,读出一个 magic number ,如果用 V1.0 之后的MXNet版本,magic number

都是会等于 NDARRAY_V2_MAGIC,具体定义见下面:

/* magic number for ndarray version 1, with int64_t TShape */
static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8;

/* magic number for ndarray version 2, with storage type */
static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9;

所以不会进入 LegacyLoad 函数,接着就是读 storage type,NDArray的类型,除了常用的

普通类型,现在也已经支持了稀疏类型:

enum NDArrayStorageType {
  kUndefinedStorage = -1,  // undefined storage
  kDefaultStorage,         // dense
  kRowSparseStorage,       // row sparse
  kCSRStorage,             // csr
};

一般来说,storage type 都是 kDefaultStorage 类型,我现在写的解析小工具里面也只考虑

了解析普通类型的NDArray,之后再改进吧。然后看到 num_aux_data函数,这个函数如果传入

普通类型则返回0,所以 nad 的值为 0。

size_t num_aux_data(NDArrayStorageType stype) {
  size_t num = 0;
  switch (stype) {
    case kDefaultStorage: num = 0; break;
    case kCSRStorage: num = 2; break;
    case kRowSparseStorage: num = 1; break;
     default: LOG(FATAL) << "Unknown storage type" << stype; break;
  }
  return num;
}

nad 值为0 的话整个代码就简洁很多了,简化之后如下:

bool NDArray::Load(dmlc::Stream *strm) {
  uint32_t magic;
  if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
  if (magic != NDARRAY_V2_MAGIC) {
    return LegacyLoad(strm, magic);
  }

  // load storage type
  int32_t stype;
  if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false;

  // load shape
  TShape shape;
  if (!shape.Load(strm)) return false;
  if (shape.ndim() == 0) {
    *this = NDArray(); return true;
  }

  // load context
  Context ctx;
  if (!ctx.Load(strm)) return false;

  // load type flag
  int32_t type_flag;
  if (strm->Read(&type_flag, sizeof(type_flag)) != sizeof(type_flag)) return false;
  // load data into CPU
  NDArray temp;
  temp = NDArray(shape, Context::CPU(), false, type_flag);
  
  // load data
  TBlob load_data = temp.data();
  size_t type_size = mshadow::mshadow_sizeof(type_flag);
  size_t nread = type_size * load_data.Size();
  if (strm->Read(load_data.dptr_, nread) != nread) return false;
}

到这里为,大概怎么读取NDArray,相信应该挺清晰的了。

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Java帮帮-微信公众号-技术文章全总结

Java设计模式-模板方式模式

模板方法模式: 定义一个操作中的算法的骨架, 而将一些步骤延迟到子类中. 模板方法使得子类可以在不改变一个算法的结构的前提下重定义该算法的某些特定步骤. ? ...

4578
来自专栏james大数据架构

MVC前台Post/Get异步获得数据时参数的取值问题

Post方法,返回text,后台获得Data View         $.ajax({ type: "POST", ...

2445
来自专栏二进制文集

JDK源码分析 Integer

对于JDK源码分析的文章,仅仅记录我认为重要的地方。源码的细节实在太多,不可能面面俱到地写清每个逻辑。所以我的JDK源码分析,着重在JDK的体系架构层面,具体源...

883
来自专栏小灰灰

Java 动手写爬虫: 二、 深度爬取

第二篇 前面实现了一个最基础的爬取单网页的爬虫,这一篇则着手解决深度爬取的问题 简单来讲,就是爬了一个网页之后,继续爬这个网页中的链接 1. 需求背景 背景...

45410
来自专栏xiaoxi666的专栏

状态机编程思想(2):删除代码注释(目前支持C/C++和Java)

之前考虑过正则表达式,但是感觉实现起来相当麻烦。而状态机可以把多种情况归为一类状态再行分解,大大简化问题。本文就是基于状态机实现的。

632
来自专栏cmazxiaoma的架构师之路

SpringBoot之路(二)之Web进阶

1854
来自专栏chenssy

【死磕 Spring】----- IOC 之 Spring 统一资源加载策略

在学 Java SE 的时候我们学习了一个标准类 java.net.URL,该类在 Java SE 中的定位为统一资源定位器(Uniform Resource ...

1213
来自专栏tkokof 的技术,小趣及杂念

Sweet Snippet 系列之 有序列表

很朴素的一种想法,为了维持 List 有序,我们可以在 Add 操作之后进行 Sort 操作(Remove 操作后不需要重新 Sort):

311
来自专栏Android机器圈

Java设计模式总汇二(小白也要飞)

PS:上一篇我介绍了适配器设计模式、单例设计模式、静态代理设计模式、简单工厂设计模式,如果没有看过第一篇的小火鸡可以点这个看看http://www.cnblog...

3349
来自专栏函数式编程语言及工具

FunDA(13)- 示范:用户自定义操作函数 - user defined tasks

   FunDA是一种函数式的编程工具,它所产生的程序是由许多功能单一的细小函数组合而成,这些函数就是用户自定义操作函数了。我们在前面曾经提过FunDA的运作原...

1748

扫码关注云+社区