首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何使用pytorch-lightning将模型权重保存到mlflow跟踪服务器?

如何使用pytorch-lightning将模型权重保存到mlflow跟踪服务器?
EN

Stack Overflow用户
提问于 2019-12-03 11:20:46
回答 3查看 498关注 0票数 4

我想使用pytorch-lightning将模型权重保存到mlflow跟踪。pytorch-lightning支持logging。但是,似乎不支持将模型权重保存为mlflow上的工件。

一开始,我计划重写ModelCheckpoint类来做这件事,但我发现这对我来说很困难,因为复杂的Mixin操作。

有谁知道实现它的简单方法吗?

EN

回答 3

Stack Overflow用户

发布于 2020-10-01 20:45:23

正如@xela所说,您可以使用mlflow记录器的experiment对象来记录工件。

如果您希望在训练期间频繁记录模型权重,可以扩展ModelCheckpoint

代码语言:javascript
运行
复制
class MLFlowModelCheckpoint(ModelCheckpoint):
    def __init__(self, mlflow_logger, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mlflow_logger = mlflow_logger

    @rank_zero_only
    def on_validation_end(self, trainer, pl_module):
        super().on_validation_end(trainer, pl_module)
        run_id = self.mlflow_logger.run_id
        self.mlflow_logger.experiment.log_artifact(run_id, self.best_model_path)

然后在你的训练代码中使用

代码语言:javascript
运行
复制
mlflow_logger = MLFlowLogger()
checkpoint_callback = MLFlowModelCheckpoint(mlflow_logger)
trainer = pl.Trainer(checkpoint_callback=checkpoint_callback, logger=mlflow_logger)
票数 1
EN

Stack Overflow用户

发布于 2020-03-11 11:11:26

票数 0
EN

Stack Overflow用户

发布于 2021-08-12 15:55:16

@stecklin的替代方法是使用记录器的after_save_checkpoint方法。

您可以像这样扩展MLFlowLogger

代码语言:javascript
运行
复制
class MLFlowLoggerCheckpointer(pl.MLFlowLogger):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def after_save_checkpoint(self, model_checkpoint: pl.ModelCheckpoint) -> None:
        """
        Called after model checkpoint callback saves a new checkpoint.
        """

        self.experiment.log_artifact(
            self.run_id, model_checkpoint.best_model_path
        )
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59149725

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档