diff --git a/config/datasets/anomaly_detection_images.yaml b/config/datasets/anomaly_detection_images.yaml new file mode 100644 index 00000000..8496966c --- /dev/null +++ b/config/datasets/anomaly_detection_images.yaml @@ -0,0 +1,25 @@ +task: + - anomaly-detection-images + +name: MVTEC +description: " +The MVTec anomaly detection dataset (MVTec AD) +https://www.mvtec.com/company/research/datasets/mvtec-ad +DOI: 10.1007/s11263-020-01400-4 +DOI: 10.1109/CVPR.2019.00982 +https://www:mvtec:com/company/research/datasets/mvtec-ad" +markup_info: 'Train images do not contain anomalies' +date_time: 20.07.2024 + +_target_: innofw.core.datamodules.lightning_datamodules.anomaly_detection_images.ImageAnomaliesLightningDataModule + +train: + source: https://api.blackhole.ai.innopolis.university/public-datasets/anomaly_detection_mvtec/train.zip + target: ./data/MVTEC/train +test: + source: https://api.blackhole.ai.innopolis.university/public-datasets/anomaly_detection_mvtec/test.zip + target: ./data/MVTEC/test + +infer: + source: https://api.blackhole.ai.innopolis.university/public-datasets/anomaly_detection_mvtec/test.zip + target: ./data/MVTEC/test diff --git a/config/experiments/anomaly-detection/KG_210724_ba083ak_anomaly_detection_images.yaml b/config/experiments/anomaly-detection/KG_210724_ba083ak_anomaly_detection_images.yaml new file mode 100644 index 00000000..9ec8aa85 --- /dev/null +++ b/config/experiments/anomaly-detection/KG_210724_ba083ak_anomaly_detection_images.yaml @@ -0,0 +1,24 @@ +# @package _global_ +defaults: + - override /models: anomaly-detection/cae + - override /datasets: anomaly_detection_images + - override /optimizers: adam + - override /augmentations_train: none + - override /augmentations_val: none + - override /augmentations_test: none + - override /losses: mse + + +project: "anomaly-detection-mvtec" +task: "anomaly-detection-images" +random_seed: 0 +epochs: 50 +batch_size: 8 +accelerator: gpu + +wandb: + enable: True + project: anomaly_detect_mvtec + 
entity: "k-galliamov" + group: none + job_type: training diff --git a/config/losses/mse.yaml b/config/losses/mse.yaml index 832d8e09..07c2a62c 100755 --- a/config/losses/mse.yaml +++ b/config/losses/mse.yaml @@ -2,6 +2,7 @@ name: MSE description: Mean squared error measures the average of the squares of the errors task: - regression + - anomaly-detection-images implementations: sklearn: diff --git a/config/models/anomaly-detection/cae.yaml b/config/models/anomaly-detection/cae.yaml new file mode 100644 index 00000000..7631e625 --- /dev/null +++ b/config/models/anomaly-detection/cae.yaml @@ -0,0 +1,4 @@ +name: convolutional AE +_target_: innofw.core.models.torch.architectures.autoencoders.convolutional_ae.CAE +description: Base Unet segmentation model with 3 channels input +anomaly_threshold: 0.05 \ No newline at end of file diff --git a/innofw/core/datamodules/lightning_datamodules/__init__.py b/innofw/core/datamodules/lightning_datamodules/__init__.py index 5d7130a7..8c506ac5 100755 --- a/innofw/core/datamodules/lightning_datamodules/__init__.py +++ b/innofw/core/datamodules/lightning_datamodules/__init__.py @@ -1,4 +1,5 @@ from .image_folder_dm import ImageLightningDataModule from .qsar_dm import QsarSelfiesDataModule from .semantic_segmentation.hdf5 import HDF5LightningDataModule -from .drugprot import DrugprotDataModule \ No newline at end of file +from .drugprot import DrugprotDataModule +from .anomaly_detection_images import ImageAnomaliesLightningDataModule \ No newline at end of file diff --git a/innofw/core/datamodules/lightning_datamodules/anomaly_detection_images.py b/innofw/core/datamodules/lightning_datamodules/anomaly_detection_images.py new file mode 100644 index 00000000..d95b1d28 --- /dev/null +++ b/innofw/core/datamodules/lightning_datamodules/anomaly_detection_images.py @@ -0,0 +1,105 @@ +import os +import logging +import pathlib + +import pandas as pd +import torch +import cv2 +import numpy as np +from torch.utils.data import random_split + 
class ImageAnomaliesLightningDataModule(BaseLightningDataModule):
    """
    A class used for working with image anomaly detection datasets
    ...

    Attributes
    ----------
    aug : dict
        The list of augmentations
    val_size: float
        The proportion of the labelled test dataset to hold out as the validation set

    Methods
    -------
    save_preds(preds, stage: Stages, dst_path: pathlib.Path):
        Saves predicted anomaly masks and mask overlays as PNG files

    setup_infer():
        The method prepares inference data

    """

    task = ["anomaly-detection-images"]
    framework = [Frameworks.torch]

    def __init__(
        self,
        train,
        test,
        infer=None,
        batch_size: int = 2,
        val_size: float = 0.5,
        num_workers: int = 1,
        augmentations=None,
        stage=None,
        *args,
        **kwargs,
    ):
        super().__init__(
            train, test, infer, batch_size, num_workers, stage, *args, **kwargs
        )
        self.aug = augmentations
        self.val_size = val_size

    def setup_train_test_val(self, **kwargs):
        """Build the train dataset (images only) and split the labelled test data into test/val."""
        self.train_dataset = AnomaliesDataset(
            self.train_source, self.get_aug(self.aug, 'train'), add_labels=False
        )
        self.test_dataset = AnomaliesDataset(
            self.test_source, self.get_aug(self.aug, 'test'), add_labels=True
        )

        # divide into train, val, test - val is a part of test since train does not have anomalies
        n = len(self.test_dataset)
        test_size = int(n * (1 - self.val_size))
        self.test_dataset, self.val_dataset = random_split(
            self.test_dataset, [test_size, n - test_size]
        )

    def predict_dataloader(self):
        """Return a DataLoader over the inference dataset prepared by ``setup_infer``."""
        return torch.utils.data.DataLoader(
            self.predict_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def setup_infer(self):
        """Create the (unlabelled) inference dataset using the test-time augmentations."""
        self.predict_dataset = AnomaliesDataset(
            self.predict_source, self.get_aug(self.aug, 'test')
        )

    def save_preds(self, out_batches, stage: Stages, dst_path: pathlib.Path):
        """Write every predicted mask and its overlay on the input image to ``dst_path/results``.

        Parameters
        ----------
        out_batches
            Iterable of ``(images, masks)`` pairs; each pair is iterated in lockstep.
        stage : Stages
            Unused here; kept for interface compatibility.
        dst_path : pathlib.Path
            Directory under which the ``results`` folder is created.
        """
        out_file_path = pathlib.Path(dst_path) / "results"
        # Idempotent: a re-run into the same directory must not crash on an existing folder.
        out_file_path.mkdir(parents=True, exist_ok=True)
        n = 0
        for batch in out_batches:
            for img, pred in zip(batch[0], batch[1]):
                # Move to host memory first: .numpy() raises on CUDA tensors.
                img = img.detach().cpu().numpy()
                pred = pred.detach().cpu().numpy() * 255
                if pred.dtype != np.uint8:
                    pred = pred.astype(np.uint8)
                # Build both names explicitly instead of string-replacing 'out_' in the
                # whole path, which would also rewrite matching directory names.
                mask_path = out_file_path / f"out_{n}.png"
                vis_path = out_file_path / f"vis_{n}.png"
                n += 1
                cv2.imwrite(str(mask_path), pred)
                # Paint the mask into channel 1 of an image-shaped canvas.
                # assumes img is channel-first (C, H, W) with values in [0, 1] - TODO confirm
                mask_vis = np.zeros_like(img)
                mask_vis[1, :, :] = pred / 255
                img_with_mask = (
                    (img * 255 * 0.75 + mask_vis * 255 * 0.25)
                    .astype(np.uint8)
                    .transpose((1, 2, 0))
                )
                # Swap the first and last channel before writing with OpenCV.
                img_with_mask = cv2.cvtColor(img_with_mask, cv2.COLOR_BGR2RGB)
                cv2.imwrite(str(vis_path), img_with_mask)
        logging.info(f"Saved result to: {out_file_path}")
class AnomaliesDataset(Dataset):
    """
    A dataset of images with optional per-pixel anomaly masks.

    data_path: str or Path
        path to folder with structure:
        data_path/images/
        data_path/labels/ (optional)
        A path pointing directly at the ``images`` or ``labels`` sub-folder is
        accepted too; its parent is used as the root.

    augmentations: transforms to apply on images

    add_labels: whether to return anomaly segmentation with the image

    Methods
    -------
    __getitem__(self, idx):
        returns the image, and the anomaly mask as well when ``add_labels`` is set
    """

    def __init__(self, data_path, augmentations, add_labels=False):
        # Normalise to Path so plain strings work as well as Path objects.
        root = Path(data_path)
        if root.name in ("images", "labels"):
            root = root.parent
        # iterdir() yields entries in arbitrary order; sort so that images[idx] and
        # labels[idx] refer to the same sample and indexing is reproducible.
        # assumes image and label file names sort in the same order - TODO confirm
        self.images = sorted((root / "images").iterdir())
        self.add_labels = add_labels
        self.augmentations = augmentations
        if self.add_labels:
            self.labels = sorted((root / "labels").iterdir())

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Scale 8-bit pixel values to [0, 1]; layout stays as read by OpenCV (H, W, C).
        # NOTE(review): any channel-first conversion is presumably left to `augmentations` - confirm
        image = torch.div(torch.from_numpy(image).float(), 255)
        if not self.add_labels:
            if self.augmentations is None:
                return image
            return self.augmentations(image)

        label_path = self.labels[idx]
        mask = cv2.imread(str(label_path), 0)  # flag 0: load as single-channel grayscale
        if mask is None:
            raise FileNotFoundError(f"Could not read label: {label_path}")
        if self.augmentations is not None:
            image, mask = self.augmentations(image, mask)
        return image, mask
class AnomalyDetectionImagesLightningModule(BaseLightningModule):
    """
    PyTorchLightning module for Anomaly Detection in Images
    ...

    The wrapped model reconstructs its input; pixels whose squared
    reconstruction error (summed over channels) reaches
    ``model.anomaly_threshold`` are flagged as anomalous.

    Attributes
    ----------
    model : nn.Module
        model to train
    losses : losses
        loss to use while training
    optimizer_cfg : cfg
        optimizer configurations
    scheduler_cfg : cfg
        scheduler configuration

    Methods
    -------
    forward(x):
        returns result of prediction
    """

    def __init__(
        self,
        model,
        losses,
        optimizer_cfg,
        scheduler_cfg,
        *args: Any,
        **kwargs: Any,
    ):
        super().__init__(*args, **kwargs)
        self.model = model
        self.losses = losses
        self.optimizer_cfg = optimizer_cfg
        self.scheduler_cfg = scheduler_cfg

        # NOTE(review): `losses` is stored but the reconstruction loss is hard-coded to MSE.
        self.loss_fn = torch.nn.MSELoss()

        # Training is pure reconstruction, so it is tracked with regression metrics.
        metrics = MetricCollection(
            [MeanSquaredError(), MeanAbsoluteError()]
        )
        self.train_metrics = metrics.clone(prefix='train')

        # Validation/test compare the thresholded error map with the ground-truth mask.
        segmentation_metrics = MetricCollection(
            [
                BinaryF1Score(),
                BinaryPrecision(),
                BinaryRecall(),
                BinaryJaccardIndex(),
            ]
        )
        self.val_metrics = segmentation_metrics.clone(prefix='val')
        # Distinct prefix so test metrics are not logged under validation names.
        self.test_metrics = segmentation_metrics.clone(prefix='test')

    def forward(self, x, *args, **kwargs) -> Any:
        return self.model(x.float())

    def training_step(self, x, batch_idx):
        """Reconstruct the batch and optimise the reconstruction error."""
        x_rec = self.forward(x)
        loss = self.loss_fn(x_rec, x)
        metrics = self.compute_metrics('train', x_rec, x)
        self.log_metrics('train', metrics)
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        return self._segmentation_step('val', batch)

    def test_step(self, batch, batch_idx):
        return self._segmentation_step('test', batch)

    def predict_step(self, x, batch_idx, **kwargs):
        """Return the input batch together with its boolean anomaly mask."""
        return (x, self.compute_anomaly_mask(x))

    def compute_anomaly_mask(self, x):
        """Run the model on ``x`` and threshold the reconstruction error."""
        x_rec = self.forward(x)  # (B, C, W, H)
        return self._mask_from_reconstruction(x, x_rec)

    def log_metrics(self, stage, metrics_res, *args, **kwargs):
        for key, value in metrics_res.items():
            self.log(key, value)  # , sync_dist=True

    def compute_metrics(self, stage, predictions, labels):
        """Update and return the metric collection belonging to ``stage``.

        For ``"train"`` the labels are the input images and are passed through
        unchanged. For ``"val"``/``"test"`` the labels are segmentation masks,
        which are squeezed to [B, H, W] and binarised to {0, 1}.
        """
        if stage == "train":
            # Casting the float image to long would truncate it to 0/1.
            return self.train_metrics(predictions, labels)

        # Reshape labels from [B, 1, H, W] to [B, H, W]
        if labels.dim() == 4 and labels.shape[1] == 1:
            labels = labels.squeeze(1)
        # Any non-zero mask value (e.g. 255) counts as anomalous.
        labels = labels.bool().type(dtype=torch.long)

        if stage == "val":
            return self.val_metrics(predictions, labels)
        if stage == "test":
            return self.test_metrics(predictions, labels)
        raise ValueError(f"Unknown stage: {stage}")

    def _segmentation_step(self, stage, batch):
        """Shared val/test logic: one forward pass feeds both the loss and the mask."""
        x, y = batch
        x_rec = self.forward(x)
        loss = self.loss_fn(x_rec, x)
        mask = self._mask_from_reconstruction(x, x_rec)
        metrics = self.compute_metrics(stage, mask, y)
        self.log_metrics(stage, metrics)
        self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True)
        return {"loss": loss}

    def _mask_from_reconstruction(self, x, x_rec):
        """Threshold the per-pixel squared error between ``x`` and ``x_rec``."""
        diff = ((x - x_rec) ** 2).sum(dim=1)  # sum across channels
        return diff >= self.model.anomaly_threshold
{"_target_": "torch.optim.Adam", "lr": 3e-2} ), }, + "anomaly-detection-images": { + "lightning_module": AnomalyDetectionImagesLightningModule, + "trainer_cfg": OmegaConf.create( + {"_target_": "pytorch_lightning.Trainer", "max_epochs": 1} + ), + "optimizers_cfg": OmegaConf.create( + {"_target_": "torch.optim.Adam", "lr": 3e-2} + ), + }, "text-ner": { "lightning_module": BiobertNERModel, "trainer_cfg": OmegaConf.create( diff --git a/tests/fixtures/config/datasets.py b/tests/fixtures/config/datasets.py index 7f24eb84..9fcd91c0 100644 --- a/tests/fixtures/config/datasets.py +++ b/tests/fixtures/config/datasets.py @@ -405,6 +405,33 @@ ) +anomaly_detection_images_datamodule_cfg_w_target = DictConfig( + { + "task": ["anomaly-detection-images"], + "name": "MVTEC", + "description": """The MVTec anomaly detection dataset (MVTec AD) +https://www.mvtec.com/company/research/datasets/mvtec-ad +DOI: 10.1007/s11263-020-01400-4 +DOI: 10.1109/CVPR.2019.00982 +https://www:mvtec:com/company/research/datasets/mvtec-a""", + "markup_info": "Train images do not contain anomalies", + "date_time": "20.07.2024", + "_target_": "innofw.core.datamodules.lightning_datamodules.anomaly_detection_images.ImageAnomaliesLightningDataModule", + "train": { + "source": "https://api.blackhole.ai.innopolis.university/public-datasets/anomaly_detection_mvtec/train.zip", + "target": "./tmp/MVTEC/train", + }, + "test": { + "source": "https://api.blackhole.ai.innopolis.university/public-datasets/anomaly_detection_mvtec/test.zip", + "target": "./tmp/MVTEC/test", + }, + "infer": { + "source": "https://api.blackhole.ai.innopolis.university/public-datasets/anomaly_detection_mvtec/test.zip", + "target": "./tmp/MVTEC/infer", + }, + } +) + pngstroke_segmentation_datamodule_cfg_w_target = DictConfig( { "task": ["image_segmentation"], diff --git a/tests/fixtures/config/losses.py b/tests/fixtures/config/losses.py index 32a147ed..f88cd3b8 100644 --- a/tests/fixtures/config/losses.py +++ b/tests/fixtures/config/losses.py 
@@ -1,5 +1,7 @@ from omegaconf import DictConfig +torch_mse = "torch.nn.MSELoss" + jaccard_loss_w_target = DictConfig( { "name": "Segmentation", @@ -66,12 +68,12 @@ "mse": { "weight": 1.0, "object": { - "_target_": "torch.nn.MSELoss"} + "_target_": torch_mse} }, "target_loss": { "weight": 1.0, "object": { - "_target_": "torch.nn.MSELoss" + "_target_": torch_mse } }, "kld": { @@ -141,3 +143,22 @@ } } ) + + +mse_loss_w_target = DictConfig( + { + "name": "MSE", + "description": "Mean squared error measures the average of the squares of the errors", + "task": ["regression", "anomaly-detection-images"], + "implementations": { + "torch": { + "mse": { + "weight": 1, + "object": { + "_target_": torch_mse + } + } + } + } + } +) \ No newline at end of file diff --git a/tests/fixtures/config/models.py b/tests/fixtures/config/models.py index 932482f4..38ed9953 100644 --- a/tests/fixtures/config/models.py +++ b/tests/fixtures/config/models.py @@ -46,6 +46,15 @@ } ) +unet_anomalies_cfg_w_target = DictConfig( + { + "name": "convolutional AE", + "description": "Base Unet segmentation model with 3 channels input", + "_target_": "innofw.core.models.torch.architectures.autoencoders.convolutional_ae.CAE", + "anomaly_threshold": 0.3 + } +) + # case: SegFormer segformer_retaining = DictConfig( { diff --git a/tests/integration/models/torch/lighting_modules/test_anomaly_detection_images.py b/tests/integration/models/torch/lighting_modules/test_anomaly_detection_images.py new file mode 100644 index 00000000..fe9c8325 --- /dev/null +++ b/tests/integration/models/torch/lighting_modules/test_anomaly_detection_images.py @@ -0,0 +1,60 @@ +import shutil + +from omegaconf import DictConfig + +from innofw.constants import Frameworks, Stages +from innofw.core.datamodules.lightning_datamodules import \ + ImageAnomaliesLightningDataModule +from innofw.core.models.torch.lightning_modules import ( + AnomalyDetectionImagesLightningModule +) +from innofw.utils.framework import get_datamodule +from 
def test_anomaly_detection():
    """Smoke-test one train/val/test/predict step of the image anomaly detection module."""
    cfg = DictConfig(
        {
            "models": fixt_models.unet_anomalies_cfg_w_target,
            "trainer": fixt_trainers.trainer_cfg_w_cpu_devices,
            "losses": fixt_losses.mse_loss_w_target,
        }
    )
    model = get_model(cfg.models, cfg.trainer)
    losses = get_losses(cfg, "anomaly-detection-images", Frameworks.torch)
    optimizer_cfg = DictConfig(fixt_optimizers.adam_optim_w_target)
    scheduler_cfg = DictConfig(fixt_schedulers.linear_w_target)

    module = AnomalyDetectionImagesLightningModule(
        model=model, losses=losses, optimizer_cfg=optimizer_cfg, scheduler_cfg=scheduler_cfg
    )

    assert module is not None

    try:
        datamodule: ImageAnomaliesLightningDataModule = get_datamodule(
            anomaly_detection_images_datamodule_cfg_w_target,
            Frameworks.torch,
            task="anomaly-detection-images"
        )
        datamodule.setup(Stages.train)

        module.training_step(next(iter(datamodule.train_dataloader())), 0)
        model.eval()
        module.validation_step(next(iter(datamodule.val_dataloader())), 0)
        module.test_step(next(iter(datamodule.test_dataloader())), 0)
        datamodule.setup_infer()
        module.predict_step(next(iter(datamodule.predict_dataloader())), 0)
    finally:
        # Best-effort cleanup of the data under ./tmp, run even when a step above
        # fails. Retried a few times; only filesystem errors are tolerated so that
        # KeyboardInterrupt and real bugs still propagate.
        for _ in range(3):
            try:
                shutil.rmtree('./tmp')
                break
            except OSError:
                continue