前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >如何保存机器学习模型

如何保存机器学习模型

作者头像
MeteoAI
发布2019-08-21 17:29:10
2.5K0
发布2019-08-21 17:29:10
举报
文章被收录于专栏:MeteoAIMeteoAI

很多场合下我们都需要将训练完的模型存下以便于以后复用。

这篇文章主要介绍持久化存储机器学习模型的两种方式:picklejoblib,以及如何DIY自己的模型存储模块。

Before

对于下面这个例子,我们用逻辑回归算法训练了模型,那么如何在以后的场景中,重复应用这个训练完的模型呢?

代码语言:javascript
复制
from sklearn.linear_model import LogisticRegression  
from sklearn.datasets import load_iris  
from sklearn.model_selection import train_test_split

# Load and split data
data = load_iris()  
Xtrain, Xtest, Ytrain, Ytest = train_test_split(data.data, data.target, test_size=0.3, random_state=4)

# Create a model
model = LogisticRegression(C=0.1,  
                           max_iter=20, 
                           fit_intercept=True, 
                           n_jobs=3, 
                           solver='liblinear')
model.fit(Xtrain, Ytrain)

"""
our resulting model:
LogisticRegression(C=0.1, class_weight=None, dual=False, fit_intercept=True,  
    intercept_scaling=1, max_iter=20, multi_class='ovr', n_jobs=3,
    penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
    verbose=0, warm_start=False)
"""

下面我们介绍python中三种常用的方法,来做模型的持久化存储,以便于后续的上线部署。

Pickle Module (also: cPickle)

pickle可以序列化对象并保存到磁盘中,并在需要的时候读取出来,任何对象都可以执行序列化操作。同样我们也可以将训练好的模型对象序列化并存储到本地。具体的操作过程如下:

This saving procedure is also known as object serialization - representing an object with a stream of bytes, in order to store it on disk, send it over a network or save to a database, while the restoring procedure is known as deserialization.

代码语言:javascript
复制
import pickle

# Create your model here (same as above)

# Save to file in the current working directory
pkl_filename = "pickle_model.pkl"  
with open(pkl_filename, 'wb') as file:  
    pickle.dump(model, file)

# Load from file
with open(pkl_filename, 'rb') as file:  
    pickle_model = pickle.load(file)

# Calculate the accuracy score and predict target values
score = pickle_model.score(Xtest, Ytest)  
print("Test score: {0:.2f} %".format(100 * score))  
Ypredict = pickle_model.predict(Xtest)

也可以将一些过程中的参数通过tuple的形式保存下来:

代码语言:javascript
复制
tuple_objects = (model, Xtrain, Ytrain, score)

# Save tuple
pickle.dump(tuple_objects, open("tuple_model.pkl", 'wb'))

# Restore tuple
pickled_model, pickled_Xtrain, pickled_Ytrain, pickled_score = pickle.load(open("tuple_model.pkl", 'rb'))

cPickle是用 C 编码的pickle模块,性能更好,推荐在大多数的场景中使用该模块。

需要注意的是:在序列化模型的时候尽可能的保持python及主要的依赖库(如numpy, sklearn等)版本一致,以防不兼容的错误。

Joblib Module

joblibsklearn中自带的一个工具,用于模型的持久化存储,做了很多的优化。在多数场景下,joblib的性能要优于pickle,尤其是当数据量较大的情况更加明显。

代码语言:javascript
复制
from sklearn.externals import joblib

# Save to file in the current working directory
joblib_file = "joblib_model.pkl"  
joblib.dump(model, joblib_file)

# Load from file
joblib_model = joblib.load(joblib_file)

# Calculate the accuracy and predictions
score = joblib_model.score(Xtest, Ytest)  
print("Test score: {0:.2f} %".format(100 * score))  
Ypredict = pickle_model.predict(Xtest)

参见下面pickle和joblib的对比试验:

代码语言:javascript
复制
from sklearn.externals import joblib
import time
import numpy
import pickle

bigarray = numpy.zeros([190,91,190])
bigarray = bigarray.flatten()

### Saving
start = time.time()
joblib.dump(bigarray,"bigarray1.pkl")
end = time.time() - start
end
# 0.31264686584472656

start = time.time()
pickle.dump(bigarray,open("bigarray2.pkl","wb"))
end = time.time()-start
end
# 4.827500104904175

