ADD: mlflow autologging support for PL trainer
matbun committed Nov 23, 2023
1 parent f2ccfae commit 1af8ba7
Showing 3 changed files with 95 additions and 12 deletions.
77 changes: 77 additions & 0 deletions src/itwinai/torch/mlflow.py
@@ -0,0 +1,77 @@
from typing import Dict, Optional
import os

import mlflow
import yaml


def _get_mlflow_logger_conf(pl_config: Dict) -> Optional[Dict]:
    """Extract the MLFlowLogger configuration from a PyTorch Lightning
    configuration file, if present.

    Args:
        pl_config (Dict): Lightning configuration loaded in memory.

    Returns:
        Optional[Dict]: if present, the MLFlowLogger constructor arguments
            (under the 'init_args' key).
    """
    if isinstance(pl_config['trainer']['logger'], list):
        # If multiple loggers are provided, pick the MLFlow one, if any
        for logger_conf in pl_config['trainer']['logger']:
            if logger_conf['class_path'].endswith('MLFlowLogger'):
                return logger_conf['init_args']
    elif pl_config['trainer']['logger']['class_path'].endswith('MLFlowLogger'):
        return pl_config['trainer']['logger']['init_args']


def _mlflow_log_pl_config(pl_config: Dict, local_yaml_path: str) -> None:
    """Dump the Lightning configuration to a local YAML file and log it
    to MLflow as an artifact of the active run."""
    os.makedirs(os.path.dirname(local_yaml_path), exist_ok=True)
    with open(local_yaml_path, 'w') as outfile:
        yaml.dump(pl_config, outfile, default_flow_style=False)
    mlflow.log_artifact(local_yaml_path)


def init_lightning_mlflow(
    pl_config: Dict,
    default_experiment_name: str = 'Default',
    **autolog_kwargs
) -> None:
    """Initialize MLflow for PyTorch Lightning, also setting up
    auto-logging (mlflow.pytorch.autolog(...)). Creates a new MLflow
    run and attaches it to the MLflow auto-logger.

    Args:
        pl_config (Dict): PyTorch Lightning configuration loaded in memory.
        default_experiment_name (str, optional): used as the experiment name
            if none is given in the Lightning conf. Defaults to 'Default'.
        **autolog_kwargs (kwargs): args for mlflow.pytorch.autolog(...).
    """
    mlflow_conf: Optional[Dict] = _get_mlflow_logger_conf(pl_config)
    if not mlflow_conf:
        return

    tracking_uri = mlflow_conf.get('tracking_uri')
    if not tracking_uri:
        # Fall back to a local file store; assumes 'save_dir' is set
        # in the logger conf whenever 'tracking_uri' is not.
        save_path = mlflow_conf.get('save_dir')
        tracking_uri = "file://" + os.path.abspath(save_path)

    experiment_name = mlflow_conf.get('experiment_name')
    if not experiment_name:
        experiment_name = default_experiment_name

    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(experiment_name)
    mlflow.pytorch.autolog(**autolog_kwargs)
    mlflow.start_run()

    # Write the resolved values back into the logger conf, so that the
    # MLFlowLogger later instantiated by Lightning attaches to this same
    # experiment and run.
    mlflow_conf['experiment_name'] = experiment_name
    mlflow_conf['run_id'] = mlflow.active_run().info.run_id

    _mlflow_log_pl_config(pl_config, '.tmp/pl_config.yml')


def teardown_lightning_mlflow() -> None:
    """End the active MLflow run, if any."""
    if mlflow.active_run() is not None:
        mlflow.end_run()
24 changes: 12 additions & 12 deletions use-cases/3dgan/pipeline.yaml
@@ -20,21 +20,21 @@ executor:
       barebones: false
       benchmark: null
       # callbacks:
-      # # - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping
-      # #   init_args:
-      # #     monitor: val_loss
-      # #     patience: 2
+      # - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping
+      #   init_args:
+      #     monitor: real_batch_loss_epoch
+      #     patience: 2
       # - class_path: lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor
       #   init_args:
       #     logging_interval: step
-      # # - class_path: lightning.pytorch.callbacks.ModelCheckpoint
-      # #   init_args:
-      # #     dirpath: checkpoints
-      # #     filename: best-checkpoint
-      # #     mode: min
-      # #     monitor: val_loss
-      # #     save_top_k: 1
-      # #     verbose: true
+      # - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+      #   init_args:
+      #     dirpath: checkpoints
+      #     filename: best-checkpoint
+      #     mode: min
+      #     monitor: real_batch_loss_epoch
+      #     save_top_k: 1
+      #     verbose: true
       check_val_every_n_epoch: 1
       default_root_dir: null
       detect_anomaly: false
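For reference, _get_mlflow_logger_conf() expects the loaded configuration to carry an MLFlowLogger entry under trainer.logger. A sketch of the dict shape it inspects (illustrative values only; the actual logger section of this file lies outside the hunk above):

pl_config = {
    'trainer': {
        'logger': {
            'class_path': 'lightning.pytorch.loggers.MLFlowLogger',
            'init_args': {
                'experiment_name': '3dgan',
                'save_dir': 'ml_logs/mlflow_logs',
            },
        },
    },
}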
6 changes: 6 additions & 0 deletions use-cases/3dgan/trainer.py
@@ -11,6 +11,10 @@
 from itwinai.serialization import ModelLoader
 from itwinai.torch.inference import TorchModelLoader
 from itwinai.torch.types import Batch
+from itwinai.torch.mlflow import (
+    init_lightning_mlflow,
+    teardown_lightning_mlflow
+)
 
 from model import ThreeDGAN
 from dataloader import ParticlesDataModule
@@ -26,6 +30,7 @@ def __init__(self, config: Union[Dict, str]):
         self.conf = config
 
     def train(self) -> Any:
+        init_lightning_mlflow(self.conf, registered_model_name='3dgan-lite')
         old_argv = sys.argv
         sys.argv = ['some_script_placeholder.py']
         cli = LightningCLI(
@@ -42,6 +47,7 @@ def train(self) -> Any:
         )
         sys.argv = old_argv
         cli.trainer.fit(cli.model, datamodule=cli.datamodule)
+        teardown_lightning_mlflow()
 
     def execute(
         self,
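The registered_model_name keyword in train() is not handled by init_lightning_mlflow() itself: it travels through **autolog_kwargs, so the call above amounts to configuring the auto-logger as in this sketch, where registered_model_name registers each auto-logged model as a new version in the MLflow Model Registry:

import mlflow

mlflow.pytorch.autolog(registered_model_name='3dgan-lite')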
