前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >MXNet 源码解读系列之一 C++端如何解析NDArray参数文件

MXNet 源码解读系列之一 C++端如何解析NDArray参数文件

原创
作者头像
Ldpe2G
修改2018-06-21 15:40:23
2.7K0
修改2018-06-21 15:40:23
举报

本文相关代码: parsingNDArray

要想弄清楚MXNet 是如何解析参数文件,并从中提取预训练好的权值,首先第一步要看

MXNet Python端是如何是调用C接口来完成读取NDArray参数文件的。

这部分代码见源码 python/mxnet/ndarray/utils.py 第149行:

代码语言:python
复制
def load(fname):
    """Loads an array from file.

    See more details in ``save``.

    Parameters
    ----------
    fname : str
        The filename.

    Returns
    -------
    list of NDArray, RowSparseNDArray or CSRNDArray, or \
    dict of str to NDArray, RowSparseNDArray or CSRNDArray
        Loaded data.
    """
    if not isinstance(fname, string_types):
        raise TypeError('fname required to be a string')
    out_size = mx_uint()
    out_name_size = mx_uint()
    handles = ctypes.POINTER(NDArrayHandle)()
    names = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXNDArrayLoad(c_str(fname),                                         
                                  ctypes.byref(out_size),
                                  ctypes.byref(handles),
                                  ctypes.byref(out_name_size),
                                  ctypes.byref(names)))
    if out_name_size.value == 0:
        return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(out_size.value)]
    else:
        assert out_name_size.value == out_size.value
        return dict(
            (py_str(names[i]), _ndarray_cls(NDArrayHandle(handles[i])))
            for i in range(out_size.value))

这个 load 函数接收参数路径作为输入,然后根据参数文件中有没有包含参数的名字选择返回

NDArray参数数组或者字典。然后可以看到是调用了 MXNDArrayLoad 这个C接口函数,这个函数的

代码见 src/c_api/c_api.cc 第308行:

代码语言:c++
复制
int MXNDArrayLoad(const char* fname,
                  mx_uint *out_size,
                  NDArrayHandle** out_arr,
                  mx_uint *out_name_size,
                  const char*** out_names) {
  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
  ret->ret_vec_str.clear();
  API_BEGIN();
  std::vector<NDArray> data;
  std::vector<std::string> &names = ret->ret_vec_str;
  {
    std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
    mxnet::NDArray::Load(fi.get(), &data, &names);
  }
  ret->ret_handles.resize(data.size());
  for (size_t i = 0; i < data.size(); ++i) {
    NDArray *ptr = new NDArray();
    *ptr = data[i];
    ret->ret_handles[i] = ptr;
  }
  ret->ret_vec_charp.resize(names.size());
  for (size_t i = 0; i < names.size(); ++i) {
    ret->ret_vec_charp[i] = names[i].c_str();
  }
  *out_size = static_cast<mx_uint>(data.size());
  *out_arr = dmlc::BeginPtr(ret->ret_handles);
  *out_name_size = static_cast<mx_uint>(names.size());
  *out_names = dmlc::BeginPtr(ret->ret_vec_charp);
  API_END();
}

然后可以看到最核心的代码就是第319行调用了NDArray类的静态Load函数获得参数的名字和

内容,Load函数具体实现见:src/ndarray/ndarray.cc 第 1812行:

代码语言:c++
复制
void NDArray::Load(dmlc::Stream* fi,
                   std::vector<NDArray>* data,
                   std::vector<std::string>* keys) {
  uint64_t header, reserved;
  CHECK(fi->Read(&header))
      << "Invalid NDArray file format";
  CHECK(fi->Read(&reserved))
      << "Invalid NDArray file format";
  CHECK(header == kMXAPINDArrayListMagic)
      << "Invalid NDArray file format";
  CHECK(fi->Read(data))
      << "Invalid NDArray file format";
  CHECK(fi->Read(keys))
      << "Invalid NDArray file format";
  CHECK(keys->size() == 0 || keys->size() == data->size())
      << "Invalid NDArray file format";
}

从这里读取内容的过程可以大概看出NDArray参数文件存储的内容的顺序是什么了,首先是会

存两个uint64_t类型的数字,然后就是NDArray数组,接着是每个NDArray对应的名字的数组。

