首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Keras ModelCheckpoint:保存_freq/周期可以动态变化吗?

Keras ModelCheckpoint:保存_freq/周期可以动态变化吗?
EN

Stack Overflow用户
提问于 2021-02-24 16:51:00
回答 1查看 1K关注 0票数 0

我想知道是否有一种方法可以在Keras ModelCheckpoint中更改ModelCheckpoint参数。我想做的是在每一个时代之后储蓄,但在第十个时代之后开始储蓄,而不是以前。所以我希望ModelCheckpoint在前10个时代是“非活动的”,在第10个时代之后,用save_freq='epoch'__变成“活动的”。

这样的事情能实现还是不可能?

我尝试了save_freq=batch_size if epoch<10 else 'epoch',但是,正如预期的那样,epoch变量是未知的。在ModelCheckpoint中是否有一个“局部变量”用于历代?

UPDATE:我没有尝试动态更改save_freq参数,而是尝试进行两个不同的培训,一个接一个。第一次训练持续10次,在最后(第10次)保存训练模型,第二次训练持续110次,每一阶段后保存模型。

我的问题是:我是否必须在第一次训练之后加载保存的模型,还是可以立即继续进行第二次培训,就像下面的代码所示?

代码语言:javascript
运行
复制
# [...]

batch_size = 32

epochs0    = 10
epochs     = 110


callbacks0 = [
    keras.callbacks.ModelCheckpoint(
        resuts_dir+'model0/',
        monitor='val_loss',
        verbose=1,
        save_best_only=False,
        save_weights_only=False,
        save_freq=batch_size*epochs0)
    ]
callbacks  = [
    keras.callbacks.ModelCheckpoint(
        resuts_dir+'model/',
        monitor='val_loss',
        verbose=1,
        save_best_only=True,
        save_weights_only=False,
        save_freq='epoch'),
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        min_delta=0,
        patience=20,
        verbose=1,
        mode='auto')
    ]

# [...]

training_history0 = model.fit(train_gen, epochs=epochs0, validation_data=val_gen, callbacks=callbacks0, verbose=1)
training_history  = model.fit(train_gen, epochs=epochs,  validation_data=val_gen, callbacks=callbacks,  verbose=1)
EN

回答 1

Stack Overflow用户

发布于 2021-04-19 20:35:30

我不太明白你的更新,但要回答“我能在10岁以后才能保存检查点吗?”,答案是肯定的。您可以创建从类ModelCheckpoint继承的类并修改类方法on_epoch_end,如下所示:

代码语言:javascript
运行
复制
class CustomCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    def __init__(self, filepath, monitor, verbose, save_best_only, 
                 save_weights_only, min_epochs, save_freq):
        self.min_epochs = min_epochs
        super().__init__(filepath=filepath, 
                         monitor=monitor,
                         verbose=verbose,
                         save_best_only=save_best_only,
                         save_weights_only=save_weights_only
                         save_freq=save_freq)
    
    def on_epoch_end(self, epoch, logs={}):
        if self.save_freq == 'epoch':
            if (epoch+1) >= self.min_epochs:
                self._save_model(epoch=epoch, logs=logs)

然后在model.fit上调用该类的实例:

代码语言:javascript
运行
复制
MIN_EPOCHS = 10
custom_callback = CustomCheckpoint(filepath="checkpoint_e{epoch:02d}", # this filepath includes the epoch name
                                   min_epochs=MIN_EPOCHS, 
                                   monitor='val_loss',
                                   verbose=1,
                                   save_best_only=False,
                                   save_weights_only=False,
                                   save_freq='epoch')

model.fit(..., callbacks=[custom_callback])
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66355298

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档