首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何向lgbm自定义损失函数传递附加参数?

如何向lgbm自定义损失函数传递附加参数?
EN

Stack Overflow用户
提问于 2020-08-14 00:53:01
回答 1查看 265关注 1票数 2

我用以下方式编写了rmsse自定义损失函数

代码语言:javascript
运行
复制
def wrmsse(preds, y_true,store_name):
    '''
    preds - Predictions: pd.DataFrame of size (30490 rows, N day columns)
    y_true - True values: pd.DataFrame of size (30490 rows, N day columns)
    sequence_length - np.array of size (42840,)
    sales_weight - sales weights based on last 28 days: np.array (42840,)
    '''
    preds = preds[-(30490 * 28):]
    y_true = y_true.get_label()[-(30490 * 30490):]
    preds = preds.reshape(28, 30490).T
    y_true = y_true.reshape(28, 30490).T    
    sw = list(SW_store.keys())[key]
    return 'wrmsse', np.sum(np.sqrt(np.mean(np.square(rollup(preds-y_true)),axis=1)) * sw)/12,False #<-used 

我正在像下面这样训练模型

代码语言:javascript
运行
复制
model = 

store_name = 'CA_1    lgbm.train(params,train_set=train_set,num_boost_round=2500,early_stopping_rounds=50,valid_sets=val_set,verbose_eval = 100, feval= wrmsse)

我想将商店名称作为参数传递,我该怎么做呢?

EN

回答 1

Stack Overflow用户

发布于 2021-02-14 02:16:08

您可以通过将您的自定义ndarray附加到dataset来完成此操作。

例如,在声明数据集设置自定义类属性后,

代码语言:javascript
运行
复制
dtrain = lgb.Dataset(X_train, y_train, feature_name =feature_names, categorical_feature=categorical_feature, free_raw_data=False)
dval = lgb.Dataset(X_val, y_val, reference=dtrain, feature_name =feature_names, categorical_feature=categorical_feature, free_raw_data=False)

dtrain.indexes = np.arange(0, X_train.shape[0])
dval.indexes =  np.arange(0, X_val.shape[0])

这里的索引是我想在公制中使用的自定义数组,

然后,在您的指标函数中,将您的自定义数组作为闭包传递,并使用索引访问它们,

代码语言:javascript
运行
复制
def utility_score(weight, resp, date_):   
    def func(preds, train_data):
        score = 0.
        labels = train_data.get_label()
        indexes = train_data.indexes
        y_pred = preds.reshape(-1, 1)

        weight_ = weight[indexes, :]
        resp_ = resp[indexes, :]
        date__ = date_[indexes, :]
       
        # do whatever with ur custom vars and calculate score....
        
        return 'utility', score, True
    return func

像这样使用它,

代码语言:javascript
运行
复制
feval=utility_score(weight, resp, date_)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63399806

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档