好了接下来就是解读源码中是如何从Stream中解析出内容的,首先我们来看下Stream类的

Read函数,具体见 io.h 第435行:

代码语言:c++
复制
template<typename T>
inline bool Stream::Read(T *out_data) {
  return serializer::Handler<T>::Read(this, out_data);
}

这里可以看到,Read 函数内部又调用了 Handler这个类的Read静态函数,这个静态函数对应的

代码见 serializer.h 第262行:

代码语言:c++
复制
inline static bool Read(Stream *strm, T *data) {
    return IfThenElse<dmlc::is_pod<T>::value,
                      PODHandler<T>,
                      IfThenElse<dmlc::has_saveload<T>::value,
                                 SaveLoadClassHandler<T>,
                                 UndefinedSerializerFor<T>, T>,
                      T>
    ::Read(strm, data);
  }
};

这里代码我第一次看的时候有点蒙,后来仔细研究了下也看懂了。首先我们要看 IfThenElse

是什么东西,这里还是看到 io.h 的第 38 到 66行:

代码语言:c++
复制
//! \cond Doxygen_Suppress
/*!
 * \brief Serializer that redirect calls by condition
 * \tparam cond the condition
 * \tparam Then the serializer used for then condition
 * \tparam Else the serializer used for else condition
 * \tparam Return the type of data the serializer handles
 */
template<bool cond, typename Then, typename Else, typename Return>
struct IfThenElse;

template<typename Then, typename Else, typename T>
struct IfThenElse<true, Then, Else, T> {
  inline static void Write(Stream *strm, const T &data) {
    Then::Write(strm, data);
  }
  inline static bool Read(Stream *strm, T *data) {
    return Then::Read(strm, data);
  }
};
template<typename Then, typename Else, typename T>
struct IfThenElse<false, Then, Else, T> {
  inline static void Write(Stream *strm, const T &data) {
    Else::Write(strm, data);
  }
  inline static bool Read(Stream *strm, T *data) {
    return Else::Read(strm, data);
  }
};

这里可以看到 IfThenElse 就是一个结构体,有四个模板参数,意思很明显了,如果第一个参数

为true,则会调用Then这个类的Read静态函数,如果第一个参数为false,则会调用Else这个类的

Read静态函数。看完 IfThenElse 的定义之后,我们看回 262 行的Read函数就很清楚了,

代码语言:c++
复制
inline static bool Read(Stream *strm, T *data) {
    return IfThenElse<dmlc::is_pod<T>::value,
                      PODHandler<T>,
                      IfThenElse<dmlc::has_saveload<T>::value,
                                 SaveLoadClassHandler<T>,
                                 UndefinedSerializerFor<T>, T>,
                      T>
    ::Read(strm, data);
  }
};

意思就是,如果 dmlc::is_pod<T>::value 这个值为 true,那么就会调用 PODHandler 的Read

函数,否则就会走到下一个条件判断,下一个条件判断是当 dmlc::has_saveload<T>::value 这个

值为true的话就调用 SaveLoadClassHandler 的 Read 静态函数,否则就走到

UndefinedSerializerFor。好了,那么现在就是要看具体走了哪个分支,首先我们要知道T在运行时

是什么类型,看回上面的NDArray 的 Load 函数,知道了首先读取得两个数字的类型是 uint64_t,

接着跳转到源码 type_traits.h,看第126和第152行:

代码语言:c++
复制
/*! \brief macro to quickly declare traits information */
#define DMLC_DECLARE_TRAITS(Trait, Type, Value)       \
  template<>                                          \
  struct Trait<Type> {                                \
    static const bool value = Value;                  \
  }

DMLC_DECLARE_TRAITS(is_pod, uint64_t, true);

很明显可以看到,dmlc::is_pod<uint64_t>::value 的值为 true,因此会调用 PODHandler 的

Read 函数,代码:

代码语言:c++
复制
/*! \brief Serializer for POD(plain-old-data) data */
template<typename T>
struct PODHandler {
  inline static void Write(Stream *strm, const T &data) {
    strm->Write(&data, sizeof(T));
  }
  inline static bool Read(Stream *strm, T *dptr) {
    return strm->Read((void*)dptr, sizeof(T)) == sizeof(T);  // NOLINT(*)
  }
};

