diff --git a/src/itwinai/torch/trainer.py b/src/itwinai/torch/trainer.py
index 68216a00..28079205 100644
--- a/src/itwinai/torch/trainer.py
+++ b/src/itwinai/torch/trainer.py
@@ -1468,18 +1468,23 @@ def _set_train_loop_config(self):
         train_loop_config = self.config.get("train_loop_config", {})
 
         if train_loop_config:
-            for name, param in train_loop_config.items():
-                try:
+            try:
+                for name, param in train_loop_config.items():
                     if isinstance(param, dict):
+                        # Convert specific keys to float if necessary
+                        for key in ["lower", "upper", "mean", "std"]:
+                            if key in param:
+                                param[key] = float(param[key])
+
                         param_type = param.pop("type")
-                        # Dynamically call corresponding tune method
                         param = getattr(tune, param_type)(**param)
                         train_loop_config[name] = param
-                except AttributeError:
-                    raise ValueError(
-                        f"{param} could not be set. Check that this parameter type is "
-                        "supported by Ray Tune or the itwinai TrainingConfiguration."
-                    )
+
+            except AttributeError:
+                print(
+                    f"{param} could not be set. Check that this parameter type is "
+                    "supported by Ray Tune or the itwinai TrainingConfiguration."
+                )
         else:
             print(
                 "WARNING: No training_loop_config detected. "
diff --git a/use-cases/virgo/config.yaml b/use-cases/virgo/config.yaml
index 30ce4dda..9ad9ef89 100644
--- a/use-cases/virgo/config.yaml
+++ b/use-cases/virgo/config.yaml
@@ -126,12 +126,12 @@ ray_training_pipeline:
             GPU: 1
         tune_config:
           num_samples: 2
-          # scheduler:
-          #   name: asha
-          #   max_t: 5
-          #   grace_period: 2
-          #   reduction_factor: 6
-          #   brackets: 1
+          scheduler:
+            name: asha
+            max_t: 5
+            grace_period: 2
+            reduction_factor: 6
+            brackets: 1
         run_config:
           storage_path: ray_checkpoints
           name: Virgo-HPO-Experiment
@@ -144,10 +144,10 @@ ray_training_pipeline:
             lower: 1e-5
             upper: 1e-3
           num_epochs: 5
-        generator: simple #unet
-        loss: l1
-        save_best: false
-        shuffle_train: true
+          generator: simple #unet
+          loss: l1
+          save_best: false
+          shuffle_train: true
       strategy: ${strategy}
       random_seed: 17
       logger:
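
For reference, below is a minimal, standalone sketch (not part of the patch) of how the reworked loop in `_set_train_loop_config` turns a YAML-style hyperparameter entry into a Ray Tune search space. The parameter name `optim_lr` and its bounds are made up for the example; the real keys come from `train_loop_config` in the use-case config.

```python
# Illustrative sketch of the conversion performed by the patched loop above.
# Assumes Ray Tune is installed; "optim_lr" and its bounds are hypothetical.
from ray import tune

train_loop_config = {
    "optim_lr": {"type": "uniform", "lower": "1e-5", "upper": "1e-3"},
    "num_epochs": 5,  # non-dict values are passed through unchanged
}

for name, param in train_loop_config.items():
    if isinstance(param, dict):
        # YAML loaders often read 1e-5 as a string (no decimal point),
        # so numeric bounds are coerced to float before building the space.
        for key in ["lower", "upper", "mean", "std"]:
            if key in param:
                param[key] = float(param[key])
        param_type = param.pop("type")
        # Dynamically resolve the Ray Tune sampler, e.g. tune.uniform(1e-05, 0.001)
        train_loop_config[name] = getattr(tune, param_type)(**param)

print(train_loop_config)  # {'optim_lr': <Ray Tune Float search space>, 'num_epochs': 5}
```

This is presumably why the `lower`/`upper`/`mean`/`std` float coercion was added: bounds written as `1e-5` in the YAML can arrive as strings, and the dynamically resolved `tune` sampler would otherwise receive string bounds.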