More type stuff
annaelisalappe committed Dec 13, 2024
1 parent a68ba70 commit 54bba46
Showing 2 changed files with 12 additions and 10 deletions.
16 changes: 8 additions & 8 deletions src/itwinai/torch/trainer.py
@@ -1257,7 +1257,7 @@ def __init__(
self._set_configs(config=config)
self.torch_rng = set_seed(random_seed)

-def _set_strategy_and_init_ray(self, strategy: str):
+def _set_strategy_and_init_ray(self, strategy: str) -> None:
"""Set the distributed training strategy. This will initialize the ray backend.
Args:
@@ -1274,7 +1274,7 @@ def _set_strategy_and_init_ray(self, strategy: str):
else:
raise ValueError(f"Unsupported strategy: {strategy}")

-def _set_configs(self, config: Dict):
+def _set_configs(self, config: Dict) -> None:
self.config = config
self._set_scaling_config()
self._set_tune_config()
@@ -1362,8 +1362,8 @@ def create_dataloaders(
def execute(
self,
train_dataset: Dataset,
-validation_dataset: Optional[Dataset] = None,
-test_dataset: Optional[Dataset] = None,
+validation_dataset: Dataset | None = None,
+test_dataset: Dataset | None = None,
) -> Tuple[Dataset, Dataset, Dataset, Any]:
"""Execute the training pipeline with the given datasets.
@@ -1400,7 +1400,7 @@ def set_epoch(self, epoch: int) -> None:
if self.test_dataloader is not None:
self.test_dataloader.sampler.set_epoch(epoch)

-def _set_tune_config(self):
+def _set_tune_config(self) -> None:
tune_config = self.config.get("tune_config", {})

if not tune_config:
@@ -1431,7 +1431,7 @@ def _set_tune_config(self):
"https://docs.ray.io/en/latest/tune/api/doc/ray.tune.TuneConfig.html."
)

-def _set_scaling_config(self):
+def _set_scaling_config(self) -> None:
scaling_config = self.config.get("scaling_config", {})

if not scaling_config:
@@ -1447,7 +1447,7 @@ def _set_run_config(self):
"https://docs.ray.io/en/latest/train/api/doc/ray.train.ScalingConfig.html"
)

-def _set_run_config(self):
+def _set_run_config(self) -> None:
run_config = self.config.get("run_config", {})

if not run_config:
@@ -1471,7 +1471,7 @@ def _set_run_config(self):
"https://docs.ray.io/en/latest/train/api/doc/ray.train.RunConfig.html"
)

-def _set_train_loop_config(self):
+def _set_train_loop_config(self) -> None:
train_loop_config = self.config.get("train_loop_config", {})

if train_loop_config:
6 changes: 4 additions & 2 deletions src/itwinai/torch/tuning.py
@@ -7,6 +7,8 @@
# - Anna Lappe <[email protected]> - CERN
# --------------------------------------------------------------------------------------

+from typing import Dict

from ray.tune.schedulers import (
AsyncHyperBandScheduler,
HyperBandForBOHB,
@@ -25,7 +27,7 @@


def get_raytune_search_alg(
-tune_config,
+tune_config: Dict,
) -> (
TuneBOHB
| BayesOptSearch
@@ -103,7 +105,7 @@ def get_raytune_search_alg(


def get_raytune_scheduler(
-tune_config,
+tune_config: Dict,
) -> (
AsyncHyperBandScheduler
| HyperBandScheduler
Expand Down
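Note (not part of the commit): the annotations introduced above follow standard Python typing conventions. `X | None` is the PEP 604 spelling of `Optional[X]` (Python 3.10+), and an explicit `-> None` marks methods that are called only for their side effects. A minimal, self-contained sketch with hypothetical names illustrating the same style:

```python
from typing import Dict


def set_configs(config: Dict) -> None:
    # Explicit "-> None": this function only validates/mutates state and returns nothing.
    if "scaling_config" not in config:
        raise ValueError("missing 'scaling_config' section")


def pick_dataset(name: str | None = None) -> str:
    # "str | None" (PEP 604, Python 3.10+) is equivalent to Optional[str].
    return name if name is not None else "default-dataset"


if __name__ == "__main__":
    set_configs({"scaling_config": {"num_workers": 2}})
    print(pick_dataset())         # default-dataset
    print(pick_dataset("mnist"))  # mnist
```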
