From 30765e0719b281634c7a63de627d1d16c8e1823e Mon Sep 17 00:00:00 2001 From: Matteo Bunino <48362942+matbun@users.noreply.github.com> Date: Thu, 17 Oct 2024 10:07:27 +0200 Subject: [PATCH] Scalability report update - Communication Plot [cleaned up] (#231) * make comm profiler into decorator * add tryout profiler script with mnist * add code for testing out pytorch profiler * add dummy script for monitoring and profiling mnist * add functionality for multi node GPU utilization * split code into files and create analyzation * update profiler to handle multi-gpu * add comm vs comp analysis * remove adjustable output path * create script for comm plot * add slurm script for comm calculation * update jupyter notebook * add docstrings, error handling and more comm patterns * Do data analysis * add docstrings etc. * update slurm script * make comm profiler into decorator * accommodate asymmetric runs and make table prettier * make comm profiler into decorator * add scheduler to profiler * remove regex dependency from file names in comm plot * add dynamic specification of directories for comm plot generator * small bugfix and black formatter * format code * fix linting errors * remove unused files and create new directory for gpu-monitoring * update docstrings * move imports into function in cli * move profiler to own file * move communication plot to torch folder * add deepspeed import in ds strategy * fix linting errors * remove gpu-monitoring files for this branch * add another communication entry * remove plots * fix small docstring typo * remove plots and small cleanup * move profiling files into new profiling module * move horovod imports and create new profiling module * fix diffs --------- Co-authored-by: Jarl Saether --- src/itwinai/cli.py | 267 +++++++++++------- src/itwinai/components.py | 76 ++--- src/itwinai/torch/distributed.py | 51 ++-- src/itwinai/torch/profiling/__init__.py | 0 .../torch/profiling/communication_plot.py | 232 +++++++++++++++ 
src/itwinai/torch/profiling/profiler.py | 102 +++++++ src/itwinai/torch/trainer.py | 75 ++--- use-cases/eurac/config.yaml | 2 +- use-cases/eurac/plots/comm_plot.png | Bin 0 -> 30280 bytes use-cases/eurac/runall.sh | 2 +- use-cases/eurac/slurm.sh | 2 +- use-cases/eurac/trainer.py | 32 ++- use-cases/virgo/trainer.py | 2 +- 13 files changed, 629 insertions(+), 214 deletions(-) create mode 100644 src/itwinai/torch/profiling/__init__.py create mode 100644 src/itwinai/torch/profiling/communication_plot.py create mode 100644 src/itwinai/torch/profiling/profiler.py create mode 100644 use-cases/eurac/plots/comm_plot.png diff --git a/src/itwinai/cli.py b/src/itwinai/cli.py index c6598694..fc42d94f 100644 --- a/src/itwinai/cli.py +++ b/src/itwinai/cli.py @@ -19,21 +19,69 @@ app = typer.Typer(pretty_exceptions_enable=False) +@app.command() +def generate_communication_plot( + log_dir: str = "profiling_logs", + pattern: str = r"profile_(\w+)_(\d+)_(\d+)\.csv$", + output_file: str = "plots/comm_plot.png", +) -> None: + """Generate stacked plot showing computation vs. communication fraction and save it. + + Args: + log_dir: The directory where the csv logs are stored. Defaults to + ``profiling_logs``. + pattern: A regex pattern to recognize the file names in the 'log_dir' folder. + Defaults to ``profile_(\\w+)_(\\d+)_(\\d+)\\.csv$``. + output_file: The path to where the resulting plot should be saved. Defaults to + ``plots/comm_plot.png``. + """ + import matplotlib.pyplot as plt + + from itwinai.torch.profiling.communication_plot import ( + create_combined_comm_overhead_df, + create_stacked_plot, + get_comp_fraction_full_array, + ) + + log_dir_path = Path(log_dir) + if not log_dir_path.exists(): + raise IOError( + f"The directory '{log_dir_path.resolve()}' does not exist, so could not" + f"extract profiling logs. Make sure you are running this command in the " + f"same directory as the logging dir."
+ ) + + df = create_combined_comm_overhead_df(logs_dir=log_dir_path, pattern=pattern) + values = get_comp_fraction_full_array(df, print_table=True) + + strategies = sorted(df["strategy"].unique()) + gpu_numbers = sorted(df["num_gpus"].unique(), key=lambda x: int(x)) + + fig, _ = create_stacked_plot(values, strategies, gpu_numbers) + + # TODO: set these dynamically? + fig.set_figwidth(8) + fig.set_figheight(6) + + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + + plt.savefig(output_path) + print(f"\nSaved computation vs. communication plot at '{output_path.resolve()}'") + + @app.command() def sanity_check( - torch: Annotated[Optional[bool], typer.Option( - help=("Check also itwinai.torch modules.") - )] = False, - tensorflow: Annotated[Optional[bool], typer.Option( - help=("Check also itwinai.tensorflow modules.") - )] = False, - all: Annotated[Optional[bool], typer.Option( - help=("Check all modules.") - )] = False, + torch: Annotated[ + Optional[bool], typer.Option(help=("Check also itwinai.torch modules.")) + ] = False, + tensorflow: Annotated[ + Optional[bool], typer.Option(help=("Check also itwinai.tensorflow modules.")) + ] = False, + all: Annotated[Optional[bool], typer.Option(help=("Check all modules."))] = False, ): - """Run sanity checks on the installation of itwinai and - its dependencies by trying to import itwinai modules. - By default, only itwinai core modules (neither torch, nor + """Run sanity checks on the installation of itwinai and its dependencies by trying + to import itwinai modules. 
By default, only itwinai core modules (neither torch, nor tensorflow) are tested.""" from itwinai.tests.sanity_check import ( sanity_check_all, @@ -41,6 +89,7 @@ def sanity_check( sanity_check_tensorflow, sanity_check_torch, ) + all = (torch and tensorflow) or all if all: sanity_check_all() @@ -54,18 +103,15 @@ def sanity_check( @app.command() def scalability_report( - pattern: Annotated[str, typer.Option( - help="Python pattern matching names of CSVs in sub-folders." - )], - plot_title: Annotated[Optional[str], typer.Option( - help=("Plot name.") - )] = None, - skip_id: Annotated[Optional[int], typer.Option( - help=("Skip epoch ID.") - )] = None, - archive: Annotated[Optional[str], typer.Option( - help=("Archive name to backup the data, without extension.") - )] = None, + pattern: Annotated[ + str, typer.Option(help="Python pattern matching names of CSVs in sub-folders.") + ], + plot_title: Annotated[Optional[str], typer.Option(help=("Plot name."))] = None, + skip_id: Annotated[Optional[int], typer.Option(help=("Skip epoch ID."))] = None, + archive: Annotated[ + Optional[str], + typer.Option(help=("Archive name to backup the data, without extension.")), + ] = None, ): """ Generate scalability report merging all CSVs containing epoch time @@ -88,7 +134,7 @@ def scalability_report( import numpy as np import pandas as pd - regex = re.compile(r'{}'.format(pattern)) + regex = re.compile(r"{}".format(pattern)) combined_df = pd.DataFrame() csv_files = [] for root, _, files in os.walk(os.getcwd()): @@ -104,14 +150,13 @@ def scalability_report( print(combined_df) avg_times = ( - combined_df - .drop(columns='epoch_id') - .groupby(['name', 'nodes']) + combined_df.drop(columns="epoch_id") + .groupby(["name", "nodes"]) .mean() .reset_index() ) print("\nAvg over name and nodes:") - print(avg_times.rename(columns=dict(time='avg(time)'))) + print(avg_times.rename(columns=dict(time="avg(time)"))) # fig, (sp_up_ax, eff_ax) = plt.subplots(1, 2, figsize=(12, 4)) fig, sp_up_ax = 
plt.subplots(1, 1, figsize=(6, 4)) @@ -125,7 +170,7 @@ def scalability_report( series_names = sorted(set(avg_times.name.values)) for name in series_names: - df = avg_times[avg_times.name == name].drop(columns='name') + df = avg_times[avg_times.name == name].drop(columns="name") # Debug # compute_time = [3791., 1884., 1011., 598.] @@ -133,32 +178,42 @@ def scalability_report( # d = {'nodes': nodes, 'time': compute_time} # df = pd.DataFrame(data=d) - df["NGPUs"] = df["nodes"]*4 + df["NGPUs"] = df["nodes"] * 4 # speedup df["Speedup - ideal"] = df["nodes"].astype(float) df["Speedup"] = df["time"].iloc[0] / df["time"] df["Nworkers"] = 1 # efficiency - df["Threadscaled Sim. Time / s"] = df["time"] * \ - df["nodes"] * df["Nworkers"] - df["Efficiency"] = df["Threadscaled Sim. Time / s"].iloc[0] / \ - df["Threadscaled Sim. Time / s"] + df["Threadscaled Sim. Time / s"] = df["time"] * df["nodes"] * df["Nworkers"] + df["Efficiency"] = ( + df["Threadscaled Sim. Time / s"].iloc[0] / df["Threadscaled Sim. 
Time / s"] + ) sp_up_ax.plot( - df["NGPUs"].values, df["Speedup"].values, - marker=next(markers), lw=1.0, label=name, alpha=0.7) + df["NGPUs"].values, + df["Speedup"].values, + marker=next(markers), + lw=1.0, + label=name, + alpha=0.7, + ) - sp_up_ax.plot(df["NGPUs"].values, df["Speedup - ideal"].values, - ls='dashed', lw=1.0, c='k', label="ideal") + sp_up_ax.plot( + df["NGPUs"].values, + df["Speedup - ideal"].values, + ls="dashed", + lw=1.0, + c="k", + label="ideal", + ) sp_up_ax.legend(ncol=1) sp_up_ax.set_xticks(df["NGPUs"].values) - sp_up_ax.get_xaxis().set_major_formatter( - matplotlib.ticker.ScalarFormatter()) + sp_up_ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter()) - sp_up_ax.set_ylabel('Speedup') - sp_up_ax.set_xlabel('NGPUs (4 per node)') + sp_up_ax.set_ylabel("Speedup") + sp_up_ax.set_xlabel("NGPUs (4 per node)") sp_up_ax.grid() # Sort legend @@ -168,42 +223,42 @@ def scalability_report( plot_png = f"scaling_plot_{plot_title}.png" plt.tight_layout() - plt.savefig(plot_png, bbox_inches='tight', format='png', dpi=300) + plt.savefig(plot_png, bbox_inches="tight", format="png", dpi=300) print("Saved scaling plot to: ", plot_png) if archive is not None: - if '/' in archive: - raise ValueError("Archive name must NOT contain a path. " - f"Received: '{archive}'") - if '.' in archive: - raise ValueError("Archive name must NOT contain an extension. " - f"Received: '{archive}'") + if "/" in archive: + raise ValueError( + f"Archive name must NOT contain a path. Received: '{archive}'" + ) + if "." in archive: + raise ValueError( + f"Archive name must NOT contain an extension. Received: '{archive}'" + ) if os.path.isdir(archive): - raise ValueError(f"Folder '{archive}' already exists. " - "Change archive name.") + raise ValueError(f"Folder '{archive}' already exists. 
Change archive name.") os.makedirs(archive) for csvfile in csv_files: - shutil.copyfile(csvfile, os.path.join(archive, - os.path.basename(csvfile))) + shutil.copyfile(csvfile, os.path.join(archive, os.path.basename(csvfile))) shutil.copyfile(plot_png, os.path.join(archive, plot_png)) avg_times.to_csv(os.path.join(archive, "avg_times.csv"), index=False) print("Archived AVG epoch times CSV") # Copy SLURM logs: *.err *.out files - if os.path.exists('logs_slurm'): + if os.path.exists("logs_slurm"): print("Archived SLURM logs") - shutil.copytree('logs_slurm', os.path.join(archive, 'logs_slurm')) + shutil.copytree("logs_slurm", os.path.join(archive, "logs_slurm")) # Copy other SLURM logs - for ext in ['*.out', '*.err']: + for ext in ["*.out", "*.err"]: for file in glob.glob(ext): shutil.copyfile(file, os.path.join(archive, file)) # Create archive archive_name = shutil.make_archive( base_name=archive, # archive file name - format='gztar', + format="gztar", # root_dir='.', - base_dir=archive # folder path inside archive + base_dir=archive, # folder path inside archive ) shutil.rmtree(archive) print("Archived logs and plot at: ", archive_name) @@ -211,24 +266,37 @@ def scalability_report( @app.command() def exec_pipeline( - config: Annotated[Path, typer.Option( - help="Path to the configuration file of the pipeline to execute." - )], - pipe_key: Annotated[str, typer.Option( - help=("Key in the configuration file identifying " - "the pipeline object to execute.") - )] = "pipeline", - steps: Annotated[Optional[str], typer.Option( - help=("Run only some steps of the pipeline. 
Accepted values are " - "indices, python slices (e.g., 0:3 or 2:10:100), and " - "string names of steps.") - )] = None, - print_config: Annotated[bool, typer.Option( - help=("Print config to be executed after overrides.") - )] = False, + config: Annotated[ + Path, + typer.Option(help="Path to the configuration file of the pipeline to execute."), + ], + pipe_key: Annotated[ + str, + typer.Option( + help=( + "Key in the configuration file identifying " + "the pipeline object to execute." + ) + ), + ] = "pipeline", + steps: Annotated[ + Optional[str], + typer.Option( + help=( + "Run only some steps of the pipeline. Accepted values are " + "indices, python slices (e.g., 0:3 or 2:10:100), and " + "string names of steps." + ) + ), + ] = None, + print_config: Annotated[ + bool, typer.Option(help=("Print config to be executed after overrides.")) + ] = False, overrides_list: Annotated[ - Optional[List[str]], typer.Option( - "--override", "-o", + Optional[List[str]], + typer.Option( + "--override", + "-o", help=( "Nested key to dynamically override elements in the " "configuration file with the " @@ -237,13 +305,11 @@ def exec_pipeline( "Example: [...] " "-o pipeline.init_args.trainer.init_args.lr=0.001 " "-o pipeline.my_list.2.batch_size=64 " - ) - ) - ] = None + ), + ), + ] = None, ): - """Execute a pipeline from configuration file. - Allows dynamic override of fields. - """ + """Execute a pipeline from configuration file. 
Allows dynamic override of fields.""" # Add working directory to python path so that the interpreter is able # to find the local python files imported from the pipeline file import os @@ -251,23 +317,26 @@ def exec_pipeline( import sys from .utils import str_to_slice + sys.path.append(os.path.dirname(config)) sys.path.append(os.getcwd()) # Parse and execute pipeline from itwinai.parser import ConfigParser + overrides_list = overrides_list if overrides_list is not None else [] overrides = { - k: v for k, v - in map(lambda x: (x.split('=')[0], x.split('=')[1]), overrides_list) + k: v + for k, v in map(lambda x: (x.split("=")[0], x.split("=")[1]), overrides_list) } parser = ConfigParser(config=config, override_keys=overrides) if print_config: import json + print() - print("#="*15 + " Used configuration " + "#="*15) + print("#=" * 15 + " Used configuration " + "#=" * 15) print(json.dumps(parser.config, indent=2)) - print("#="*50) + print("#=" * 50) print() pipeline = parser.parse_pipeline(pipeline_nested_key=pipe_key) if steps: @@ -282,13 +351,9 @@ def exec_pipeline( @app.command() def mlflow_ui( path: str = typer.Option("ml-logs/", help="Path to logs storage."), - port: int = typer.Option( - 5000, help="Port on which the MLFlow UI is listening." - ), + port: int = typer.Option(5000, help="Port on which the MLFlow UI is listening."), ): - """ - Visualize Mlflow logs. - """ + """Visualize Mlflow logs.""" import subprocess subprocess.run(f"mlflow ui --backend-store-uri {path} --port {port}".split()) @@ -297,12 +362,9 @@ def mlflow_ui( @app.command() def mlflow_server( path: str = typer.Option("ml-logs/", help="Path to logs storage."), - port: int = typer.Option( - 5000, help="Port on which the server is listening."), + port: int = typer.Option(5000, help="Port on which the server is listening."), ): - """ - Spawn Mlflow server. 
- """ + """Spawn Mlflow server.""" import subprocess subprocess.run(f"mlflow server --backend-store-uri {path} --port {port}".split()) @@ -310,18 +372,13 @@ def mlflow_server( @app.command() def kill_mlflow_server( - port: int = typer.Option( - 5000, help="Port on which the server is listening."), + port: int = typer.Option(5000, help="Port on which the server is listening."), ): - """ - Kill Mlflow server. - """ + """Kill Mlflow server.""" import subprocess subprocess.run( - f"kill -9 $(lsof -t -i:{port})".split(), - check=True, - stderr=subprocess.DEVNULL + f"kill -9 $(lsof -t -i:{port})".split(), check=True, stderr=subprocess.DEVNULL ) diff --git a/src/itwinai/components.py b/src/itwinai/components.py index 700173c0..4b052319 100644 --- a/src/itwinai/components.py +++ b/src/itwinai/components.py @@ -82,12 +82,14 @@ >>> python my_train.py --config training_pipe.yaml --lr 0.002 """ - from __future__ import annotations import functools import time from abc import ABC, abstractmethod + +# import logging +# from logging import Logger as PythonLogger from typing import Any, Callable, Dict, List, Optional, Tuple, Union from .serialization import ModelLoader, Serializable @@ -113,6 +115,7 @@ def wrapper(self: BaseComponent, *args, **kwargs) -> Any: msg = f"'{self.name}' executed in {self.exec_t:.3f}s" self._printout(msg) return result + return wrapper @@ -127,7 +130,8 @@ class BaseComponent(ABC, Serializable): Args: name (Optional[str], optional): unique identifier for a step. Defaults to None. - """ + """ + _name: str = None #: Dictionary storing constructor arguments. Needed to serialize the #: class to dictionary. Set by ``self.save_parameters()`` method. @@ -144,11 +148,8 @@ def __init__( @property def name(self) -> str: - """Name of current component. Defaults to ``self.__class__.__name__``. - """ - return ( - self._name if self._name is not None else self.__class__.__name__ - ) + """Name of current component. 
Defaults to ``self.__class__.__name__``.""" + return self._name if self._name is not None else self.__class__.__name__ @name.setter def name(self, name: str) -> None: @@ -157,7 +158,7 @@ def name(self, name: str) -> None: @abstractmethod @monitor_exec def execute(self, *args, **kwargs) -> Any: - """"Execute some operations.""" + """Execute some operations.""" # def setup_console(self): # """Setup Python logging""" @@ -186,9 +187,9 @@ def cleanup(self): @staticmethod def _printout(msg: str): msg = f"# {msg} #" - print("#"*len(msg)) + print("#" * len(msg)) print(msg) - print("#"*len(msg)) + print("#" * len(msg)) class DataGetter(BaseComponent): @@ -213,7 +214,7 @@ def execute( self, train_dataset: MLDataset, validation_dataset: MLDataset, - test_dataset: MLDataset + test_dataset: MLDataset, ) -> Tuple[MLDataset, MLDataset, MLDataset]: """Trains a machine learning model. @@ -230,6 +231,7 @@ def execute( class DataSplitter(BaseComponent): """Splits a dataset into train, validation, and test splits.""" + _train_proportion: Union[int, float] _validation_proportion: Union[int, float] _test_proportion: Union[int, float] @@ -239,7 +241,7 @@ def __init__( train_proportion: Union[int, float], validation_proportion: Union[int, float], test_proportion: Union[int, float], - name: Optional[str] = None + name: Optional[str] = None, ) -> None: super().__init__(name) self.save_parameters(**self.locals2params(locals())) @@ -291,10 +293,7 @@ def test_proportion(self, prop: Union[int, float]) -> None: @abstractmethod @monitor_exec - def execute( - self, - dataset: MLDataset - ) -> Tuple[MLDataset, MLDataset, MLDataset]: + def execute(self, dataset: MLDataset) -> Tuple[MLDataset, MLDataset, MLDataset]: """Splits a dataset into train, validation and test splits. 
Args: @@ -315,7 +314,7 @@ def execute( self, train_dataset: MLDataset, validation_dataset: MLDataset, - test_dataset: MLDataset + test_dataset: MLDataset, ) -> Tuple[MLDataset, MLDataset, MLDataset, MLModel]: """Trains a machine learning model. @@ -348,9 +347,7 @@ def __init__( @abstractmethod @monitor_exec def execute( - self, - predict_dataset: MLDataset, - model: Optional[MLModel] = None + self, predict_dataset: MLDataset, model: Optional[MLModel] = None ) -> MLDataset: """Applies a machine learning model on a dataset of samples. @@ -433,26 +430,29 @@ def execute(self, *args) -> Tuple: """ result = [] for itm in self.policy: - if isinstance(itm, str) and itm.startswith(self.INPUT_PREFIX): - arg_idx = int(itm[len(self.INPUT_PREFIX):]) - if arg_idx >= len(args): - max_idx = max(map( - lambda itm: int(itm[len(self.INPUT_PREFIX):]), + if not (isinstance(itm, str) and itm.startswith(self.INPUT_PREFIX)): + result.append(itm) + continue + + arg_idx = int(itm[len(self.INPUT_PREFIX) :]) + if arg_idx >= len(args): + max_idx = max( + map( + lambda itm: int(itm[len(self.INPUT_PREFIX) :]), filter( lambda el: ( - isinstance(el, str) - and el.startswith(self.INPUT_PREFIX) + isinstance(el, str) and el.startswith(self.INPUT_PREFIX) ), - self.policy - ))) - raise IndexError( - f"The args received as input by '{self.name}' " - "are not consistent with the given adapter policy " - "because input args are too few! " - f"Input args are {len(args)} but the policy foresees " - f"at least {max_idx+1} items." + self.policy, + ), ) - result.append(args[arg_idx]) - else: - result.append(itm) + ) + raise IndexError( + f"The args received as input by '{self.name}' " + "are not consistent with the given adapter policy " + "because input args are too few! " + f"Input args are {len(args)} but the policy foresees " + f"at least {max_idx+1} items." 
+ ) + result.append(args[arg_idx]) return tuple(result) diff --git a/src/itwinai/torch/distributed.py b/src/itwinai/torch/distributed.py index 6534e508..559ba2b0 100644 --- a/src/itwinai/torch/distributed.py +++ b/src/itwinai/torch/distributed.py @@ -2,8 +2,6 @@ import os from typing import Any, Iterable, List, Literal, Optional, Tuple, Union -import deepspeed -import horovod.torch as hvd import torch import torch.distributed as dist import torch.nn as nn @@ -565,10 +563,7 @@ class DeepSpeedStrategy(TorchDistributedStrategy): #: Torch distributed communication backend. backend: Literal['nccl', 'gloo', 'mpi'] - def __init__( - self, - backend: Literal['nccl', 'gloo', 'mpi'] - ) -> None: + def __init__(self, backend: Literal['nccl', 'gloo', 'mpi']) -> None: super().__init__() self.backend = backend @@ -581,6 +576,8 @@ def init(self) -> None: DistributedStrategyError: when trying to initialize a strategy already initialized. """ + import deepspeed + self.deepspeed = deepspeed if not distributed_resources_available(): raise RuntimeError( "Trying to run distributed on insufficient resources.") @@ -591,10 +588,11 @@ def init(self) -> None: # https://github.com/Lightning-AI/pytorch-lightning/issues/13567 ompi_lrank = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] = os.environ.get( - 'LOCAL_RANK', ompi_lrank) + 'LOCAL_RANK', ompi_lrank + ) # https://deepspeed.readthedocs.io/en/latest/initialize.html#training-initialization - deepspeed.init_distributed(dist_backend=self.backend) + self.deepspeed.init_distributed(dist_backend=self.backend) self.is_initialized = True self.set_device() @@ -608,9 +606,10 @@ def distributed( """Setup model, optimizer and scheduler for distributed.""" if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." 
+ ) - distrib_model, optimizer, _, lr_scheduler = deepspeed.initialize( + distrib_model, optimizer, _, lr_scheduler = self.deepspeed.initialize( model=model, model_parameters=model_parameters, optimizer=optimizer, @@ -752,7 +751,11 @@ def init(self) -> None: "Trying to run distributed on insufficient resources.") if self.is_initialized: raise DistributedStrategyError("Strategy was already initialized") - hvd.init() + + import horovod.torch as hvd + self.hvd = hvd + + self.hvd.init() self.is_initialized = True self.set_device() @@ -772,16 +775,16 @@ def distributed( # Scale learning rate # https://github.com/horovod/horovod/issues/1653#issuecomment-574764452 lr_scaler = 1 - if optim_kwargs.get('op') == hvd.Adasum: - lr_scaler = hvd.local_size() - elif optim_kwargs.get('op') == hvd.Average: - lr_scaler = hvd.size() + if optim_kwargs.get('op') == self.hvd.Adasum: + lr_scaler = self.hvd.local_size() + elif optim_kwargs.get('op') == self.hvd.Average: + lr_scaler = self.hvd.size() for g in optimizer.param_groups: g['lr'] *= lr_scaler self._broadcast_params(model, optimizer) - distOptimizer = hvd.DistributedOptimizer( + distOptimizer = self.hvd.DistributedOptimizer( optimizer, named_parameters=model.named_parameters(), **optim_kwargs @@ -799,8 +802,8 @@ def _broadcast_params( optimizer (optim.Optimizer): Optimizer that is to be broadcasted across processes. """ - hvd.broadcast_parameters(model.state_dict(), root_rank=0) - hvd.broadcast_optimizer_state(optimizer, root_rank=-0) + self.hvd.broadcast_parameters(model.state_dict(), root_rank=0) + self.hvd.broadcast_optimizer_state(optimizer, root_rank=-0) def global_world_size(self) -> int: """Returns the total number of processes (global world size). @@ -811,7 +814,7 @@ def global_world_size(self) -> int: if not self.is_initialized: raise UninitializedStrategyError( "Strategy has not been initialized. 
Use the init method.") - return hvd.size() + return self.hvd.size() def local_world_size(self) -> int: """Returns the local number of workers available per node, @@ -823,7 +826,7 @@ def local_world_size(self) -> int: if not self.is_initialized: raise UninitializedStrategyError( "Strategy has not been initialized. Use the init method.") - return hvd.local_size() + return self.hvd.local_size() def global_rank(self) -> int: """Returns the global rank of the current process, where @@ -835,7 +838,7 @@ def global_rank(self) -> int: if not self.is_initialized: raise UninitializedStrategyError( "Strategy has not been initialized. Use the init method.") - return hvd.rank() + return self.hvd.rank() def local_rank(self) -> int: """Returns the local rank of the current process. @@ -846,14 +849,14 @@ def local_rank(self) -> int: if not self.is_initialized: raise UninitializedStrategyError( "Strategy has not been initialized. Use the init method.") - return hvd.local_rank() + return self.hvd.local_rank() def clean_up(self) -> None: """Shuts Horovod down.""" if not self.is_initialized: raise UninitializedStrategyError( "Strategy has not been initialized. Use the init method.") - hvd.shutdown() + self.hvd.shutdown() def allgather_obj(self, obj: Any) -> list[Any]: """All-gathers scalar objects across all workers to a @@ -869,7 +872,7 @@ def allgather_obj(self, obj: Any) -> list[Any]: raise UninitializedStrategyError( "Strategy has not been initialized. Use the init method." 
) - return hvd.allgather_object(obj) + return self.hvd.allgather_object(obj) def gather_obj(self, obj: Any, dst_rank: int = 0) -> list[Any]: """The same as ``allgather_obj``, as gather is not supported diff --git a/src/itwinai/torch/profiling/__init__.py b/src/itwinai/torch/profiling/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/itwinai/torch/profiling/communication_plot.py b/src/itwinai/torch/profiling/communication_plot.py new file mode 100644 index 00000000..23dec5a0 --- /dev/null +++ b/src/itwinai/torch/profiling/communication_plot.py @@ -0,0 +1,232 @@ +from pathlib import Path +from re import Pattern, compile +from typing import Any, List, Tuple + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from matplotlib.patches import Patch + +# Doing this because otherwise I get an error about X11 Forwarding which I believe +# is due to the server trying to pass the image to the client computer +matplotlib.use("Agg") + +# import logging +# from logging import Logger as PythonLogger + + +def calculate_comp_and_comm_time(df: pd.DataFrame) -> Tuple[float, float]: + """Calculates the time spent computing and time spent communicating and returns a + tuple of these numbers in seconds. Assumes that you are running with an NCCL + backend. + + Raises: + ValueError: If not all expected columns ('name', 'self_cuda_time_total') are + found in the given DataFrame. + """ + expected_columns = {"name", "self_cuda_time_total"} + if not expected_columns.issubset(df.columns): + missing_columns = expected_columns - set(df.columns) + raise ValueError( + f"Invalid data format! DataFrame does not contain the necessary columns." 
+ f"\nMissing columns: {missing_columns}" + ) + + nccl_comm_pattern = ( + r"ncclKernel_(?:AllReduce|Broadcast|Reduce|AllGather|ReduceScatter|SendRecv)" + ) + cuda_stream_pattern = r"cudaStream(?:WaitEvent|Synchronize)" + + # Any operation that is a part of PyTorch's ATen library is considered a computation + aten_comp_pattern = r"aten::" + + comm_df = df[ + (df["name"].str.contains(nccl_comm_pattern)) + | (df["name"].str.contains(cuda_stream_pattern)) + ] + comp_df = df[df["name"].str.contains(aten_comp_pattern)] + + comp_time = comp_df["self_cuda_time_total"].sum() + comm_time = comm_df["self_cuda_time_total"].sum() + + # Converting from microseconds to seconds + comp_time *= 1e-6 + comm_time *= 1e-6 + + return comp_time, comm_time + + +def create_stacked_plot( + values: np.ndarray, strategy_labels: List, gpu_numbers: List +) -> Tuple[Any, Any]: + """Creates a stacked plot showing values from 0 to 1, where the given value + will be placed on the bottom and the complement will be placed on top for + each value in 'values'. Returns the figure and the axis so that the caller can + do what they want with it, e.g. save to file, change it or just show it. + + Notes: + - Assumes that the rows of 'values' correspond to the labels in + 'strategy_labels' sorted alphabetically and that the columns correspond to + the GPU numbers in 'gpu_numbers' sorted numerically in ascending order. 
+ """ + sns.set_theme() + + strategy_labels = sorted(strategy_labels) + gpu_numbers = sorted(gpu_numbers, key=lambda x: int(x)) + + width = 1 / (len(strategy_labels) + 1) + comp_color = "lightblue" + comm_color = "lightgreen" + complements = 1 - values + + x = np.arange(len(gpu_numbers)) + fig, ax = plt.subplots() + + # Creating an offset to "center" around zero + static_offset = len(strategy_labels) / 2 - 0.5 + for strategy_idx in range(len(strategy_labels)): + dynamic_bar_offset = strategy_idx - static_offset + + ax.bar( + x=x + dynamic_bar_offset * width, + height=values[strategy_idx], + width=width, + color=comp_color, + ) + ax.bar( + x=x + dynamic_bar_offset * width, + height=complements[strategy_idx], + width=width, + bottom=values[strategy_idx], + color=comm_color, + ) + + # Positioning the labels under the stacks + for gpu_idx in range(len(gpu_numbers)): + if np.isnan(values[strategy_idx, gpu_idx]): + continue + dynamic_label_offset = strategy_idx - static_offset + ax.text( + x=x[gpu_idx] + dynamic_label_offset * width, + y=-0.1, + s=strategy_labels[strategy_idx], + ha="center", + va="top", + fontsize=10, + rotation=60, + ) + + ax.set_ylabel("Computation fraction") + ax.set_title("Computation vs Communication Time by Method") + ax.set_xticks(x) + ax.set_xticklabels(gpu_numbers) + ax.set_ylim(0, 1) + + # Setting the appropriate colors since the legend is manual + legend_elements = [ + Patch(facecolor=comm_color, label="Communication"), + Patch(facecolor=comp_color, label="Computation"), + ] + + # Positioning the legend outside of the plot to not obstruct it + ax.legend( + handles=legend_elements, + loc="upper left", + bbox_to_anchor=(0.80, 1.22), + borderaxespad=0.0, + ) + fig.subplots_adjust(bottom=0.25) + fig.subplots_adjust(top=0.85) + return fig, ax + + +def create_combined_comm_overhead_df(logs_dir: Path, pattern: str) -> pd.DataFrame: + """Reads and combines all files in a folder that matches the given regex pattern + into a single DataFrame. 
The files must be formatted as csv files. + + Raises: + ValueError: If not all expected columns are found in the stored DataFrame. + ValueError: If no matching files are found in the given logging directory. + """ + re_pattern: Pattern = compile(pattern) + dataframes = [] + expected_columns = { + "strategy", + "num_gpus", + "global_rank", + "name", + "self_cuda_time_total", + } + for entry in logs_dir.iterdir(): + match = re_pattern.search(str(entry)) + if not match: + continue + + df = pd.read_csv(entry) + if not expected_columns.issubset(df.columns): + missing_columns = expected_columns - set(df.columns) + raise ValueError( + f"Invalid data format! File at '{match.string}' doesn't contain all" + f" necessary columns. \nMissing columns: {missing_columns}" + ) + + dataframes.append(df) + if len(dataframes) == 0: + raise ValueError( + f"No matching files found in '{logs_dir.resolve()}' for pattern '{pattern}'" + ) + return pd.concat(dataframes) + + +def get_comp_fraction_full_array( + df: pd.DataFrame, print_table: bool = False +) -> np.ndarray: + """Creates a MxN NumPy array where M is the number of strategies + and N is the number of GPU configurations. The strategies are sorted + alphabetically and the GPU configurations are sorted in ascending number + of GPUs. + """ + unique_num_gpus = sorted(df["num_gpus"].unique(), key=lambda x: int(x)) + unique_strategies = sorted(df["strategy"].unique()) + values = [] + + table_string = "" + + for strategy in unique_strategies: + strategy_values = [] + for num_gpus in unique_num_gpus: + filtered_df = df[ + (df["strategy"] == strategy) & (df["num_gpus"] == num_gpus) + ] + + row_string = f"{strategy:>12} | {num_gpus:>10}" + + # Allows asymmetric testing, i.e. 
not testing all num gpus and all + # strategies together + if len(filtered_df) == 0: + comp_time, comm_time = np.NaN, np.NaN + strategy_values.append(np.NaN) + + row_string += f" | {'(NO DATA)':>15}" + else: + comp_time, comm_time = calculate_comp_and_comm_time(df=filtered_df) + # Avoid division-by-zero errors (1e-10) + comp_fraction = comp_time / (comp_time + comm_time + 1e-10) + strategy_values.append(comp_fraction) + + row_string += f" | {comp_time:>8.2f}s" + row_string += f" | {comm_time:>8.2f}s" + + table_string += row_string + "\n" + values.append(strategy_values) + + if print_table: + print(f"{'-'*50}") + print(f"{'Strategy':>12} | {'Num. GPUs':>10} | {'Comp.':>9} | {'Comm.':>8}") + print(f"{'-'*50}") + print(table_string) + print(f"{'-'*50}") + + return np.array(values) diff --git a/src/itwinai/torch/profiling/profiler.py b/src/itwinai/torch/profiling/profiler.py new file mode 100644 index 00000000..78c740e7 --- /dev/null +++ b/src/itwinai/torch/profiling/profiler.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import functools +from pathlib import Path +from typing import Any, Callable, Iterable + +import matplotlib +import pandas as pd +from torch.profiler import ProfilerActivity, profile, schedule + +from itwinai.torch.distributed import ( + DeepSpeedStrategy, + HorovodStrategy, + NonDistributedStrategy, + TorchDDPStrategy, +) +from itwinai.torch.trainer import TorchTrainer + +# Doing this because otherwise I get an error about X11 Forwarding which I believe +# is due to the server trying to pass the image to the client computer +matplotlib.use("Agg") + + +def profile_torch_trainer(method: Callable) -> Callable: + """Decorator for execute method for components. Profiles the communication time + vs. computation time and stores the result for future analysis. 
+ """ + + def gather_profiling_data(key_averages: Iterable) -> pd.DataFrame: + profiling_data = [] + for event in key_averages: + profiling_data.append( + { + "name": event.key, + "node_id": event.node_id, + "self_cpu_time_total": event.self_cpu_time_total, + "cpu_time_total": event.cpu_time_total, + "cpu_time_total_str": event.cpu_time_total_str, + "self_cuda_time_total": event.self_cuda_time_total, + "cuda_time_total": event.cuda_time_total, + "cuda_time_total_str": event.cuda_time_total_str, + "calls": event.count, + } + ) + return pd.DataFrame(profiling_data) + + @functools.wraps(method) + def profiled_method(self: TorchTrainer, *args, **kwargs) -> Any: + + profiler = profile( + activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], + with_modules=True, + schedule=schedule( + # skip_first=1 + wait=1, + warmup=2, + active=100, + ), + ) + profiler.start() + + self.profiler = profiler + try: + result = method(self, *args, **kwargs) + finally: + profiler.stop() + + strategy = self.strategy + if isinstance(strategy, NonDistributedStrategy): + strategy_str = "non-dist" + elif isinstance(strategy, TorchDDPStrategy): + strategy_str = "ddp" + elif isinstance(strategy, DeepSpeedStrategy): + strategy_str = "deepspeed" + elif isinstance(strategy, HorovodStrategy): + strategy_str = "horovod" + else: + strategy_str = "unk" + + global_rank = strategy.global_rank() + num_gpus_global = strategy.global_world_size() + + # Extracting and storing the profiling data + key_averages = profiler.key_averages() + profiling_dataframe = gather_profiling_data(key_averages=key_averages) + profiling_dataframe["strategy"] = strategy_str + profiling_dataframe["num_gpus"] = num_gpus_global + profiling_dataframe["global_rank"] = global_rank + + profiling_log_dir = Path("profiling_logs") + profiling_log_dir.mkdir(parents=True, exist_ok=True) + + filename = f"profile_{strategy_str}_{num_gpus_global}_{global_rank}.csv" + output_path = profiling_log_dir / filename + + print(f"Writing profiling 
dataframe to {output_path}") + profiling_dataframe.to_csv(output_path) + strategy.clean_up() + + return result + + return profiled_method diff --git a/src/itwinai/torch/trainer.py b/src/itwinai/torch/trainer.py index 121372bc..fa73289a 100644 --- a/src/itwinai/torch/trainer.py +++ b/src/itwinai/torch/trainer.py @@ -98,6 +98,8 @@ class TorchTrainer(Trainer, LogMixin): test_glob_step: int = 0 #: Dictionary of ``torchmetrics`` metrics, indexed by user-defined names. metrics: Dict[str, Metric] + #: PyTorch Profiler for communication vs. computation comparison + profiler: Optional[Any] def __init__( self, @@ -134,6 +136,7 @@ def __init__( self.checkpoints_location = checkpoints_location os.makedirs(self.checkpoints_location, exist_ok=True) self.checkpoint_every = checkpoint_every + self.profiler = None @property def strategy(self) -> TorchDistributedStrategy: @@ -372,7 +375,7 @@ def execute( if self.logger: self.logger.destroy_logger_context() - self.strategy.clean_up() + # self.strategy.clean_up() return train_dataset, validation_dataset, test_dataset, self.model def _set_epoch_dataloaders(self, epoch: int): @@ -392,6 +395,8 @@ def set_epoch(self, epoch: int) -> None: Args: epoch (int): epoch number, from 0 to ``epochs-1``. 
""" + if self.profiler is not None: + self.profiler.step() self._set_epoch_dataloaders(epoch) def log( @@ -520,10 +525,8 @@ def train(self): val_loss = self.validation_epoch(epoch) # Checkpointing current best model - worker_val_losses = self.strategy.gather( - val_loss, dst_rank=0 - ) - if self.strategy.global_rank() == 0: + worker_val_losses = self.strategy.gather(val_loss, dst_rank=0) + if self.strategy.is_main_worker: avg_loss = torch.mean( torch.stack(worker_val_losses) ).detach().cpu() @@ -628,7 +631,7 @@ def train_step( ) return loss, metrics - def validation_epoch(self, epoch: int) -> torch.Tensor: + def validation_epoch(self, epoch: int) -> Optional[torch.Tensor]: """Perform a complete sweep over the validation dataset, completing an epoch of validation. @@ -636,43 +639,45 @@ def validation_epoch(self, epoch: int) -> torch.Tensor: epoch (int): current epoch number, from 0 to ``self.epochs - 1``. Returns: - Loss: average validation loss for the current epoch. + Optional[Loss]: average validation loss for the current epoch if + self.validation_dataloader is not None """ - if self.validation_dataloader is not None: - self.model.eval() - validation_losses = [] - validation_metrics = [] - for batch_idx, val_batch \ - in enumerate(self.validation_dataloader): - loss, metrics = self.validation_step( - batch=val_batch, - batch_idx=batch_idx - ) - validation_losses.append(loss) - validation_metrics.append(metrics) + if self.validation_dataloader is None: + return + + self.model.eval() + validation_losses = [] + validation_metrics = [] + for batch_idx, val_batch in enumerate(self.validation_dataloader): + loss, metrics = self.validation_step( + batch=val_batch, + batch_idx=batch_idx + ) + validation_losses.append(loss) + validation_metrics.append(metrics) - # Important: update counter - self.validation_glob_step += 1 + # Important: update counter + self.validation_glob_step += 1 - # Aggregate and log losses - avg_loss = torch.mean(torch.stack(validation_losses)) + # 
Aggregate and log losses + avg_loss = torch.mean(torch.stack(validation_losses)) + self.log( + item=avg_loss.item(), + identifier='validation_loss_epoch', + kind='metric', + step=self.validation_glob_step, + ) + # Aggregate and log metrics + avg_metrics = pd.DataFrame(validation_metrics).mean().to_dict() + for m_name, m_val in avg_metrics.items(): self.log( - item=avg_loss.item(), - identifier='validation_loss_epoch', + item=m_val, + identifier='validation_' + m_name + '_epoch', kind='metric', step=self.validation_glob_step, ) - # Aggregate and log metrics - avg_metrics = pd.DataFrame(validation_metrics).mean().to_dict() - for m_name, m_val in avg_metrics.items(): - self.log( - item=m_val, - identifier='validation_' + m_name + '_epoch', - kind='metric', - step=self.validation_glob_step, - ) - return avg_loss + return avg_loss def validation_step( self, diff --git a/use-cases/eurac/config.yaml b/use-cases/eurac/config.yaml index 64ee45f7..8912e898 100644 --- a/use-cases/eurac/config.yaml +++ b/use-cases/eurac/config.yaml @@ -6,7 +6,7 @@ tmp_stats: /p/scratch/intertwin/datasets/eurac/stats experiment: "drought use case lstm" run_name: "alps_test" -epochs: 2 +epochs: 5 random_seed: 1010 lr: 0.001 batch_size: 256 diff --git a/use-cases/eurac/plots/comm_plot.png b/use-cases/eurac/plots/comm_plot.png new file mode 100644 index 0000000000000000000000000000000000000000..b406b1a08737cb91521ed856bc7e8bd2934b9724 GIT binary patch literal 30280 zcmeFZcUV`TH_ZQyb^HZk8Ai($VNiLM%nVRjon2nBZ}Nbn`>7sZLXSJV!2^t zWo=?RnsAa0*Q7M9DHOVkK3?A6QL(F=Sl(>&gpfD68OLZK9H74VnMzA+Rq{55O8Z)trU#q07xskXK@uJbS5XlZV%PRo(+ zy?MGJL8oS@#jf-F_xP(#$<*ZG^(R=&W!*hI$g8Z>b(=8Ur| z-^P_WxN+*^(PExe)N4i1eh>mEOkvgMDr@Wu8HS)bnO2X2yo zcp!A&{^1DjKR*YZlqvJ$R}DQUR~tsP?)uypAFF;Q)qXs~-P5zqPcPw%oBQENZ*Ok$ zr**}wXx&f*!x4+GnF-n%7JailWg>NpB2rFHFDzQ;w8N#JK8cMryF0?a^NKj`S-|b~ z?NeK5)|-1ZCHxG#7CYQhTqVi-RjcIf+m~Tssq4Jo-=i~~9Sh%eP(Ml6Wp 
z+6XyEM&Zn56TiUdZG0C#;HCnPg^YEVU#WA}+D)D%%bGjvcLpw3%W3Xc_&S>WIRC?9HP3hhZGL`pv3GFjFWR0HqY$#h8T@|HPro(u`kC?W@_H{Z zrYf=J4;-#btNUxCQ?hLbIH}FqlPAr;X6m{wZsYMS4XX}!7QEVc?J2jqZgpM?v&VK< ziv-o|?LTp;GlhruHzeFKt&iJz;>KWnQ_WIO`4(SfDcsuhbtS8B-1f1w+if4hF9znGcM zthZ8p8@JkpmUQ!~#y6JzpKT`pd{7$7*Xb#g@Pyu7^pT&J(sBtG(~u1-}o zMnyH}24k8@4UH%Ba_}d=#FwkrE!ipO;?d|}nN9cO21LmEM-?pRM?B+g ziMmoDs+B&t$&+#Z?oo@cF?)op!e)gO0*`9)YNb}!M%Q5R&;^{ddnWDYBEPg`%iJio zuj3I7-7r6O*NG$N?`+&w{gm5p+wZ?$8nzzNiqc9mjX!;7Blq6DiYM(X8n_c_tKGMm zsiqnqO0#TNJZAB=@)2W;pxv+&)h94;e$Djuty@FY^_(@gY}w*4KV;Y0(eXN5W_w_U zAHSZ9b^Dt2Qv;u#mO1?JJd|zS^XT%I7YTheV%1?(<(h?TLsk-)D6G`qe*4Yc%S(M} zVV1dhGp48mj?_ek$g5wy8kc;fg0D|kTxtum&*#I&by13C zU%S75fAvDhx_@qRfT6{4SV+uzCNsmLYFRI%+u3|N-EMKN-D~peQoI_k>v+Ltn%nN% z`CkPdON?w@{_wycMlJrXkdtspk;mHrG1mn75Yh0KjOAlXJ!4&ECJ7_8(Y{=o$xnow zqNY7EA4}UC7yT<;oQa35oD#v3>hav&lS@x3!m*eR!9S?n#E7{AOlR5@61! zUu`;KKjx|=Tu}Vm&Yhfyc%047PEPdn6Oj@vw}U-ci;BWW(n>j2604LCE=J|mX||x_ z*DpMH@W9H-;isg8f`S6ufdgCu0!qQb!QE3oe;(XPdzYU1&-n79B8d$fHZYhgDn3>) zYSGoz&9>#9q^@&WoJ=K=zOUwrmTc>8M#lJD`|$`fS3kbRr>{kXg)amL>vF0mt?}=< zbP$iE@j^v#-Zq9CGBSRAF5kuw+hdf%Kly}-INs3WX~}c+b8=q%d`0Xkn`K+>+(4oS zgL&~IMrW>;1jW#2E^~dEOH1?3Yy66fT2Gc;cfEPidF~08Mp|#r;9%U$%`)!^7Q`I2 z^OYeCds%T0#a_&U%XCN74+h>!jI(KTYF2Ju-nL!4V#)&xRApsT z^B1hrb{#PeL-AoOGZDUeZbM7t=eMi5KYgTH6$oCC3mUBHRd&2wB?Sf zq_>7j@|(A$OPsm4rP^)PZPlprdv)?0R1pB?mZNWQb4WhgGdJQI8DrV*XEWZtnfF4b zxc!($@FIt9?w<&}X*Y(r@ENLwg$0*d+)ejI1VDMS`ncNOszkFAZ}D*+jYp3j^|$3+ zabKEpqW~*;-*VdHJZKT}{_bY=BKN1M7Oh&&GvC)1ef-F!5S$yM)1oKhIQ5v-lUa;9 z-cwQKr^ly?u;IRqJE7D!GTLQvJJXxVhsX2_^~t&}@(}Yw24*~w8_gH0j*@fqi$r5?W~aWBnjR>@!#Ud0)u<@ck;|eNHUO`^Z`}{1wqAjP*3-n%L#H1aJN? 
z;>&-7(GnJ#+ov8q+H!Q(;Gl35W8?Hldt`&2t7>_G;Q2fU!MZmFyXAsTP%!}!j0J%qw&T^V*a&+%u_MfTal`lpPeG%wS=-|3yshY+7nC;+KQtO)LJv{ zdH9;rf971j{uCa z`rJfcLZ0Ka#>8TUtLXWPU>%#mM#ahgI#vGMZ;9%T3akQY2`uY8`-K&E-d z>O9?g09ixVr}njkL0y7QUG89gd|<#4liU-wrm@>~f2#L<84|Q?yNW;?m;v;mN+lCJ z=EhJ=j80ArBF>-oO$GdKo0Gnq*$y=;P7X9E0UD(lmhRi>d2W1cB)=uBIW4w-y1l~e z!~KiEkB?5+StQ157sxE$8+-Vj_OJYo@OlEF3bBQUoxSJY)Z(|>c9TE6?g`o$za+~N z(ZhJK@z`XQWC#b#w{_M9Nl6FG112BLO*3<_v$H#ol@`z8BVGOdFUHG<2#of3{`|AIPF>HJl~?1ZKXOB@x}LeWhew%#NF?ClK)4@Gp9eof zUyVXfNo%XlmoL-xdB7yXB6%bo+uyv|59F(rd-E1>Mfh}EAWL7Hk=kvVNp(NI*1BB# zO;n`^zq2uCCMM=Vm8xD#uu=2#UHa?eeUT6i+eSh=5y1iTp0PgVR90(i6DltQ6#TMphf16VwJGIF zgX_}xtu`m?hZl5G0{K%ug@`Sm;m?wmmag_;3yB5D-z74~OI3>3T9THP1%~3TO>CDL z8yh=1J@Pdx@Pys4Y1_!FGwUbb(>W%erF{RL2hm02W^LxIa-_Vp&n`hz)2g;G>fuF% z0e8>E_N&UMwhrB+qjG7bK4*DpA4lwK$3<^@EU>ts*(p|lM>=NX`pdd#Hy z0If~2|62C}Su9Yx=4>H_wBT+*?= zIg?jR;7EOGOy=56cmS7?C)G3>d3klRXKt0zA>K(p<&yE48f-Fob9H)HXKZrvaLD3K zrMGWSWm|P6&yJ!{nCuC0jk=IE<})0$_qvr;wVtcak!h-wRAq(Ovf&%XD?N@^be5de z_5;;Ql*}#7n{)PG1GegHdhlApgY(Ll7o3+no72qhy$A|YO}LPCFCpaVHB$6bF5C9k zDvz`+rW#kR*2&!w7aeZ{}+*$3S4`&C;_|;=i-98r}G`_W{s3_XF>QZyQ zb1bRyO=?8#P19y(W_mr3oqIkaIX2rf{)AmYH_f6|-ly~1x0f`Vcl8Y`W+iOcWyr!UPcEPRT(`2HgSCGXx@L}M5n2IB2_BEB(-%$bv%8{#zf2G#hh zw!C?K%cbTF%fzA^xcS#|fk#b?+>JhgCP0z+`qmn~`k<2;Y1h6xQI+1`^%OL%4MdUl z8-s1`=*J!9&qv>!d*mc{{kk10y&eBQtnsoUw!o}b{evYn+&O!g1jAKm4*iwa^IU2j zyPm%N^muV8MF| zg-uPGH1G3rOxH}Y^6|;jyk8!sx8$Wgy&nA6YJg2LMOMfzQoJnMeWwv%_U-8CpxbDM*uJ&Fe}K_4L$Www)9Z}_W-kh?*}1tGRH4Wmffsai%yl}p;+lZ z9eG<@o{PEWOZ)M~15yYVeG(|h%sh_SpoIXcc~1{>ZONZJlIR&Cs7Br55}q}rbE%toz=kW-$Y z8d9Fn0{kFe)-OwJJGJoOA1Y0G-I$h+Y=IS{QB$oGN(&8Y0amn+#wupJQ_*j<$=c(eiJ9l z0}HXNsHlrt_SpvU6g&<~kcgw&=G{kC0*{(M*)O?8C-hTAXys+(%8 zNI-j}LdY}Fv#MsFdBBK9jO5SrNl=ins!Dy?mzNWM^o;}h^}V~ZY3;9tpMGH}+?r|C z<$mtbp6EZ;Z8F8I#5zv2*G$ySv+1j58*0sIxUsHz1AUDFvxmiAq6UtrFmh^Zom8h5UK_~h$+FepKTt7cQgfiJ3=>makeT-Gtrx^9b;O@`< zq4Ogm0nKS$d-v`g8y~O1QYmlI&aw&*b8K*7U#6zt9KB>a(xwB}QK`s%TV1qrBz=SY 
z{OoLBi0e|--Oal{53uWvgf#E)21giUJ5UcupW9y-8@XZYfv}mERgntJC`QlW?jnKM zLy;VjR_9R5Tyzg9E+B7Vrc+ek;}$BL4{3_}3&)fZP1fwhGU7Cbjv z%*@*_C)#rG%9lOkqkb29aF(#iJco`C8*uBI+qhKD|M=*ccKydk`b0SgkzVHjr>`Mm zuDU3P)WD6`rkW^^&ky4o`m4h|FxjF&C=dMn{+=`4adFO;Ol>4t#o4j0IwWyu0u3fH z<_H?5ZMj*qc2%Q4ip$Gf7Dgp!XJ<{Y1~}_ma_v({Y{Xqi4u|4vt*&3Ej-Kqw8Oc{J zd&aB0IM&f(Ss=K)gw+y3Ru~8lf4}Lz8cE0L;kvq5_448m1;v+kc%CCMW7VovG3Q_X zfr%Cg1}D6+0pt659{2yO?dFfoIt4Bw5(PRrH=>bl$#E!C#tR#0N;V!}vs(BTR>&y{ zcX~xUWailA%u+iZ^DO!S!gR992`PWq*hNZu6 zW;~&?^?)o{%C(`Abku4DJP9HDSuow;`e%Q#&hcS9#JQa+w1AnzQ4NzQtDu9o})Uq zTHJXrCXrN^MF&VCjjxwW9`8q}FMdiw(B`50wU*HwCBeIcK2cmjg5{f zg3m_=iacuG9EAzx{@_6s)@}T$+cYMv*|uET>2Xd@PDfl8Zi39xB~KJnovrWHnu?k^ zT}I?RR9|tA7*F;GTIN?hKV+bia6w39`m1%HOF0u8m{L%3NIUVU{x^|KjpE3s%DA?( z)VD6r(T>C~kL~>R1BnHMjQE6_*vDp|o)?9_DhM8^Gd{Qx6qt#2?at}|JzlodYAmoc z+X1yntPdc3tvrXcS=;H!NmUeCDlf%XnYZOOEd;QWMF;xO^g08FAg~+>A7pAF&Pn@x z2@y-&w4KiwBETHxhH~6t%QoTi7edO2oGMtd+nAUVFyW%q6Lp(2uZN=Y5wz(&5AwHf zq+q!ofq{)sg$ym4nb%%Hw?s9J(%TqXhm>XR=v_R_E5j+610+%)*i=R-hSD=IMDQ0Z zD2XjEq}c2EKvp^hrN@j*Cp-0=7gIcvx!GXj8&db~*IV<=-(a;j9|GqAJv}`VXU8-a z!9QI=EhLyXKb*rbgy=2^pb>z`eaXVCG3ixlLqo&Fc!gMC%a%RI^~r%C_o{3r`&F&y zkyd5%^YbH^T^FjpPEEB~24Wq^ZDp4~k^u&#ri{PfDzfeiP9C1mP=sY6^%%I@${q0S zLp)u%0<2>Li@<{y|G>9Q->Tp}dbGadx|35b2{94*&huKS>IoN`CUE9Rjb%)OL*&ws zDKFt{5#Ujjv#h#g*U-=&WMgCVpO4pg^R9DXz-$iWa=AH1{U|y7=+#Dj3l9&E!vJp$ z6U}CcB1vtASa&NI00n@z8BDSz*v#Zvd(cUxDru&w%9k#6oXMEBymjjqIZXNN*$)Pj zmhJg@rS0wQHuF>F)-5OzWQ{&OIYd%Mce-ZEWq)$-BYD%oS-gi1RnheBXcP^ zSIXsbX~hA|3EsJNG&BKVZ%wM|afHVS6q+VvB!(K_>@RY8crJ9-zPPAJlgH!!{e!>( z#g{@L#BrL=+j*8&kL!xKEMyktBg0t+fEaUk4)31mtBKB!pmPX0z&ssz%<^aklJd&s zlN(wbE938;>TGA6www>LYdb_98$A<~#hb*(`y~TNHs}ypw8XHlJV5`$S3a9#HVcB~ zeoMb&hp|PVsB?C-)9hG8Nr{xTwuZ)YC|4n_-zqvzRX9R0j6kAz+u74|W#i_}g8s4x zY9yJR4NbphzL`xkZp}`!&qFPs3{W8l9Wp;bWjg(v!2VWk%z&2`@BU{a(K~b z+SMij-^9ekzUg+?WB#(!E)BEu^MS}o2HB3&30duQVhby?6^pOtYdzqzGGY#|t_K2Z z0G=+h#+sizKYr(733?qG$`-!xkw*}5)Rz?nuMEQwrc{zfa2)U_L;DsrKw2M_oEzrb 
zcr*@5%mYm+4W{6hyIt8`Yy7>vU+$NDB&=5MHL#?azi;0@`fc07P3vO9Gmi{jM-~82ibz`vcm4MNwu$id~KK4o3IIaC4SohkXlXg#@KaaxFiJYI2sc=k#(%4-rDk>VC z->smGz)=hEGX4W9=BJt(WxyVRWnFKjrAHvSn%@}Q0|CJb_UY%(YvUpKBqtY_G@zAp z=Ta1&lM|NlJYdgfi%xrmmwiZw4DZUvQPGrg&h%g~HH#gO& zi~-ZgyfwS7BdgsKF;-xgCs1j4k+lpHzt6M2hbL|O4G$gVob7qBpam40xN*}aW#m$x zk9|jsE5)I=9r0R~ySzvksL@=MfQyZjl!}dfZ&z1y&MMt~{``4yzbnF>(DCD8nS&e( z!NM4lT3{#M=1ue*`lH8Z#o}E<5)=}Cv>WW-vuBS=nAC3`>79ML%Xx|HlTcHP+j8y6 zh!qKr^#xUN966!_=}=|b`ucUn$NMCdY5@?J=;y)AIeZJ>xm*tp5{wIbw##BRdBni9 zh@Y~^A4jJ9no=Tx$s+JzxM4zx!4uF*zZ&i4=626we#wA!R^VYRt60x^xl(Ol14iMS z=a(18y+D~e098eGi-+O@q)G7660^kb%+7+Iyq{9ozCCVeTy4~WkLEVrNVk};mw;jfZFQ8 z!uBe-ScH#L2t{gsd^)VHISZs>pt~R@_~pwN)2i2Z^5=(7TC-TRqz|o`Uz%*tL->zo zI%(p6jV0~J(b+Vo#`$!^QlB|MEuJ}%ULijY#eJ<8cP6Q?{gox?YisM%W0tx)nb-Ep z4Tuci8>^*Ji>SBAw!Aj$TYu@uE@sEw?D8)vNzs9+*?XH#^gOPq6QZMK{_Fvr@v+Ea z-_I}Wl8t2KLPT{C!IYPm7SqgIj)GLz_|pTDnF%oeiJL!P5;UbB@y zZ>UN}(4Ln|Ge77&EK5S3%~xz;lPMb=eL2ZoADru1`zWx}GOs43d}0n6P7qa!ph@-F zHS4zk>11{G_C^6R3y?5!;cU*jrnc zGWJQhM_np%M`ax7Fg+a4c=ARR)%O5r#GkfU7r=70h<>;d0u3>H>S7b~1Nx6;fx4&# zOOR?(6Cp|C-H9cAI9m_=3?{FC1fBfO(y}rcrB5jqrR`&)ydW~24GWhMN>#BY%l!`* zckBoP6ChEL`oc41%CW8)__MRS`#@;b%-NCf_2o2}TCSXqQ-cgwzQZaKj;gjFrBn^d zl&-~e7Sqz8X~lE%^zHiaPW z*7=VXU7wy>-MDddS(1#2(1{aTvO%2GSoK8tH}M!)8R^+qHC0uLG`HV&w2a`n4z=Z- zEINwQk7|rhHbg5+V(Dn*J7pCe^8~SzhH(m>GedF4P4fF(cc3L;%n&{EBS3cZ(MI5* zh0UV{kw$rXpeR*P@FTp+qcoF#BMhwKL_;)tg4n^Y8jtRR+`$ZP?fl(6%za(M!wJNi z@!`Q%h9Tf_n9xk>rgq+XBzI3VkpqQn4yxB*HfZ^60)eH+AkHd7xE~NBFB@u>770Q$YC*M}5JwWnc zml8mBOjY`K0gfJ!lq{yNc_cIa$alWI0{V}_=sH=gC;WQ4K#DM@nDMx5A8g4`ICbh& zwbXJQX(9e-SRpJ?*n`vv%EB2unT@8%mkzStqy>tgtCKaAQm9m zXL$1g6!M|4d&MHiAy=OPDkHQDTy$HZM$SjqqZB;Kc>J2UWx--PDfY3lS17#`Hc_); zu25aljH}?lit~P}e;%2D5GbGk!yC!9vSKmFDWOauaR8HL;Ldj4N|afzP(2ZJj3fp# zU@Q!RRrCf2u__~hLV7WU&Li8Xg2eH7n_NoDE5njsG4myF(`;74sezzW(FxtS*aX-sCeHI;J>cVM4jzZE_P{kznJEUmsS&J>?8E=K%YV zDtI&4Q#ztueNSHdCQgC}Av<-QMkN5Tc<}-71rr!#0G1UB&t~N`|d-mG$-WJBs({2>Y zS{|zuMO3M~|J(i}&T}@z34)U0G1QXNr%%Ig$VbXc4D%D45xZ)PnxN4qdXm|MtU9Hx 
z^LyXBe&a^{Zp++Ph#oO^BWR@~hwHx!#25$$d7`d})dp+hLxsVv1ZME zn~^gQQEh{9c*Ut4QGytaaln!Z<+*I(uxqTRA`zrcFBTw4|4LU0$u!p=n+R_i7q|)d zUW5st({&t5Q%U?yq5H>=P=FF+1heLB+t|yWpTB}UNM5Xlet3!1PYtywJKYaHV&Zc^{zVk-kOMS5 zQ~*bP3U$~lemaJ?6Jvq(QS&v+It|E+&=tg-3|aRKzh$-_u3II`$_TDKbtr;lRwHe{ zCfK{T1qC03+*0EsF&KM=`M$89h~Kd{M?18#ZK7xMd1uFa)JbKt`dme@2;pDj^$K*c z9t8{dYhEKzS=rjwfoh*@UtUtjG(3;6GJ)u?B`f+H?i=M>e~p4z9V$Y&AQk2O1tj3n zafvHjFDyfQ;GEK`iWiVHdhw)O<8o-%h@X>zWD9&#hs%-+7H+EVq9h&q=muSk76F3~ zPc;^_Ogi5>@dMsn{BH@*MCK~4gYf;Z*>xJoF`5ear(XMhn#7^+-@lWr^6?R4EMnLd zrz*I}Zjqx=)j{!hj=-eQ#bBq}QYiwrcwA{ZaT+KaAt+>~O5TP|YlA4kJcSi{o)CMV zeA&G~ns|S=%U?$+)Q!V~{!#4Z=jBy#X2gFDi&^QIfjH`A`(OJ&C{2JI5*?$Jf*lM* zZzz(A2~da6uv}yl*1@{9YmFCXMtABQg@~o^m%OmkTe4h#oP`iAs2O>Ow|#>h6hf{k z@SsA9Fpm1>@6)wL%33G}Yr`K}#bAP{?!LRAEt8L;Z%0Xlz z2zESQ{s4jf3G!oE8Fy>T(Oq)x4;Ok&p3F+Yy)-uQIbrBf%eG*8i>J=3)0MtvW%e5; zz8`(WbL$7A<^!)#FRje9^M;0TK0Xh3=8Gv*3|iTAo?Ku*MWNp7e8UXhQw4S9=ON$% zE{&v@2(KZ=^A8^ET<*@&yQ!3EaiBG|Uv!mOS)aG^@5LvRX&%`bN>HiiVqNc^p8K`E z`V|UbD?7Ut0@)k|@Ji@jz51%gPTZn*meK5C7QlZpj;wvCINcvTB8!NPijjVTKtUL2 zn1qJDW-0qTK@ufl5zOo}yoJ&o^`>cSmbE0}Y-*D=kr)MXc$mfCy?Y0%%Qci<#lRV) z1ZNAaih?c7Lliatvr4aorEm@T7-xe) zhYdSl6<*kxpdzoWj&6>REH=27|56D4zxwB2-4%wk%#zI7eiXywkF4j3zWi;cfw$q2 zoptZx@5IZjihE})(44(CY!%d7w<`~+KDXbVx?*{b>%iVp#DJZj%9|9s#%?xe)Bb%# z?d0?0tq){UZFh>_9sAN9W@Ud>XI=FlD?JmGxVyyZNGUX3v4{MB`A?re{`doo6Mr<3 z{ZP&$oJ7FsA(>zYk}(gZ3sEY-p~fKr8zaT4MeV?YOH#m?71o1ns!h}rBgHh)Eg(#y zBSD481!N+hX&@Z>AGdEq0^!)kabA3WdZb}y^hXx}Wr9wQRv%f5l&Sq?>%Fn;I1U^* zXWjkzF_050Cpuq<6yv@343KmF`UJc2ZaIRUlp|ynK{c^bpswVv(UA){GJ2Pn6ws3`g&T+rB<^y2Os5YeWnwFqv5R>w*uf-Pv# z2>GCZValZ=0rWd(6|(E)xFsdm+dEutXbODnPL)djB)Uqw^l7VRxK(ybHZ43M4<0^L zRq{FXe)>nBe^3gq0tW^T5iMfJz;9q2!JS0x*lksS?4(Rg9+Q)4#{9^vfGd%(;lS6E z&l>#v`KV>^ypGp{nxM?HvBO)F?kEYF>#U-%gtD^jj9=QhvCp^bhqk!58{DCBguZP- zRPKVq2-cN|Y~5FvFp*6zC8`g#%285vtW`#p)EKY7Z=R)s-RL?ex|zWz&sj=-Lr zBZi0R;a1+^Rn?fNg67S&3&&Yl#7O~7gc%e@YOt_Y6*IfW0cu475~ttf!uiA_{D2-n ziB%(5AGzLf76kaG@^U#)kM&T#c;~KUI-%+AA?dx;K*G8eV@MHoW< 
z3Nc~t^g9yl{iT(B6u*Bzq}Jr;>wDX(+omSCz(Rm#ZAI=ori2*HgX~CQYmvg(+L~U} zM%`qBb=u(RAr?_l3)<7Jr*j=}MgK1bsgI@g>%xlQ1(*OGfd&(9c%`0{8YXL49Na;G z{cctA$iFEL9&qD3 z0=va|^ggvY4Db?e2JjOOb0Zgc_s9hIC8LS3xwYwbqpyVrrR7VK0YNWNy@YJ;i{%A>qv)*-d6yGnqN1!!cT>oOa+wj1SGwd zq|-Q#;&llR7(_Xj;_I_9Win4VUO^e=RwH-M5|4|d=T?&ikwmqfg)WFHUGZPCM(qOl>*X1{avaBvZQ`t@#l@1i z{dC)v3{btdhQ@n-#KvX$ha`2Fj@-a_Y3iylkZ>4OA_}80{OrZth2ljmY75>McQ9RuqTkp$EPMxIZf7H#Uzqap{7a870=A10L5{OqYG(@ZTg@`(95qHKV3A{(C z{$#Pb1ZAoXW`l~~Jrwg`Q0uwL2#W_a$x-Z|>^^1z!y@cR$o`{RE(%y9<;|{E|ept@1H)XqaNe78Z~@@J{Bi`bT!cLZQby2kU)T$>TFfgq48( zR}~N7NVnoA(Q8VO;t0&6Dgo_Ie80yW0|wz}0-5dN4o3YHs{$mUmTWSraKl;9yb6JR z3T;5LwzF85S({Gab?;JKvHu~1yHGxS`p4kd&}S%so{wIVyhYFo_jW zY{?t`9<6J=pFiILVlN{5D?F&U6}7mPC!mG%al(j1B%l5EUfDNgJiHl-rL*nQf<43oS)f{VP(y4~SR}l`_#-3L z6}Y;#Fp$QKZ?7chd8W&63jpY z5vc_H#>F38f2U1+A+UJBNDyJOa^1oMgO_RVLk2_CVWq;pY-MFd61lxS$UC?o zo^%hLJ9{=tufT;%H#hTgMR5E;L!ud`KM^F`T$dN`B{;$e0-;xqP5MbC_)nAXc>g<$b~(1!Ux0F5gx3;e# z^6t+Ky95mHLof2{&xRWRhHxr4m4C&jnfdw1EYeuc5}On#lcam|`#STcWC?ISK4E^k z4%Cx32G2ujuEi)52MV5%%=is7z3{_RMcUs$VtxYe4_8fe!gh3kVSz+Luj~FvF8SN{ z!EV|TaPu(`>jT6F?a=9gUP)5!?O|jj)&?wp&hnPcTej4L(=x8=208+);y8Z%c-!Z9WdD?@ldSR@6ZLij+=FDyV{D4X@~e&z!{{TCb@bLe%b&a1pG zhKE}oBBlqr%O@m6FJR!%O7+fOJ`jE{G<1!n8bZ44{r?K^mGG^;&x#&5Z3H0VCE=xc zC>H%+Aix-iV8G8{UG7;NgSy-*$adiJLLlJ6H%*@>hMk9>1SJIma_;K-i@Auwa)Y#U z-d*Me%qDZwBOwt(G@w1utIA+-3QLV|{rdIL|Buy8ZrsMj3MT{+w284IgBicT{NhE+ zfguuCBX9#qBFGcwmp6}T)CW|>Ur z7UD~T|5TM^V;H<%9=`I$xYUPj*YRsF5eFWpvy-T3xC)Ts#zP#0W4(GO#M%O zOwswl!vquHKI=$wghnQG#+1yeL7j-rEa2#fn1|qOOI1SHfd`U}nK>jfKT#=M+N7^K zoNg!uPKNdA88B~wu5#!XB(uKawd~11U4d(fg+Xk%C1d2wh%8BrEaPRVt5~7x5V?2u z4KQ2ghP`;5GVg1`{}znIQzD^a1(m*kf5wL}$gqd7)Y;sVUmf3oGp&KhUQmbA`(fkUVTaH%Q<*?h?r*f;}ba~OdTg#PDga5 z^a94re~}~#ULsYaOP{&1*nre zx}d#4A4*nSNoPz&$aCm{Aom3RR0hHG39K#gG9+hspTpSY!;DA5DgiMsf`ZI-$OCh; zNdgGO8io!e^nC?J7~oi9kOhnbn?SBmp%CXr;YyfOV!^5pC$CDe1IG$n`tS={TG6Z6 zR*fUj0_P%hT%())_S=`3Y)w)B8LE9ivqc*iSEN;t8Th!jJig}{zn;aw@8oR4dvP{$ zt%XZXxFlkFr 
zpuhq7n3XINgw;wgTWK)qX659xWGaylM86572pD{%B_?cSR%HbD@?)&B=r2KmMc{`i z-0u$X5M8~%iNoD$8_RciV~xbXL5`F86}46d1qIbX)gtCZ)M0u5IoWR&C;dV2dW2=? zo_+AU`~CY#|6naCOeNTzgH`-S5aN*+{)NM&fNO~@aWj*!#(Fv?6)CCv&@ybk6|N#i zL4vMe(l8nzJ4#T5|7@i_J%S5?nVeJ=13<70WnkA~{D}h2ZH^*$6Zoib=C(W6)}upQxANUg;BtqlC><~GcD%wq;rX!O&~;e zw>ZE6D+U$Ja4EWD7G@Lyy|3KqC+~+zVygSHD?EFvf3V%=;Orst;`~#SHBz7|nC1vy z;wm(H7BnXmo5w5bnMA{t^&Eq@$$H+xGJVo@Dfd6E z-(dg0xHJJVps9USWWyTh5tFd{5=RaCRU|^7y*0fQARotwd~tIk3LCUX3v>Wt(+nah zAcByngqtH>7?4G@{Z`=MUXs?Kvy-Bd(DdLkh#QhuCs!r?T)+NS#J6{}=h=VZ>P3QRyS5mkiap=0O?Nc-eu$F@SMw%jCS$liT zU$Rd_>rA8v|o zq=emTUHdCFn&u&REMmzKvhGf`k|6I??^E_#1qJ3&_m?2oEABYGh-|hC>*+<5w91y%Sq7`~y&*dHdR)X+`P7Q{aLv!ldo*dn<=R1T*I?Yy?8Snni8^>LomI zhb3~bu>{y!rtD#8_Vl}V$3m`(GzszQ=rjzvfO&Cj7RCOzS2Gt3#De`P7|f?fDlSa2 z5=q0Xo`(qZka_J*TWqf#|w?T!Xk`7F;IUrJrCs9Z!e6dULe+G&o z@nL@QT+WfcEZ|Qp3TTuu2Iw|vvjB8cemHsG-Mt!44dsWMv)ZA^+b+ypCw|7FWJGcj z4L1Jri9D9Q7KBMOhVx%VUD9-)_Bx|)QiN8O2QUWPT7*EbySt#l^*^Nc=YD6(FEdE7suQZ`-khEX=9$D60HPHx9`k6k6!-5;}-iDgq;&t13PDk*oaMT z5(qj5zY~p>5B~{`9-4u=D?<3;;>Ue8r=Q31lxWO`WR!^L95_!bw8s3Xnpdx0eFViA zpl|?%q#MIaHx56z4j|LtkQmZG=H^C0MXB*-A31@E5$He!aKpY;z$xh^lIa}w-WxF1}K5Wx7khdW+@0g3#DsdPni67R-i?L^m?>dl*(#A8B? zvG3o%uV*8JEwvxnhcrg%x-Qxh%J3h^`KK;gz#_hrFgo}Gp$N@`a;V2@A-?>>lFB5( zbr6x6d7K!0h>->b6=av^uiK$VTs@ENAV5EX)aV31woA+<_mY7DF^UQ&0UNNm;a!UV zbDSFii}WpE0;3#x+BrOIG2lbq40{ZcGRtW;nE+2@g5SwRp`0Y+2FJ3yyL<9q%wDh? 
zPJ$*gZDODU=K1FW53!(*CAJkp>|^L+#c!>)x|Z}=0id2t8pS*0j4PlXY{FhD*+%Ws zrAsRwsowXV8%gB+T#ls*&__fR{K;0tNoamXd6sCyld zut=K&EQ7L0+%+T_=@(L9hCf4^i){Vz0SXfdB%rRbUxy@dw17Lw$Pz&&azK(L90R!K z(!2jqTE{HgRN%Hz*3kvzY0b5-ao@Hw6Sy%G5O&bcS*)bvf+*Hi5z_vz_ZUyuRKi0t zfsDs}{I~|J+pr&R1l_To+36+mih=d~XEE?-fcUM$Y@I;IC24}Ox|Zn|1VaKI0NJVw zx6+&`O|GBL;Ml4U@b@93s-i?C+8T*EKB(6B!kG{y>o0IL;a4137>b-qRf~|_`HnnU z{YZhy9p0FFL_EZLn>{m+EkyjV!vx(EZ*TqZPvqo)%v)HS?@j2PlPAU0^sZdF0+Z^MYb(Gf_$*lhx15Q?5P6jwP0yHY$yO_K z7#OSlei8`OxrrJDqJ)x`Nm8YubpSjbJvR6P1J?`Yj2LCuL!4ftifzs|B#(gZaWwHR zF)^cAwi>>ZG#K;U&)Pjb%>p7NmwQwDZ7D=VwC zwYBx*)0Mfj3Cz&6wVfUNO>*`dT*`E$DVy^1>RvbULej34O*AO6|9AY!uD#uRC%~Nv zI*~}U#=VjrV#2HeaYrXugNek>2r5K}00Qp9{#qn6!~cTf+k0Cqxv)V=_#y!PNs_>Z zS_w#81F&YGizH^aEiV=p9kIViJfubJW2dKrC`FoMNG~BA#8v35VrOP%<`yLcM4_nR z>`G$#?{k;ag_B4%5riKp?Q`Y72-caH{jlksB5BAtH;dD_=@1JKoYwylKiSjQXY!jU zE2rzy0@}~2;B6C#xj?=<0mAI>d-lYEEz?THu6x8%hD9o0lg_7`8-+$9m|pu$o{{+{ zd^mZ5P2<0!6^q**!9Nm4^~F3T`!38qFtxHumglpv`8-CToA7=D-IARs_F9DZMa)zr zec>qcs`D8CW(nQNBoLnen z86r}W+_I9whQH26*ZwzHYhPJ$22wwvD{#Nv=?GV4P$CcmoO`oDIS_|lEgYKw7ng0K z$-{5`pO<;j>IsR8FyW*Toj89N$bM`8EBY>*KXb?lTWQkg>+Y^ZS~lj;lR`pbn)D*M z9FarM5jf*L3YVc@qqWZ9sWZrAAmn31vjY*sAK0{vLfqo%r>?1%{~;vQ_Lz-WsYc;O zIaf5^J2kW;_Teh*%V6;9+}4gd0?+<8wvGD#Kl^`9BmO`6E)~oJ0R_ZENp@%;t^+5h zh5ZeO{bjM$Sp^9qXj)W6zO6y4q|fL>1k`V4#3#3*@E%iyE){X5p$XbLAyZs?z) zjVD@olDeh~tLPaPYTi59}0N8?7t|(>vk5v-eQM z4-lW2uq|?vGBktlO=dp{Lmbx>N)Far?&&D;N1acBkD3^9yw!&R_v3lTXRd=l9PmNbtO<4lqu z8v7{1s^D{gHR{*+0gOL_mZ`btKkZiIr%%9I(ylMK!~dH>8F`e-54^lGY&YQ1lTD%x zcagns0W63Y`S-IhT}n_^3H*ja7R7ELwgDo%1TYtoQc1u7N390FCSGPua3gli2t(t~ zxPeRie(uB1Z?31P89~~EjtNGI6bk@)aBz?m9Ok&p39|;o{>H|I*>UW9bb$m!5r>KM zz|NgNei#fSyV+zxI@NXPV?*0S4BAgf!wr@S+vrJV&8@4}CcTD32X+N6lsRbeu0^vv z(OZbx10B8r;q3{xIv)t+#?P23WIH<;|F8k_IXKRJpqc|kCL(CX!f6DTSc+ zh{Xt|wNR)7Q78$py>hm_(5bbW8}N;{KpmIBtU{VsptBH?fY=a;Mo!coFjm@rzi6~B zG8YnN1o$5tR_+`7`LhyyB0b<9J|2aF3(W<@!*cb)A@8FIqu5~T5tKw;ur$P{2$ZM> zsn1+zm%b{o!j^FP@aJJiFn7=r#Fq}>FXKaSDe7dijt>t=DS{pF=yt1b($KS-(MZ4Q 
zPjTgtlro}*<3Iw9a(q!LG>eHQ902OS_VNtrUcuAKOkJxPfbt2})zs@h z&QX10V`EvP!5|QS2HE081gqES{B4~k_`9(OMkvUT0R9wn?2`Z;J`_Q%7x)&;{;xm8 zz1gm9VDP?`Kij<<*;xjagF5JB)%2^HLD)8h7`LE85TzK7Rl>CrsQ}QYu7i5cLJ&=poY*@?ZD??b@`2h{mXgjC}If&h_O~@QUkXYG`0}+jwR+ycB-om!wB2*;| z?nJ$|Wi(n4e9q2@Kl9mh0E`Ezrbv%2>JZXA0|6}@Dgh_p0AU}xe*8EieoyDnE>E&Y z25Kg#7e~-cip|gNc6WiRR>F!QI|;w%S~YNyfrsUG5ZqIE7Uv;RgYJ&RZXDQ6ZzFsU zPgqxMxnSL~1nY9_EV1Pj**@ln0FU7YSUliRA)8ttMCx&3*=AU_tD)c{Y8?z#9N4{z z^xNPHKHs+SPo|&h&zo+m0VhDVG$K1~WMBLK2Wf0V-76$Cns+dnX7w-Lr^pU|*bfqN z98v0FjzA@)28OaO%}gCzOo?L?fyYGMF~LX%9J34tj+n%niA7y^)P+X;W474Y+@J3@ zk-lECY_aIczUr9Sfu7uJt9zCVPC{QI9S8{JS}^kVwQXP7%m!WXpPq1rqGG|`dQEq? ztwA~jmL|RG+24aA>P*Lf(slLpb zFY|p|%jH_?;hgjS-}kxfXYc*usI{3H-cHsijR|-MD z@`DQ;&g-xu+pWc4VlrjV#v$*d1D5T=3`x)w_F-E|C=2{LM)rD8bcMfSx4pnnKBq1;xR^Q7UY9CbIZ_URa4{lx^!C_yrbg45s^7BUrbt4jK5YBxGFVLP# zI3*S*Uc4n?S+iRWpz=r*b+uqL4-bzGkzFF49yB(y8j{^X@)2%M zA&D(7b0&C7$=f(V;R22^WV2W-@~V-tMwtU zct`+;vU~5WC9&gg&1t*%BY_q&hJ$Y=RT%MWFbBG-_)~vGd?m#BMSvFodZN@~LU|w& zt4Y(i!)r2}Nz$T~YZu{1<}cm|j~Ab(*Kl&bjMBwFG_2!{-h|i`iJZH~9 zC-d0kO-`I&_2gs=w*PQs)ML($Y}_PG9r%TY70l9n8T8Mc?%wdaB_^Kv%o~IPUogNd z6f@dQ0BwH!m6Tq5y@^arOFqxHoHktUd+tnrMnhxKCY(;q^;+o(J-_C2-G_VRS2kvNXX~EnXXk zo`t65?$s>ix43&{VFbc5iQork#E(@_)YemFS07k(&}$?ADsWhBAt|Bg+Ir+$n1DG$ zZUPJ1Sl?_tb86rnE05lsLYxIIVRaz-FF~mgo({##w_t)(Kv>^_C4dH^gv~s6djH@> z1-@&aAY7KGpm9z^`Ee{-lK7nl;U1z&0N#Oaw%@nD`5_Go{KTyXuDov+i$<&@sWj^9 z>R^$sXW?GGZ!ix3?hrJs5cec-+AS1mwNMo;b4mC6?zJr zPJJCsHa0frspz$8Es0nLM7(%GK+l_R@P!v);1G=!TRQ$@D&OYJ$$3#e7d-;lTy7O%=%w`4N^q_oLCEGJ|wTgO^OmPGDGqW`3vW z`#+sc@rTUi@~~CX)Z6v9s7LXVso@h~p=uWc@~mik6gB5d@0G{87FxS^r>>qD05Mfm zg_hAXHH!dw3>sdhIPs?q;QRJpTX5R0P>SDF1{`IT zM`XI|4IO6s6&yP)k0J3GNkS#;am}<9;D~3wz=@*cB)U%XW0LUY{5-GgxsD4Rm(N99 zq6QrxfCt-OV{i5)wOYuf=`)|lotYib89bNhB8`gN&Kc5YMK?v#>P-+$6s%7b#kBS9 zQQM-+Y)SW=>i;=!Wnr7WcL0JK7B4JKPuh{mH$gv5`g)NTFX9zVey@p4P z9y6wPRP%#)N|zlhWZ_KrG)^V?ZT-O%I4&tpPFs5IGbfk0?mulo&!#Z(`2Mwc>GVM0 zXZ`FRH*O76zWj~_4qt2Tg_2IJK}I`3g5QwKR9;n#FIW-PFeD{7 
zzJO1Fbjt*brNMoz-m{(IW@Sh|RB1bdq3d~G%LOpQ7;H@-r(YClovN#r04f&@qn&5) zP|L~FWH?z`z;543ZqVgCDw)K3MEXZxGA`AS`T?JF@eRwd2l6Nqj{$Po9d}?9*>~~G zzL7i^yS`oSaf8h{D8BJyn9SE~>*}OVlgHmF9L#Ig1K8 z<~x|lh>boZy+i(^=T>wuEDpEqiHDMNHIb`Gly5B}T`6~>f)Zzgpv#PV`ZLcB^!oFl z8BTZ1$$dlEL`so-)E7Yv#QS5&ticsXM;bL?EnM8tOVi7p)ZFjrIXr5`%9S;V{y~!| zDFx-?H=Yl~qRo1*790%606f13Iy<>HS_*)X`K7 zD>Q`z`L}@7sfT_6?OICoFGfIS@onqh{ zx#myrS{Wu5le`jnH|ABf$e}U9+zK@<3snh$!z^kX$R49Q)u0IGI(i0op`;t2Mh

AYV%&ejf+HF@tK|Zl5(&!igQDcb^(3%9!7Hm zYdhb(r2oSg6gLlkfgc&~!Y=;-cZT&f*p6ZGQl~{EWJfoBnK|QFCY6ac9`(k+Mnd%r z`7r8U8A9-ojAs`|73@2t);zTUG~w&`p6tOO?Z|!Pemwy@63q$vTUDRPq&A*hJ?2Yu6A54nkN8o zKaRM`M~D%~DRrSq4Kr&=3GQTCVO!RkzK2YputxkOOH77zChG*SUAbpNiI}=@;ow&1 z;uZ#5I~p8|Eo~F(FYPxI^bRqyUlM>sQHH7vajSvObm!6TudUA)lL9yfr6vC!)7wPX z?mc#F0x&F}L*+nud2m<%r_1^XK$`yK(VdC13c%iSjaaptJ(l>e64$E~=0=aUzS{BR zySm~>kBtSr*LXDs`Bwl*Qj4{)v2MB?{7QR275SkXQDR|}%4rg#jGAcl{o!<}s8`h` zP77&zH$9A(P*q{N9IkSVlbxckxHX2r5?Zn3AKXA)l+*jG1r zRe~pJtOjoPKIXklty__9%UlIGer`<3Ylk1oC#Y7aWN=eVONm12RaM)qFxS^h`@8@# zL}wxzZP*+ZSyn7yG794hQuw&l#9#qlhyGTd=^6d7Y;S1!aT#L*eiecdW9BFUog|VF zJT^9U17m@Q$gMS9lzu2_(qP{V$*|uWBhzjVOfA7s#B zRT`D~=-)u!)Da1J0T###_#yYCE=C|Js*$PjY+K-6LKw)xrgKgPA00pJa)wmPr zSl?Nw8(gMmA2XfTN$ooIX^#=vc1}Df)lZr@D*&_O=wX*$kE>>Fr{clTQqLAewLbhK zd@>Th3@^k6k?O9N1=fav3kX`-i-TC+PZw~Me=Z{r1ssYJd3I5hcpvKGDN88|#6Twg zj#0UpoYpX?)%yN-83;L}LFH)wn9-wQolxUBUr;_1{wFfynan7d8R~c)a7^Yw1MGe4 zDw5`7@vF%`MdG*9<{h0?8xcFZ4fuP5SHWyI*{1_HQv8B z(F+te@-r8hh5Lla>gY{H5FwZwepx|THc-$(O5We&<(?4_7OL)<2o;VcEK%>D$AhJ* zl;H>olf%w5R<~os8s(p&%q@_gyH+kKM53-!aYLj!@6kEUD7PnJD}I@Q{Y>pW*yKZ= z9DsIZ`Hv1=#%VHRm86Ivj?$+5Oeun};<3oN4CGbpr#UAiX4rSE`N=eTz*2m>u=U!m zJ;X*78JJFTJJSSZJC;DcsX!P6eo{(o-sVE8+H}sJ7(CQ4DlrjVxyHoAU=kz)7{MOx z+q6QVknomlAT~tb%Ig9)wRx4TLyJU+Kq)`h37dv(WqzMV>X!MZ?UX)I(0p4izG_D?zw41Tac{(-q2lXoI;vP1Ug2O;!+bd47K-2{q%vgm{KZq zO9&Uv2LqbfLTesI_}AwF`N!Doxh5;a*r8v=8Vo*Vc|6d=s==0J|r3s*X^~fc{fT?er2p$Q2pDVY%a!-|M!0QA6Dau?a6@i4Lj10 R?9hu$uTB`9`PbkNeg(6umzn?o literal 0 HcmV?d00001 diff --git a/use-cases/eurac/runall.sh b/use-cases/eurac/runall.sh index 6169366e..03a4798d 100755 --- a/use-cases/eurac/runall.sh +++ b/use-cases/eurac/runall.sh @@ -12,7 +12,7 @@ if [ -z "$NUM_GPUS" ]; then NUM_GPUS=4 fi if [ -z "$TIME" ]; then - TIME=0:20:00 + TIME=0:40:00 fi if [ -z "$DEBUG" ]; then DEBUG=false diff --git a/use-cases/eurac/slurm.sh b/use-cases/eurac/slurm.sh index e1ec58b1..e907e54c 100644 --- a/use-cases/eurac/slurm.sh +++ b/use-cases/eurac/slurm.sh @@ -100,7 
+100,7 @@ if [ "$DIST_MODE" == "horovod" ] ; then srun --cpu-bind=none \ --ntasks-per-node=$SLURM_GPUS_PER_NODE \ --cpus-per-task=$SLURM_CPUS_PER_GPU \ - --ntasks=$SLURM_GPUS_PER_NODE \ + --ntasks=$(($SLURM_GPUS_PER_NODE * $SLURM_NNODES)) \ $TRAINING_CMD else # E.g. for 'deepspeed' or 'ddp' srun --cpu-bind=none --ntasks-per-node=1 \ diff --git a/use-cases/eurac/trainer.py b/use-cases/eurac/trainer.py index 53c50202..88ac42f5 100644 --- a/use-cases/eurac/trainer.py +++ b/use-cases/eurac/trainer.py @@ -1,7 +1,7 @@ import os from pathlib import Path from timeit import default_timer as timer -from typing import Dict, Literal, Optional, Union +from typing import Dict, Literal, Optional, Union, Any, Tuple import pandas as pd import torch @@ -13,8 +13,10 @@ from hython.trainer import ConvTrainer, RNNTrainer, RNNTrainParams from ray import train from torch.optim.lr_scheduler import ReduceLROnPlateau +from torch.utils.data import Dataset from tqdm.auto import tqdm +from itwinai.distributed import suppress_workers_print from itwinai.loggers import EpochTimeTracker, Logger from itwinai.torch.config import TrainingConfiguration from itwinai.torch.distributed import ( @@ -25,6 +27,7 @@ ) from itwinai.torch.trainer import TorchTrainer from itwinai.torch.type import Metric +from itwinai.torch.profiling.profiler import profile_torch_trainer class RNNDistributedTrainer(TorchTrainer): @@ -88,6 +91,16 @@ def __init__( ) self.save_parameters(**self.locals2params(locals())) + @suppress_workers_print + @profile_torch_trainer + def execute( + self, + train_dataset: Dataset, + validation_dataset: Optional[Dataset] = None, + test_dataset: Optional[Dataset] = None + ) -> Tuple[Dataset, Dataset, Dataset, Any]: + return super().execute(train_dataset, validation_dataset, test_dataset) + def create_model_loss_optimizer(self) -> None: self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.lr) self.lr_scheduler = ReduceLROnPlateau( @@ -125,6 +138,14 @@ def 
create_model_loss_optimizer(self) -> None: **distribute_kwargs, ) + def set_epoch(self, epoch: int): + if self.profiler is not None: + self.profiler.step() + + if self.strategy.is_distributed: + self.train_loader.sampler.set_epoch(epoch) + self.val_loader.sampler.set_epoch(epoch) + def train(self): """Override version of hython to support distributed strategy.""" # Tracking epoch times for scaling test @@ -158,11 +179,7 @@ def train(self): best_loss = float("inf") for epoch in tqdm(range(self.epochs)): epoch_start_time = timer() - if self.strategy.is_distributed: - # *Added for distributed* - self.train_loader.sampler.set_epoch(epoch) - self.val_loader.sampler.set_epoch(epoch) - + self.set_epoch(epoch) self.model.train() # set time indices for training @@ -368,7 +385,6 @@ def create_model_loss_optimizer(self) -> None: patience=self.config.lr_reduction_patience ) - target_weights = { t: 1 / len(self.config.target_names) for t in self.config.target_names } @@ -489,7 +505,7 @@ def create_dataloaders(self, train_dataset, validation_dataset, test_dataset): processing=( "multi-gpu" if self.config.distributed else "single-gpu" ), - ) + ) val_sampler_builder = SamplerBuilder( validation_dataset, diff --git a/use-cases/virgo/trainer.py b/use-cases/virgo/trainer.py index 8226e070..5fd2a3f9 100644 --- a/use-cases/virgo/trainer.py +++ b/use-cases/virgo/trainer.py @@ -43,7 +43,7 @@ def __init__( ) -> None: super().__init__( epochs=num_epochs, - config={}, + config=config, strategy=strategy, logger=logger, random_seed=random_seed,