diff --git a/src/itwinai/torch/trainer.py b/src/itwinai/torch/trainer.py index 60b2bd6f..1a114dfa 100644 --- a/src/itwinai/torch/trainer.py +++ b/src/itwinai/torch/trainer.py @@ -32,13 +32,12 @@ import torch.optim as optim import torchvision from lightning.pytorch.cli import LightningCLI -from pydantic.utils import deep_update from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader, Dataset, Sampler from torch.utils.data.distributed import DistributedSampler -from itwinai.torch.tuning import get_raytune_schedule, get_raytune_search_alg +from itwinai.torch.tuning import get_raytune_scheduler, get_raytune_search_alg # Imports from this repository from ..components import Trainer, monitor_exec @@ -59,34 +58,6 @@ from .reproducibility import seed_worker, set_seed from .type import Batch, LrScheduler, Metric -DEFAULT_RAY_CONFIG = { - "scaling_config": { - "num_workers": 4, # Default to 4 workers - "use_gpu": True, - "resources_per_worker": {"CPU": 5, "GPU": 1}, - }, - "tune_config": { - "num_samples": 1, # Number of trials to run, increase for more thorough tuning - "metric": "loss", - "mode": "min", - }, - "run_config": { - "checkpoint_at_end": True, # Save checkpoint at the end of each trial - "checkpoint_freq": 10, # Save checkpoint every 10 iterations - "storage_path": "ray_results", # Directory to save results, logs, and checkpoints - }, - "train_loop_config": { - "learning_rate": 1e-3, - "batch_size": 32, - "epochs": 10, - "optimizer": "adam", - "loss": "cross_entropy", - "optim_momentum": 0.9, - "optim_weight_decay": 0, - "random_seed": 21, - }, -} - class TorchTrainer(Trainer, LogMixin): """Trainer class for torch training algorithms. @@ -1275,17 +1246,18 @@ class RayTorchTrainer(Trainer): def __init__( self, config: Dict, - strategy: Optional[Literal["ddp", "deepspeed"]] = "ddp", - name: Optional[str] = None, - logger: Optional[Logger] = None, + strategy: Literal["ddp", "deepspeed"] = "ddp", + name: str | None = None, + logger: Logger | None = None, + random_seed: int = 1234, ) -> None: super().__init__(name=name) self.logger = logger self._set_strategy_and_init_ray(strategy) self._set_configs(config=config) - self.torch_rng = set_seed(self.train_loop_config["random_seed"]) + self.torch_rng = set_seed(random_seed) - def _set_strategy_and_init_ray(self, strategy: str): + def _set_strategy_and_init_ray(self, strategy: str) -> None: """Set the distributed training strategy. This will initialize the ray backend. Args: @@ -1302,9 +1274,8 @@ def _set_strategy_and_init_ray(self, strategy: str): else: raise ValueError(f"Unsupported strategy: {strategy}") - def _set_configs(self, config: Dict): - # TODO: Think about how to implement the config more nicely - self.config = deep_update(DEFAULT_RAY_CONFIG, config) + def _set_configs(self, config: Dict) -> None: + self.config = config self._set_scaling_config() self._set_tune_config() self._set_run_config() @@ -1391,8 +1362,8 @@ def create_dataloaders( def execute( self, train_dataset: Dataset, - validation_dataset: Optional[Dataset] = None, - test_dataset: Optional[Dataset] = None, + validation_dataset: Dataset | None = None, + test_dataset: Dataset | None = None, ) -> Tuple[Dataset, Dataset, Dataset, Any]: """Execute the training pipeline with the given datasets. 
@@ -1429,130 +1400,111 @@ def set_epoch(self, epoch: int) -> None:
         if self.test_dataloader is not None:
             self.test_dataloader.sampler.set_epoch(epoch)
 
-    def _set_tune_config(self):
-        tune_config = self.config["tune_config"]
+    def _set_tune_config(self) -> None:
+        tune_config = self.config.get("tune_config", {})
 
         if not tune_config:
             print(
-                "INFO: Empty Tune Config configured. Using the default configuration with "
+                "WARNING: No tune_config provided. Using the default configuration with "
                 "a single trial."
             )
 
-        num_samples = tune_config.get("num_samples", 1)
-        metric = tune_config.get("metric", "loss")
-        mode = tune_config.get("mode", "min")
+        search_alg = get_raytune_search_alg(tune_config)
+        scheduler = get_raytune_scheduler(tune_config)
 
-        search_alg = get_raytune_search_alg(tune_config)
-        scheduler = get_raytune_schedule(tune_config)
+        # Pop metric and mode so that they are not passed to tune.TuneConfig twice
+        # through **tune_config below
+        metric = tune_config.pop("metric", "loss")
+        mode = tune_config.pop("mode", "min")
 
-        # Only set metric and mode if search_alg and scheduler aren't defined
-        self.tune_config = tune.TuneConfig(
-            num_samples=num_samples,
-            metric=metric,
-            mode=mode,
-            search_alg=search_alg,
-            scheduler=scheduler,
-        )
+        try:
+            self.tune_config = tune.TuneConfig(
+                **tune_config,
+                search_alg=search_alg,
+                scheduler=scheduler,
+                metric=metric,
+                mode=mode,
+            )
+        except (AttributeError, TypeError) as e:
+            print(
+                "Could not set Tune Config. Please ensure that you have passed the "
+                "correct arguments for it. You can find more information for which "
+                "arguments to set at "
+                "https://docs.ray.io/en/latest/tune/api/doc/ray.tune.TuneConfig.html."
+            )
+            print(e)
 
-    def _set_scaling_config(self):
-        scaling_config = self.config["scaling_config"]
+    def _set_scaling_config(self) -> None:
+        scaling_config = self.config.get("scaling_config", {})
 
         if not scaling_config:
-            print("INFO: Empty Scaling Config configured. Running trials non-distributed.")
-            self.scaling_config = None
-            return
-
-        self.scaling_config = ray.train.ScalingConfig(**self.config["scaling_config"])
-
-    def _set_run_config(self):
-        run_config = self.config["run_config"]
-
-        if not run_config:
-            print("INFO: Empty RunConfig provided. Assuming local or single-node execution.")
-            self.run_config = None
-            return
-
-        storage_path = Path(run_config.get("storage_path"))
-        if not storage_path:
-            print("INFO: Empty storage path provided. Using default path 'ray_checkpoints'")
-            storage_path = Path("ray_checkpoints")
-
-        self.run_config = ray.train.RunConfig(storage_path=Path.absolute(storage_path))
+            print("WARNING: No scaling_config provided. Running trials non-distributed.")
 
-    def _set_train_loop_config(self):
-        train_loop_config = self.config["train_loop_config"]
-
-        if train_loop_config:
-            self.train_loop_config = self._set_searchspace(train_loop_config)
-        else:
+        try:
+            self.scaling_config = ray.train.ScalingConfig(**scaling_config)
+        except (AttributeError, TypeError) as e:
             print(
-                "INFO: No training_loop_config detected. "
-                "No parameters are being tuned or passed to the training function."
+                "Could not set Scaling Config. Please ensure that you have passed the "
+                "correct arguments for it. You can find more information for which "
+                "arguments to set at "
+                "https://docs.ray.io/en/latest/train/api/doc/ray.train.ScalingConfig.html"
             )
+            print(e)
 
-    def _set_searchspace(self, train_loop_dict: Dict):
-        train_loop_config = {}
-
-        for name, values in train_loop_dict.items():
-            if not isinstance(values, dict):
-                # Constant parameters can be added as-is
-                train_loop_config[name] = values
-                continue
-
-            param_type = values.get("type")
-
-            if param_type == "choice":
-                train_loop_config[name] = tune.choice(values["options"])
-
-            elif param_type == "uniform":
-                train_loop_config[name] = tune.uniform(
-                    float(values["min"]), float(values["max"])
-                )
-
-            elif param_type == "quniform":
-                train_loop_config[name] = tune.quniform(
-                    values["min"], values["max"], values["q"]
-                )
-
-            elif param_type == "loguniform":
-                train_loop_config[name] = tune.loguniform(values["min"], values["max"])
-
-            elif param_type == "qloguniform":
-                train_loop_config[name] = tune.qloguniform(
-                    values["min"], values["max"], values["q"]
-                )
-
-            elif param_type == "randint":
-                train_loop_config[name] = tune.randint(values["min"], values["max"])
-
-            elif param_type == "qrandint":
-                train_loop_config[name] = tune.qrandint(
-                    values["min"], values["max"], values["q"]
-                )
-
-            elif param_type == "lograndint":
-                train_loop_config[name] = tune.lograndint(values["min"], values["max"])
-
-            elif param_type == "qlograndint":
-                train_loop_config[name] = tune.qlograndint(
-                    values["min"], values["max"], values["q"]
-                )
-
-            elif param_type == "randn":
-                train_loop_config[name] = tune.randn(values["mean"], values["stddev"])
-
-            elif param_type == "qrandn":
-                train_loop_config[name] = tune.qrandn(
-                    values["mean"], values["stddev"], values["q"]
-                )
-
-            elif param_type == "grid_search":
-                train_loop_config[name] = tune.grid_search(values["options"])
-
-            else:
-                raise ValueError(f"Unsupported search space type: {param_type}")
-
-        return train_loop_config
+    def _set_run_config(self) -> None:
+        run_config = self.config.get("run_config", {})
+
+        if not run_config:
+            print("WARNING: No RunConfig provided. Assuming local or single-node execution.")
+
+        try:
+            storage_path = run_config.pop("storage_path", None)
+
+            if not storage_path:
+                print(
+                    "INFO: Empty storage path provided. Using default path 'ray_checkpoints'"
+                )
+                storage_path = "ray_checkpoints"
+
+            self.run_config = ray.train.RunConfig(
+                **run_config, storage_path=Path(storage_path).resolve()
+            )
+        except (AttributeError, TypeError) as e:
+            print(
+                "Could not set Run Config. Please ensure that you have passed the "
+                "correct arguments for it. You can find more information for which "
+                "arguments to set at "
+                "https://docs.ray.io/en/latest/train/api/doc/ray.train.RunConfig.html"
+            )
+            print(e)
+
+    def _set_train_loop_config(self) -> None:
+        self.train_loop_config = self.config.get("train_loop_config", {})
+
+        if not self.train_loop_config:
+            print(
+                "WARNING: No training_loop_config detected. "
+                "If you want to tune any hyperparameters, make sure to define them here."
+            )
+            return
+
+        try:
+            for name, param in self.train_loop_config.items():
+                if not isinstance(param, dict):
+                    continue
+
+                # Convert specific keys to float if necessary
+                for key in ["lower", "upper", "mean", "std"]:
+                    if key in param:
+                        param[key] = float(param[key])
+
+                param_type = param.pop("type")
+                param = getattr(tune, param_type)(**param)
+                self.train_loop_config[name] = param
+
+        except (AttributeError, TypeError) as e:
+            print(
+                f"{param} could not be set. Check that this parameter type is "
+                "supported by Ray Tune at "
+                "https://docs.ray.io/en/latest/tune/api/search_space.html"
+            )
+            print(e)
 
     # TODO: Can I also log the checkpoint?
     def checkpoint_and_report(self, epoch, tuning_metrics, checkpointing_data=None):
diff --git a/src/itwinai/torch/tuning.py b/src/itwinai/torch/tuning.py
index 1cf10872..25445df5 100644
--- a/src/itwinai/torch/tuning.py
+++ b/src/itwinai/torch/tuning.py
@@ -1,3 +1,14 @@
+# --------------------------------------------------------------------------------------
+# Part of the interTwin Project: https://www.intertwin.eu/
+#
+# Created by: Anna Lappe
+#
+# Credit:
+# - Anna Lappe - CERN
+# --------------------------------------------------------------------------------------
+
+from typing import Dict
+
 from ray.tune.schedulers import (
     AsyncHyperBandScheduler,
     HyperBandForBOHB,
@@ -5,14 +16,29 @@
     PopulationBasedTraining,
 )
 from ray.tune.schedulers.pb2 import PB2  # Population Based Bandits
+from ray.tune.search.ax import AxSearch
 from ray.tune.search.bayesopt import BayesOptSearch
 from ray.tune.search.bohb import TuneBOHB
+from ray.tune.search.hebo import HEBOSearch
 from ray.tune.search.hyperopt import HyperOptSearch
+from ray.tune.search.nevergrad import NevergradSearch
+from ray.tune.search.optuna import OptunaSearch
+from ray.tune.search.zoopt import ZOOptSearch
 
 
 def get_raytune_search_alg(
-    tune_config, seeds=False
-) -> TuneBOHB | BayesOptSearch | HyperOptSearch | None:
+    tune_config: Dict,
+) -> (
+    TuneBOHB
+    | BayesOptSearch
+    | HyperOptSearch
+    | AxSearch
+    | HEBOSearch
+    | NevergradSearch
+    | OptunaSearch
+    | ZOOptSearch
+    | None
+):
     """Get the appropriate Ray Tune search algorithm based on the provided configuration.
 
     Args:
@@ -23,59 +49,64 @@
 
     Returns:
         An instance of the chosen Ray Tune search algorithm or None if no search algorithm is
-    used or if the search algorithm does not match any of the supported options.
-
-    Notes:
-        - `TuneBOHB` is automatically chosen for BOHB scheduling.
+        used or if the search algorithm does not match any of the supported options.
     """
-    scheduler = tune_config.get("scheduler", {}).get("name")
+    scheduler_name = tune_config.get("scheduler", {}).get("name", "")
 
-    search_alg = tune_config.get("search_alg", {}).get("name")
-
+    # Pop the search_alg entry here so that it is not forwarded to tune.TuneConfig as well
+    search_alg = tune_config.pop("search_alg", {})
+    search_alg_name = search_alg.pop("name", "")
+
-    if (scheduler == "pbt") or (scheduler == "pb2"):
-        if search_alg is None:
-            return None
-        else:
+    match scheduler_name.lower():
+        case "pbt" | "pb2":
             print(
-                "INFO: Using schedule '{}' \
-                is not compatible with Ray Tune search algorithms.".format(scheduler)
+                f"INFO: The {scheduler_name} scheduler is not compatible with Ray Tune "
+                "search algorithms."
             )
+            print(f"Using the Ray Tune {scheduler_name} scheduler without a search algorithm.")
+            return None
+
+        case "bohb":
             print(
-                "INFO: Using the Ray Tune '{}' scheduler without search algorithm".format(
-                    scheduler
-                )
+                "INFO: Using the TuneBOHB search algorithm since it is required for the BOHB "
+                "scheduler."
             )
+            return TuneBOHB()
 
-    if (scheduler == "bohb") or (scheduler == "BOHB"):
-        print("INFO: Using TuneBOHB search algorithm since it is required for BOHB shedule")
-        if seeds:
-            seed = 1234
-        else:
-            seed = None
-        return TuneBOHB(
-            seed=seed,
-        )
-
-    # requires pip install bayesian-optimization
-    if search_alg == "bayes":
-        print("INFO: Using BayesOptSearch")
-        return BayesOptSearch(
-            random_search_steps=tune_config["search_alg"]["n_random_steps"],
-        )
-
-    # requires pip install hyperopt
-    if search_alg == "hyperopt":
-        print("INFO: Using HyperOptSearch")
-        return HyperOptSearch(
-            n_initial_points=tune_config["search_alg"]["n_random_steps"],
-            # points_to_evaluate=,
-        )
-
-    print("INFO: Not using any Ray Tune search algorithm")
-    return None
+    try:
+        match search_alg_name.lower():
+            case "ax":
+                return AxSearch(**search_alg)
+            case "bayesopt":
+                return BayesOptSearch(**search_alg)
+            case "hyperopt":
+                return HyperOptSearch(**search_alg)
+            case "bohb":
+                return TuneBOHB(**search_alg)
+            case "hebo":
+                return HEBOSearch(**search_alg)
+            case "nevergrad":
+                return NevergradSearch(**search_alg)
+            case "optuna":
+                return OptunaSearch(**search_alg)
+            case "zoo":
+                return ZOOptSearch(**search_alg)
+            case _:
+                print(
+                    "INFO: No search algorithm detected. Using Ray Tune BasicVariantGenerator."
+                )
+                return None
+    except (AttributeError, TypeError) as e:
+        print(
+            "Invalid search algorithm configuration passed. Please make sure that the search "
+            "algorithm you are using has the correct attributes. You can read more about the "
+            "different search algorithms supported by Ray Tune at "
+            "https://docs.ray.io/en/latest/tune/api/suggestion.html."
+        )
+        print(e)
 
 
-def get_raytune_schedule(
-    tune_config,
+def get_raytune_scheduler(
+    tune_config: Dict,
 ) -> (
     AsyncHyperBandScheduler
     | HyperBandScheduler
@@ -89,48 +120,37 @@
     Args:
        tune_config (Dict): Configuration dictionary specifying the scheduler type, metric,
        mode, and, depending on the scheduler, other parameters.
+
     Returns:
         An instance of the chosen Ray Tune scheduler or None if no scheduler is used
-    or if the scheduler does not match any of the supported options.
+        or if the scheduler does not match any of the supported options.
""" - scheduler = tune_config.get("scheduler", {}).get("name") - - if scheduler == "asha": - return AsyncHyperBandScheduler( - time_attr="training_iteration", - max_t=tune_config["scheduler"]["max_t"], - grace_period=tune_config["scheduler"]["grace_period"], - reduction_factor=tune_config["scheduler"]["reduction_factor"], - brackets=tune_config["scheduler"]["brackets"], - ) - elif scheduler == "hyperband": - return HyperBandScheduler( - time_attr="training_iteration", - max_t=tune_config["scheduler"]["max_t"], - reduction_factor=tune_config["scheduler"]["reduction_factor"], - ) - # requires pip install hpbandster ConfigSpace - elif (scheduler == "bohb") or (scheduler == "BOHB"): - return HyperBandForBOHB( - time_attr="training_iteration", - max_t=tune_config["scheduler"]["max_t"], - reduction_factor=tune_config["scheduler"]["reduction_factor"], - ) - elif (scheduler == "pbt") or (scheduler == "PBT"): - return PopulationBasedTraining( - time_attr="training_iteration", - perturbation_interval=tune_config["scheduler"]["perturbation_interval"], - hyperparam_mutations=tune_config["scheduler"]["hyperparam_mutations"], - log_config=True, - ) - # requires pip install GPy sklearn - elif (scheduler == "pb2") or (scheduler == "PB2"): - return PB2( - time_attr="training_iteration", - perturbation_interval=tune_config["scheduler"]["perturbation_interval"], - hyperparam_bounds=tune_config["scheduler"]["hyperparam_bounds"], - log_config=True, + + scheduler = tune_config.pop("scheduler", {}) + scheduler_name = scheduler.pop("name", "") + + try: + match scheduler_name.lower(): + case "asha": + return AsyncHyperBandScheduler(**scheduler) + case "hyperband": + return HyperBandScheduler(**scheduler) + case "bohb": + return HyperBandForBOHB(**scheduler) + case "pbt": + return PopulationBasedTraining(**scheduler) + case "pb2": + return PB2(**scheduler) + case _: + print( + "INFO: No search algorithm detected. Using default Ray Tune FIFOScheduler." + ) + return None + except AttributeError as e: + print( + "Invalid scheduler configuration passed. Please make sure that the scheduler " + "you are using has the correct attributes. You can read more about the " + "different schedulers supported by Ray Tune at " + "https://docs.ray.io/en/latest/tune/api/schedulers.html." 
) - else: - print("INFO: Not using any Ray Tune trial scheduler.") - return None + print(e) diff --git a/tutorials/hpo-workflows/config.yaml b/tutorials/hpo-workflows/config.yaml deleted file mode 100644 index 01dbb8cc..00000000 --- a/tutorials/hpo-workflows/config.yaml +++ /dev/null @@ -1,53 +0,0 @@ -hpo_training_pipeline: - class_path: itwinai.pipeline.Pipeline - init_args: - steps: - - class_path: data.FashionMNISTGetter - - class_path: data.FashionMNISTSplitter - init_args: - train_proportion: 0.9 - validation_proportion: 0.1 - test_proportion: 0.0 - - class_path: trainer.MyRayTorchTrainer - init_args: - config: - scaling_config: - num_workers: 2 - use_gpu: true - resources_per_worker: - CPU: 6 - GPU: 1 - train_loop_config: - batch_size: - type: choice - options: [32, 64, 128] - learning_rate: - type: uniform - min: 1e-5 - max: 1e-3 - epochs: 10 - tune_config: - num_samples: 4 - scheduler: - name: asha - max_t: 10 - grace_period: 5 - reduction_factor: 4 - brackets: 1 - # search_alg: - # name: bayes - # metric: loss - # mode: min - # n_random_steps: 5 - run_config: - storage_path: ray_checkpoints - name: Virgo-HPO-Experiment - strategy: deepspeed - logger: - class_path: itwinai.loggers.LoggersCollection - init_args: - loggers: - - class_path: itwinai.loggers.MLFlowLogger - init_args: - experiment_name: MNIST HPO Experiment - log_freq: batch \ No newline at end of file diff --git a/tutorials/hpo-workflows/data.py b/tutorials/hpo-workflows/data.py deleted file mode 100644 index c53739dc..00000000 --- a/tutorials/hpo-workflows/data.py +++ /dev/null @@ -1,91 +0,0 @@ -import argparse -import sys -from pathlib import Path -from typing import Tuple - -from torch.utils.data import Dataset, random_split -from torchvision import datasets, transforms - -from itwinai.components import DataGetter, DataSplitter - -data_dir = Path("data") - - -def download_fashion_mnist() -> None: - """Download the FashionMNIST dataset using torchvision.""" - print("Downloading FashionMNIST dataset...") - datasets.FashionMNIST( - data_dir, - train=True, - download=True, - transform=transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] - ), - ) - datasets.FashionMNIST( - data_dir, - train=False, - download=True, - transform=transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] - ), - ) - print("Download complete!") - - -class FashionMNISTGetter(DataGetter): - def __init__(self) -> None: - super().__init__() - - def execute(self) -> Tuple[Dataset, Dataset]: - """Load the FashionMNIST dataset from the specified directory.""" - print("Loading FashionMNIST dataset...") - train_dataset = datasets.FashionMNIST( - data_dir, - train=True, - download=False, - transform=transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] - ), - ) - print("Loading complete!") - return train_dataset - - -class FashionMNISTSplitter(DataSplitter): - def __init__( - self, - train_proportion: float, - validation_proportion: float, - test_proportion: float, - name: str | None = None, - ) -> None: - super().__init__(train_proportion, validation_proportion, test_proportion, name) - - def execute(self, dataset: Dataset) -> Tuple[Dataset, Dataset, Dataset]: - """Split the dataset into train, validation, and test sets.""" - print("Splitting dataset...") - total_size = len(dataset) - train_size = int(self.train_proportion * total_size) - validation_size = int(self.validation_proportion * total_size) - test_size = total_size - train_size - validation_size - - 
train_dataset, validation_dataset, test_dataset = random_split( - dataset, [train_size, validation_size, test_size] - ) - print("Splitting complete!") - return train_dataset, validation_dataset, test_dataset - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="FashionMNIST Dataset Loader") - parser.add_argument( - "--download_only", - action="store_true", - help="Download the FashionMNIST dataset and exit", - ) - args = parser.parse_args() - - if args.download_only: - download_fashion_mnist() - sys.exit() diff --git a/tutorials/hpo-workflows/slurm_hpo.sh b/tutorials/hpo-workflows/slurm_hpo.sh deleted file mode 100644 index 3570a4e6..00000000 --- a/tutorials/hpo-workflows/slurm_hpo.sh +++ /dev/null @@ -1,91 +0,0 @@ -#!/bin/bash - -# Job configuration -#SBATCH --job-name=ray_tune_hpo -#SBATCH --account=intertwin -#SBATCH --time=00:10:00 -#SBATCH --partition=develbooster - -# Resources allocation -#SBATCH --cpus-per-task=32 -#SBATCH --nodes=2 -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=4 -#SBATCH --exclusive - -# Output and error logs -#SBATCH -o logs_slurm/hpo-job.out -#SBATCH -e logs_slurm/hpo-job.err - -# Load environment modules -ml --force purge -ml Stages/2024 GCC/12.3.0 OpenMPI CUDA/12 MPI-settings/CUDA -ml Python/3.11 HDF5 PnetCDF libaio mpi4py CMake cuDNN/8.9.5.29-CUDA-12 - -# Set and activate virtual environment -PYTHON_VENV="../../envAI_juwels" -source $PYTHON_VENV/bin/activate - -# make sure CUDA devices are visible -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -num_gpus=$SLURM_GPUS_PER_NODE -num_cpus=$SLURM_CPUS_PER_TASK - -# This tells Tune to not change the working directory to the trial directory -# which makes relative paths accessible from inside a trial -export RAY_CHDIR_TO_TRIAL_DIR=0 -export RAY_DEDUP_LOGS=0 -export RAY_USAGE_STATS_DISABLE=1 - -######### Set up Ray cluster ######## - -# Get the node names -nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") -mapfile -t nodes_array <<< "$nodes" - -# The head node will act as the central manager (head) of the Ray cluster. -head_node=${nodes_array[0]} -port=7639 # This port will be used by Ray to communicate with worker nodes. - -echo "Starting HEAD at $head_node" -# Start Ray on the head node. -# The `--head` option specifies that this node will be the head of the Ray cluster. -# `srun` submits a job that runs on the head node to start the Ray head with the specified -# number of CPUs and GPUs. -srun --nodes=1 --ntasks=1 -w "$head_node" \ - ray start --head --node-ip-address="$head_node"i --port=$port \ - --num-cpus "$num_cpus" --num-gpus "$num_gpus" --block & - -# Wait for a few seconds to ensure that the head node has fully initialized. -sleep 10 - -echo HEAD node started. - -# Start Ray worker nodes -# These nodes will connect to the head node and become part of the Ray cluster. -worker_num=$((SLURM_JOB_NUM_NODES - 1)) # Total number of worker nodes (excl the head node) -for ((i = 1; i <= worker_num; i++)); do - node_i=${nodes_array[$i]} # Get the current worker node hostname. - echo "Starting WORKER $i at $node_i" - - # Use srun to start Ray on the worker node and connect it to the head node. - # The `--address` option tells the worker node where to find the head node. - srun --nodes=1 --ntasks=1 -w "$node_i" \ - ray start --address "$head_node"i:"$port" --redis-password='5241580000000000' \ - --num-cpus "$num_cpus" --num-gpus "$num_gpus" --block & - - sleep 5 # Wait before starting the next worker to prevent race conditions. -done -echo All Ray workers started. 
- -############################################################################################## - -# Run the Python script using Ray -echo 'Starting HPO.' - -# Run pipeline -$PYTHON_VENV/bin/itwinai exec-pipeline --config config.yaml --pipe-key hpo_training_pipeline - -# Shutdown Ray after completion -ray stop \ No newline at end of file diff --git a/tutorials/hpo-workflows/trainer.py b/tutorials/hpo-workflows/trainer.py deleted file mode 100644 index d9636f76..00000000 --- a/tutorials/hpo-workflows/trainer.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import Dict, Literal - -import numpy as np -import torch -from torch.nn import CrossEntropyLoss -from torch.optim import Adam -from torchvision.models import resnet18 - -from itwinai.loggers import Logger -from itwinai.torch.distributed import RayDeepSpeedStrategy -from itwinai.torch.trainer import RayTorchTrainer - - -class MyRayTorchTrainer(RayTorchTrainer): - def __init__( - self, - config: Dict, - strategy: Literal["ddp", "deepspeed"] = "ddp", - name: str | None = None, - logger: Logger | None = None, - ) -> None: - super().__init__(config=config, strategy=strategy, name=name, logger=logger) - - def create_model_loss_optimizer(self): - model = resnet18(num_classes=10) - model.conv1 = torch.nn.Conv2d( - 1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False - ) - # First, define strategy-wise optional configurations - if isinstance(self.strategy, RayDeepSpeedStrategy): - distribute_kwargs = dict( - config_params=dict( - train_micro_batch_size_per_gpu=self.training_config["batch_size"] - ) - ) - else: - distribute_kwargs = {} - optimizer = Adam(model.parameters(), lr=self.training_config["learning_rate"]) - self.model, self.optimizer, _ = self.strategy.distributed( - model, optimizer, **distribute_kwargs - ) - self.loss = CrossEntropyLoss() - - def train(self, config, data): - self.training_config = config - - # Because of the way the ray cluster is set up, - # the initialisation of the strategy and logger, as well as the creation of the - # model, loss, optimizer and dataloader are done from within the train() function - self.strategy.init() - self.initialize_logger( - hyperparams=self.training_config, rank=self.strategy.global_rank() - ) - self.create_model_loss_optimizer() - self.create_dataloaders( - train_dataset=data[0], validation_dataset=data[1], test_dataset=data[2] - ) - - for epoch in range(self.training_config["epochs"]): - if self.strategy.global_world_size() > 1: - self.set_epoch(epoch) - - train_losses = [] - val_losses = [] - - for images, labels in self.train_dataloader: - if isinstance(self.strategy, RayDeepSpeedStrategy): - device = self.strategy.device() - images, labels = images.to(device), labels.to(device) - outputs = self.model(images) - train_loss = self.loss(outputs, labels) - self.optimizer.zero_grad() - train_loss.backward() - self.optimizer.step() - train_losses.append(train_loss.detach().cpu().numpy()) - - for images, labels in self.validation_dataloader: - if isinstance(self.strategy, RayDeepSpeedStrategy): - device = self.strategy.device() - images, labels = images.to(device), labels.to(device) - with torch.no_grad(): - outputs = self.model(images) - val_loss = self.loss(outputs, labels) - val_losses.append(val_loss.detach().cpu().numpy()) - - self.log(np.mean(train_losses), "train_loss", kind="metric", step=epoch) - self.log(np.mean(val_losses), "val_loss", kind="metric", step=epoch) - checkpoint = { - "epoch": epoch, - "loss": train_loss, - "val_loss": val_loss, - } - metrics = {"loss": 
val_loss.item()}
-            self.checkpoint_and_report(
-                epoch, tuning_metrics=metrics, checkpointing_data=checkpoint
-            )
diff --git a/use-cases/virgo/config.yaml b/use-cases/virgo/config.yaml
index 2fc72b61..1af23a55 100644
--- a/use-cases/virgo/config.yaml
+++ b/use-cases/virgo/config.yaml
@@ -140,20 +140,18 @@ ray_training_pipeline:
         train_loop_config:
           batch_size:
             type: choice
-            options: [1, 2, 4]
-          learning_rate:
+            categories: [1, 2, 4]
+          optim_lr:
             type: uniform
-            min: 1e-5
-            max: 1e-3
-          epochs: 5
+            lower: 1e-5
+            upper: 1e-3
+          num_epochs: 5
           generator: simple #unet
           loss: l1
           save_best: false
           shuffle_train: true
-          random_seed: 17
-          tracking_uri: mllogs/mlflow
-          experiment_name: Virgo-HPO-Experiment
         strategy: ${strategy}
+        random_seed: 17
         logger:
           class_path: itwinai.loggers.LoggersCollection
           init_args:
diff --git a/use-cases/virgo/trainer.py b/use-cases/virgo/trainer.py
index af96290b..6d402265 100644
--- a/use-cases/virgo/trainer.py
+++ b/use-cases/virgo/trainer.py
@@ -378,7 +378,7 @@ def train(self):
             epoch_time_tracker.add_epoch_time(epoch - 1, timer() - lt)
 
         # Report training metrics of last epoch to Ray
-        train.report({"loss": np.mean(val_loss), "train_loss": np.mean(epoch_loss)})
+        train.report({"loss": np.mean(val_loss)})
 
         return loss_plot, val_loss_plot, acc_plot, val_acc_plot
 
@@ -390,12 +390,15 @@ def __init__(
         strategy: Optional[Literal["ddp", "deepspeed"]] = "ddp",
         name: Optional[str] = None,
         logger: Optional[Logger] = None,
+        random_seed: int = 1234,
     ) -> None:
-        super().__init__(config=config, strategy=strategy, name=name, logger=logger)
+        super().__init__(
+            config=config, strategy=strategy, name=name, logger=logger, random_seed=random_seed
+        )
 
     def create_model_loss_optimizer(self) -> None:
         # Select generator
-        generator = self.training_config["generator"]
+        generator = self.training_config.generator
         scaling = 0.02
         if generator == "simple":
             self.model = Decoder(3, norm=False)
@@ -412,7 +415,7 @@ def create_model_loss_optimizer(self) -> None:
             init_weights(self.model, "normal", scaling=scaling)
 
         # Select loss
-        loss = self.training_config["loss"]
+        loss = self.training_config.loss
         if loss == "l1":
             self.loss = nn.L1Loss()
         elif loss == "l2":
@@ -421,17 +424,16 @@ def create_model_loss_optimizer(self) -> None:
             raise ValueError("Unrecognized loss type! Got", loss)
 
         # Optimizer
         self.optimizer = torch.optim.Adam(
-            self.model.parameters(), lr=self.training_config["learning_rate"]
+            self.model.parameters(), lr=self.training_config.optim_lr
         )
 
         # First, define strategy-wise optional configurations
         if isinstance(self.strategy, RayDeepSpeedStrategy):
             # Batch size definition is not optional for DeepSpeedStrategy!
             distribute_kwargs = dict(
-                config_params=dict(
-                    train_micro_batch_size_per_gpu=self.training_config["batch_size"]
-                )
+                config_params=dict(
+                    train_micro_batch_size_per_gpu=self.training_config.batch_size
+                )
             )
         else:
             distribute_kwargs = {}
@@ -458,7 +460,7 @@ def train(self, config, data):
         # Start the timer for profiling
         st = timer()
 
-        self.training_config = config
+        self.training_config = VirgoTrainingConfiguration(**config)
 
         self.create_model_loss_optimizer()
 
@@ -482,7 +484,7 @@ def train(self, config, data):
         val_acc_plot = []
         best_val_loss = float("inf")
 
-        for epoch in tqdm(range(self.training_config["epochs"])):
+        for epoch in tqdm(range(self.training_config.num_epochs)):
            # lt = timer()
 
             if self.strategy.global_world_size() > 1:
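
For reference, after this refactor the RayTorchTrainer config is expected to mirror the Ray APIs directly: keys under scheduler and search_alg are forwarded verbatim to the constructor selected by name, the remaining tune_config/scaling_config/run_config keys go straight to tune.TuneConfig, ray.train.ScalingConfig and ray.train.RunConfig, and every train_loop_config entry with a type key is mapped onto the ray.tune search-space function of the same name. The sketch below illustrates this under those assumptions; the concrete values, the asha scheduler parameters and the hyperopt n_initial_points setting are illustrative (taken from the removed tutorial config and Ray's documented arguments), not part of this changeset.

config:
  scaling_config:              # forwarded to ray.train.ScalingConfig
    num_workers: 2
    use_gpu: true
    resources_per_worker:
      CPU: 6
      GPU: 1
  tune_config:                 # remaining keys forwarded to tune.TuneConfig
    num_samples: 4
    metric: loss
    mode: min
    scheduler:
      name: asha               # remaining keys go to AsyncHyperBandScheduler
      max_t: 10
      grace_period: 5
      reduction_factor: 4
    search_alg:
      name: hyperopt           # remaining keys go to HyperOptSearch
      n_initial_points: 5
  run_config:                  # forwarded to ray.train.RunConfig
    storage_path: ray_checkpoints
  train_loop_config:
    batch_size:
      type: choice             # becomes tune.choice(categories=[32, 64, 128])
      categories: [32, 64, 128]
    optim_lr:
      type: uniform            # becomes tune.uniform(lower=1e-5, upper=1e-3)
      lower: 1e-5
      upper: 1e-3
    num_epochs: 10             # non-dict values are passed through unchanged

Because the kwargs are passed through as-is, key names must match the Ray constructor signatures exactly; misspelled keys now surface as constructor errors, which the new try/except blocks report together with a link to the relevant Ray documentation.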