前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用xgboost的c接口推理模型

使用xgboost的c接口推理模型

原创
作者头像
plus sign
修改2024-02-27 15:11:01
1390
修改2024-02-27 15:11:01

官方c api tutorial文档,非常恶心的一点是,tutorial和文档问题很多。

也参考了不少开源项目,主要有xgboost-c-cplusplus,xgboostpp.

首先导入头文件#include "xgboost/c_api.h" ,接下来xgboost的绝大多数接口都包含在了这个头文件中。

然后我们需要一个宏,来用它获取xgboost函数使用的情况.在每次调用xgboost函数时都应该调用这个宏。

代码语言:c
复制
#define safe_xgboost(call) {  \
  int err = (call); \
  if (err != 0) { \
    fprintf(stderr, "%s:%d: error in %s: %s\n", __FILE__, __LINE__, #call, XGBGetLastError());  \
    exit(1); \
  } \
}

我们使用的模型文件为xgboost_model.bin ,训练数据的输入是 11 个元素。

首先我们声明一个boost模型的句柄BoosterHandle booster; 接着用XGBoosterCreate 函数创建一个模型 。

代码语言:c
复制
BoosterHandle booster;
safe_xgboost(XGBoosterCreate(NULL, 0, &booster));

设置一个字符串作为模型路径const char *model_path = "../xgboost_model.bin";(../是因为编译出来的可执行文件在build目录下) , 通过句柄使用XGBoosterLoadModel函数加载模型。

代码语言:c
复制
const char *model_path = "../xgboost_model.bin";
XGBoosterLoadModel(booster, model_path)

设置一组数据作为推理测试,这里我选的数据标签是1.接着将输入数据转为xgboost的DMatrix格式。

代码语言:c
复制
float a[11]= {14.0,2.0,1.0,12.0,19010.0,120.0,14.0,0.0,0.0,0.0,0.0};
DMatrixHandle h_test;
safe_xgboost(XGDMatrixCreateFromMat(a, 1, 11, -1, &h_test));

下面就可以进行模型推理了,out_len 代表输出的长度(实际上是一个整型变量),f的模型推理的结果。

代码语言:c
复制
bst_ulong out_len;
const float *f;
safe_xgboost(XGBoosterPredict(booster, h_test, 0, 0, 1, &out_len, &f));

我们可以打印输出查看结果

代码语言:c
复制
printf("Value of the variable: %f\n", f[0]);

最后记得释放内存

代码语言:c
复制
XGDMatrixFree(h_test);
XGBoosterFree(booster);

完整的代码

代码语言:c
复制
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include "xgboost/c_api.h"

#define safe_xgboost(call) {  \
  int err = (call); \
  if (err != 0) { \
    fprintf(stderr, "%s:%d: error in %s: %s\n", __FILE__, __LINE__, #call, XGBGetLastError());  \
    exit(1); \
  } \
}

int main(int argc, char const *argv[]) {
    const char *model_path = "../xgboost_model.bin";

    // create booster handle first
    BoosterHandle booster;
    safe_xgboost(XGBoosterCreate(NULL, 0, &booster));
    // load model
    safe_xgboost(XGBoosterLoadModel(booster, model_path));

    //generate random data of a a[11],every nuber from 0 to 2
    // float a[11]= {1.0,12.0,1.0,1.0,16134.0,20600.0,0.0,1.0,0.0,0.0,0.0}; // label: 0.0
    float a[11]= {14.0,2.0,1.0,12.0,19010.0,120.0,14.0,0.0,0.0,0.0,0.0}; // label: 1.0

    for (int i = 0; i < 11; i++) {
        printf("%f, ", a[i]);
        if (i == 10) {
            printf("\n");
        }
    }
    // convert to DMatrix
    DMatrixHandle h_test;
    safe_xgboost(XGDMatrixCreateFromMat(a, 1, 11, -1, &h_test));
    // predict
    bst_ulong out_len;
    const float *f;
    safe_xgboost(XGBoosterPredict(booster, h_test, 0, 0, 1, &out_len, &f));
    printf("Value of the variable: %f\n", f[0]);

    XGDMatrixFree(h_test);
    XGBoosterFree(booster);
    return 0;
}

使用cmake编译

代码语言:CMakeLists.txt
复制
cmake_minimum_required(VERSION 3.18)
project(project_name LANGUAGES C CXX VERSION 0.1)
set(xgboost_DIR "/usr/include/xgboost")

include_directories(${xgboost_DIR})
link_directories(${xgboost_DIR})

add_executable(project_name test.c)
target_link_libraries(project_name xgboost)
代码语言:bash
复制
mkdir build
cd ./build
cmake ..
make .
./project_name

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档