TypeError: fit_generator()获得意外的关键字参数“samples_per_epoch”
这个错误是由于在使用Keras的fit_generator()函数时,传入了意外的关键字参数“samples_per_epoch”。在Keras 2.1.0版本之后,fit_generator()函数的参数发生了变化,不再使用“samples_per_epoch”参数,而是使用“steps_per_epoch”参数来指定每个epoch中的步数。
解决这个问题的方法是将“samples_per_epoch”参数改为“steps_per_epoch”,并将其设置为每个epoch中的步数。步数可以通过将样本总数除以批次大小来计算得到。
以下是一个示例的修正代码:
from keras.models import Sequential
from keras.layers import Dense
from keras.preprocessing.image import ImageDataGenerator
# 创建模型
model = Sequential()
model.add(Dense(10, input_shape=(784,)))
# 编译模型
model.compile(optimizer='sgd', loss='categorical_crossentropy')
# 创建ImageDataGenerator对象
datagen = ImageDataGenerator()
# 生成器函数
def generator():
while True:
# 生成批次数据
x_batch, y_batch = datagen.flow(x_train, y_train, batch_size=32).next()
yield x_batch, y_batch
# 训练模型
model.fit_generator(generator(), steps_per_epoch=len(x_train) // 32, epochs=10)
在上面的示例中,我们使用ImageDataGenerator生成器来生成批次数据,并通过yield语句将数据传递给fit_generator()函数。steps_per_epoch参数被设置为样本总数除以批次大小的结果。
腾讯云相关产品和产品介绍链接地址:
没有搜到相关的沙龙
领取专属 10元无门槛券
手把手带您无忧上云