首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >KerasClassifier无法适应模型,尽管一切都很正常,否则

KerasClassifier无法适应模型,尽管一切都很正常,否则
EN

Stack Overflow用户
提问于 2020-09-27 19:50:07
回答 1查看 3K关注 0票数 2

我正在尝试使用一个KerasClassifier包装,以使我的工作流程科学友好。但是,当我尝试将它与下面的函数一起使用时,它会出现一个错误;使用本机Keras模型fit()来训练模型是有效的。(这是Tensorflow 2.2.0,运行在conda环境中)

代码语言:javascript
运行
复制
def model_arch(n_features: int):
    i = tf.keras.layers.Input(shape=(n_features,))

    hidden_dense = tf.keras.layers.Dense(64)(i)
    hidden_dense = tf.keras.layers.BatchNormalization()(hidden_dense)
    hidden_dense = tf.keras.layers.Activation(tf.nn.tanh)(hidden_dense)

    o = tf.keras.layers.Dense(1)(hidden_dense)
    o = tf.keras.layers.BatchNormalization()(o)
    o = tf.keras.layers.Activation("sigmoid")(o)

    classifier = tf.keras.models.Model(inputs=i, outputs=o)

    opt = tf.keras.optimizers.SGD(lr=1e-3, decay=1e-6, momentum=0.9, nesterov=True)
    classifier.compile(
        loss="binary_crossentropy",
        optimizer=opt,
        metrics=["accuracy"],
    )
    
    return classifier

以下工作:

代码语言:javascript
运行
复制
X = np.random.random((100,3))
y = np.random.random((100,)) # 'y' is a binary vector in reality

clf = model_arch(3)
clf.fit(X, y, epochs=10)

然而,当我尝试使用KerasClassifier包装时,我会得到一个错误:

代码语言:javascript
运行
复制
clf = KerasClassifier(model_arch(3), epochs=10)
clf.fit(X, y)

# ValueError: The first argument to `Layer.call` must always be passed.

我在互联网上看到的每一个例子似乎都与我一样:定义一个返回编译好的keras模型的函数,然后将它传递给包装器,并在管道中安装或使用它。我注意到的唯一不同是,大多数(如果不是全部)示例使用的是Sequential API而不是函数API,但是afaik不应该是一个问题,对吗?

Tensorflow文档似乎没有给出任何应该传递给包装器的函数类型的例子,但是由于每个示例都使用与我类似的函数,所以我认为这是正确的。

有人能给点线索吗?谢谢。

编辑(注释后):

我像这样导入KerasClassifier:

代码语言:javascript
运行
复制
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier

错误日志:

代码语言:javascript
运行
复制
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/adrian/miniconda3/envs/kaggle/lib/python3.8/site-packages/tensorflow/python/keras/wrappers/scikit_learn.py", line 223, in fit
    return super(KerasClassifier, self).fit(x, y, **kwargs)
  File "/home/adrian/miniconda3/envs/kaggle/lib/python3.8/site-packages/tensorflow/python/keras/wrappers/scikit_learn.py", line 154, in fit
    self.model = self.build_fn(
  File "/home/adrian/miniconda3/envs/kaggle/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 799, in __call__
    raise ValueError(
ValueError: The first argument to `Layer.call` must always be passed.
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-09-27 20:16:31

KerasClassifer需要一个构建函数,而不是模型实例本身,它在调用时返回一个Keras模型的编译实例。因此,要用最小的更改来解决这个问题,必须将其包装在一个函数中:

代码语言:javascript
运行
复制
clf = KerasClassifier(lambda: model_arch(3), epochs=10)

或者,更好的方法是将模型的参数作为关键字参数传递:

代码语言:javascript
运行
复制
clf = KerasClassifier(model_arch, n_features=3, epochs=10)
票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64092664

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档