首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >我可以使用pybind11将numpy数组传递给接受Eigen::Tensor的函数吗?

我可以使用pybind11将numpy数组传递给接受Eigen::Tensor的函数吗?
EN

Stack Overflow用户
提问于 2019-10-16 20:00:19
回答 2查看 1.7K关注 0票数 5

是否可以使用pybind1将三维数值数组传递给接受Eigen::Tensor作为参数的c++函数。例如,考虑以下c++函数:

代码语言:javascript
运行
复制
Eigen::Tensor<double, 3> addition_tensor(Eigen::Tensor<double, 3> a,
                                         Eigen::Tensor<double, 3> b) {
    return a + b;
}

在编译函数、将其导入python并向其传递numpy数组np.ones((1, 2, 2))之后,我收到以下错误消息:

代码语言:javascript
运行
复制
TypeError: addition_tensor(): incompatible function arguments. The following argument types are supported:
    1. (arg0: Eigen::Tensor<double, 3, 0, long>, arg1: Eigen::Tensor<double, 3, 0, long>) -> Eigen::Tensor<double, 3, 0, long>

我特别惊讶于不能传递三维numpy array数组,因为我可以向接受Eigen::MatrixXd的函数传递二维numpy,如下所示:

代码语言:javascript
运行
复制
Eigen::MatrixXd addition(Eigen::MatrixXd a, Eigen::MatrixXd b) { return a + b; }

我在这个例子中使用的整个代码是:

代码语言:javascript
运行
复制
#include <eigen-git-mirror/Eigen/Dense>
#include <eigen-git-mirror/unsupported/Eigen/CXX11/Tensor>
#include "pybind11/include/pybind11/eigen.h"
#include "pybind11/include/pybind11/pybind11.h"

Eigen::MatrixXd addition(Eigen::MatrixXd a, Eigen::MatrixXd b) { return a + b; }

Eigen::Tensor<double, 3> addition_tensor(Eigen::Tensor<double, 3> a,
                                         Eigen::Tensor<double, 3> b) {
    return a + b;
}

PYBIND11_MODULE(example, m) {
    m.def("addition", &addition, "A function which adds two numbers");
    m.def("addition_tensor", &addition_tensor,
          "A function which adds two numbers");
}

我用g++ -shared -fPIC `python3 -m pybind11 --includes` example.cpp -o example`python3-config --extension-suffix`编译了上面的代码。有人知道如何将三维numpy数组转换为接受三维Eigen::Tensor的函数吗

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2019-10-16 20:18:00

它不是直接支持的,这里有一些讨论(如果您想要将其添加到您的项目中,包括一些进行映射的代码):https://github.com/pybind/pybind11/issues/1377

票数 4
EN

Stack Overflow用户

发布于 2019-10-16 21:29:46

感谢@John Zwinck的回答,我可以实现我正在寻找的东西。如果有人感兴趣,这里是复制:

代码语言:javascript
运行
复制
#include <eigen-git-mirror/Eigen/Dense>
#include <eigen-git-mirror/unsupported/Eigen/CXX11/Tensor>
#include "pybind11/include/pybind11/eigen.h"
#include "pybind11/include/pybind11/numpy.h"
#include "pybind11/include/pybind11/pybind11.h"

Eigen::Tensor<double, 3, Eigen::RowMajor> getTensor(
    pybind11::array_t<double> inArray) {
    // request a buffer descriptor from Python
    pybind11::buffer_info buffer_info = inArray.request();

    // extract data an shape of input array
    double *data = static_cast<double *>(buffer_info.ptr);
    std::vector<ssize_t> shape = buffer_info.shape;

    // wrap ndarray in Eigen::Map:
    // the second template argument is the rank of the tensor and has to be
    // known at compile time
    Eigen::TensorMap<Eigen::Tensor<double, 3, Eigen::RowMajor>> in_tensor(
        data, shape[0], shape[1], shape[2]);
    return in_tensor;
}

pybind11::array_t<double> return_array(
    Eigen::Tensor<double, 3, Eigen::RowMajor> inp) {
    std::vector<ssize_t> shape(3);
    shape[0] = inp.dimension(0);
    shape[1] = inp.dimension(1);
    shape[2] = inp.dimension(2);
    return pybind11::array_t<double>(
        shape,  // shape
        {shape[1] * shape[2] * sizeof(double), shape[2] * sizeof(double),
         sizeof(double)},  // strides
        inp.data());       // data pointer
}

pybind11::array_t<double> addition(pybind11::array_t<double> a,
                                   pybind11::array_t<double> b) {
    Eigen::Tensor<double, 3, Eigen::RowMajor> a_t = getTensor(a);
    Eigen::Tensor<double, 3, Eigen::RowMajor> b_t = getTensor(b);
    Eigen::Tensor<double, 3, Eigen::RowMajor> res = a_t + b_t;
    return return_array(res);
}

PYBIND11_MODULE(example, m) {
    m.def("addition", &addition, "A function which adds two numbers");
}

与约翰提到的链接中的建议不同,我不介意对Eigen::Tensor使用RowMajor存储顺序。我也看到这个存储顺序在tensorflow代码中被多次使用。我不知道上面的代码是否不必要地复制了数据。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58412795

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档