PODHandler 的Read函数就是调用 Stream 的Read,这里如果读者想再详细了解 Stream 类

Read 函数的工作原理可以自己再去细看,不过对于本文来说,到这里知道了会根据T的字节数读取

内容到dptr里面就够了。

Ok,现在已经读取完两个数字 header, reserved,然后就是读 NDArray Vector 了,然后这

里还是跳转到,调用 Handler<T>::Read 函数,不过这里和读数字不一样的地方在于,这里传入的

模板参数是vector<NDArray>,所以调用的是下面这个Handler定义的Read函数:

代码语言:c++
复制
//! \cond Doxygen_Suppress
template<typename T>
struct Handler<std::vector<T> > {
  inline static void Write(Stream *strm, const std::vector<T> &data) {
    IfThenElse<dmlc::is_pod<T>::value,
               PODVectorHandler<T>,
               ComposeVectorHandler<T>, std::vector<T> >
    ::Write(strm, data);
  }
  inline static bool Read(Stream *strm, std::vector<T> *data) {
    return IfThenElse<dmlc::is_pod<T>::value,
                      PODVectorHandler<T>,
                      ComposeVectorHandler<T>,
                      std::vector<T> >
    ::Read(strm, data);
  }
};

然后这里的判断分支是会调用 ComposeVectorHandler 的 Read 函数:

代码语言:c++
复制
/*!
 * \brief Serializer handler for std::vector<T> where T can be composed type
 * \tparam T element type
 */
template<typename T>
struct ComposeVectorHandler {
  inline static void Write(Stream *strm, const std::vector<T> &vec) {
    uint64_t sz = static_cast<uint64_t>(vec.size());
    strm->Write(&sz, sizeof(sz));
    for (size_t i = 0; i < vec.size(); ++i) {
      Handler<T>::Write(strm, vec[i]);
    }
  }
  inline static bool Read(Stream *strm, std::vector<T> *out_vec) {
    uint64_t sz;
    if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false;
    size_t size = static_cast<size_t>(sz);
    out_vec->resize(size);
    for (size_t i = 0; i < size; ++i) {
      if (!Handler<T>::Read(strm, &(*out_vec)[i])) return false;
    }
    return true;
  }
};

首先先读出 vector 数组的大小,然后分别读取每个 NDArray,这里在读每个 NDArray 的时

候又会调用 Handler<T>::Read 函数,这次 IfThenElse 分支判断那里会走

SaveLoadClassHandler这个分支:

代码语言:c++
复制
// serializer for class that have save/load function
template<typename T>
struct SaveLoadClassHandler {
  inline static void Write(Stream *strm, const T &data) {
    data.Save(strm);
  }
  inline static bool Read(Stream *strm, T *data) {
    return data->Load(strm);
  }
};

最后看到其实就是调用了 NDArray 类本身的 Load 函数,见源码 src/ndarray/ndarray.cc

代码语言:c++
复制
bool NDArray::Load(dmlc::Stream *strm) {
  uint32_t magic;
  if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
  if (magic != NDARRAY_V2_MAGIC) {
    return LegacyLoad(strm, magic);
  }

  // load storage type
  int32_t stype;
  if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false;
  const int32_t nad = num_aux_data(static_cast<NDArrayStorageType>(stype));

  // load storage shape
  TShape sshape;
  if (nad > 0) {
    if (!sshape.Load(strm)) return false;
  }

  // load shape
  TShape shape;
  if (!shape.Load(strm)) return false;
  if (shape.ndim() == 0) {
    *this = NDArray(); return true;
  }

  // load context
  Context ctx;
  if (!ctx.Load(strm)) return false;

  // load type flag
  int32_t type_flag;
  if (strm->Read(&type_flag, sizeof(type_flag)) != sizeof(type_flag)) return false;

  // load aux_types and aux_shapes
  std::vector<int32_t> aux_types;
  std::vector<TShape> aux_shapes;
  if (nad > 0) {
    aux_types.resize(nad);
    aux_shapes.resize(nad);
    for (int i = 0; i < nad; ++i) {
      // load aux_type(i)
      if (strm->Read(&aux_types[i], sizeof(aux_types[i])) != sizeof(aux_types[i])) return false;
      // load aux_shapes(i)
      if (!aux_shapes[i].Load(strm)) return false;
    }
  }

  // load data into CPU
  NDArray temp;
  if (0 == nad) {
    temp = NDArray(shape, Context::CPU(), false, type_flag);
  } else {
    temp = NDArray(static_cast<NDArrayStorageType>(stype), shape,
                   Context::CPU(), false, type_flag,
                   aux_types, aux_shapes, sshape);
  }
  // load data
  TBlob load_data = temp.data();
  size_t type_size = mshadow::mshadow_sizeof(type_flag);
  size_t nread = type_size * load_data.Size();
  if (strm->Read(load_data.dptr_, nread) != nread) return false;

  // load aux_data
  if (nad > 0) {
    for (int i = 0; i < nad; ++i) {
      load_data = temp.aux_data(i);
      type_size = mshadow::mshadow_sizeof(load_data.type_flag_);
      nread = type_size * load_data.Size();
      if (strm->Read(load_data.dptr_, nread) != nread) return false;
    }
  }

  if (ctx.dev_mask() == cpu::kDevMask) {
    *this = std::move(temp); return true;
  } else {
#if MXNET_USE_CUDA
    *this = temp.Copy(ctx); return true;
#else
    *this = std::move(temp); return true;
#endif
  }
}

