Skip to content

Commit

Permalink
More refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
annaelisalappe committed Dec 13, 2024
1 parent 7f132db commit a3ab518
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 18 deletions.
21 changes: 13 additions & 8 deletions src/itwinai/torch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1468,18 +1468,23 @@ def _set_train_loop_config(self):
train_loop_config = self.config.get("train_loop_config", {})

if train_loop_config:
for name, param in train_loop_config.items():
try:
try:
for name, param in train_loop_config.items():
if isinstance(param, dict):
# Convert specific keys to float if necessary
for key in ["lower", "upper", "mean", "std"]:
if key in param:
param[key] = float(param[key])

param_type = param.pop("type")
# Dynamically call corresponding tune method
param = getattr(tune, param_type)(**param)
train_loop_config[name] = param
except AttributeError:
raise ValueError(
f"{param} could not be set. Check that this parameter type is "
"supported by Ray Tune or the itwinai TrainingConfiguration."
)

except AttributeError:
print(
f"{param} could not be set. Check that this parameter type is "
"supported by Ray Tune or the itwinai TrainingConfiguration."
)
else:
print(
"WARNING: No training_loop_config detected. "
Expand Down
20 changes: 10 additions & 10 deletions use-cases/virgo/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@ ray_training_pipeline:
GPU: 1
tune_config:
num_samples: 2
# scheduler:
# name: asha
# max_t: 5
# grace_period: 2
# reduction_factor: 6
# brackets: 1
scheduler:
name: asha
max_t: 5
grace_period: 2
reduction_factor: 6
brackets: 1
run_config:
storage_path: ray_checkpoints
name: Virgo-HPO-Experiment
Expand All @@ -144,10 +144,10 @@ ray_training_pipeline:
lower: 1e-5
upper: 1e-3
num_epochs: 5
generator: simple #unet
loss: l1
save_best: false
shuffle_train: true
generator: simple #unet
loss: l1
save_best: false
shuffle_train: true
strategy: ${strategy}
random_seed: 17
logger:
Expand Down

0 comments on commit a3ab518

Please sign in to comment.