首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >如何将keras模型与其他数据一起保存并一起加载?

如何将keras模型与其他数据一起保存并一起加载?
EN

Stack Overflow用户
提问于 2018-08-09 05:13:08
回答 1查看 326关注 0票数 1

我有以下代码来训练keras神经网络

代码语言:javascript
复制
from keras import Sequential
from keras.layers import Dense
from keras.models import load_model

import numpy as np

class Model:
    def __init__(self, data=None):
        self.data = data
        self.metrics = []
        self.model = self.__build_model()

    def __build_model(self):
        model = Sequential()
        model.add(Dense(4, activation='relu', input_shape=(3,)))
        model.add(Dense(1, activation='relu'))
        model.compile(loss='mean_squared_error', optimizer='adam', metrics=['accuracy'])
        return model

    def train(self, epochs):
        self.model.fit(self.data[:, :-1], self.data[:,-1], epochs=epochs)
        return self

    def test(self, data):
        self.metrics = self.model.evaluate(data[:, :-1], data[:, -1])
        return self

    def predict(self, input):
        return self.model.predict(input)

    def save(self, path):
        self.model.save(path)
        # I would like to save self.metrics at the same time

    def load(self, path):
        self.model = load_model(path)


if __name__ == '__main__':
    train_data = np.random.rand(1000, 4)
    test_data = np.random.rand(100, 4)
    print("TRAINING, TESTING & SAVING..")
    model = Model(train_data)\
                .train(epochs=5)\
                .test(test_data)\
                .save('./model.h5')

    print('LOADING model & PREDICTING..')
    test_sample = np.random.rand(1, 3)
    model = Model()
    model.load('./model.h5')
    # I can then do like:
    test_output = model.predict(test_sample)
    print(test_output)

    # And want to get metrics which i had saved with it like:
    metrics = model.metrics
    print(metrics)

正如您所看到的,它将模型保存到h5文件中,但只保存keras模型,而不保存其他任何内容。如何同时保存其他数据,例如指标,然后在加载keras模型的同时也能够加载它们。

谢谢!

EN

回答 1

Stack Overflow用户

发布于 2018-08-09 05:30:11

您可以使用任何序列化框架来实现这一点。

代码语言:javascript
复制
import hickle

def save(self, path):
      self.model.save(path)
      hkl.dump(self.metrics, 'metrics.hkl', mode='w')

def load(self, path):
      self.model = load_model(path)
      self.metrics = hkl.load('metrics.hkl')

您也可以将其保存为单个文件,只需从指标和模型对象中创建一个列表或另一个对象。我建议将它们分开保存。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/51755721

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档