首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >model = tf.keras.models.load_model()

model = tf.keras.models.load_model()
EN

Stack Overflow用户
提问于 2021-03-15 16:12:36
回答 2查看 6.6K关注 0票数 0

我用这种类型的代码保存了MLP回归类型算法:

代码语言:javascript
运行
复制
#define model
model = Sequential()
model.add(Dense(80, input_dim=2, kernel_initializer='normal', activation='relu'))
model.add(Dense(60, kernel_initializer='normal', activation='relu'))
model.add(Dense(40, kernel_initializer='normal', activation='relu'))
model.add(Dense(20, kernel_initializer='normal', activation='relu'))
model.add(Dense(10, kernel_initializer='normal', activation='relu'))
model.add(Dense(5, kernel_initializer='normal', activation='relu'))
model.add(Dense(1, kernel_initializer='normal'))
model.summary()
model.compile(loss='mse', optimizer='adam', metrics=[rmse])



# train model, test callback option
history = model.fit(X_train, Y_train, epochs=75, batch_size=1, verbose=2, callbacks=[callback])
#history = model.fit(X_train, Y_train, epochs=60, batch_size=1, verbose=2)

# plot metrics
plt.plot(history.history['rmse'])
plt.title('kW RSME Vs Epoch')
plt.show()


model.save('./saved_model/kwSummer')

但是当我尝试加载保存的模型时:

代码语言:javascript
运行
复制
model = tf.keras.models.load_model('./saved_model/kwSummer')

# Check its architecture
new_model.summary()

当我试图加载模型时,我在下面得到了这个错误。有人想试试吗?

代码语言:javascript
运行
复制
ValueError: Unable to restore custom object of type _tf_keras_metric currently. Please make sure that the layer implements `get_config`and `from_config` when saving. In addition, please use the `custom_objects` arg when calling `load_model()`.

我一直在试验使用Python3.7来训练模型,然后使用IPython AnacondaPython3.8加载模型,这与这个问题有什么关系吗?就像两个不同版本的tensorflow?

编辑, --这是整个脚本

代码语言:javascript
运行
复制
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras import backend

from datetime import datetime
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import seaborn as sns
import math


df = pd.read_csv('./colabData.csv', index_col='Date', parse_dates=True)

print(df.info())



# This function keeps the learning rate at 0.001
# and decreases it exponentially after that.
def scheduler(epoch):
  if epoch < 1:
    return 0.001
  else:
    return 0.001 * tf.math.exp(0.01 * (1 - epoch))

callback = tf.keras.callbacks.LearningRateScheduler(scheduler)


#function to calculate RSME
def rmse(y_true, y_pred):
    return backend.sqrt(backend.mean(backend.square(y_pred - y_true), axis=-1))




dfTrain = df.copy()

# split into input (X) and output (Y) variables
X = dfTrain.drop(['kW'],1)
Y = dfTrain['kW']

#define training & testing data set
offset = int(X.shape[0] * 0.8)
X_train, Y_train = X[:offset], Y[:offset]
X_test, Y_test = X[offset:], Y[offset:]


#define model
model = Sequential()
model.add(Dense(80, input_dim=2, kernel_initializer='normal', activation='relu'))
model.add(Dense(60, kernel_initializer='normal', activation='relu'))
model.add(Dense(40, kernel_initializer='normal', activation='relu'))
model.add(Dense(20, kernel_initializer='normal', activation='relu'))
model.add(Dense(10, kernel_initializer='normal', activation='relu'))
model.add(Dense(5, kernel_initializer='normal', activation='relu'))
model.add(Dense(1, kernel_initializer='normal'))
model.summary()
model.compile(loss='mse', optimizer='adam', metrics=[rmse])



# train model, test callback option
history = model.fit(X_train, Y_train, epochs=75, batch_size=1, verbose=2, callbacks=[callback])
#history = model.fit(X_train, Y_train, epochs=60, batch_size=1, verbose=2)

# plot metrics
plt.plot(history.history['rmse'])
plt.title('kW RSME Vs Epoch')
plt.show()

model.save('./saved_model/kwSummer')
print('[INFO] Saved model to drive')
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-03-15 17:50:50

由于您有一个自定义对象,所以必须使用custom_object参数加载它。它还在错误日志中通知了您。Src

代码语言:javascript
运行
复制
In addition, please use the `custom_objects` arg when calling `load_model()`.

尝试如下:

代码语言:javascript
运行
复制
new_model = tf.keras.models.load_model('./saved_model/kwSummer', , 
                                       custom_objects={"rmse": rmse})
票数 2
EN

Stack Overflow用户

发布于 2021-03-15 17:44:14

我可以建议通过google运行代码吗?这可能有助于查看代码问题还是兼容性问题。因为google将确保兼容性,因为它解决了很多我遇到的ML问题。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66641637

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档