From d3a2630aa8f19a545dea5c53f1ca66c926600c1b Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 23 Nov 2023 16:03:59 +0100 Subject: [PATCH] UPDATE: load components from pipeline --- src/itwinai/components.py | 36 +++++++++++++++++++++-------------- use-cases/3dgan/pipeline.yaml | 20 +++++++++++-------- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/src/itwinai/components.py b/src/itwinai/components.py index de155236..8232834c 100644 --- a/src/itwinai/components.py +++ b/src/itwinai/components.py @@ -344,23 +344,27 @@ def _pack_args(self, args) -> Tuple: return args -def recursive_replace(config: Dict, target_field: str, new_value: Any) -> None: - def _recursive_replace_key(sub_dict: Dict): - if not isinstance(sub_dict, dict): - return - for k, v in sub_dict.items(): - if k == target_field: - sub_dict[k] = new_value - return - else: - _recursive_replace_key(v) - _recursive_replace_key(config) +def add_replace_field( + config: Dict, + key_chain: str, + value: Any +) -> None: + sub_config = config + for idx, k in enumerate(key_chain.split('.')): + if idx >= len(key_chain.split('.')) - 1: + # Last key reached + break + if not isinstance(sub_config.get(k), dict): + sub_config[k] = dict() + sub_config = sub_config[k] + sub_config[k] = value def load_pipeline_step( pipe: Union[str, Dict], step_id: Union[str, int], - override_keys: Optional[Dict[str, Any]] = None + override_keys: Optional[Dict[str, Any]] = None, + verbose: bool = False ) -> Executable: if isinstance(pipe, str): # Load pipe from YAML file path @@ -369,8 +373,12 @@ def load_pipeline_step( # Override fields if override_keys is not None: - for key, value in override_keys.items(): - recursive_replace(step_dict_config, key, value) + for key_chain, value in override_keys.items(): + add_replace_field(step_dict_config, key_chain, value) + if verbose: + import json + print(f"NEW STEP CONFIG:") + print(json.dumps(step_dict_config, indent=4)) # Wrap config under "step" field and parse it step_dict_config = dict(step=step_dict_config) diff --git a/use-cases/3dgan/pipeline.yaml b/use-cases/3dgan/pipeline.yaml index 676424aa..cd45674d 100644 --- a/use-cases/3dgan/pipeline.yaml +++ b/use-cases/3dgan/pipeline.yaml @@ -2,12 +2,14 @@ executor: class_path: itwinai.components.Executor init_args: steps: - - class_path: dataloader.Lightning3DGANDownloader + dataloading_step: + class_path: dataloader.Lightning3DGANDownloader init_args: data_path: exp_data/ # Set to null to skip dataset download data_url: https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX - - class_path: trainer.Lightning3DGANTrainer + training_step: + class_path: trainer.Lightning3DGANTrainer init_args: # Pytorch lightning config for training config: @@ -49,7 +51,7 @@ executor: limit_test_batches: null limit_train_batches: null limit_val_batches: null - log_every_n_steps: 2 + log_every_n_steps: 1 logger: # - class_path: lightning.pytorch.loggers.CSVLogger # init_args: @@ -59,8 +61,8 @@ executor: experiment_name: 3DGAN save_dir: ml_logs/mlflow_logs log_model: all - max_epochs: 1 - max_steps: 20 + max_epochs: 5 + # max_steps: 2000 max_time: null min_epochs: null min_steps: null @@ -69,7 +71,7 @@ executor: plugins: null profiler: null reload_dataloaders_every_n_epochs: 0 - strategy: ddp_find_unused_parameters_true #auto + strategy: auto #ddp_find_unused_parameters_true #auto sync_batchnorm: false use_distributed_sampler: true val_check_interval: null @@ -79,7 +81,7 @@ executor: class_path: model.ThreeDGAN init_args: latent_size: 256 - batch_size: 64 + batch_size: 4 loss_weights: [3, 0.1, 25, 0.1] power: 0.85 lr: 0.001 @@ -90,4 +92,6 @@ executor: class_path: dataloader.ParticlesDataModule init_args: datapath: exp_data/*/*.h5 - batch_size: 64 + batch_size: 4 + num_workers: 0 + max_samples: 12