首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >使用对数损失函数的sklearn.linear_model.SGDClassifier运行时间非常长

使用对数损失函数的sklearn.linear_model.SGDClassifier运行时间非常长
EN

Stack Overflow用户
提问于 2019-07-22 00:34:07
回答 1查看 258关注 0票数 1

我正在尝试复制已完成的工作here

这个项目的目标是建立一个逻辑机器学习模型来预测投资失败或失败的概率。

代码的sklearn部分如下:

代码语言:javascript
复制
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import Imputer, StandardScaler
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import SGDClassifier

pipeline_sgdlogreg = Pipeline([
    ('imputer', Imputer(copy=False)), # Mean imputation by default
    ('scaler', StandardScaler(copy=False)),
    ('model', SGDClassifier(loss='log', max_iter=1000, tol=1e-3, random_state=1, warm_start=True))
])


param_grid_sgdlogreg = {
    'model__alpha': [10**-5, 10**-2, 10**1],
    'model__penalty': ['l1', 'l2']
}

grid_sgdlogreg = GridSearchCV(estimator=pipeline_sgdlogreg, param_grid=param_grid_sgdlogreg, 
                              scoring='roc_auc', n_jobs=-1, pre_dispatch='2*n_jobs', cv=5, 
                              verbose=1, return_train_score=False)
grid_sgdlogreg.fit(X_train, y_train)

GridSearchCV(cv=5, error_score='raise',
       estimator=Pipeline(memory=None,
       steps=[('imputer', Imputer(axis=0, copy=False, missing_values='NaN', strategy='mean', verbose=0)), 
              ('scaler', StandardScaler(copy=False, with_mean=True, with_std=True)), 
              ('model', SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,
       penalty='l2', power_t=0.5, random_state=1, shuffle=True,
       tol=0.001, verbose=0, warm_start=True))]),
       fit_params=None, iid=True, n_jobs=-1,
       param_grid={'model__alpha': [1e-05, 0.01, 10], 'model__penalty': ['l1', 'l2']},
       pre_dispatch='2*n_jobs', refit=True, return_train_score=False,
       scoring='roc_auc', verbose=1)
grid_sgdlogreg.best_score_
grid_sgdlogreg.best_params_

我在训练集中有大约200000行数据。初始运行6小时后未产生任何结果或消息。

将我的训练集减少到100行,模型运行大约9秒。200行似乎无休止地运行。

在我尝试复制的python笔记本示例中,训练集有730000行,运行时间仅为3.3min。

我在一个8核的i7上并行运行,内存为16 gb。

我猜我需要对数据做额外的预处理,但不知道从哪里开始。我应该在我的数据中查找哪些可能导致此问题的内容?

EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57135050

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档