首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >宽数据集(400列)上的lightgbm内存问题

宽数据集(400列)上的lightgbm内存问题
EN

Stack Overflow用户
提问于 2020-12-12 01:49:45
回答 1查看 165关注 0票数 1

我是第一次接触lightgbm。我有大数据(数十亿行不断更新)。为训练准备的数据集也很宽,大约有400列。

我有两个问题:

首先,即使对于像10000行这样的小子集,我的内核在几千个时期之后仍然会死掉。在训练过程中,内存使用量一直在上升,直到失败。我有126G的内存。

我尝试过使用不同的参数进行训练,评论说也是尝试过的

代码语言:javascript
运行
复制
parameters = {
  'histogram_pool_size': 5000,
  'objective': 'regression',
  'metric': 'l2',
  'boosting': 'dart',#'gbdt
  'num_leaves': 10, #100
  'learning_rate': 0.01,
  'verbose': 0,
  'max_bin': 66,.
  'force_col_wise':True, #default
  'max_bin': 6,  #60 #default
  'max_depth': 10, #default
  'min_data_in_leaf': 30, #default
  'min_child_samples': 20,#default
  'feature_fraction': 0.5,#default
  'bagging_fraction': 0.8,#default
  'bagging_freq': 40,#default
  'bagging_seed': 11,#default
  'lambda_l1': 2 #default
  'lambda_l2': 0.1 #default }

限制列的数量似乎是有帮助的,但我知道一些得分较低且具有全局特征重要性的列在某些局部范围内具有重要意义。

其次,用大数据逐步训练lightgbm并用新数据更新lightgbm模型的正确方法是什么?我之前主要使用神经网络,它是由自然增量训练的,我知道树不会以这种方式工作,尽管在技术上可以更新模型,但它不会与以整体方式训练的模型相同。如何应对?

完整代码:

代码语言:javascript
运行
复制
# X is dataframe

cat_names =  X.select_dtypes(['bool','category',object]).columns.tolist()

for c in cat_names: X[c] = X[c].astype('category')
cat_cols = [c for c, col in enumerate(cat_names)]
X[cat_names] = X[cat_names].apply(lambda x: x.cat.codes)
x = X.values
x_train, x_valid, y_train, y_valid = train_test_split(x, y, test_size=0.2, random_state=42)

train_ds = lightgbm.Dataset(x_train, label=y_train)
valid_ds = lightgbm.Dataset(x_valid, label=y_valid)

model = lightgbm.train(parameters,
                   train_ds,
                   valid_sets=valid_ds,
                   categorical_feature = cat_cols,
                   num_boost_round=2000,
                   early_stopping_rounds=50)
EN

回答 1

Stack Overflow用户

发布于 2020-12-14 00:08:41

将数据类型更改为较少的冗余修复了内存问题!如果您的数据集是pandas dataframe,请执行以下操作:

代码语言:javascript
运行
复制
ds[ds.select_dtypes('float64').columns] = ds.select_dtypes('float64').astype('float32')
ds[ds.select_dtypes('int64').columns] = ds.select_dtypes('int64').astype('int32')

!!! caution您的数据范围可能超出所选的数据类型范围,在这种情况下,pandas会弄乱您的数据。例如,int8数据类型仅限于-128 to 127中的范围,因此请选择能够处理您的数据的范围。

您可以使用检查选定的dtype范围

代码语言:javascript
运行
复制
import numpy as np
np.iinfo('int32').min, np.iinfo('int32').max
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65256226

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档