I want to save model weights to MLflow tracking while using pytorch-lightning. pytorch-lightning supports logging to MLflow, but it does not seem to support saving model weights as an artifact on MLflow.
At first I planned to override the ModelCheckpoint class to do this, but I found it difficult because of the complex Mixin operations involved.
Does anyone know a simple way to achieve this?
Posted on 2020-10-01 20:45:23
As @xela said, you can use the experiment object of the MLflow logger to log artifacts.
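For instance, a single checkpoint file can be uploaded directly through that object, outside of any callback. A minimal sketch; the experiment name and checkpoint path are placeholders:

    from pytorch_lightning.loggers import MLFlowLogger

    mlflow_logger = MLFlowLogger(experiment_name="my_experiment")  # placeholder name
    # `experiment` is an MlflowClient, so any local file can be attached
    # to the current run as an artifact, including a saved checkpoint.
    mlflow_logger.experiment.log_artifact(mlflow_logger.run_id, "checkpoints/best.ckpt")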
If you want to log model weights frequently during training, you can extend ModelCheckpoint:
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.utilities import rank_zero_only

    class MLFlowModelCheckpoint(ModelCheckpoint):
        def __init__(self, mlflow_logger, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.mlflow_logger = mlflow_logger

        @rank_zero_only
        def on_validation_end(self, trainer, pl_module):
            super().on_validation_end(trainer, pl_module)
            # Upload the best checkpoint saved so far as an MLflow artifact.
            run_id = self.mlflow_logger.run_id
            self.mlflow_logger.experiment.log_artifact(run_id, self.best_model_path)
Then use it in your training code:
    import pytorch_lightning as pl
    from pytorch_lightning.loggers import MLFlowLogger

    mlflow_logger = MLFlowLogger()
    checkpoint_callback = MLFlowModelCheckpoint(mlflow_logger)
    trainer = pl.Trainer(checkpoint_callback=checkpoint_callback, logger=mlflow_logger)
Posted on 2020-03-11 11:11:26
Posted on 2021-08-12 15:55:16
An alternative to @stecklin's approach is to use the logger's after_save_checkpoint method.
You can extend MLFlowLogger like this:
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.loggers import MLFlowLogger

    class MLFlowLoggerCheckpointer(MLFlowLogger):
        def after_save_checkpoint(self, model_checkpoint: ModelCheckpoint) -> None:
            """Called after the model checkpoint callback saves a new checkpoint."""
            # Upload the best checkpoint as an artifact of the current run.
            self.experiment.log_artifact(
                self.run_id, model_checkpoint.best_model_path
            )
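Wiring it into training then mirrors the first answer. A sketch; the experiment name and monitored metric are placeholders:

    import pytorch_lightning as pl

    # Requires a pytorch-lightning version whose loggers expose after_save_checkpoint.
    mlflow_logger = MLFlowLoggerCheckpointer(experiment_name="my_experiment")
    checkpoint_callback = ModelCheckpoint(monitor="val_loss")
    trainer = pl.Trainer(callbacks=[checkpoint_callback], logger=mlflow_logger)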
https://stackoverflow.com/questions/59149725