首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >由于不可克隆,将KerasRegressor与cross_validate一起使用失败

由于不可克隆,将KerasRegressor与cross_validate一起使用失败
EN

Stack Overflow用户
提问于 2021-09-27 15:50:18
回答 1查看 195关注 0票数 0

我正在使用Keras的sklearn包装器作为一个回归器,即tf.keras.wrappers.scikit_learn.KerasRegressor。我希望这个回归器在sklearn的交叉验证方案中工作,也就是sklearn.model_selection.cross_validate

回归器通常在没有CV的情况下工作。然而,后者失败了,因为我在回归器的__init__方法中有一个必要的参数来定义批处理输入形状,但它似乎丢失了。这似乎是因为使用clone(estimator)不能正确克隆MyRegressorKerasRegressor。具体的错误消息是:

代码语言:javascript
运行
复制
KeyError: 'batch_input_shape'

有没有一种方法可以让MyRegressorcross_validate一起工作?我是否在某种程度上违反了sklearn的要求?

请看这个精简的工作示例:

代码语言:javascript
运行
复制
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_validate
from tensorflow.keras.layers import Dense, LSTM
from tensorflow.keras.models import Sequential
from tensorflow.keras.wrappers.scikit_learn import KerasRegressor

class MyRegressor(KerasRegressor):
    def __init__(self, batch_input_shape, build_fn=None, **kwargs):
        self.batch_input_shape = batch_input_shape
        super().__init__(**kwargs)

    def __call__(self, *kwargs):
        model = Sequential([
            LSTM(16, stateful=True, batch_input_shape=self.batch_input_shape),
            Dense(1),
        ])
        model.compile(optimizer='adam', loss='mean_squared_error', metrics=['RootMeanSquaredError'])
        return model

    def reset_states(self):
        self.model.reset_states()

X, y = make_regression(6400, 5)
X = X.reshape(X.shape[0], 1, X.shape[1])
batch_size = 64
batch_input_shape = (batch_size, 1, X.shape[-1])

# Works fine
reg = MyRegressor(batch_input_shape)
for i in range(10):
    reg.fit(X, y, batch_size=batch_size)
    reg.reset_states()

# Doesn't work
reg = MyRegressor(batch_input_shape)
results = cross_validate(reg, X, y, scoring=['neg_mean_squared_error'])
EN

回答 1

Stack Overflow用户

发布于 2021-09-27 18:22:54

可复制性需要适当的get_params方法。大多数情况下,这是通过继承sklearn的BaseEstimator获得的,但是KerasRegressor直接实现了自己的(source)。它的工作方式与您的附加batch_input_shape不兼容;您可以调整它以使其正常工作:

代码语言:javascript
运行
复制
    def get_params(self, deep=False):
        res = self.sk_params.copy()  # sk_params was set by KerasRegressor.__init__
        res.update({
            'build_fn': self.build_fn,
            'batch_input_shape': self.batch_input_shape,
        })
        return res

(在这次更新之后,我在您的示例中得到了一个关于输入形状的错误。但我对批处理大小和kera不太熟悉,无法回答后续问题。)

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69349589

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档