Add more Search Algorithms and Schedulers for Tuning (#276)
* Some refactoring, added more search algorithms and schedulers

* Some more refactoring, bugfixes

* More refactoring

* Some files must have gotten duplicated during the merge...

* Add return types back to tuning functions

* More type stuff

* Added more error messages

* PR comments
annaelisalappe authored Dec 13, 2024
1 parent 85119c9 commit ce3f18c
Showing 8 changed files with 212 additions and 571 deletions.
src/itwinai/torch/trainer.py: 226 changes (89 additions, 137 deletions)
@@ -32,13 +32,12 @@
 import torch.optim as optim
 import torchvision
 from lightning.pytorch.cli import LightningCLI
-from pydantic.utils import deep_update
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader, Dataset, Sampler
 from torch.utils.data.distributed import DistributedSampler
 
-from itwinai.torch.tuning import get_raytune_schedule, get_raytune_search_alg
+from itwinai.torch.tuning import get_raytune_scheduler, get_raytune_search_alg
 
 # Imports from this repository
 from ..components import Trainer, monitor_exec
@@ -59,34 +58,6 @@
 from .reproducibility import seed_worker, set_seed
 from .type import Batch, LrScheduler, Metric
 
-DEFAULT_RAY_CONFIG = {
-    "scaling_config": {
-        "num_workers": 4,  # Default to 4 workers
-        "use_gpu": True,
-        "resources_per_worker": {"CPU": 5, "GPU": 1},
-    },
-    "tune_config": {
-        "num_samples": 1,  # Number of trials to run, increase for more thorough tuning
-        "metric": "loss",
-        "mode": "min",
-    },
-    "run_config": {
-        "checkpoint_at_end": True,  # Save checkpoint at the end of each trial
-        "checkpoint_freq": 10,  # Save checkpoint every 10 iterations
-        "storage_path": "ray_results",  # Directory to save results, logs, and checkpoints
-    },
-    "train_loop_config": {
-        "learning_rate": 1e-3,
-        "batch_size": 32,
-        "epochs": 10,
-        "optimizer": "adam",
-        "loss": "cross_entropy",
-        "optim_momentum": 0.9,
-        "optim_weight_decay": 0,
-        "random_seed": 21,
-    },
-}
-
 
 class TorchTrainer(Trainer, LogMixin):
     """Trainer class for torch training algorithms.
@@ -1275,17 +1246,18 @@ class RayTorchTrainer(Trainer):
     def __init__(
         self,
        config: Dict,
-        strategy: Optional[Literal["ddp", "deepspeed"]] = "ddp",
-        name: Optional[str] = None,
-        logger: Optional[Logger] = None,
+        strategy: Literal["ddp", "deepspeed"] = "ddp",
+        name: str | None = None,
+        logger: Logger | None = None,
+        random_seed: int = 1234,
     ) -> None:
         super().__init__(name=name)
         self.logger = logger
         self._set_strategy_and_init_ray(strategy)
         self._set_configs(config=config)
-        self.torch_rng = set_seed(self.train_loop_config["random_seed"])
+        self.torch_rng = set_seed(random_seed)
 
-    def _set_strategy_and_init_ray(self, strategy: str):
+    def _set_strategy_and_init_ray(self, strategy: str) -> None:
         """Set the distributed training strategy. This will initialize the ray backend.
 
         Args:
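Note: with DEFAULT_RAY_CONFIG removed, the user-supplied config is no longer merged with defaults, so it now has to carry the full Ray setup itself. Below is a minimal sketch of such a dict, assuming only the key names read by the _set_* methods in this commit; every value is illustrative and not taken from the commit.

# Illustrative Ray config for RayTorchTrainer; all values here are examples.
ray_config = {
    "scaling_config": {
        # Forwarded as-is to ray.train.ScalingConfig(**scaling_config)
        "num_workers": 4,
        "use_gpu": True,
        "resources_per_worker": {"CPU": 5, "GPU": 1},
    },
    "tune_config": {
        # Forwarded to ray.tune.TuneConfig; "metric"/"mode" default to "loss"/"min"
        "num_samples": 10,
    },
    "run_config": {
        # "storage_path" is popped and resolved; the rest goes to ray.train.RunConfig
        "storage_path": "ray_results",
    },
    "train_loop_config": {
        # Plain values pass through; dict values become Ray Tune search spaces
        "batch_size": 32,
        "epochs": 10,
        "learning_rate": {"type": "uniform", "lower": 1e-4, "upper": 1e-2},
    },
}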
@@ -1302,9 +1274,8 @@ def _set_strategy_and_init_ray(self, strategy: str):
         else:
             raise ValueError(f"Unsupported strategy: {strategy}")
 
-    def _set_configs(self, config: Dict):
-        # TODO: Think about how to implement the config more nicely
-        self.config = deep_update(DEFAULT_RAY_CONFIG, config)
+    def _set_configs(self, config: Dict) -> None:
+        self.config = config
         self._set_scaling_config()
         self._set_tune_config()
         self._set_run_config()
@@ -1391,8 +1362,8 @@ def create_dataloaders(
     def execute(
         self,
         train_dataset: Dataset,
-        validation_dataset: Optional[Dataset] = None,
-        test_dataset: Optional[Dataset] = None,
+        validation_dataset: Dataset | None = None,
+        test_dataset: Dataset | None = None,
     ) -> Tuple[Dataset, Dataset, Dataset, Any]:
         """Execute the training pipeline with the given datasets.
@@ -1429,130 +1400,111 @@ def set_epoch(self, epoch: int) -> None:
         if self.test_dataloader is not None:
             self.test_dataloader.sampler.set_epoch(epoch)
 
-    def _set_tune_config(self):
-        tune_config = self.config["tune_config"]
+    def _set_tune_config(self) -> None:
+        tune_config = self.config.get("tune_config", {})
 
         if not tune_config:
             print(
-                "INFO: Empty Tune Config configured. Using the default configuration with "
+                "WARNING: Empty Tune Config configured. Using the default configuration with "
                 "a single trial."
             )
 
-        num_samples = tune_config.get("num_samples", 1)
+        search_alg = get_raytune_search_alg(tune_config)
+        scheduler = get_raytune_scheduler(tune_config)
+
         metric = tune_config.get("metric", "loss")
         mode = tune_config.get("mode", "min")
 
-        search_alg = get_raytune_search_alg(tune_config)
-        scheduler = get_raytune_schedule(tune_config)
-
-        # Only set metric and mode if search_alg and scheduler aren't defined
-        self.tune_config = tune.TuneConfig(
-            num_samples=num_samples,
-            metric=metric,
-            mode=mode,
-            search_alg=search_alg,
-            scheduler=scheduler,
-        )
+        try:
+            self.tune_config = tune.TuneConfig(
+                **tune_config,
+                search_alg=search_alg,
+                scheduler=scheduler,
+                metric=metric,
+                mode=mode,
+            )
+        except AttributeError as e:
+            print(
+                "Could not set Tune Config. Please ensure that you have passed the "
+                "correct arguments for it. You can find more information for which "
+                "arguments to set at "
+                "https://docs.ray.io/en/latest/tune/api/doc/ray.tune.TuneConfig.html."
+            )
+            print(e)
 
-    def _set_scaling_config(self):
-        scaling_config = self.config["scaling_config"]
+    def _set_scaling_config(self) -> None:
+        scaling_config = self.config.get("scaling_config", {})
 
         if not scaling_config:
-            print("INFO: Empty Scaling Config configured. Running trials non-distributed.")
-            self.scaling_config = None
-            return
-
-        self.scaling_config = ray.train.ScalingConfig(**self.config["scaling_config"])
-
-    def _set_run_config(self):
-        run_config = self.config["run_config"]
-
-        if not run_config:
-            print("INFO: Empty RunConfig provided. Assuming local or single-node execution.")
-            self.run_config = None
-            return
-
-        storage_path = Path(run_config.get("storage_path"))
-        if not storage_path:
-            print("INFO: Empty storage path provided. Using default path 'ray_checkpoints'")
-            storage_path = Path("ray_checkpoints")
-
-        self.run_config = ray.train.RunConfig(storage_path=Path.absolute(storage_path))
+            print("WARNING: No Scaling Config configured. Running trials non-distributed.")
 
-    def _set_train_loop_config(self):
-        train_loop_config = self.config["train_loop_config"]
-
-        if train_loop_config:
-            self.train_loop_config = self._set_searchspace(train_loop_config)
-        else:
+        try:
+            self.scaling_config = ray.train.ScalingConfig(**scaling_config)
+        except AttributeError as e:
             print(
-                "INFO: No training_loop_config detected. "
-                "No parameters are being tuned or passed to the training function."
+                "Could not set Scaling Config. Please ensure that you have passed the "
+                "correct arguments for it. You can find more information for which "
+                "arguments to set at "
+                "https://docs.ray.io/en/latest/train/api/doc/ray.train.ScalingConfig.html"
            )
+            print(e)
 
-    def _set_searchspace(self, train_loop_dict: Dict):
-        train_loop_config = {}
-
-        for name, values in train_loop_dict.items():
-            if not isinstance(values, dict):
-                # Constant parameters can be added as-is
-                train_loop_config[name] = values
-                continue
-
-            param_type = values.get("type")
-
-            if param_type == "choice":
-                train_loop_config[name] = tune.choice(values["options"])
+    def _set_run_config(self) -> None:
+        run_config = self.config.get("run_config", {})
 
-            elif param_type == "uniform":
-                train_loop_config[name] = tune.uniform(
-                    float(values["min"]), float(values["max"])
-                )
-
-            elif param_type == "quniform":
-                train_loop_config[name] = tune.quniform(
-                    values["min"], values["max"], values["q"]
-                )
-
-            elif param_type == "loguniform":
-                train_loop_config[name] = tune.loguniform(values["min"], values["max"])
-
-            elif param_type == "qloguniform":
-                train_loop_config[name] = tune.qloguniform(
-                    values["min"], values["max"], values["q"]
-                )
+        if not run_config:
+            print("WARNING: No RunConfig provided. Assuming local or single-node execution.")
 
-            elif param_type == "randint":
-                train_loop_config[name] = tune.randint(values["min"], values["max"])
+        try:
+            storage_path = Path(run_config.pop("storage_path")).resolve()
 
-            elif param_type == "qrandint":
-                train_loop_config[name] = tune.qrandint(
-                    values["min"], values["max"], values["q"]
+            if not storage_path:
+                print(
+                    "INFO: Empty storage path provided. Using default path 'ray_checkpoints'"
                )
+                storage_path = Path("ray_checkpoints").resolve()
 
-            elif param_type == "lograndint":
-                train_loop_config[name] = tune.lograndint(values["min"], values["max"])
+            self.run_config = ray.train.RunConfig(**run_config, storage_path=storage_path)
+        except AttributeError as e:
+            print(
+                "Could not set Run Config. Please ensure that you have passed the "
+                "correct arguments for it. You can find more information for which "
+                "arguments to set at "
+                "https://docs.ray.io/en/latest/train/api/doc/ray.train.RunConfig.html"
+            )
+            print(e)
 
-            elif param_type == "qlograndint":
-                train_loop_config[name] = tune.qlograndint(
-                    values["min"], values["max"], values["q"]
-                )
+    def _set_train_loop_config(self) -> None:
+        self.train_loop_config = self.config.get("train_loop_config", {})
 
-            elif param_type == "randn":
-                train_loop_config[name] = tune.randn(values["mean"], values["stddev"])
+        if not self.train_loop_config:
+            print(
+                "WARNING: No training_loop_config detected. "
+                "If you want to tune any hyperparameters, make sure to define them here."
+            )
+            return
 
-            elif param_type == "qrandn":
-                train_loop_config[name] = tune.qrandn(
-                    values["mean"], values["stddev"], values["q"]
-                )
+        try:
+            for name, param in self.train_loop_config.items():
+                if not isinstance(param, dict):
+                    continue
 
-            elif param_type == "grid_search":
-                train_loop_config[name] = tune.grid_search(values["options"])
+                # Convert specific keys to float if necessary
+                for key in ["lower", "upper", "mean", "std"]:
+                    if key in param:
+                        param[key] = float(param[key])
 
-            else:
-                raise ValueError(f"Unsupported search space type: {param_type}")
+                param_type = param.pop("type")
+                param = getattr(tune, param_type)(**param)
+                self.train_loop_config[name] = param
 
-        return train_loop_config
+        except AttributeError as e:
+            print(
+                f"{param} could not be set. Check that this parameter type is "
+                "supported by Ray Tune at "
+                "https://docs.ray.io/en/latest/tune/api/search_space.html"
            )
+            print(e)
 
     # TODO: Can I also log the checkpoint?
     def checkpoint_and_report(self, epoch, tuning_metrics, checkpointing_data=None):
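Note: under the new _set_train_loop_config, each dict-valued entry is turned into a Ray Tune search space via getattr(tune, param["type"])(**param), so the remaining keys must match the keyword arguments of the corresponding tune.* function. A small sketch of that mapping; the parameter names and bounds are illustrative, not from the commit.

from ray import tune

# Hypothetical entries and the search spaces the new code builds from them:
train_loop_config = {
    "batch_size": 32,  # plain values are passed through unchanged
    "learning_rate": {"type": "uniform", "lower": 1e-4, "upper": 1e-2},
    "optimizer": {"type": "choice", "categories": ["adam", "sgd"]},
}
# "learning_rate" -> tune.uniform(lower=0.0001, upper=0.01)
# "optimizer"     -> tune.choice(categories=["adam", "sgd"])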
