首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >使用PyO3将列表作为参数传递给Python中的Rust

使用PyO3将列表作为参数传递给Python中的Rust
EN

Stack Overflow用户
提问于 2021-03-18 21:54:36
回答 1查看 745关注 0票数 2

我试图使用Py03将列表从Python传递到Rust。我试图将它传递给的函数具有以下签名:

代码语言:javascript
运行
复制
pub fn k_nearest_neighbours(k: usize, x: &[[f32; 2]], y: &[[f32; 3]]) -> Vec<Option<f32>> 

我正在为预先存在的库编写绑定,因此不能更改原始代码。我目前的做法是:

代码语言:javascript
运行
复制
// This is example code == DOES NOT WORK
#[pyfunction] // make a new function within a new library with pyfunction macro
fn k_nearest_neighbours(k: usize, x: Vec<Vec<f32>>, y: Vec<f32>) -> Vec<Option<f32>> {
    // reformat input where necessary
    let x_slice = x.as_slice();
    // return original lib's function return
    classification::k_nearest_neighbours(k, x_slice, y)
}

x.as_slice()函数几乎完成了我需要它做的事情,它给了我一片向量&[Vec<f32>],而不是切片&[[f32; 3]]

我希望能够运行以下Python代码:

代码语言:javascript
运行
复制
from rust_code import k_nearest_neighbours as knn  # this is the Rust function compiled with PyO3

X = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [0.06, 7.0]]
train = [
        [0.0, 0.0, 0.0],
        [0.5, 0.5, 0.0],
        [3.0, 3.0, 1.0],
        [4.0, 3.0, 1.0],
    ]

k = 2
y_true = [0, 1, 1, 1]
y_test = knn(k, X, train)
assert(y_true == y_test)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-03-18 22:06:17

查看一下k_nearest_neighbours的签名就会发现,它期望[f32; 2][f32; 3]是数组,而不是切片(例如&[f32] )。

数组在编译时具有静态已知的大小,而片是动态大小的。向量也是如此,在你的例子中,你无法控制内部向量的长度。因此,从输入向量到预期数组的转换是错误的。

可以使用TryFrom将片转换为数组,即:

代码语言:javascript
运行
复制
use std::convert::TryFrom;
fn main() {
    let x = vec![vec![3.5, 3.4, 3.6]];
    let x: Result<Vec<[f32; 3]>, _> = x.into_iter().map(TryFrom::try_from).collect::<Result<Vec<_>, _>>();
}

综合起来,您的函数将需要在不正确的输入上返回一个错误,并且您需要创建一个新的向量,数组可以传递给您的函数:

代码语言:javascript
运行
复制
#[pyfunction] // make a new function within a new library with pyfunction macro
fn k_nearest_neighbours(k: usize, x: Vec<Vec<f32>>, y: Vec<f32>) -> PyResult<Vec<Option<f32>>> {
    let x = x.into_iter().map(TryFrom::try_from).collect::<Result<Vec<_>, _>>();
    let y = y.into_iter().map(TryFrom::try_from).collect::<Result<Vec<_>, _>>();
    // Error handling is missing here, you'll need to look into PyO3's documentation for that
    ...
    // return original lib's function return
    Ok(classification::k_nearest_neighbours(k, &x, &y))
}
票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66699656

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档