diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 00000000..697f33f2
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,131 @@
+**/TODO
+**/mamba*
+pl-training.yml
+.vscode
+
+# Project folders/files
+# use-cases
+workflows
+tests
+CHANGELOG
+
+# Docs
+docs
+
+# Data
+**/MNIST
+**/*-predictions/
+**/*-data/
+**/*.tar.gz
+**/exp_data
+
+# Logs
+**/logs
+**/lightning_logs
+**/mlruns
+**/.logs
+**/mllogs
+**/nohup*
+**/*.out
+**/*.err
+**/checkpoints/
+**/*_logs
+**/tmp*
+**/.tmp*
+
+# Markdown
+**/*.md
+
+# Custom envs
+**/.venv*
+
+# Git
+.git
+.gitignore
+.github
+
+# CI
+.codeclimate.yml
+.travis.yml
+.taskcluster.yml
+
+# Docker
+docker-compose.yml
+.docker
+.dockerignore
+Dockerfile
+
+# Byte-compiled / optimized / DLL files
+**/__pycache__/
+**/*.py[cod]
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+**/eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+**/*.egg-info/
+**/.installed.cfg
+**/*.egg
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.cache
+nosetests.xml
+coverage.xml
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Virtual environment
+.env/
+.venv/
+venv/
+
+# PyCharm
+.idea
+
+# Python mode for VIM
+.ropeproject
+*/.ropeproject
+*/*/.ropeproject
+*/*/*/.ropeproject
+
+# Vim swap files
+*.swp
+*/*.swp
+*/*/*.swp
+*/*/*/*.swp
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 41349b5f..2f0ad142 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+*_logs
+exp_data/
 TODO
 /data
 nohup*
@@ -13,6 +15,10 @@ mllogs
 *.err
 .logs/
 pl-training.yml
+*-predictions/
+*-data/
+*.pth
+*.tar.gz

 # Custom envs
 .venv*
diff --git a/src/itwinai/components.py b/src/itwinai/components.py
index 2e7e2e79..c1e6e372 100644
--- a/src/itwinai/components.py
+++ b/src/itwinai/components.py
@@ -1,11 +1,14 @@
 from __future__ import annotations
-from .cluster import ClusterEnvironment
-from typing import Iterable, Dict, Any, Optional, Tuple
+from typing import Iterable, Dict, Any, Optional, Tuple, Union
 from abc import ABCMeta, abstractmethod
 import time
 # import logging
 # from logging import Logger as PythonLogger

+from .cluster import ClusterEnvironment
+from .types import ModelML, DatasetML
+from .serialization import ModelLoader
+

 class Executable(metaclass=ABCMeta):
     """Base Executable class.
@@ -136,6 +139,7 @@ def _printout(self, msg: str):


 class Trainer(Executable):
+    """Trains a machine learning model."""
     @abstractmethod
     def train(self, *args, **kwargs):
         pass
@@ -149,6 +153,57 @@ def load_state(self):
         pass


+class Predictor(Executable):
+    """Applies a pre-trained machine learning model to unseen data."""
+
+    model: ModelML
+
+    def __init__(
+        self,
+        model: Union[ModelML, ModelLoader],
+        name: Optional[str] = None,
+        **kwargs
+    ) -> None:
+        super().__init__(name, **kwargs)
+        self.model = model() if isinstance(model, ModelLoader) else model
+
+    def execute(
+        self,
+        predict_dataset: DatasetML,
+        config: Optional[Dict] = None,
+    ) -> Tuple[Optional[Tuple], Optional[Dict]]:
+        """Applies the pre-trained model to the inference dataset.
+
+        Args:
+            predict_dataset (DatasetML): dataset object for inference.
+            config (Dict, optional): key-value configuration.
+                Defaults to None.
+
+        Returns:
+            Tuple[Optional[Tuple], Optional[Dict]]: tuple structured as
+                (results, config).
+        """
+        return self.predict(predict_dataset), config
+
+    @abstractmethod
+    def predict(
+        self,
+        predict_dataset: DatasetML,
+        model: Optional[ModelML] = None
+    ) -> Iterable[Any]:
+        """Applies a machine learning model to a dataset of samples.
+
+        Args:
+            predict_dataset (DatasetML): dataset for inference.
+            model (Optional[ModelML], optional): overrides the internal model,
+                if given. Defaults to None.
+
+        Returns:
+            Iterable[Any]: predictions with the same cardinality as the
+                input dataset.
+        """
+
+
 class DataGetter(Executable):
     @abstractmethod
     def load(self, *args, **kwargs):
         pass
@@ -167,18 +222,12 @@ def preproc(self, *args, **kwargs):
 #         pass


-class Evaluator(Executable):
+class Saver(Executable):
     @abstractmethod
-    def evaluate(self, *args, **kwargs):
+    def save(self, *args, **kwargs):
         pass


-# class Saver(Executable):
-#     @abstractmethod
-#     def save(self, *args, **kwargs):
-#         pass
-
-
 class Executor(Executable):
     """Sets up and executes a sequence of Executable steps."""
diff --git a/src/itwinai/loggers.py b/src/itwinai/loggers.py
index 1116c503..d04becd7 100644
--- a/src/itwinai/loggers.py
+++ b/src/itwinai/loggers.py
@@ -9,7 +9,7 @@
 import wandb
 import mlflow
-import mlflow.keras
+# import mlflow.keras

 BASE_EXP_NAME: str = 'unk_experiment'
diff --git a/src/itwinai/serialization.py b/src/itwinai/serialization.py
new file mode 100644
index 00000000..a7b70cd3
--- /dev/null
+++ b/src/itwinai/serialization.py
@@ -0,0 +1,14 @@
+from .types import ModelML
+import abc
+
+
+class ModelLoader(abc.ABC):
+    """Loads a machine learning model from somewhere."""
+
+    def __init__(self, model_uri: str) -> None:
+        super().__init__()
+        self.model_uri = model_uri
+
+    @abc.abstractmethod
+    def __call__(self) -> ModelML:
+        """Loads model from model URI."""
diff --git a/src/itwinai/torch/inference.py b/src/itwinai/torch/inference.py
new file mode 100644
index 00000000..4d7797c6
--- /dev/null
+++ b/src/itwinai/torch/inference.py
@@ -0,0 +1,217 @@
+from typing import Optional, Dict, Any, Union
+import os
+import abc
+
+import torch
+from torch import nn
+from torch.utils.data import DataLoader, Dataset
+
+from ..utils import dynamically_import_class
+from .utils import clear_key
+from ..components import Predictor
+from .types import TorchDistributedStrategy as StrategyT
+from .types import Metric, Batch
+from ..serialization import ModelLoader
+
+
+class TorchModelLoader(ModelLoader):
+    """Loads a torch model from somewhere.
+
+    Args:
+        model_uri (str): Can be a path on local filesystem
+            or an mlflow 'locator' in the form:
+            'mlflow+MLFLOW_TRACKING_URI+RUN_ID+ARTIFACT_PATH'
+    """
+
+    def __call__(self) -> nn.Module:
+        """Loads model from model URI.
+
+        Raises:
+            ValueError: if the model URI is not recognized
+                or the model is not found.
+
+        Returns:
+            nn.Module: torch neural network.
+        """
+        if os.path.exists(self.model_uri):
+            # Model is on local filesystem.
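+            # Note: torch.load on a fully serialized model (not a plain
+            # state_dict) unpickles the entire nn.Module, so the model's
+            # class definition must be importable from the current
+            # environment. eval() switches dropout and batch-norm layers
+            # to inference behavior.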
+ model = torch.load(self.model_uri) + return model.eval() + + if self.model_uri.startswith('mlflow+'): + # Model is on an MLFLow server + # Form is 'mlflow+MLFLOW_TRACKING_URI+RUN_ID+ARTIFACT_PATH' + import mlflow + from mlflow import MlflowException + _, tracking_uri, run_id, artifact_path = self.model_uri.split('+') + mlflow.set_tracking_uri(tracking_uri) + + # Check that run exists + try: + mlflow.get_run(run_id) + except MlflowException: + raise ValueError(f"Run ID '{run_id}' was not found!") + + # Download model weights + ckpt_path = mlflow.artifacts.download_artifacts( + run_id=run_id, + artifact_path=artifact_path, + dst_path='tmp/', + tracking_uri=mlflow.get_tracking_uri() + ) + model = torch.load(ckpt_path) + return model.eval() + + raise ValueError( + 'Unrecognized model URI: model may not be there!' + ) + + +class TorchPredictor(Predictor): + """Applies a pre-trained torch model to unseen data.""" + + model: nn.Module = None + test_dataset: Dataset + test_dataloader: DataLoader = None + _strategy: StrategyT = StrategyT.NONE.value + epoch_idx: int = 0 + train_glob_step: int = 0 + validation_glob_step: int = 0 + train_metrics: Dict[str, Metric] + validation_metrics: Dict[str, Metric] + + def __init__( + self, + model: Union[nn.Module, ModelLoader], + test_dataloader_class: str = 'torch.utils.data.DataLoader', + test_dataloader_kwargs: Optional[Dict] = None, + # strategy: str = StrategyT.NONE.value, + # seed: Optional[int] = None, + # logger: Optional[List[Logger]] = None, + # cluster: Optional[ClusterEnvironment] = None, + # test_metrics: Optional[Dict[str, Metric]] = None, + name: str = None + ) -> None: + super().__init__(model=model, name=name) + self.model = self.model.eval() + # self.seed = seed + # self.strategy = strategy + # self.cluster = cluster + + # Train and validation dataloaders + self.test_dataloader_class = dynamically_import_class( + test_dataloader_class + ) + test_dataloader_kwargs = ( + test_dataloader_kwargs + if test_dataloader_kwargs is not None else {} + ) + self.test_dataloader_kwargs = clear_key( + test_dataloader_kwargs, 'train_dataloader_kwargs', 'dataset' + ) + + # # Loggers + # self.logger = logger if logger is not None else ConsoleLogger() + + # # Metrics + # self.train_metrics = ( + # {} if train_metrics is None else train_metrics + # ) + # self.validation_metrics = ( + # self.train_metrics if validation_metrics is None + # else validation_metrics + # ) + + def predict( + self, + test_dataset: Dataset, + model: nn.Module = None, + ) -> Dict[str, Any]: + """Applies a torch model to a dataset for inference. + + Args: + test_dataset (Dataset[str, Any]): each item in this dataset is a + couple (item_unique_id, item) + model (nn.Module, optional): torch model. Overrides the existing + model, if given. Defaults to None. + + Returns: + Dict[str, Any]: maps each item ID to the corresponding predicted + value(s). 
+ """ + if model is not None: + # Overrides existing "internal" model + self.model = model + + test_dataloader = self.test_dataloader_class( + test_dataset, **self.test_dataloader_kwargs + ) + + all_predictions = dict() + for samples_ids, samples in test_dataloader: + with torch.no_grad(): + pred = self.model(samples) + pred = self.transform_predictions(pred) + for idx, pre in zip(samples_ids, pred): + # For each item in the batch + if pre.numel() == 1: + pre = pre.item() + else: + pre = pre.to_dense().tolist() + all_predictions[idx] = pre + return all_predictions + + @abc.abstractmethod + def transform_predictions(self, batch: Batch) -> Batch: + """ + Post-process the predictions of the torch model (e.g., apply + threshold in case of multilabel classifier). + """ + + +class MulticlassTorchPredictor(TorchPredictor): + """ + Applies a pre-trained torch model to unseen data for + multiclass classification. + """ + + def transform_predictions(self, batch: Batch) -> Batch: + batch = batch.argmax(-1) + return batch + + +class MultilabelTorchPredictor(TorchPredictor): + """ + Applies a pre-trained torch model to unseen data for + multilabel classification, applying a threshold on the + output of the neural network. + """ + + threshold: float + + def __init__( + self, + model: Union[nn.Module, ModelLoader], + test_dataloader_class: str = 'torch.utils.data.DataLoader', + test_dataloader_kwargs: Optional[Dict] = None, + threshold: float = 0.5, + name: str = None + ) -> None: + super().__init__( + model, test_dataloader_class, test_dataloader_kwargs, name + ) + self.threshold = threshold + + def transform_predictions(self, batch: Batch) -> Batch: + return (batch > self.threshold).float() + + +class RegressionTorchPredictor(TorchPredictor): + """ + Applies a pre-trained torch model to unseen data for + regression, leaving untouched the output of the neural + network. + """ + + def transform_predictions(self, batch: Batch) -> Batch: + return batch diff --git a/src/itwinai/types.py b/src/itwinai/types.py new file mode 100644 index 00000000..9c302eb1 --- /dev/null +++ b/src/itwinai/types.py @@ -0,0 +1,11 @@ +""" +Framework-independent types. +""" + + +class DatasetML: + """A framework-independent machine learning dataset.""" + + +class ModelML: + """A framework-independent machine learning model.""" diff --git a/use-cases/3dgan/Dockerfile b/use-cases/3dgan/Dockerfile new file mode 100644 index 00000000..515caff0 --- /dev/null +++ b/use-cases/3dgan/Dockerfile @@ -0,0 +1,25 @@ +FROM python:3.9.12 + +WORKDIR /usr/src/app + +RUN pip install --upgrade pip + +# Install pytorch (cpuonly) +# Ref:https://pytorch.org/get-started/previous-versions/#linux-and-windows-5 +RUN pip install --no-cache-dir torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu +RUN pip install --no-cache-dir lightning + +# Add 3DGAN custom requirements +COPY use-cases/3dgan/requirements.txt ./ +RUN pip install --no-cache-dir -r requirements.txt + +# Install itwinai and dependencies +COPY pyproject.toml ./ +COPY src ./ +RUN pip install --no-cache-dir . 
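+# Note: dependencies are installed in the layers above, before the use case
+# sources are copied, so Docker's layer cache only rebuilds the dependency
+# layers when requirements.txt or pyproject.toml change.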
+
+# Add 3DGAN use case files
+COPY use-cases/3dgan/* ./
+
+# Run inference
+CMD [ "python", "train.py", "-p", "inference-pipeline.yaml"]
\ No newline at end of file
diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md
new file mode 100644
index 00000000..95a428c8
--- /dev/null
+++ b/use-cases/3dgan/README.md
@@ -0,0 +1,122 @@
+# 3DGAN use case
+
+## Training
+
+At CERN, use the dedicated configuration file:
+
+```bash
+cd use-cases/3dgan
+python train.py -p cern-pipeline.yaml
+```
+
+Anywhere else, use the general purpose training configuration:
+
+```bash
+cd use-cases/3dgan
+python train.py -p pipeline.yaml
+```
+
+To visualize the logs with MLflow, run the following in the terminal:
+
+```bash
+micromamba run -p ../../.venv-pytorch mlflow ui --backend-store-uri ml_logs/mlflow_logs
+```
+
+And select the "3DGAN" experiment.
+
+## Inference
+
+The following is preliminary and not 100% ML/scientifically sound.
+
+1. As inference dataset we can reuse the training/validation dataset,
+for instance the one downloaded from the Google Drive folder: if the
+dataset root folder is not present, the dataset is downloaded.
+The inference dataset is a set of H5 files stored inside `exp_data`
+sub-folders:
+
+    ```text
+    ├── exp_data
+    │   ├── data
+    |   │   ├── file_0.h5
+    |   │   ├── file_1.h5
+    ...
+    |   │   ├── file_N.h5
+    ```
+
+2. As model, if a pre-trained checkpoint is not available,
+we can create a dummy version of it with:
+
+    ```python
+    import torch
+    from model import ThreeDGAN
+    # Same params as in the training config file!
+    my_gan = ThreeDGAN()
+    torch.save(my_gan, '3dgan-inference.pth')
+    ```
+
+3. Run the inference command. This generates a "3dgan-generated-data"
+folder containing the generated particle traces in the form of torch
+tensors (.pth files) and 3D scatter plots (.jpg images):
+
+    ```bash
+    python train.py -p inference-pipeline.yaml
+    ```
+
+Note the same entry point as for training.
+
+The inference execution will produce a folder called
+"3dgan-generated-data" containing
+generated 3D particle trajectories (overwritten if already
+there). Each generated 3D image is stored both as a
+torch tensor (.pth) and a 3D scatter plot (.jpg):
+
+```text
+├── 3dgan-generated-data
+|   ├── energy=1.296749234199524&angle=1.272539496421814.pth
+|   ├── energy=1.296749234199524&angle=1.272539496421814.jpg
+...
+|   ├── energy=1.664689540863037&angle=1.4906378984451294.pth
+|   ├── energy=1.664689540863037&angle=1.4906378984451294.jpg
+```
+
+### Docker image
+
+Build from the project root with:
+
+```bash
+# Local
+docker buildx build -t itwinai-3dgan-inference -f use-cases/3dgan/Dockerfile .

+# Ghcr.io
+docker buildx build -t ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.1 -f use-cases/3dgan/Dockerfile .
+docker push ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.1
+```
+
+Run from wherever the 3DGAN H5 dataset is available
+(folder called 'exp_data/'):
+
+```text
+├── $PWD
+|   ├── exp_data
+|   │   ├── data
+|   |   │   ├── file_0.h5
+|   |   │   ├── file_1.h5
+...
+|   |   │   ├── file_N.h5
+```
+
+```bash
+docker run -it --rm --name running-inference -v "$PWD":/usr/data ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.1
+```
+
+This command will store the results in a folder called "3dgan-generated-data":
+
+```text
+├── $PWD
+|   ├── 3dgan-generated-data
+|   │   ├── energy=1.296749234199524&angle=1.272539496421814.pth
+|   │   ├── energy=1.296749234199524&angle=1.272539496421814.jpg
+...
+| │ ├── energy=1.664689540863037&angle=1.4906378984451294.pth +| │ ├── energy=1.664689540863037&angle=1.4906378984451294.jpg +``` diff --git a/use-cases/3dgan/cern-pipeline.yaml b/use-cases/3dgan/cern-pipeline.yaml new file mode 100644 index 00000000..7d251ae5 --- /dev/null +++ b/use-cases/3dgan/cern-pipeline.yaml @@ -0,0 +1,95 @@ +executor: + class_path: itwinai.components.Executor + init_args: + steps: + - class_path: dataloader.Lightning3DGANDownloader + init_args: + data_path: /eos/user/k/ktsolaki/data/3dgan_data # exp_data/ + data_url: null # https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX + + - class_path: trainer.Lightning3DGANTrainer + init_args: + # Pytorch lightning config for training + config: + seed_everything: 4231162351 + trainer: + accelerator: auto + accumulate_grad_batches: 1 + barebones: false + benchmark: null + # callbacks: + # # - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping + # # init_args: + # # monitor: val_loss + # # patience: 2 + # - class_path: lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor + # init_args: + # logging_interval: step + # # - class_path: lightning.pytorch.callbacks.ModelCheckpoint + # # init_args: + # # dirpath: checkpoints + # # filename: best-checkpoint + # # mode: min + # # monitor: val_loss + # # save_top_k: 1 + # # verbose: true + check_val_every_n_epoch: 1 + default_root_dir: null + detect_anomaly: false + deterministic: null + devices: auto #[0] + enable_checkpointing: true + enable_model_summary: null + enable_progress_bar: null + fast_dev_run: false + gradient_clip_algorithm: null + gradient_clip_val: null + inference_mode: true + limit_predict_batches: null + limit_test_batches: null + limit_train_batches: null + limit_val_batches: null + log_every_n_steps: 2 + logger: + # - class_path: lightning.pytorch.loggers.CSVLogger + # init_args: + # save_dir: ml_logs/csv_logs + class_path: lightning.pytorch.loggers.MLFlowLogger + init_args: + experiment_name: 3DGAN + save_dir: ml_logs/mlflow_logs + log_model: all + max_epochs: 100 + max_steps: -1 + max_time: null + min_epochs: null + min_steps: null + num_sanity_val_steps: null + overfit_batches: 0.0 + plugins: null + profiler: null + reload_dataloaders_every_n_epochs: 0 + strategy: ddp_find_unused_parameters_true #auto + sync_batchnorm: false + use_distributed_sampler: true + val_check_interval: null + + # Lightning Model configuration + model: + class_path: model.ThreeDGAN + init_args: + latent_size: 256 + batch_size: 128 + loss_weights: [3, 0.1, 25, 0.1] + power: 0.85 + lr: 0.001 + checkpoint_path: checkpoints/3dgan.pth + + # Lightning data module configuration + data: + class_path: dataloader.ParticlesDataModule + init_args: + datapath: /eos/user/k/ktsolaki/data/3dgan_data/*.h5 # exp_data/*/*.h5 + batch_size: 128 + num_workers: 0 + max_samples: 3000 diff --git a/use-cases/3dgan/dataloader.py b/use-cases/3dgan/dataloader.py new file mode 100644 index 00000000..d6e5a880 --- /dev/null +++ b/use-cases/3dgan/dataloader.py @@ -0,0 +1,216 @@ +from typing import Optional, Tuple, Dict +import os +from lightning.pytorch.utilities.types import EVAL_DATALOADERS + +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset +import lightning as pl +import glob +import h5py +import gdown + +from itwinai.components import DataGetter + + +class Lightning3DGANDownloader(DataGetter): + def __init__( + self, + data_path: str, + data_url: Optional[str] = None, + name: Optional[str] = None, + **kwargs) -> None: + 
super().__init__(name, **kwargs)
+        self.data_path = data_path
+        self.data_url = data_url
+
+    def load(self):
+        # Download data
+        if not os.path.exists(self.data_path):
+            if self.data_url is None:
+                print("WARNING! Data URL is None. "
+                      "Skipping dataset downloading")
+                return
+            gdown.download_folder(
+                url=self.data_url, quiet=False,
+                output=self.data_path
+            )
+
+    def execute(
+        self,
+        config: Optional[Dict] = None
+    ) -> Tuple[None, Optional[Dict]]:
+        self.load()
+        return None, config
+
+
+class ParticlesDataset(Dataset):
+    def __init__(self, datapath: str, max_samples: Optional[int] = None):
+        self.datapath = datapath
+        self.max_samples = max_samples
+        self.data = dict()
+
+        self.fetch_data()
+
+    def __len__(self):
+        return len(self.data["X"])
+
+    def __getitem__(self, idx):
+        return {"X": self.data["X"][idx], "Y": self.data["Y"][idx],
+                "ang": self.data["ang"][idx], "ecal": self.data["ecal"][idx]}
+
+    def fetch_data(self) -> None:
+
+        print("Searching in:", self.datapath)
+        files = sorted(glob.glob(self.datapath))
+        print("Found {} files.".format(len(files)))
+        if len(files) == 0:
+            raise RuntimeError(f"No H5 files found at '{self.datapath}'!")
+
+        # concatenated_datasets = []
+        # for datafile in files:
+        #     f = h5py.File(datafile, 'r')
+        #     dataset = self.GetDataAngleParallel(f)
+        #     concatenated_datasets.append(dataset)
+        # # Initialize result dictionary
+        # result = {key: [] for key in concatenated_datasets[0].keys()}
+        # for d in concatenated_datasets:
+        #     for key in result.keys():
+        #         result[key].extend(d[key])
+        # return result
+
+        for datafile in files:
+            f = h5py.File(datafile, 'r')
+            dataset = self.GetDataAngleParallel(f)
+            for field, vals_array in dataset.items():
+                if self.data.get(field) is not None:
+                    # Resize to include the new array
+                    new_shape = list(self.data[field].shape)
+                    new_shape[0] += len(vals_array)
+                    self.data[field].resize(new_shape)
+                    self.data[field][-len(vals_array):] = vals_array
+                else:
+                    self.data[field] = vals_array
+
+            # Stop loading data, if self.max_samples reached
+            if (self.max_samples is not None
+                    and len(self.data[field]) >= self.max_samples):
+                for field, vals_array in self.data.items():
+                    self.data[field] = vals_array[:self.max_samples]
+                break
+
+    def GetDataAngleParallel(
+        self,
+        dataset,
+        xscale=1,
+        xpower=0.85,
+        yscale=100,
+        angscale=1,
+        angtype="theta",
+        thresh=1e-4,
+        daxis=-1
+    ):
+        """Preprocess function for the dataset.
+
+        Args:
+            dataset (h5py.File): open H5 file holding the raw data.
+            xscale (int, optional): value to scale the ECAL values.
+                Defaults to 1.
+            xpower (float, optional): exponent used to scale the ECAL
+                values. Defaults to 0.85.
+            yscale (int, optional): value to scale the energy values.
+                Defaults to 100.
+            angscale (int, optional): value to scale the angle values.
+                Defaults to 1.
+            angtype (str, optional): which type of angle to use.
+                Defaults to "theta".
+            thresh (float, optional): threshold below which ECAL values
+                are zeroed. Defaults to 1e-4.
+            daxis (int, optional): axis to expand values. Defaults to -1.
+
+        Returns:
+            Dict: Dictionary containing the preprocessed dataset.
+        """
+        X = np.array(dataset.get("ECAL")) * xscale
+        Y = np.array(dataset.get("energy")) / yscale
+        X[X < thresh] = 0
+        X = X.astype(np.float32)
+        Y = Y.astype(np.float32)
+        ecal = np.sum(X, axis=(1, 2, 3))
+        indexes = np.where(ecal > 10.0)
+        X = X[indexes]
+        Y = Y[indexes]
+        if angtype in dataset:
+            ang = np.array(dataset.get(angtype))[indexes]
+        else:
+            # ang = gan.measPython(X)
+            raise ValueError(
+                f"Angle type '{angtype}' not found in the dataset!")
+        X = np.expand_dims(X, axis=daxis)
+        ecal = ecal[indexes]
+        ecal = np.expand_dims(ecal, axis=daxis)
+        if xpower != 1.0:
+            X = np.power(X, xpower)
+
+        Y = np.array([[el] for el in Y])
+        ang = np.array([[el] for el in ang])
+        ecal = np.array([[el] for el in ecal])
+
+        final_dataset = {"X": X, "Y": Y, "ang": ang, "ecal": ecal}
+
+        return final_dataset
+
+
+class ParticlesDataModule(pl.LightningDataModule):
+    def __init__(
+        self,
+        datapath: str,
+        batch_size: int,
+        num_workers: int = 4,
+        max_samples: Optional[int] = None
+    ) -> None:
+        super().__init__()
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.datapath = datapath
+        self.max_samples = max_samples
+
+    def setup(self, stage: Optional[str] = None):
+        # make assignments here (val/train/test split)
+        # called on every process in DDP
+
+        if stage == 'fit' or stage is None:
+            self.dataset = ParticlesDataset(
+                self.datapath,
+                max_samples=self.max_samples
+            )
+            dataset_length = len(self.dataset)
+            split_point = int(dataset_length * 0.9)
+            self.train_dataset, self.val_dataset = \
+                torch.utils.data.random_split(
+                    self.dataset, [split_point, dataset_length - split_point])
+
+        if stage == 'predict':
+            # TODO: inference dataset should be different in that it
+            # does not contain images!
+            self.predict_dataset = ParticlesDataset(
+                self.datapath,
+                max_samples=self.max_samples
+            )
+
+        # if stage == 'test' or stage is None:
+        #     self.test_dataset = MyDataset(self.data_dir, train=False)
+
+    def train_dataloader(self):
+        return DataLoader(self.train_dataset, num_workers=self.num_workers,
+                          batch_size=self.batch_size, drop_last=True)
+
+    def val_dataloader(self):
+        return DataLoader(self.val_dataset, num_workers=self.num_workers,
+                          batch_size=self.batch_size, drop_last=True)
+
+    def predict_dataloader(self) -> EVAL_DATALOADERS:
+        return DataLoader(self.predict_dataset, num_workers=self.num_workers,
+                          batch_size=self.batch_size, drop_last=True)
+
+    # def test_dataloader(self):
+    #     return DataLoader(self.test_dataset, batch_size=self.batch_size)
diff --git a/use-cases/3dgan/inference-pipeline.yaml b/use-cases/3dgan/inference-pipeline.yaml
new file mode 100644
index 00000000..3939b206
--- /dev/null
+++ b/use-cases/3dgan/inference-pipeline.yaml
@@ -0,0 +1,104 @@
+executor:
+  class_path: itwinai.components.Executor
+  init_args:
+    steps:
+      - class_path: dataloader.Lightning3DGANDownloader
+        init_args:
+          data_path: /usr/data/exp_data/
+          data_url: https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX
+
+      - class_path: trainer.Lightning3DGANPredictor
+        init_args:
+          model:
+            class_path: trainer.LightningModelLoader
+            init_args:
+              model_uri: 3dgan-inference.pth
+
+          # Pytorch lightning config for training
+          config:
+            seed_everything: 4231162351
+            trainer:
+              accelerator: auto
+              accumulate_grad_batches: 1
+              barebones: false
+              benchmark: null
+              # callbacks:
+              #   # - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping
+              #   #   init_args:
+              #   #     monitor: val_loss
+              #   #     patience: 2
+              #   - class_path: lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor
+              #     init_args:
+              #
logging_interval: step + # # - class_path: lightning.pytorch.callbacks.ModelCheckpoint + # # init_args: + # # dirpath: checkpoints + # # filename: best-checkpoint + # # mode: min + # # monitor: val_loss + # # save_top_k: 1 + # # verbose: true + check_val_every_n_epoch: 1 + default_root_dir: null + detect_anomaly: false + deterministic: null + devices: auto #[0] + enable_checkpointing: true + enable_model_summary: null + enable_progress_bar: null + fast_dev_run: false + gradient_clip_algorithm: null + gradient_clip_val: null + inference_mode: true + limit_predict_batches: null + limit_test_batches: null + limit_train_batches: null + limit_val_batches: null + log_every_n_steps: 2 + logger: + # - class_path: lightning.pytorch.loggers.CSVLogger + # init_args: + # save_dir: ml_logs/csv_logs + class_path: lightning.pytorch.loggers.MLFlowLogger + init_args: + experiment_name: 3DGAN + save_dir: ml_logs/mlflow_logs + log_model: all + max_epochs: 1 + max_steps: 20 + max_time: null + min_epochs: null + min_steps: null + num_sanity_val_steps: null + overfit_batches: 0.0 + plugins: null + profiler: null + reload_dataloaders_every_n_epochs: 0 + strategy: ddp_find_unused_parameters_true #auto + sync_batchnorm: false + use_distributed_sampler: true + val_check_interval: null + + # Lightning Model configuration + model: + class_path: model.ThreeDGAN + init_args: + latent_size: 256 + batch_size: 64 + loss_weights: [3, 0.1, 25, 0.1] + power: 0.85 + lr: 0.001 + checkpoint_path: exp_data/3dgan.pth + + # Lightning data module configuration + data: + class_path: dataloader.ParticlesDataModule + init_args: + datapath: /usr/data/exp_data/*/*.h5 + batch_size: 64 + num_workers: 2 + max_samples: 10 + + - class_path: saver.ParticleImagesSaver + init_args: + save_dir: /usr/data/3dgan-generated-data \ No newline at end of file diff --git a/use-cases/3dgan/model.py b/use-cases/3dgan/model.py new file mode 100644 index 00000000..4fc5cc99 --- /dev/null +++ b/use-cases/3dgan/model.py @@ -0,0 +1,783 @@ +import sys +import os +import pickle +from collections import defaultdict +import math +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +import lightning as pl +import numpy as np + + +class Generator(nn.Module): + def __init__(self, latent_dim): # img_shape + super().__init__() + # self.img_shape = img_shape + self.latent_dim = latent_dim + + self.l1 = nn.Linear(self.latent_dim, 5184) + self.up1 = nn.Upsample( + scale_factor=(6, 6, 6), + mode='trilinear', + align_corners=False + ) + self.conv1 = nn.Conv3d( + in_channels=8, out_channels=8, + kernel_size=(6, 6, 8), padding=0 + ) + nn.init.kaiming_uniform_(self.conv1.weight) + # num_features is the number of channels (see doc) + self.bn1 = nn.BatchNorm3d(num_features=8, eps=1e-6) + self.pad1 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0) + + self.conv2 = nn.Conv3d( + in_channels=8, out_channels=6, + kernel_size=(4, 4, 6), padding=0 + ) + nn.init.kaiming_uniform_(self.conv2.weight) + self.bn2 = nn.BatchNorm3d(num_features=6, eps=1e-6) + self.pad2 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0) + + self.conv3 = nn.Conv3d( + in_channels=6, out_channels=6, + kernel_size=(4, 4, 6), padding=0 + ) + nn.init.kaiming_uniform_(self.conv3.weight) + self.bn3 = nn.BatchNorm3d(num_features=6, eps=1e-6) + self.pad3 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0) + + self.conv4 = nn.Conv3d( + in_channels=6, out_channels=6, + kernel_size=(4, 4, 6), padding=0 + ) + nn.init.kaiming_uniform_(self.conv4.weight) + self.bn4 = nn.BatchNorm3d(num_features=6, eps=1e-6) + 
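+        # nn.ConstantPad3d takes (left, right, top, bottom, front, back),
+        # where (left, right) pad the last spatial dim: (0, 0, 1, 1, 1, 1)
+        # therefore pads the first two spatial dims by one voxel per side
+        # and leaves the last dim unchanged.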
self.pad4 = nn.ConstantPad3d((0, 0, 1, 1, 1, 1), 0) + + self.conv5 = nn.Conv3d( + in_channels=6, out_channels=6, + kernel_size=(3, 3, 5), padding=0 + ) + nn.init.kaiming_uniform_(self.conv5.weight) + self.bn5 = nn.BatchNorm3d(num_features=6, eps=1e-6) + self.pad5 = nn.ConstantPad3d((0, 0, 1, 1, 1, 1), 0) + + self.conv6 = nn.Conv3d( + in_channels=6, out_channels=6, + kernel_size=(3, 3, 3), padding=0 + ) + nn.init.kaiming_uniform_(self.conv6.weight) + + self.conv7 = nn.Conv3d( + in_channels=6, out_channels=1, + kernel_size=(2, 2, 2), padding=0 + ) + nn.init.xavier_normal_(self.conv7.weight) + + def forward(self, z): + img = self.l1(z) + img = img.view(-1, 8, 9, 9, 8) + img = self.up1(img) + img = self.conv1(img) + img = F.relu(img) + img = self.bn1(img) + img = self.pad1(img) + + img = self.conv2(img) + img = F.relu(img) + img = self.bn2(img) + img = self.pad2(img) + + img = self.conv3(img) + img = F.relu(img) + img = self.bn3(img) + img = self.pad3(img) + + img = self.conv4(img) + img = F.relu(img) + img = self.bn4(img) + img = self.pad4(img) + + img = self.conv5(img) + img = F.relu(img) + img = self.bn5(img) + img = self.pad5(img) + + img = self.conv6(img) + img = F.relu(img) + + img = self.conv7(img) + img = F.relu(img) + + return img + + +class Discriminator(nn.Module): + def __init__(self, power): + super().__init__() + + self.power = power + + self.conv1 = nn.Conv3d( + in_channels=1, out_channels=16, + kernel_size=(5, 6, 6), padding=(2, 3, 3) + ) + self.drop1 = nn.Dropout(0.2) + self.pad1 = nn.ConstantPad3d((1, 1, 0, 0, 0, 0), 0) + + self.conv2 = nn.Conv3d( + in_channels=16, out_channels=8, + kernel_size=(5, 6, 6), padding=0 + ) + self.bn1 = nn.BatchNorm3d(num_features=8, eps=1e-6) + self.drop2 = nn.Dropout(0.2) + self.pad2 = nn.ConstantPad3d((1, 1, 0, 0, 0, 0), 0) + + self.conv3 = nn.Conv3d( + in_channels=8, out_channels=8, + kernel_size=(5, 6, 6), padding=0 + ) + self.bn2 = nn.BatchNorm3d(num_features=8, eps=1e-6) + self.drop3 = nn.Dropout(0.2) + + self.conv4 = nn.Conv3d( + in_channels=8, out_channels=8, + kernel_size=(5, 6, 6), padding=0 + ) + self.bn3 = nn.BatchNorm3d(num_features=8, eps=1e-6) + self.drop4 = nn.Dropout(0.2) + + self.avgpool = nn.AvgPool3d((2, 2, 2)) + self.flatten = nn.Flatten() + + # The input features for the Linear layer need to be calculated based + # on the output shape from the previous layers. + self.fakeout = nn.Linear(19152, 1) + self.auxout = nn.Linear(19152, 1) # The same as above for this layer. + + # calculate sum of intensities + def ecal_sum(self, image, daxis): + sum = torch.sum(image, dim=daxis) + return sum + + # angle calculation + def ecal_angle(self, image, daxis1): + image = torch.squeeze(image, dim=daxis1) # squeeze along channel axis + + # get shapes + x_shape = image.shape[1] + y_shape = image.shape[2] + z_shape = image.shape[3] + sumtot = torch.sum(image, dim=(1, 2, 3)) # sum of events + + # get 1. 
where event sum is 0 and 0 elsewhere + amask = torch.where(sumtot == 0.0, torch.ones_like( + sumtot), torch.zeros_like(sumtot)) + # masked_events = torch.sum(amask) # counting zero sum events + + # ref denotes barycenter as that is our reference point + x_ref = torch.sum(torch.sum(image, dim=(2, 3)) + * (torch.arange(x_shape, device=image.device, + dtype=torch.float32).unsqueeze(0) + 0.5), + dim=1,) # sum for x position * x index + y_ref = torch.sum( + torch.sum(image, dim=(1, 3)) + * (torch.arange(y_shape, device=image.device, + dtype=torch.float32).unsqueeze(0) + 0.5), + dim=1,) + z_ref = torch.sum( + torch.sum(image, dim=(1, 2)) + * (torch.arange(z_shape, device=image.device, + dtype=torch.float32).unsqueeze(0) + 0.5), + dim=1,) + + # return max position if sumtot=0 and divide by sumtot otherwise + x_ref = torch.where( + sumtot == 0.0, torch.ones_like(x_ref), x_ref / sumtot) + y_ref = torch.where( + sumtot == 0.0, torch.ones_like(y_ref), y_ref / sumtot) + z_ref = torch.where( + sumtot == 0.0, torch.ones_like(z_ref), z_ref / sumtot) + + # reshape + x_ref = x_ref.unsqueeze(1) + y_ref = y_ref.unsqueeze(1) + z_ref = z_ref.unsqueeze(1) + + sumz = torch.sum(image, dim=(1, 2)) # sum for x,y planes going along z + + # Get 0 where sum along z is 0 and 1 elsewhere + zmask = torch.where(sumz == 0.0, torch.zeros_like( + sumz), torch.ones_like(sumz)) + + x = torch.arange(x_shape, device=image.device).unsqueeze( + 0) # x indexes + x = (x.unsqueeze(2).float()) + 0.5 + y = torch.arange(y_shape, device=image.device).unsqueeze( + 0) # y indexes + y = (y.unsqueeze(2).float()) + 0.5 + + # barycenter for each z position + x_mid = torch.sum(torch.sum(image, dim=2) * x, dim=1) + y_mid = torch.sum(torch.sum(image, dim=1) * y, dim=1) + + x_mid = torch.where(sumz == 0.0, torch.zeros_like( + sumz), x_mid / sumz) # if sum != 0 then divide by sum + y_mid = torch.where(sumz == 0.0, torch.zeros_like( + sumz), y_mid / sumz) # if sum != 0 then divide by sum + + # Angle Calculations + z = (torch.arange( + z_shape, + device=image.device, + dtype=torch.float32 + # Make an array of z indexes for all events + ) + 0.5) * torch.ones_like(z_ref) + + # projection from z axis with stability check + zproj = torch.sqrt( + torch.max( + (x_mid - x_ref) ** 2.0 + (z - z_ref) ** 2.0, + torch.tensor( + [torch.finfo(torch.float32).eps] + ).to(x_mid.device) + ) + ) + # torch.finfo(torch.float32).eps)) + # to avoid divide by zero for zproj =0 + m = torch.where(zproj == 0.0, torch.zeros_like( + zproj), (y_mid - y_ref) / zproj) + m = torch.where(z < z_ref, -1 * m, m) # sign inversion + ang = (math.pi / 2.0) - torch.atan(m) # angle correction + zmask = torch.where(zproj == 0.0, torch.zeros_like(zproj), zmask) + ang = ang * zmask # place zero where zsum is zero + ang = ang * z # weighted by position + sumz_tot = z * zmask # removing indexes with 0 energies or angles + + # zunmasked = K.sum(zmask, axis=1) # used for simple mean + # Mean does not include positions where zsum=0 + # ang = K.sum(ang, axis=1)/zunmasked + + # sum ( measured * weights)/sum(weights) + ang = torch.sum(ang, dim=1) / torch.sum(sumz_tot, dim=1) + # Place 100 for measured angle where no energy is deposited in events + ang = torch.where(amask == 0.0, ang, 100.0 * torch.ones_like(ang)) + ang = ang.unsqueeze(1) + return ang + + def forward(self, x): + z = self.conv1(x) + z = F.leaky_relu(z) + z = self.drop1(z) + z = self.pad1(z) + + z = self.conv2(z) + z = F.leaky_relu(z) + z = self.bn1(z) + z = self.drop2(z) + z = self.pad2(z) + + z = self.conv3(z) + z = F.leaky_relu(z) + z = 
self.bn2(z) + z = self.drop3(z) + + z = self.conv4(z) + z = F.leaky_relu(z) + z = self.bn3(z) + z = self.drop4(z) + z = self.avgpool(z) + z = self.flatten(z) + + # generation output that says fake/real + fake = torch.sigmoid(self.fakeout(z)) + aux = self.auxout(z) # auxiliary output + inv_image = x.pow(1.0 / self.power) + ang = self.ecal_angle(inv_image, 1) # angle calculation + ecal = self.ecal_sum(inv_image, (2, 3, 4)) # sum of energies + + return fake, aux, ang, ecal + + +class ThreeDGAN(pl.LightningModule): + def __init__( + self, + latent_size=256, + batch_size=64, + loss_weights=[3, 0.1, 25, 0.1], + power=0.85, + lr=0.001, + checkpoint_path: str = '3Dgan.pth' + ): + super().__init__() + self.save_hyperparameters() + self.automatic_optimization = False + + self.latent_size = latent_size + self.batch_size = batch_size + self.loss_weights = loss_weights + self.lr = lr + self.power = power + + self.generator = Generator(self.latent_size) + self.discriminator = Discriminator(self.power) + + self.epoch_gen_loss = [] + self.epoch_disc_loss = [] + self.disc_epoch_test_loss = [] + self.gen_epoch_test_loss = [] + self.index = 0 + self.train_history = defaultdict(list) + self.test_history = defaultdict(list) + self.pklfile = checkpoint_path + checkpoint_dir = os.path.dirname(checkpoint_path) + if not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir) + + def BitFlip(self, x, prob=0.05): + """ + Flips a single bit according to a certain probability. + + Args: + x (list): list of bits to be flipped + prob (float): probability of flipping one bit + + Returns: + list: List of flipped bits + + """ + x = np.array(x) + selection = np.random.uniform(0, 1, x.shape) < prob + x[selection] = 1 * np.logical_not(x[selection]) + return x + + def mean_absolute_percentage_error(self, y_true, y_pred): + return torch.mean(torch.abs((y_true - y_pred) / (y_true + 1e-7))) * 100 + + def compute_global_loss( + self, + labels, + predictions, + loss_weights=(3, 0.1, 25, 0.1) + ): + # Can be initialized outside + binary_crossentropy_object = nn.BCEWithLogitsLoss(reduction='none') + # there is no equivalent in pytorch for + # tf.keras.losses.MeanAbsolutePercentageError --> using the + # custom "mean_absolute_percentage_error" above! 
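+        # The helper above computes
+        #   mean(|y_true - y_pred| / (y_true + 1e-7)) * 100.
+        # Note that the calls below pass predictions[i] as the first
+        # argument (the y_true slot), so the error is normalized by the
+        # prediction rather than by the label.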
+ mean_absolute_percentage_error_object1 = \ + self.mean_absolute_percentage_error(predictions[1], labels[1]) + mean_absolute_percentage_error_object2 = \ + self.mean_absolute_percentage_error(predictions[3], labels[3]) + mae_object = nn.L1Loss(reduction='none') + + binary_example_loss = binary_crossentropy_object( + predictions[0], labels[0]) * loss_weights[0] + + # mean_example_loss_1 = mean_absolute_percentage_error_object( + # predictions[1], labels[1]) * loss_weights[1] + mean_example_loss_1 = \ + mean_absolute_percentage_error_object1 * loss_weights[1] + + mae_example_loss = mae_object( + predictions[2], labels[2]) * loss_weights[2] + + # mean_example_loss_2 = mean_absolute_percentage_error_object( + # predictions[3], labels[3]) * loss_weights[3] + mean_example_loss_2 = \ + mean_absolute_percentage_error_object2 * loss_weights[3] + + binary_loss = binary_example_loss.mean() + mean_loss_1 = mean_example_loss_1.mean() + mae_loss = mae_example_loss.mean() + mean_loss_2 = mean_example_loss_2.mean() + + return [binary_loss, mean_loss_1, mae_loss, mean_loss_2] + + def forward(self, z): + return self.generator(z) + + def training_step(self, batch, batch_idx): + image_batch, energy_batch, ang_batch, ecal_batch = \ + batch['X'], batch['Y'], batch['ang'], batch['ecal'] + + image_batch = image_batch.permute(0, 4, 1, 2, 3) + + image_batch = image_batch.to(self.device) + energy_batch = energy_batch.to(self.device) + ang_batch = ang_batch.to(self.device) + ecal_batch = ecal_batch.to(self.device) + + optimizer_discriminator, optimizer_generator = self.optimizers() + + noise = torch.randn( + (energy_batch.shape[0], self.latent_size - 2), + # (self.batch_size, self.latent_size - 2), + dtype=torch.float32, + device=self.device + ) + # print(f'Energy elements: {energy_batch.numel} {energy_batch.shape}') + # print(f'Angle elements: {ang_batch.numel} {ang_batch.shape}') + generator_ip = torch.cat( + (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), + dim=1 + ) + generated_images = self.generator(generator_ip) + + # Train discriminator first on real batch + fake_batch = self.BitFlip(np.ones(self.batch_size).astype(np.float32)) + fake_batch = torch.tensor([[el] for el in fake_batch]).to(self.device) + labels = [fake_batch, energy_batch, ang_batch, ecal_batch] + + predictions = self.discriminator(image_batch) + # print("calculating real_batch_loss...") + real_batch_loss = self.compute_global_loss( + labels, predictions, self.loss_weights) + self.log("real_batch_loss", sum(real_batch_loss), + prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) + # print("real batch disc train") + # the following 3 lines correspond in tf version to: + # gradients = tape.gradient(real_batch_loss, + # discriminator.trainable_variables) + # optimizer_discriminator.apply_gradients(zip(gradients, + # discriminator.trainable_variables)) in Tensorflow + optimizer_discriminator.zero_grad() + self.manual_backward(sum(real_batch_loss)) + # sum(real_batch_loss).backward() + # real_batch_loss.backward() + optimizer_discriminator.step() + + # Train discriminator on the fake batch + fake_batch = self.BitFlip(np.zeros(self.batch_size).astype(np.float32)) + fake_batch = torch.tensor([[el] for el in fake_batch]).to(self.device) + labels = [fake_batch, energy_batch, ang_batch, ecal_batch] + + predictions = self.discriminator(generated_images) + + fake_batch_loss = self.compute_global_loss( + labels, predictions, self.loss_weights) + self.log("fake_batch_loss", sum(fake_batch_loss), + prog_bar=True, on_step=True, on_epoch=True, 
sync_dist=True) + # print("fake batch disc train") + # the following 3 lines correspond to + # gradients = tape.gradient(fake_batch_loss, + # discriminator.trainable_variables) + # optimizer_discriminator.apply_gradients(zip(gradients, + # discriminator.trainable_variables)) in Tensorflow + optimizer_discriminator.zero_grad() + self.manual_backward(sum(fake_batch_loss)) + # sum(fake_batch_loss).backward() + optimizer_discriminator.step() + + # avg_disc_loss = (sum(real_batch_loss) + sum(fake_batch_loss)) / 2 + + trick = np.ones(self.batch_size).astype(np.float32) + fake_batch = torch.tensor([[el] for el in trick]).to(self.device) + labels = [fake_batch, energy_batch.view(-1, 1), ang_batch, ecal_batch] + + gen_losses_train = [] + # Train generator twice using combined model + for _ in range(2): + noise = torch.randn( + (self.batch_size, self.latent_size - 2)).to(self.device) + generator_ip = torch.cat( + (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), + dim=1 + ) + + generated_images = self.generator(generator_ip) + predictions = self.discriminator(generated_images) + + loss = self.compute_global_loss( + labels, predictions, self.loss_weights) + self.log("gen_loss", sum(loss), prog_bar=True, + on_step=True, on_epoch=True, sync_dist=True) + # print("gen train") + optimizer_generator.zero_grad() + self.manual_backward(sum(loss)) + # sum(loss).backward() + optimizer_generator.step() + + for el in loss: + gen_losses_train.append(el) + + avg_generator_loss = sum(gen_losses_train) / len(gen_losses_train) + self.log("generator_loss", avg_generator_loss.item(), + prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) + # avg_generator_loss = [(a + b) / 2 for a, b in zip(*gen_losses_train)] + # self.log("generator_loss", sum(avg_generator_loss), prog_bar=True, + # on_step=True, on_epoch=True, sync_dist=True) + + gen_losses = [] + # I'm not returning anything as in pl you do not return anything when + # you back-propagate manually + # return_loss = real_batch_loss + real_batch_loss = [real_batch_loss[0], real_batch_loss[1], + real_batch_loss[2], real_batch_loss[3]] + fake_batch_loss = [fake_batch_loss[0], fake_batch_loss[1], + fake_batch_loss[2], fake_batch_loss[3]] + gen_batch_loss = [gen_losses_train[0], gen_losses_train[1], + gen_losses_train[2], gen_losses_train[3]] + gen_losses.append(gen_batch_loss) + gen_batch_loss = [gen_losses_train[4], gen_losses_train[5], + gen_losses_train[6], gen_losses_train[7]] + gen_losses.append(gen_batch_loss) + + real_batch_loss = [el.cpu().detach().numpy() for el in real_batch_loss] + real_batch_loss_total_loss = np.sum(real_batch_loss) + new_real_batch_loss = [real_batch_loss_total_loss] + for i_weights in range(len(real_batch_loss)): + new_real_batch_loss.append( + real_batch_loss[i_weights] / self.loss_weights[i_weights]) + real_batch_loss = new_real_batch_loss + + fake_batch_loss = [el.cpu().detach().numpy() for el in fake_batch_loss] + fake_batch_loss_total_loss = np.sum(fake_batch_loss) + new_fake_batch_loss = [fake_batch_loss_total_loss] + for i_weights in range(len(fake_batch_loss)): + new_fake_batch_loss.append( + fake_batch_loss[i_weights] / self.loss_weights[i_weights]) + fake_batch_loss = new_fake_batch_loss + + # if ecal sum has 100% loss(generating empty events) then end + # the training + if fake_batch_loss[3] == 100.0 and self.index > 10: + # print("Empty image with Ecal loss equal to 100.0 " + # f"for {self.index} batch") + torch.save(self.generator.state_dict(), "generator_weights.pth") + torch.save(self.discriminator.state_dict(), 
+ "discriminator_weights.pth") + # print("real_batch_loss", real_batch_loss) + # print("fake_batch_loss", fake_batch_loss) + sys.exit() + + # append mean of discriminator loss for real and fake events + self.epoch_disc_loss.append( + [(a + b) / 2 for a, b in zip(real_batch_loss, fake_batch_loss)]) + + gen_losses[0] = [el.cpu().detach().numpy() for el in gen_losses[0]] + gen_losses_total_loss = np.sum(gen_losses[0]) + new_gen_losses = [gen_losses_total_loss] + for i_weights in range(len(gen_losses[0])): + new_gen_losses.append( + gen_losses[0][i_weights] / self.loss_weights[i_weights]) + gen_losses[0] = new_gen_losses + + gen_losses[1] = [el.cpu().detach().numpy() for el in gen_losses[1]] + gen_losses_total_loss = np.sum(gen_losses[1]) + new_gen_losses = [gen_losses_total_loss] + for i_weights in range(len(gen_losses[1])): + new_gen_losses.append( + gen_losses[1][i_weights] / self.loss_weights[i_weights]) + gen_losses[1] = new_gen_losses + + generator_loss = [(a + b) / 2 for a, b in zip(*gen_losses)] + + self.epoch_gen_loss.append(generator_loss) + + # # MB: verify weight synchronization among workers + # # Ref: https://github.com/Lightning-AI/lightning/issues/9237 + # disc_w = self.discriminator.conv1.weight.reshape(-1)[0:5] + # gen_w = self.generator.conv1.weight.reshape(-1)[0:5] + # print(f"DISC w: {disc_w}") + # print(f"GEN w: {gen_w}") + + # self.index += 1 #this might be moved after test cycle + + # logging of gen and disc loss done by Trainer + # self.log('epoch_gen_loss', self.epoch_gen_loss, on_step=True, + # on_epoch=True, sync_dist=True) + # self.log('epoch_disc_loss', self.epoch_disc_loss, on_step=True, o + # n_epoch=True, sync_dist=True) + + # return avg_disc_loss + avg_generator_loss + + def on_train_epoch_end(self): # outputs + discriminator_train_loss = np.mean( + np.array(self.epoch_disc_loss), axis=0) + generator_train_loss = np.mean(np.array(self.epoch_gen_loss), axis=0) + + self.train_history["generator"].append(generator_train_loss) + self.train_history["discriminator"].append(discriminator_train_loss) + + print("-" * 65) + ROW_FMT = ( + "{0:<20s} | {1:<4.2f} | {2:<10.2f} | " + "{3:<10.2f}| {4:<10.2f} | {5:<10.2f}") + print(ROW_FMT.format("generator (train)", + *self.train_history["generator"][-1])) + print(ROW_FMT.format("discriminator (train)", + *self.train_history["discriminator"][-1])) + + torch.save(self.generator.state_dict(), "generator_weights.pth") + torch.save(self.discriminator.state_dict(), + "discriminator_weights.pth") + + with open(self.pklfile, "wb") as f: + pickle.dump({"train": self.train_history, + "test": self.test_history}, f) + + # pickle.dump({"train": self.train_history}, open(self.pklfile, "wb")) + print("train-loss:" + str(self.train_history["generator"][-1][0])) + + def validation_step(self, batch, batch_idx): + image_batch, energy_batch, ang_batch, ecal_batch = batch[ + 'X'], batch['Y'], batch['ang'], batch['ecal'] + + image_batch = image_batch.permute(0, 4, 1, 2, 3) + + image_batch = image_batch.to(self.device) + energy_batch = energy_batch.to(self.device) + ang_batch = ang_batch.to(self.device) + ecal_batch = ecal_batch.to(self.device) + + # Generate Fake events with same energy and angle as data batch + noise = torch.randn( + (energy_batch.shape[0], self.latent_size - 2), + # (self.batch_size, self.latent_size - 2), + dtype=torch.float32, + device=self.device + ) + + generator_ip = torch.cat( + (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), dim=1) + generated_images = self.generator(generator_ip) + + # concatenate to fake and 
real batches + X = torch.cat((image_batch, generated_images), dim=0) + + # y = np.array([1] * self.batch_size \ + # + [0] * self.batch_size).astype(np.float32) + y = torch.tensor([1] * self.batch_size + [0] * + self.batch_size, dtype=torch.float32).to(self.device) + y = y.view(-1, 1) + + ang = torch.cat((ang_batch, ang_batch), dim=0) + ecal = torch.cat((ecal_batch, ecal_batch), dim=0) + aux_y = torch.cat((energy_batch, energy_batch), dim=0) + + # y = [[el] for el in y] + labels = [y, aux_y, ang, ecal] + + # Calculate discriminator loss + disc_eval = self.discriminator(X) + disc_eval_loss = self.compute_global_loss( + labels, disc_eval, self.loss_weights) + + # Calculate generator loss + trick = np.ones(self.batch_size).astype(np.float32) + fake_batch = torch.tensor([[el] for el in trick]).to(self.device) + # fake_batch = [[el] for el in trick] + labels = [fake_batch, energy_batch, ang_batch, ecal_batch] + + generated_images = self.generator(generator_ip) + gen_eval = self.discriminator(generated_images) + gen_eval_loss = self.compute_global_loss( + labels, gen_eval, self.loss_weights) + + self.log('val_discriminator_loss', sum( + disc_eval_loss), on_epoch=True, prog_bar=True, sync_dist=True) + self.log('val_generator_loss', sum(gen_eval_loss), + on_epoch=True, prog_bar=True, sync_dist=True) + + disc_test_loss = [disc_eval_loss[0], disc_eval_loss[1], + disc_eval_loss[2], disc_eval_loss[3]] + gen_test_loss = [gen_eval_loss[0], gen_eval_loss[1], + gen_eval_loss[2], gen_eval_loss[3]] + + # Configure the loss so it is equal to the original values + disc_eval_loss = [el.cpu().detach().numpy() for el in disc_test_loss] + disc_eval_loss_total_loss = np.sum(disc_eval_loss) + new_disc_eval_loss = [disc_eval_loss_total_loss] + for i_weights in range(len(disc_eval_loss)): + new_disc_eval_loss.append( + disc_eval_loss[i_weights] / self.loss_weights[i_weights]) + disc_eval_loss = new_disc_eval_loss + + gen_eval_loss = [el.cpu().detach().numpy() for el in gen_test_loss] + gen_eval_loss_total_loss = np.sum(gen_eval_loss) + new_gen_eval_loss = [gen_eval_loss_total_loss] + for i_weights in range(len(gen_eval_loss)): + new_gen_eval_loss.append( + gen_eval_loss[i_weights] / self.loss_weights[i_weights]) + gen_eval_loss = new_gen_eval_loss + + self.index += 1 + # evaluate discriminator loss + self.disc_epoch_test_loss.append(disc_eval_loss) + # evaluate generator loss + self.gen_epoch_test_loss.append(gen_eval_loss) + + def on_validation_epoch_end(self): + discriminator_test_loss = np.mean( + np.array(self.disc_epoch_test_loss), axis=0) + generator_test_loss = np.mean( + np.array(self.gen_epoch_test_loss), axis=0) + + self.test_history["generator"].append(generator_test_loss) + self.test_history["discriminator"].append(discriminator_test_loss) + + print("-" * 65) + ROW_FMT = ( + "{0:<20s} | {1:<4.2f} | {2:<10.2f} | " + "{3:<10.2f}| {4:<10.2f} | {5:<10.2f}") + print(ROW_FMT.format("generator (test)", + *self.test_history["generator"][-1])) + print(ROW_FMT.format("discriminator (test)", + *self.test_history["discriminator"][-1])) + + # save loss dict to pkl file + with open(self.pklfile, "wb") as f: + pickle.dump({"train": self.train_history, + "test": self.test_history}, f) + # pickle.dump({"test": self.test_history}, open(self.pklfile, "wb")) + # print("train-loss:" + str(self.train_history["generator"][-1][0])) + + def predict_step( + self, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0 + ) -> Any: + energy_batch, ang_batch = batch['Y'], batch['ang'] + + energy_batch = energy_batch.to(self.device) 
+ ang_batch = ang_batch.to(self.device) + + # Generate Fake events with same energy and angle as data batch + noise = torch.randn( + (energy_batch.shape[0], self.latent_size - 2), + dtype=torch.float32, + device=self.device + ) + + # print(f"Reshape energy: {energy_batch.view(-1, 1).shape}") + # print(f"Reshape angle: {ang_batch.view(-1, 1).shape}") + # print(f"Noise: {noise.shape}") + + generator_ip = torch.cat( + [energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise], + dim=1 + ) + # print(f"Generator input: {generator_ip.shape}") + generated_images = self.generator(generator_ip) + # print(f"Generated batch size {generated_images.shape}") + return {'images': generated_images, + 'energies': energy_batch, + 'angles': ang_batch} + + def configure_optimizers(self): + lr = self.lr + + optimizer_discriminator = torch.optim.RMSprop( + self.discriminator.parameters(), + lr + ) + optimizer_generator = torch.optim.RMSprop( + self.generator.parameters(), + lr + ) + return [optimizer_discriminator, optimizer_generator], [] diff --git a/use-cases/3dgan/pipeline.yaml b/use-cases/3dgan/pipeline.yaml new file mode 100644 index 00000000..676424aa --- /dev/null +++ b/use-cases/3dgan/pipeline.yaml @@ -0,0 +1,93 @@ +executor: + class_path: itwinai.components.Executor + init_args: + steps: + - class_path: dataloader.Lightning3DGANDownloader + init_args: + data_path: exp_data/ # Set to null to skip dataset download + data_url: https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX + + - class_path: trainer.Lightning3DGANTrainer + init_args: + # Pytorch lightning config for training + config: + seed_everything: 4231162351 + trainer: + accelerator: auto + accumulate_grad_batches: 1 + barebones: false + benchmark: null + # callbacks: + # # - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping + # # init_args: + # # monitor: val_loss + # # patience: 2 + # - class_path: lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor + # init_args: + # logging_interval: step + # # - class_path: lightning.pytorch.callbacks.ModelCheckpoint + # # init_args: + # # dirpath: checkpoints + # # filename: best-checkpoint + # # mode: min + # # monitor: val_loss + # # save_top_k: 1 + # # verbose: true + check_val_every_n_epoch: 1 + default_root_dir: null + detect_anomaly: false + deterministic: null + devices: auto #[0] + enable_checkpointing: true + enable_model_summary: null + enable_progress_bar: null + fast_dev_run: false + gradient_clip_algorithm: null + gradient_clip_val: null + inference_mode: true + limit_predict_batches: null + limit_test_batches: null + limit_train_batches: null + limit_val_batches: null + log_every_n_steps: 2 + logger: + # - class_path: lightning.pytorch.loggers.CSVLogger + # init_args: + # save_dir: ml_logs/csv_logs + class_path: lightning.pytorch.loggers.MLFlowLogger + init_args: + experiment_name: 3DGAN + save_dir: ml_logs/mlflow_logs + log_model: all + max_epochs: 1 + max_steps: 20 + max_time: null + min_epochs: null + min_steps: null + num_sanity_val_steps: null + overfit_batches: 0.0 + plugins: null + profiler: null + reload_dataloaders_every_n_epochs: 0 + strategy: ddp_find_unused_parameters_true #auto + sync_batchnorm: false + use_distributed_sampler: true + val_check_interval: null + + # Lightning Model configuration + model: + class_path: model.ThreeDGAN + init_args: + latent_size: 256 + batch_size: 64 + loss_weights: [3, 0.1, 25, 0.1] + power: 0.85 + lr: 0.001 + checkpoint_path: exp_data/3dgan.pth + + # Lightning data module configuration + data: + 
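+      # ParticlesDataModule globs all H5 files matching 'datapath' and,
+      # in setup(), builds a 90/10 train/validation random split
+      # (see dataloader.py).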
class_path: dataloader.ParticlesDataModule + init_args: + datapath: exp_data/*/*.h5 + batch_size: 64 diff --git a/use-cases/3dgan/requirements.txt b/use-cases/3dgan/requirements.txt new file mode 100644 index 00000000..f1f3b0bf --- /dev/null +++ b/use-cases/3dgan/requirements.txt @@ -0,0 +1,6 @@ +h5py>=3.7.0 +google>=3.0.0 +protobuf>=4.24.3 +gdown>=4.7.1 +# plotly>=5.18.0 +# kaleido>=0.2.1 \ No newline at end of file diff --git a/use-cases/3dgan/saver.py b/use-cases/3dgan/saver.py new file mode 100644 index 00000000..7aa72429 --- /dev/null +++ b/use-cases/3dgan/saver.py @@ -0,0 +1,120 @@ +from typing import Dict, Tuple, Optional +import os +import shutil + +import torch +from torch import Tensor +import matplotlib.pyplot as plt +import numpy as np + +from itwinai.components import Saver + + +class ParticleImagesSaver(Saver): + """Saves generated particle trajectories to disk.""" + + def __init__( + self, + save_dir: str = '3dgan-generated' + ) -> None: + super().__init__() + self.save_dir = save_dir + + def execute( + self, + generated_images: Dict[str, Tensor], + config: Optional[Dict] = None + ) -> Tuple[Optional[Tuple], Optional[Dict]]: + """Saves generated images to disk. + + Args: + generated_images (Dict[str, Tensor]): maps unique item ID to + the generated image. + config (Optional[Dict], optional): inherited configuration. + Defaults to None. + + Returns: + Tuple[Optional[Tuple], Optional[Dict]]: propagation of inherited + configuration and saver return value. + """ + result = self.save(generated_images) + return ((result,), config) + + def save(self, generated_images: Dict[str, Tensor]) -> None: + """Saves generated images to disk. + + Args: + generated_images (Dict[str, Tensor]): maps unique item ID to + the generated image. + """ + if os.path.exists(self.save_dir): + shutil.rmtree(self.save_dir) + os.makedirs(self.save_dir) + + # Save as torch tensor and jpg image + for img_id, img in generated_images.items(): + img_path = os.path.join(self.save_dir, img_id) + torch.save(img, img_path + '.pth') + self._save_image(img, img_id, img_path + '.jpg') + + def _save_image( + self, + img: Tensor, + img_idx: str, + img_path: str, + center: bool = True + ) -> None: + """Converts a 3D tensor to a 3D scatter plot and saves it + to disk as jpg image. 
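+
+        Args:
+            img (Tensor): 3D image tensor; non-zero voxels are plotted.
+            img_idx (str): item ID in the form 'energy=<E>&angle=<A>'.
+            img_path (str): destination path of the jpg file.
+            center (bool, optional): if True, coordinates are shifted so
+                that the volume is centered on the origin. Defaults to True.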
+        """
+        x_offset = img.shape[0] // 2 if center else 0
+        y_offset = img.shape[1] // 2 if center else 0
+        z_offset = img.shape[2] // 2 if center else 0
+
+        # Convert tensor dimension IDs to coordinates
+        x_coords = []
+        y_coords = []
+        z_coords = []
+        values = []
+
+        for x in range(img.shape[0]):
+            for y in range(img.shape[1]):
+                for z in range(img.shape[2]):
+                    if img[x, y, z] > 0.0:
+                        x_coords.append(x - x_offset)
+                        y_coords.append(y - y_offset)
+                        z_coords.append(z - z_offset)
+                        values.append(img[x, y, z])
+
+        # import plotly.graph_objects as go
+        # normalize_intensity_by = 1
+        # trace = go.Scatter3d(
+        #     x=x_coords,
+        #     y=y_coords,
+        #     z=z_coords,
+        #     mode='markers',
+        #     marker_symbol='square',
+        #     marker_color=[
+        #         f"rgba(0,0,255,{i*100//normalize_intensity_by/10})"
+        #         for i in values],
+        # )
+        # fig = go.Figure()
+        # fig.add_trace(trace)
+        # fig.write_image(img_path)
+
+        values = np.array(values)
+        # 0-1 scaling (assumes at least two distinct intensity values)
+        values = (values - values.min()) / (values.max() - values.min())
+        fig = plt.figure()
+        ax = fig.add_subplot(projection='3d')
+        ax.scatter(x_coords, y_coords, z_coords, alpha=values)
+        ax.set_xlabel('X')
+        ax.set_ylabel('Y')
+        ax.set_zlabel('Z')
+
+        # Extract energy and angle from the item ID, which is expected
+        # in the form 'energy=<E>&angle=<A>': strip the 'energy=' and
+        # 'angle=' prefixes to recover the values
+        en, ang = img_idx.split('&')
+        en = en[7:]
+        ang = ang[6:]
+        ax.set_title(f"Energy: {en} - Angle: {ang}")
+        fig.savefig(img_path)
diff --git a/use-cases/3dgan/startscript b/use-cases/3dgan/startscript
new file mode 100644
index 00000000..579ce3b3
--- /dev/null
+++ b/use-cases/3dgan/startscript
@@ -0,0 +1,34 @@
+#!/bin/bash
+
+# general configuration of the job
+#SBATCH --job-name=PrototypeTest
+#SBATCH --account=intertwin
+#SBATCH --mail-user=
+#SBATCH --mail-type=ALL
+#SBATCH --output=job.out
+#SBATCH --error=job.err
+#SBATCH --time=00:30:00
+
+# configure node and process count on the CM
+#SBATCH --partition=batch
+#SBATCH --nodes=2
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=4
+#SBATCH --gpus-per-node=4
+
+#SBATCH --exclusive
+
+# gres options have to be disabled for deepv
+#SBATCH --gres=gpu:4
+
+# load modules
+ml --force purge
+ml Stages/2023 StdEnv/2023 NVHPC/23.1 OpenMPI/4.1.4 cuDNN/8.6.0.163-CUDA-11.7 Python/3.10.4 HDF5 libaio/0.3.112 GCC/11.3.0
+
+# shellcheck source=/dev/null
+source ~/.bashrc
+
+# ON LOGIN NODE download datasets:
+# $ micromamba run -p ../../.venv-pytorch python train.py -p pipeline.yaml --download-only
+
+srun micromamba run -p ../../.venv-pytorch python train.py -p pipeline.yaml
\ No newline at end of file
diff --git a/use-cases/3dgan/train.py b/use-cases/3dgan/train.py
new file mode 100644
index 00000000..d04596be
--- /dev/null
+++ b/use-cases/3dgan/train.py
@@ -0,0 +1,55 @@
+"""
+Training pipeline. To run this script, use the following commands.
+
+On login node:
+
+>>> micromamba run -p ../../.venv-pytorch/ \
+    python train.py -p pipeline.yaml -d
+
+On compute nodes:
+
+>>> micromamba run -p ../../.venv-pytorch/ \
+    python train.py -p pipeline.yaml
+
+"""
+
+import argparse
+
+from itwinai.components import Executor
+from itwinai.utils import parse_pipe_config
+from jsonargparse import ArgumentParser
+
+
+if __name__ == "__main__":
+    # Create CLI parser
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-p", "--pipeline", type=str, required=True,
+        help='Path to the configuration file of the pipeline to execute.'
+    )
+    parser.add_argument(
+        '-d', '--download-only',
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help=('Whether to download only the dataset and exit execution '
+              '(suggested on login nodes of HPC systems)')
+    )
+    args = parser.parse_args()
+
+    # Create parser for the pipeline (ordered)
+    pipe_parser = ArgumentParser()
+    pipe_parser.add_subclass_arguments(Executor, "executor")
+
+    # Parse and instantiate the pipeline
+    parsed = parse_pipe_config(args.pipeline, pipe_parser)
+    pipe = pipe_parser.instantiate_classes(parsed)
+    executor: Executor = getattr(pipe, 'executor')
+
+    if args.download_only:
+        print('Downloading datasets and exiting...')
+        # Keep only the first step (dataset download)
+        executor = executor[:1]
+    else:
+        print('Downloading datasets (if not already done) and running...')
+    executor.setup()
+    executor()
diff --git a/use-cases/3dgan/trainer.py b/use-cases/3dgan/trainer.py
new file mode 100644
index 00000000..faf7dc32
--- /dev/null
+++ b/use-cases/3dgan/trainer.py
@@ -0,0 +1,170 @@
+import os
+import sys
+from typing import Union, Dict, Tuple, Optional, Any
+
+import torch
+from torch import Tensor
+import lightning as pl
+from lightning.pytorch.cli import LightningCLI
+
+from itwinai.components import Trainer, Predictor
+from itwinai.serialization import ModelLoader
+from itwinai.torch.inference import TorchModelLoader
+from itwinai.torch.types import Batch
+
+from model import ThreeDGAN
+from dataloader import ParticlesDataModule
+from utils import load_yaml
+
+
+class Lightning3DGANTrainer(Trainer):
+    def __init__(self, config: Union[Dict, str]):
+        super().__init__()
+        if isinstance(config, str) and os.path.isfile(config):
+            # Load from YAML
+            config = load_yaml(config)
+        self.conf = config
+
+    def train(self) -> Any:
+        old_argv = sys.argv
+        sys.argv = ['some_script_placeholder.py']
+        cli = LightningCLI(
+            args=self.conf,
+            model_class=ThreeDGAN,
+            datamodule_class=ParticlesDataModule,
+            run=False,
+            save_config_kwargs={
+                "overwrite": True,
+                "config_filename": "pl-training.yml",
+            },
+            subclass_mode_model=True,
+            subclass_mode_data=True,
+        )
+        sys.argv = old_argv
+        cli.trainer.fit(cli.model, datamodule=cli.datamodule)
+
+    def execute(
+        self,
+        config: Optional[Dict] = None
+    ) -> Tuple[Any, Optional[Dict]]:
+        result = self.train()
+        return result, config
+
+    def save_state(self):
+        return super().save_state()
+
+    def load_state(self):
+        return super().load_state()
+
+
+class LightningModelLoader(TorchModelLoader):
+    """Loads a torch lightning model from somewhere.
+
+    Args:
+        model_uri (str): Can be a path on local filesystem
+            or an mlflow 'locator' in the form:
+            'mlflow+MLFLOW_TRACKING_URI+RUN_ID+ARTIFACT_PATH'
+    """
+
+    def __call__(self) -> pl.LightningModule:
+        """Loads model from model URI.
+
+        Raises:
+            ValueError: if the model URI is not recognized
+                or the model is not found.
+
+        Returns:
+            pl.LightningModule: torch lightning module.
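+
+        Example of a valid mlflow locator, built from the format above
+        with hypothetical placeholder values:
+        'mlflow+http://localhost:5000+a1b2c3d4e5f6+model/checkpoint'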
+        """
+        # TODO: improve
+        # # Load best model
+        # loaded_model = cli.model.load_from_checkpoint(
+        #     ckpt_path,
+        #     lightning_conf['model']['init_args']
+        # )
+        return super().__call__()
+
+
+class Lightning3DGANPredictor(Predictor):
+
+    def __init__(
+        self,
+        model: Union[ModelLoader, pl.LightningModule],
+        config: Union[Dict, str],
+        name: Optional[str] = None
+    ):
+        super().__init__(model, name)
+        if isinstance(config, str) and os.path.isfile(config):
+            # Load from YAML
+            config = load_yaml(config)
+        self.conf = config
+
+    def predict(
+        self,
+        datamodule: Optional[pl.LightningDataModule] = None,
+        model: Optional[pl.LightningModule] = None
+    ) -> Dict[str, Tensor]:
+        old_argv = sys.argv
+        sys.argv = ['some_script_placeholder.py']
+        cli = LightningCLI(
+            args=self.conf,
+            model_class=ThreeDGAN,
+            datamodule_class=ParticlesDataModule,
+            run=False,
+            save_config_kwargs={
+                "overwrite": True,
+                "config_filename": "pl-training.yml",
+            },
+            subclass_mode_model=True,
+            subclass_mode_data=True,
+        )
+        sys.argv = old_argv
+
+        # Override config file with inline arguments, if given
+        if datamodule is None:
+            datamodule = cli.datamodule
+        if model is None:
+            model = cli.model
+
+        predictions = cli.trainer.predict(model, datamodule=datamodule)
+
+        # Transpose predictions into images, energies and angles
+        images = torch.cat(list(map(
+            lambda pred: self.transform_predictions(
+                pred['images']), predictions
+        )))
+        energies = torch.cat(list(map(
+            lambda pred: pred['energies'], predictions
+        )))
+        angles = torch.cat(list(map(
+            lambda pred: pred['angles'], predictions
+        )))
+
+        predictions_dict = dict()
+        for img, en, ang in zip(images, energies, angles):
+            sample_key = f"energy={en.item()}&angle={ang.item()}"
+            predictions_dict[sample_key] = img
+
+        return predictions_dict
+
+    def transform_predictions(self, batch: Batch) -> Batch:
+        """
+        Post-processes the predictions of the torch model.
+        """
+        return batch.squeeze(1)
+
+    def execute(
+        self,
+        config: Optional[Dict] = None,
+    ) -> Tuple[Optional[Tuple], Optional[Dict]]:
+        """Runs inference and propagates the inherited configuration.
+
+        Args:
+            config (Dict, optional): key-value configuration.
+                Defaults to None.
+
+        Returns:
+            Tuple[Optional[Tuple], Optional[Dict]]: tuple structured as
+                (results, config).
+        """
+        return self.predict(), config
diff --git a/use-cases/3dgan/utils.py b/use-cases/3dgan/utils.py
new file mode 100644
index 00000000..d04f9e63
--- /dev/null
+++ b/use-cases/3dgan/utils.py
@@ -0,0 +1,108 @@
+"""
+Utility functions for the 3DGAN use case.
+"""
+import os
+import yaml
+
+from collections.abc import MutableMapping
+from typing import Dict
+from omegaconf import OmegaConf
+from omegaconf.dictconfig import DictConfig
+
+
+def load_yaml(path: str) -> Dict:
+    """Load YAML file as dict.
+
+    Args:
+        path (str): path to YAML file.
+
+    Raises:
+        exc: yaml.YAMLError for loading/parsing errors.
+
+    Returns:
+        Dict: nested dict representation of parsed YAML file.
+    """
+    with open(path, "r", encoding="utf-8") as yaml_file:
+        try:
+            loaded_config = yaml.safe_load(yaml_file)
+        except yaml.YAMLError as exc:
+            print(exc)
+            raise exc
+    return loaded_config
+
+
+def load_yaml_with_deps_from_file(path: str) -> DictConfig:
+    """
+    Load YAML file with OmegaConf and merge it with its dependencies
+    specified in the `conf-dependencies` field.
+    Assumes that the dependencies live in the same folder as the
+    YAML file importing them.
+
+    Args:
+        path (str): path to YAML file.
+
+    Raises:
+        exc: yaml.YAMLError for loading/parsing errors.
+
+    Returns:
+        DictConfig: nested representation of parsed YAML file.
+    """
+    yaml_conf = load_yaml(path)
+    use_case_dir = os.path.dirname(path)
+    deps = []
+    if yaml_conf.get("conf-dependencies"):
+        for dependency in yaml_conf["conf-dependencies"]:
+            deps.append(load_yaml(os.path.join(use_case_dir, dependency)))
+
+    return OmegaConf.merge(yaml_conf, *deps)
+
+
+def load_yaml_with_deps_from_dict(dict_conf, use_case_dir) -> DictConfig:
+    """Same as `load_yaml_with_deps_from_file`, but for an already
+    loaded configuration dict."""
+    deps = []
+
+    if dict_conf.get("conf-dependencies"):
+        for dependency in dict_conf["conf-dependencies"]:
+            deps.append(load_yaml(os.path.join(use_case_dir, dependency)))
+
+    return OmegaConf.merge(dict_conf, *deps)
+
+
+def dynamically_import_class(name: str):
+    """
+    Dynamically import class by module path.
+    Adapted from https://stackoverflow.com/a/547867
+
+    Args:
+        name (str): path to the class (e.g., mypackage.mymodule.MyClass)
+
+    Returns:
+        __class__: class object.
+    """
+    module, class_name = name.rsplit(".", 1)
+    mod = __import__(module, fromlist=[class_name])
+    klass = getattr(mod, class_name)
+    return klass
+
+
+def flatten_dict(
+    d: MutableMapping, parent_key: str = "", sep: str = "."
+) -> MutableMapping:
+    """Flatten a nested dictionary.
+
+    Args:
+        d (MutableMapping): nested dictionary to flatten.
+        parent_key (str, optional): prefix for all keys. Defaults to ''.
+        sep (str, optional): separator for nested key concatenation.
+            Defaults to '.'.
+
+    Returns:
+        MutableMapping: flattened dictionary with new keys.
+    """
+    items = []
+    for k, v in d.items():
+        new_key = parent_key + sep + k if parent_key else k
+        if isinstance(v, MutableMapping):
+            items.extend(flatten_dict(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
diff --git a/use-cases/mnist/torch/Dockerfile b/use-cases/mnist/torch/Dockerfile
new file mode 100644
index 00000000..b4cf3654
--- /dev/null
+++ b/use-cases/mnist/torch/Dockerfile
@@ -0,0 +1,18 @@
+FROM python:3.9.12
+
+WORKDIR /usr/src/app
+
+# Install pytorch (cpuonly)
+# Ref: https://pytorch.org/get-started/previous-versions/#linux-and-windows-5
+RUN pip install --no-cache-dir torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu
+
+# Install itwinai and dependencies
+COPY pyproject.toml ./
+COPY src ./
+RUN pip install --no-cache-dir .
+
+# Add torch MNIST use case
+COPY use-cases/mnist/torch/* ./
+
+# Run inference
+CMD [ "python", "train.py", "-p", "inference-pipeline.yaml"]
\ No newline at end of file
diff --git a/use-cases/mnist/torch/README.md b/use-cases/mnist/torch/README.md
new file mode 100644
index 00000000..c953671f
--- /dev/null
+++ b/use-cases/mnist/torch/README.md
@@ -0,0 +1,74 @@
+# Pure torch example on MNIST dataset
+
+## Training
+
+```bash
+python train.py -p pipeline.yaml [-d]
+```
+
+Use the `-d` flag to run only the first step in the pipeline.
+
+## Inference
+
+1. Create a sample dataset
+
+    ```python
+    from dataloader import InferenceMNIST
+    InferenceMNIST.generate_jpg_sample('mnist-sample-data/', 10)
+    ```
+
+2. Generate a dummy pre-trained neural network
+
+    ```python
+    import torch
+    from model import Net
+    dummy_nn = Net()
+    torch.save(dummy_nn, 'mnist-pre-trained.pth')
+    ```
+
+3. Run the inference command. This will generate a "mnist-predictions"
+folder containing a CSV file with the predictions as rows.
+
+    ```bash
+    python train.py -p inference-pipeline.yaml
+    ```
+
+Note that the entry point is the same as for training.
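+
+### Inspecting predictions
+
+As a quick sanity check, you can load the predictions CSV and count the
+predicted labels. This is a hypothetical snippet, assuming pandas is
+installed in your environment:
+
+```python
+import pandas as pd
+
+# The CSV has no header: each row is (image filename, predicted label)
+preds = pd.read_csv('mnist-predictions/predictions.csv',
+                    names=['image', 'label'])
+print(preds['label'].value_counts())
+```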
+
+### Docker image
+
+Build from the project root with:
+
+```bash
+# Local
+docker buildx build -t itwinai-mnist-torch-inference -f use-cases/mnist/torch/Dockerfile .
+
+# Ghcr.io
+docker buildx build -t ghcr.io/intertwin-eu/itwinai-mnist-torch-inference:0.0.1 -f use-cases/mnist/torch/Dockerfile .
+docker push ghcr.io/intertwin-eu/itwinai-mnist-torch-inference:0.0.1
+```
+
+Run the image from any location where a sample of MNIST jpg images is
+available (in a folder called 'mnist-sample-data/'):
+
+```text
+├── $PWD
+│   ├── mnist-sample-data
+│   │   ├── digit_0.jpg
+│   │   ├── digit_1.jpg
+│   │   ├── digit_2.jpg
+...
+│   │   ├── digit_N.jpg
+```
+
+```bash
+docker run -it --rm --name running-inference -v "$PWD":/usr/data ghcr.io/intertwin-eu/itwinai-mnist-torch-inference:0.0.1
+```
+
+This command will store the results in a folder called "mnist-predictions":
+
+```text
+├── $PWD
+│   ├── mnist-predictions
+│   │   ├── predictions.csv
+```
diff --git a/use-cases/mnist/torch/dataloader.py b/use-cases/mnist/torch/dataloader.py
index 56ef807d..39e9b56b 100644
--- a/use-cases/mnist/torch/dataloader.py
+++ b/use-cases/mnist/torch/dataloader.py
@@ -1,7 +1,10 @@
 """Dataloader for Torch-based MNIST use case."""
 
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple, Callable, Any
+import os
+import shutil
+
+from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms, datasets
@@ -58,3 +61,99 @@ def execute(
         # )
         # return (train_dataloder, validation_dataloader)
         return (self.train_dataset, self.val_dataset), config
+
+
+class InferenceMNIST(Dataset):
+    """Loads a set of MNIST images from a folder of JPG files."""
+
+    def __init__(
+        self,
+        root: str,
+        transform: Optional[Callable] = None,
+        supported_format: str = '.jpg'
+    ) -> None:
+        self.root = root
+        self.transform = transform
+        self.supported_format = supported_format
+        self.data = dict()
+        self._load()
+
+    def _load(self):
+        for img_file in os.listdir(self.root):
+            if not img_file.lower().endswith(self.supported_format):
+                continue
+            filename = os.path.basename(img_file)
+            img = Image.open(os.path.join(self.root, img_file))
+            self.data[filename] = img
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image_identifier, image) where image_identifier
+                is the unique identifier for the image (e.g., filename).
+        """
+        img_id, img = list(self.data.items())[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        # print(type(img))
+        # img = Image.fromarray(img.numpy(), mode="L")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        return img_id, img
+
+    @staticmethod
+    def generate_jpg_sample(
+        root: str,
+        max_items: int = 100
+    ):
+        """Generate a sample dataset of JPG images starting from
+        LeCun's test dataset.
+
+        Args:
+            root (str): sample path on disk
+            max_items (int, optional): max number of images to
+                generate. Defaults to 100.
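+
+        Example (as also shown in the README):
+            >>> InferenceMNIST.generate_jpg_sample('mnist-sample-data/', 10)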
+        """
+        if os.path.exists(root):
+            shutil.rmtree(root)
+        os.makedirs(root)
+
+        test_data = datasets.MNIST(root='.tmp', train=False, download=True)
+        for idx, (img, _) in enumerate(test_data):
+            if idx >= max_items:
+                break
+            savepath = os.path.join(root, f'digit_{idx}.jpg')
+            img.save(savepath)
+
+
+class MNISTPredictLoader(DataGetter):
+    """Loads the MNIST images to be used for inference."""
+
+    def __init__(
+        self,
+        test_data_path: str
+    ) -> None:
+        super().__init__()
+        self.test_data_path = test_data_path
+
+    def execute(
+        self,
+        config: Optional[Dict] = None
+    ) -> Tuple[Dataset, Optional[Dict]]:
+        data = self.load()
+        return data, config
+
+    def load(self) -> Dataset:
+        return InferenceMNIST(
+            root=self.test_data_path,
+            transform=transforms.Compose([
+                transforms.ToTensor(),
+                transforms.Normalize((0.1307,), (0.3081,))
+            ]))
diff --git a/use-cases/mnist/torch/inference-pipeline.yaml b/use-cases/mnist/torch/inference-pipeline.yaml
new file mode 100644
index 00000000..ba4f5e86
--- /dev/null
+++ b/use-cases/mnist/torch/inference-pipeline.yaml
@@ -0,0 +1,22 @@
+executor:
+  class_path: itwinai.components.Executor
+  init_args:
+    steps:
+      - class_path: dataloader.MNISTPredictLoader
+        init_args:
+          test_data_path: /usr/data/mnist-sample-data
+
+      - class_path: itwinai.torch.inference.MulticlassTorchPredictor
+        init_args:
+          model:
+            class_path: itwinai.torch.inference.TorchModelLoader
+            init_args:
+              model_uri: mnist-pre-trained.pth
+          test_dataloader_kwargs:
+            batch_size: 3
+
+      - class_path: saver.TorchMNISTLabelSaver
+        init_args:
+          save_dir: /usr/data/mnist-predictions
+          predictions_file: predictions.csv
+          class_labels: null
\ No newline at end of file
diff --git a/use-cases/mnist/torch/pipeline.yaml b/use-cases/mnist/torch/pipeline.yaml
index df29c3fe..9bb7fb98 100644
--- a/use-cases/mnist/torch/pipeline.yaml
+++ b/use-cases/mnist/torch/pipeline.yaml
@@ -25,7 +25,7 @@ executor:
         batch_size: 32
         pin_memory: True
         shuffle: False
-        epochs: 2
+        epochs: 30
       train_metrics:
         accuracy:
           class_path: torchmetrics.classification.MulticlassAccuracy
diff --git a/use-cases/mnist/torch/saver.py b/use-cases/mnist/torch/saver.py
new file mode 100644
index 00000000..fd54c0cf
--- /dev/null
+++ b/use-cases/mnist/torch/saver.py
@@ -0,0 +1,65 @@
+"""
+This module is used during inference to save predicted labels to file.
+"""
+
+from typing import Optional, List, Dict, Tuple
+import os
+import shutil
+import csv
+
+from itwinai.components import Saver
+
+
+class TorchMNISTLabelSaver(Saver):
+    """Serializes to disk the labels predicted for the MNIST dataset."""
+
+    def __init__(
+        self,
+        save_dir: str = 'mnist_predictions',
+        predictions_file: str = 'predictions.csv',
+        class_labels: Optional[List] = None
+    ) -> None:
+        super().__init__()
+        self.save_dir = save_dir
+        self.predictions_file = predictions_file
+        self.class_labels = (
+            class_labels if class_labels is not None
+            else [f'Digit {i}' for i in range(10)]
+        )
+
+    def execute(
+        self,
+        predicted_classes: Dict[str, int],
+        config: Optional[Dict] = None
+    ) -> Tuple[Optional[Tuple], Optional[Dict]]:
+        """Translate predictions from class idx to class label and save
+        them to disk.
+
+        Args:
+            predicted_classes (Dict[str, int]): maps unique item ID to
+                the predicted class ID.
+            config (Optional[Dict], optional): inherited configuration.
+                Defaults to None.
+
+        Returns:
+            Tuple[Optional[Tuple], Optional[Dict]]: propagation of inherited
+                configuration and saver return value.
+        """
+        if os.path.exists(self.save_dir):
+            shutil.rmtree(self.save_dir)
+        os.makedirs(self.save_dir)
+
+        # Map class idx (int) to class label (str)
+        predicted_labels = {
+            itm_name: self.class_labels[cls_idx]
+            for itm_name, cls_idx in predicted_classes.items()
+        }
+        result = self.save(predicted_labels)
+        return ((result,), config)
+
+    def save(self, predicted_labels: Dict[str, str]) -> None:
+        """Writes the predicted labels to a CSV file, one
+        (item ID, label) pair per row."""
+        filepath = os.path.join(self.save_dir, self.predictions_file)
+        # newline='' avoids spurious blank lines when using the csv module
+        with open(filepath, 'w', newline='') as csv_file:
+            writer = csv.writer(csv_file)
+            for key, value in predicted_labels.items():
+                writer.writerow([key, value])
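+
+
+# Minimal usage sketch (hypothetical, outside the pipeline):
+#
+#   saver = TorchMNISTLabelSaver(save_dir='mnist-predictions')
+#   saver.execute({'digit_0.jpg': 7, 'digit_1.jpg': 2})
+#   # -> writes mnist-predictions/predictions.csv with rows like:
+#   #    digit_0.jpg,Digit 7
+#   #    digit_1.jpg,Digit 2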