首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >使用递归模板动态分配的多维数组

使用递归模板动态分配的多维数组
EN

Stack Overflow用户
提问于 2018-10-26 03:10:18
回答 1查看 62关注 0票数 1

为了读取和存储MATLAB程序的一些结果,我需要使用多达6维的矩阵。而不是像这样做:

代码语言:javascript
运行
复制
typedef std::vector<double>  Row;
typedef std::vector<Row>     Matrix2;
typedef std::vector<Matrix2> Matrix3;
typedef std::vector<Matrix3> Matrix4;
typedef std::vector<Matrix4> Matrix5;
typedef std::vector<Matrix5> Matrix6;

我决定使用模板,这是我到目前为止所做的:

代码语言:javascript
运行
复制
template <class T, int N>
class Matrix {
public:
    typedef typename Matrix<T, N - 1>::type MatrixOneDimLower;
    typedef std::vector<MatrixOneDimLower> type;

    type _data;

    template <unsigned int dn, typename ...NT>
    Matrix(unsigned int dn, NT ...drest) : _data(dn, MatrixOneDimLower(drest)) {}

    MatrixOneDimLower& operator[](unsigned int index)
    {
        return _data[index];
    }
};

template <class T>
class Matrix<T, 1> {
public:
    typedef std::vector<T> type;

    type _data;

    Matrix(unsigned int d0) : _data(d0, T(0.0)) {}

    T& operator[](unsigned int index)
    {
        return _data[index];
    }
};

不幸的是,我不太擅长各种模板和递归模板,这是不能工作的。例如,如果我尝试将其用作:

代码语言:javascript
运行
复制
Matrix<double, 4> temp(n,  dim[2], dim[1], dim[0]);

我得到这个编译时错误(Visual Studio 2017):

代码语言:javascript
运行
复制
error C2661: 'Matrix<double,4>::Matrix': no overloaded function takes 4 arguments

如果你能让我知道我做错了什么,我将不胜感激。

EN

Stack Overflow用户

回答已采纳

发布于 2018-10-26 03:41:02

代码语言:javascript
运行
复制
template<class T, std::size_t I>
struct MatrixView {
  MatrixView<T, I-1> operator[](std::size_t i) {
    return {ptr + i* *strides, strides+1};
  }
  MatrixView( T* p, std::size_t const* stride ):ptr(p), strides(stride) {}
private:
  T* ptr = 0;
  std::size_t const* strides = 0;
};
template<class T>
struct MatrixView<T, 1> {
  T& operator[](std::size_t i) {
    return ptr[i];
  }
  MatrixView( T* p, std::size_t const* stride ):ptr(p) {}
private:
  T* ptr = 0;
};
template<class T, std::size_t N>
struct Matrix {
  Matrix( std::array<std::size_t, N> sizes ) {
    std::size_t accumulated = 1;
    for (std::size_t i = 1; i < sizes.size(); ++i) {
      accumulated *= sizes[N-i];
      strides[N-i] = accumulated;
    }
    storage.resize( strides[0] * sizes[0] );
  }
  MatrixView<T, N> get() { return {storage.data(), strides.data()}; }
  MatrixView<T const, N> get() const { return {storage.data(), strides.data()}; }
private:
  std::vector<T> storage;
  std::array<std::size_t, N-1> strides;
};

这需要执行Matrix<int, 6> m{ {5,4,2,1,3,5} };来创建一个6维的矩阵。

要访问它,您需要执行m.get()[3][0][0][0][0][0] = 4

你可以去掉那个.get(),但是只要你想支持一阶张量,它就有点烦人。

数据是连续存储的。

票数 2
EN
查看全部 1 条回答
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52996486

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档