首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何在Python语言中正确加载CatBoost中的预训练模型

如何在Python语言中正确加载CatBoost中的预训练模型
EN

Stack Overflow用户
提问于 2018-08-17 20:45:30
回答 2查看 6.1K关注 0票数 5

我已经训练了CatBoostClassifier来解决我的分类任务。现在我需要保存模型,并在另一个应用程序中使用它进行预测。为了做到这一点,我通过save_model方法保存模型,并通过load_model方法恢复模型。

但是,每次我在恢复的模型中调用predict时,我都会得到一个错误:

代码语言:javascript
运行
复制
CatboostError: There is no trained model to use predict(). Use fit() to train model. Then use predict().

因此,看起来我需要再次训练我的模型,而我需要恢复预先训练的模型,并将其仅用于预测。

我在这里做错了什么?有没有什么特殊的方法可以用来加载模型进行预测?

我的训练过程如下所示:

代码语言:javascript
运行
复制
model = CatBoostClassifier(
    custom_loss=['Accuracy'],
    random_seed=42,
    logging_level='Silent',
    loss_function='MultiClass')

model.fit(
    x_train, 
    y_train,
    cat_features=None,
    eval_set=(x_validation, y_validation),
    plot=True)

...

model.save("model.cbm")

我使用下面的代码恢复模型:

代码语言:javascript
运行
复制
model = CatBoostClassifier(
    custom_loss=['Accuracy'],
    random_seed=42,
    logging_level='Silent',
    loss_function='MultiClass')
model.load_model("model.cbm")

...


predict = self.model.predict(inputs)
EN

回答 2

Stack Overflow用户

发布于 2019-04-20 21:36:51

代码语言:javascript
运行
复制
# After you train the model using fit(), save like this - 
model.save_model('model_name')    # extension not required.

# And then, later load - 
from catboost import CatBoostClassifier
model = CatBoostClassifier()      # parameters not required.
model.load_model('model_name')

# Now, try predict().
票数 10
EN

Stack Overflow用户

发布于 2018-08-18 12:44:42

几个小时后,我意外地找到了解决方案。模型加载在外部python模块中实现,然后将其导入到Jupyter Notebook中。事实证明,我只需要重新启动Jupyter内核。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/51895761

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档