Skip to content

Commit

Permalink
UPDATE: load components from pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
matbun committed Nov 23, 2023
1 parent 8e19c62 commit d3a2630
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 22 deletions.
36 changes: 22 additions & 14 deletions src/itwinai/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,23 +344,27 @@ def _pack_args(self, args) -> Tuple:
return args


def recursive_replace(config: Dict, target_field: str, new_value: Any) -> None:
def _recursive_replace_key(sub_dict: Dict):
if not isinstance(sub_dict, dict):
return
for k, v in sub_dict.items():
if k == target_field:
sub_dict[k] = new_value
return
else:
_recursive_replace_key(v)
_recursive_replace_key(config)
def add_replace_field(
config: Dict,
key_chain: str,
value: Any
) -> None:
sub_config = config
for idx, k in enumerate(key_chain.split('.')):
if idx >= len(key_chain.split('.')) - 1:
# Last key reached
break
if not isinstance(sub_config.get(k), dict):
sub_config[k] = dict()
sub_config = sub_config[k]
sub_config[k] = value


def load_pipeline_step(
pipe: Union[str, Dict],
step_id: Union[str, int],
override_keys: Optional[Dict[str, Any]] = None
override_keys: Optional[Dict[str, Any]] = None,
verbose: bool = False
) -> Executable:
if isinstance(pipe, str):
# Load pipe from YAML file path
Expand All @@ -369,8 +373,12 @@ def load_pipeline_step(

# Override fields
if override_keys is not None:
for key, value in override_keys.items():
recursive_replace(step_dict_config, key, value)
for key_chain, value in override_keys.items():
add_replace_field(step_dict_config, key_chain, value)
if verbose:
import json
print(f"NEW STEP <ID:{step_id}> CONFIG:")
print(json.dumps(step_dict_config, indent=4))

# Wrap config under "step" field and parse it
step_dict_config = dict(step=step_dict_config)
Expand Down
20 changes: 12 additions & 8 deletions use-cases/3dgan/pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ executor:
class_path: itwinai.components.Executor
init_args:
steps:
- class_path: dataloader.Lightning3DGANDownloader
dataloading_step:
class_path: dataloader.Lightning3DGANDownloader
init_args:
data_path: exp_data/ # Set to null to skip dataset download
data_url: https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX

- class_path: trainer.Lightning3DGANTrainer
training_step:
class_path: trainer.Lightning3DGANTrainer
init_args:
# Pytorch lightning config for training
config:
Expand Down Expand Up @@ -49,7 +51,7 @@ executor:
limit_test_batches: null
limit_train_batches: null
limit_val_batches: null
log_every_n_steps: 2
log_every_n_steps: 1
logger:
# - class_path: lightning.pytorch.loggers.CSVLogger
# init_args:
Expand All @@ -59,8 +61,8 @@ executor:
experiment_name: 3DGAN
save_dir: ml_logs/mlflow_logs
log_model: all
max_epochs: 1
max_steps: 20
max_epochs: 5
# max_steps: 2000
max_time: null
min_epochs: null
min_steps: null
Expand All @@ -69,7 +71,7 @@ executor:
plugins: null
profiler: null
reload_dataloaders_every_n_epochs: 0
strategy: ddp_find_unused_parameters_true #auto
strategy: auto #ddp_find_unused_parameters_true #auto
sync_batchnorm: false
use_distributed_sampler: true
val_check_interval: null
Expand All @@ -79,7 +81,7 @@ executor:
class_path: model.ThreeDGAN
init_args:
latent_size: 256
batch_size: 64
batch_size: 4
loss_weights: [3, 0.1, 25, 0.1]
power: 0.85
lr: 0.001
Expand All @@ -90,4 +92,6 @@ executor:
class_path: dataloader.ParticlesDataModule
init_args:
datapath: exp_data/*/*.h5
batch_size: 64
batch_size: 4
num_workers: 0
max_samples: 12

0 comments on commit d3a2630

Please sign in to comment.