前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >XGBoost:在Python中使用XGBoost

XGBoost:在Python中使用XGBoost

作者头像
全栈程序员站长
发布2022-08-26 13:23:06
1K0
发布2022-08-26 13:23:06
举报
文章被收录于专栏:全栈程序员必看

大家好,又见面了,我是你们的朋友全栈君。

在Python中使用XGBoost

下面将介绍XGBoost的Python模块,内容如下: * 编译及导入Python模块 * 数据接口 * 参数设置 * 训练模型l * 提前终止程序 * 预测

A walk through python example for UCI Mushroom dataset is provided.

安装

首先安装XGBoost的C++版本,然后进入源文件的根目录下的 wrappers文件夹执行如下脚本安装Python模块

代码语言:javascript
复制
python setup.py install

安装完成后按照如下方式导入XGBoost的Python模块

代码语言:javascript
复制
import xgboost as xgb

=

数据接口

XGBoost可以加载libsvm格式的文本数据,加载的数据格式可以为Numpy的二维数组和XGBoost的二进制的缓存文件。加载的数据存储在对象DMatrix中。

  • 加载libsvm格式的数据和二进制的缓存文件时可以使用如下方式
代码语言:javascript
复制
dtrain = xgb.DMatrix('train.svm.txt')
dtest = xgb.DMatrix('test.svm.buffer')
  • 加载numpy的数组到DMatrix对象时,可以用如下方式
代码语言:javascript
复制
data = np.random.rand(5,10) # 5 entities, each contains 10 features
label = np.random.randint(2, size=5) # binary target
dtrain = xgb.DMatrix( data, label=label)
  • scipy.sparse格式的数据转化为 DMatrix格式时,可以使用如下方式
代码语言:javascript
复制
csr = scipy.sparse.csr_matrix( (dat, (row,col)) )
dtrain = xgb.DMatrix( csr )
  • DMatrix 格式的数据保存成XGBoost的二进制格式,在下次加载时可以提高加载速度,使用方式如下
代码语言:javascript
复制
dtrain = xgb.DMatrix('train.svm.txt')
dtrain.save_binary("train.buffer")
  • 可以用如下方式处理 DMatrix中的缺失值:
代码语言:javascript
复制
dtrain = xgb.DMatrix( data, label=label, missing = -999.0)
  • 当需要给样本设置权重时,可以用如下方式
代码语言:javascript
复制
w = np.random.rand(5,1)
dtrain = xgb.DMatrix( data, label=label, missing = -999.0, weight=w)

参数设置

XGBoost使用key-value格式保存参数. Eg * Booster(基本学习器)参数

代码语言:javascript
复制
param = {
  
  'bst:max_depth':2, 'bst:eta':1, 'silent':1, 'objective':'binary:logistic' }
param['nthread'] = 4
plst = param.items()
plst += [('eval_metric', 'auc')] # Multiple evals can be handled in this way
plst += [('eval_metric', 'ams@0')]
  • 还可以定义验证数据集,验证算法的性能
代码语言:javascript
复制
evallist  = [(dtest,'eval'), (dtrain,'train')]

=

训练模型

有了参数列表和数据就可以训练模型了 * 训练

代码语言:javascript
复制
num_round = 10
bst = xgb.train( plst, dtrain, num_round, evallist )
  • 保存模型 在训练完成之后可以将模型保存下来,也可以查看模型内部的结构
代码语言:javascript
复制
bst.save_model('0001.model')
  • Dump Model and Feature Map You can dump the model to txt and review the meaning of model
代码语言:javascript
复制
# dump model
bst.dump_model('dump.raw.txt')
# dump model with feature map
bst.dump_model('dump.raw.txt','featmap.txt')
  • 加载模型 通过如下方式可以加载模型
代码语言:javascript
复制
bst = xgb.Booster({
  
  'nthread':4}) #init model
bst.load_model("model.bin") # load data

=

提前终止程序

如果有评价数据,可以提前终止程序,这样可以找到最优的迭代次数。如果要提前终止程序必须至少有一个评价数据在参数evals中。 If there’s more than one, it will use the last.

train(..., evals=evals, early_stopping_rounds=10)

The model will train until the validation score stops improving. Validation error needs to decrease at least every early_stopping_rounds to continue training.

If early stopping occurs, the model will have two additional fields: bst.best_score and bst.best_iteration. Note that train() will return a model from the last iteration, not the best one.

This works with both metrics to minimize (RMSE, log loss, etc.) and to maximize (MAP, NDCG, AUC).

=

Prediction

After you training/loading a model and preparing the data, you can start to do prediction.

代码语言:javascript
复制
data = np.random.rand(7,10) # 7 entities, each contains 10 features
dtest = xgb.DMatrix( data, missing = -999.0 )
ypred = bst.predict( xgmat )

If early stopping is enabled during training, you can predict with the best iteration.

代码语言:javascript
复制
ypred = bst.predict(xgmat,ntree_limit=bst.best_iteration)

发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/143956.html原文链接:https://javaforall.cn

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022年5月1,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 在Python中使用XGBoost
    • 安装
      • 数据接口
        • 参数设置
          • 训练模型
            • 提前终止程序
            相关产品与服务
            数据保险箱
            数据保险箱(Cloud Data Coffer Service,CDCS)为您提供更高安全系数的企业核心数据存储服务。您可以通过自定义过期天数的方法删除数据,避免误删带来的损害,还可以将数据跨地域存储,防止一些不可抗因素导致的数据丢失。数据保险箱支持通过控制台、API 等多样化方式快速简单接入,实现海量数据的存储管理。您可以使用数据保险箱对文件数据进行上传、下载,最终实现数据的安全存储和提取。
            领券
            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档