### Loading
start = time.time()
joblib.load("bigarray1.pkl")
end = time.time() - start
end
# 0.47748589515686035

start = time.time()
pickle.load(open("bigarray2.pkl","rb"))
end = time.time()-start
end
# 0.7575929164886475

DIY Module

也可以自己DIY模型的JSON存储模块,可以设计的更加兼容,灵活性大,但是复杂度相应也会更高。

代码语言:javascript
复制
import json  
import numpy as np

class MyLogReg(LogisticRegression):

    # Override the class constructor
    def __init__(self, C=1.0, solver='liblinear', max_iter=100, X_train=None, Y_train=None):
        LogisticRegression.__init__(self, C=C, solver=solver, max_iter=max_iter)
        self.X_train = X_train
        self.Y_train = Y_train

    # A method for saving object data to JSON file
    def save_json(self, filepath):
        dict_ = {}
        dict_['C'] = self.C
        dict_['max_iter'] = self.max_iter
        dict_['solver'] = self.solver
        dict_['X_train'] = self.X_train.tolist() if self.X_train is not None else 'None'
        dict_['Y_train'] = self.Y_train.tolist() if self.Y_train is not None else 'None'

        # Creat json and save to file
        json_txt = json.dumps(dict_, indent=4)
        with open(filepath, 'w') as file:
            file.write(json_txt)

    # A method for loading data from JSON file
    def load_json(self, filepath):
        with open(filepath, 'r') as file:
            dict_ = json.load(file)

        self.C = dict_['C']
        self.max_iter = dict_['max_iter']
        self.solver = dict_['solver']
        self.X_train = np.asarray(dict_['X_train']) if dict_['X_train'] != 'None' else None
        self.Y_train = np.asarray(dict_['Y_train']) if dict_['Y_train'] != 'None' else None

应用模块:

代码语言:javascript
复制
filepath = "mylogreg.json"

# Create a model and train it
mylogreg = MyLogReg(X_train=Xtrain, Y_train=Ytrain)  
mylogreg.save_json(filepath)

# Create a new object and load its data from JSON file
json_mylogreg = MyLogReg()  
json_mylogreg.load_json(filepath)  
json_mylogreg

"""
MyLogReg(C=1.0,  
     X_train=array([[ 4.3,  3. ,  1.1,  0.1],
       [ 5.7,  4.4,  1.5,  0.4],
       ...,
       [ 7.2,  3. ,  5.8,  1.6],
       [ 7.7,  2.8,  6.7,  2. ]]),
     Y_train=array([0, 0, ..., 2, 2]), class_weight=None, dual=False,
     fit_intercept=True, intercept_scaling=1, max_iter=100,
     multi_class='ovr', n_jobs=1, penalty='l2', random_state=None,
     solver='liblinear', tol=0.0001, verbose=0, warm_start=False)
"""

Reference

[1] Better Python compressed persistence in joblib: [http://gael-varoquaux.info/programming/new_low-overhead_persistence_in_joblib_for_big_data.html](http://gael-varoquaux.info/programming/new_low-overhead_persistence_in_joblib_for_big_data.html) [2] Save and Restore Models: https://stackabuse.com/scikit-learn-save-and-restore-models/ [3] Saving machine learning models: https://chrisalbon.com/machine_learning/basics/saving_machine_learning_models/ [4] Model persistence: http://scikit-learn.org/stable/modules/model_persistence.html [5] Save and Load Machine Learning Models in Python with scikit-learn: https://machinelearningmastery.com/save-load-machine-learning-models-python-scikit-learn/ [6] DataCamp pickle: https://www.datacamp.com/community/tutorials/pickle-python-tutorial#not [7] model persistence using JSON.: https://github.com/rasbt/python-machine-learning-book/blob/master/code/bonus/scikit-model-to-json.ipynb

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-08-18,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 MeteoAI 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • Before
  • Pickle Module (also: cPickle)
  • Joblib Module
  • DIY Module
  • Reference
相关产品与服务
文件存储
文件存储(Cloud File Storage,CFS)为您提供安全可靠、可扩展的共享文件存储服务。文件存储可与腾讯云服务器、容器服务、批量计算等服务搭配使用,为多个计算节点提供容量和性能可弹性扩展的高性能共享存储。腾讯云文件存储的管理界面简单、易使用,可实现对现有应用的无缝集成;按实际用量付费,为您节约成本,简化 IT 运维工作。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档