这里首先,读出一个 magic number ,如果用 V1.0 之后的MXNet版本,magic number

都是会等于 NDARRAY_V2_MAGIC,具体定义见下面:

代码语言:c++
复制
/* magic number for ndarray version 1, with int64_t TShape */
static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8;

/* magic number for ndarray version 2, with storage type */
static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9;

所以不会进入 LegacyLoad 函数,接着就是读 storage type,NDArray的类型,除了常用的

普通类型,现在也已经支持了稀疏类型:

代码语言:c++
复制
enum NDArrayStorageType {
  kUndefinedStorage = -1,  // undefined storage
  kDefaultStorage,         // dense
  kRowSparseStorage,       // row sparse
  kCSRStorage,             // csr
};

一般来说,storage type 都是 kDefaultStorage 类型,我现在写的解析小工具里面也只考虑

了解析普通类型的NDArray,之后再改进吧。然后看到 num_aux_data函数,这个函数如果传入

普通类型则返回0,所以 nad 的值为 0。

代码语言:c++
复制
size_t num_aux_data(NDArrayStorageType stype) {
  size_t num = 0;
  switch (stype) {
    case kDefaultStorage: num = 0; break;
    case kCSRStorage: num = 2; break;
    case kRowSparseStorage: num = 1; break;
     default: LOG(FATAL) << "Unknown storage type" << stype; break;
  }
  return num;
}

nad 值为0 的话整个代码就简洁很多了,简化之后如下:

代码语言:c++
复制
bool NDArray::Load(dmlc::Stream *strm) {
  uint32_t magic;
  if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
  if (magic != NDARRAY_V2_MAGIC) {
    return LegacyLoad(strm, magic);
  }

  // load storage type
  int32_t stype;
  if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false;

  // load shape
  TShape shape;
  if (!shape.Load(strm)) return false;
  if (shape.ndim() == 0) {
    *this = NDArray(); return true;
  }

  // load context
  Context ctx;
  if (!ctx.Load(strm)) return false;

  // load type flag
  int32_t type_flag;
  if (strm->Read(&type_flag, sizeof(type_flag)) != sizeof(type_flag)) return false;
  // load data into CPU
  NDArray temp;
  temp = NDArray(shape, Context::CPU(), false, type_flag);
  
  // load data
  TBlob load_data = temp.data();
  size_t type_size = mshadow::mshadow_sizeof(type_flag);
  size_t nread = type_size * load_data.Size();
  if (strm->Read(load_data.dptr_, nread) != nread) return false;
}

到这里为,大概怎么读取NDArray,相信应该挺清晰的了。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
文件存储
文件存储(Cloud File Storage,CFS)为您提供安全可靠、可扩展的共享文件存储服务。文件存储可与腾讯云服务器、容器服务、批量计算等服务搭配使用,为多个计算节点提供容量和性能可弹性扩展的高性能共享存储。腾讯云文件存储的管理界面简单、易使用,可实现对现有应用的无缝集成;按实际用量付费,为您节约成本,简化 IT 运维工作。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档