From 443b39ea169091214b36f75e8dcee6658c96b72d Mon Sep 17 00:00:00 2001 From: Kalliopi Tsolaki Date: Tue, 24 Oct 2023 16:09:09 +0200 Subject: [PATCH 01/57] commiting integration of 3dgan scripts --- .gitignore | 1 + use-cases/3dgan/dataloader.py | 150 ++++++++ use-cases/3dgan/model.py | 567 +++++++++++++++++++++++++++++++ use-cases/3dgan/pipeline.yaml | 98 ++++++ use-cases/3dgan/requirements.txt | 3 + use-cases/3dgan/startscript | 34 ++ use-cases/3dgan/train.py | 55 +++ use-cases/3dgan/trainer.py | 45 +++ use-cases/3dgan/utils.py | 108 ++++++ 9 files changed, 1061 insertions(+) create mode 100644 use-cases/3dgan/dataloader.py create mode 100644 use-cases/3dgan/model.py create mode 100644 use-cases/3dgan/pipeline.yaml create mode 100644 use-cases/3dgan/requirements.txt create mode 100644 use-cases/3dgan/startscript create mode 100644 use-cases/3dgan/train.py create mode 100644 use-cases/3dgan/trainer.py create mode 100644 use-cases/3dgan/utils.py diff --git a/.gitignore b/.gitignore index 41349b5f..e32d500f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +*.pth TODO /data nohup* diff --git a/use-cases/3dgan/dataloader.py b/use-cases/3dgan/dataloader.py new file mode 100644 index 00000000..f25ad7fd --- /dev/null +++ b/use-cases/3dgan/dataloader.py @@ -0,0 +1,150 @@ +from typing import Optional, Tuple, Dict +import lightning as pl + +from torch.utils.data import DataLoader, random_split +from torchvision import transforms + +from itwinai.components import DataGetter + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, Dataset +import glob +import h5py +import numpy as np +from collections import defaultdict +#import pytorch_lightning as pl +import sys +import pickle + + +class Lightning3DGANDownloader(DataGetter): + def __init__( + self, + data_path: str, + name: Optional[str] = None, + **kwargs) -> None: + super().__init__(name, **kwargs) + self.data_path = data_path + + def load(self): + ... 
+ + def execute( + self, + config: Optional[Dict] = None + ) -> Tuple[None, Optional[Dict]]: + self.load() + return None, config + + +class MyDataset(Dataset): + def __init__(self, datapath): + self.datapath = datapath + self.data = self.fetch_data(self.datapath) + + def __len__(self): + return len(self.data["X"]) + + def __getitem__(self, idx): + return {"X": self.data["X"][idx], "Y": self.data["Y"][idx], "ang": self.data["ang"][idx], "ecal": self.data["ecal"][idx]} + + def fetch_data(self, datapath): + + print("Searching in :", datapath) + Files = sorted(glob.glob(datapath)) + print("Found {} files. ".format(len(Files))) + + concatenated_datasets = [] + for datafile in Files: + f=h5py.File(datafile,'r') + dataset = self.GetDataAngleParallel(f) + concatenated_datasets.append(dataset) + result = {key: [] for key in concatenated_datasets[0].keys()} # Initialize result dictionary + for d in concatenated_datasets: + for key in result.keys(): + result[key].extend(d[key]) + return result + + def GetDataAngleParallel( + self, + dataset, + xscale=1, + xpower=0.85, + yscale=100, + angscale=1, + angtype="theta", + thresh=1e-4, + daxis=-1,): + """Preprocess function for the dataset + + Args: + dataset (str): Dataset file path + xscale (int, optional): Value to scale the ECAL values. Defaults to 1. + xpower (int, optional): Value to scale the ECAL values, exponentially. Defaults to 1. + yscale (int, optional): Value to scale the energy values. Defaults to 100. + angscale (int, optional): Value to scale the angle values. Defaults to 1. + angtype (str, optional): Which type of angle to use. Defaults to "theta". + thresh (_type_, optional): Maximum value for ECAL values. Defaults to 1e-4. + daxis (int, optional): Axis to expand values. Defaults to -1. 
+ + Returns: + Dict: Dictionary containning the preprocessed dataset + """ + X = np.array(dataset.get("ECAL")) * xscale + Y = np.array(dataset.get("energy")) / yscale + X[X < thresh] = 0 + X = X.astype(np.float32) + Y = Y.astype(np.float32) + ecal = np.sum(X, axis=(1, 2, 3)) + indexes = np.where(ecal > 10.0) + X = X[indexes] + Y = Y[indexes] + if angtype in dataset: + ang = np.array(dataset.get(angtype))[indexes] + # else: + # ang = gan.measPython(X) + X = np.expand_dims(X, axis=daxis) + ecal = ecal[indexes] + ecal = np.expand_dims(ecal, axis=daxis) + if xpower != 1.0: + X = np.power(X, xpower) + + Y = np.array([[el] for el in Y]) + ang = np.array([[el] for el in ang]) + ecal = np.array([[el] for el in ecal]) + + final_dataset = {"X": X, "Y": Y, "ang": ang, "ecal": ecal} + + return final_dataset + + +class MyDataModule(pl.LightningDataModule): + def __init__(self, batch_size: int, datapath): + super().__init__() + self.batch_size = batch_size + self.datapath = datapath + + def setup(self, stage: str = None): + # make assignments here (val/train/test split) + # called on every process in DDP + + if stage == 'fit' or stage is None: + self.dataset = MyDataset(self.datapath) + dataset_length = len(self.dataset) + split_point = int(dataset_length * 0.9) + self.train_dataset, self.val_dataset = torch.utils.data.random_split(self.dataset, [split_point, dataset_length - split_point]) + + #if stage == 'test' or stage is None: + #self.test_dataset = MyDataset(self.data_dir, train=False) + + def train_dataloader(self): + return DataLoader(self.train_dataset, batch_size=self.batch_size, drop_last=True) + + def val_dataloader(self): + return DataLoader(self.val_dataset, batch_size=self.batch_size, drop_last=True) + + #def test_dataloader(self): + #return DataLoader(self.test_dataset, batch_size=self.batch_size) + diff --git a/use-cases/3dgan/model.py b/use-cases/3dgan/model.py new file mode 100644 index 00000000..ac447949 --- /dev/null +++ b/use-cases/3dgan/model.py @@ -0,0 
+1,567 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +import lightning as pl +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, Dataset +import glob +import h5py +import numpy as np +from collections import defaultdict +#import pytorch_lightning as pl +import sys +import pickle + + +class Generator(nn.Module): + def __init__(self, latent_dim): #img_shape + super().__init__() + #self.img_shape = img_shape + self.latent_dim = latent_dim + + self.l1 = nn.Linear(self.latent_dim, 5184) + self.up1 = nn.Upsample(scale_factor=(6, 6, 6), mode='trilinear', align_corners=False) + self.conv1 = nn.Conv3d(in_channels=8, out_channels=8, kernel_size=(6, 6, 8), padding=0) + nn.init.kaiming_uniform_(self.conv1.weight) + self.bn1 = nn.BatchNorm3d(num_features=8, eps=1e-6) #num_features is the number of channels (see doc) + self.pad1 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0) + self.conv2 = nn.Conv3d(in_channels=8, out_channels=6, kernel_size=(4, 4, 6), padding=0) + nn.init.kaiming_uniform_(self.conv2.weight) + self.bn2 = nn.BatchNorm3d(num_features=6, eps=1e-6) + + self.pad2 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0) + self.conv3 = nn.Conv3d(in_channels=6, out_channels=6, kernel_size=(4, 4, 6), padding=0) + nn.init.kaiming_uniform_(self.conv3.weight) + self.bn3 = nn.BatchNorm3d(num_features=6, eps=1e-6) + self.pad3 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0) + self.conv4 = nn.Conv3d(in_channels=6, out_channels=6, kernel_size=(4, 4, 6), padding=0) + nn.init.kaiming_uniform_(self.conv4.weight) + self.bn4 = nn.BatchNorm3d(num_features=6, eps=1e-6) + self.pad4 = nn.ConstantPad3d((0, 0, 1, 1, 1, 1), 0) + self.conv5 = nn.Conv3d(in_channels=6, out_channels=6, kernel_size=(3, 3, 5), padding=0) + nn.init.kaiming_uniform_(self.conv5.weight) + self.bn5 = nn.BatchNorm3d(num_features=6, eps=1e-6) + + self.pad5 = nn.ConstantPad3d((0, 0, 1, 1, 1, 1), 0) + self.conv6 = nn.Conv3d(in_channels=6, out_channels=6, 
kernel_size=(3, 3, 3), padding=0) + nn.init.kaiming_uniform_(self.conv6.weight) + self.conv7 = nn.Conv3d(in_channels=6, out_channels=1, kernel_size=(2, 2, 2), padding=0) + nn.init.xavier_normal_(self.conv7.weight) + + + def forward(self, z): + img = self.l1(z) + img = img.view(-1, 8, 9, 9, 8) + img = self.up1(img) + img = self.conv1(img) + img = F.relu(img) + img = self.bn1(img) + img = self.pad1(img) + img = self.conv2(img) + img = F.relu(img) + img = self.bn2(img) + + img = self.pad2(img) + img = self.conv3(img) + img = F.relu(img) + img = self.bn3(img) + img = self.pad3(img) + img = self.conv4(img) + img = F.relu(img) + img = self.bn4(img) + img = self.pad4(img) + img = self.conv5(img) + img = F.relu(img) + img = self.bn5(img) + + img = self.pad5(img) + img = self.conv6(img) + img = F.relu(img) + img = self.conv7(img) + img = F.relu(img) + + return img + + +class Discriminator(nn.Module): + def __init__(self, power): + super().__init__() + + self.power = power + + self.conv1 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=(5, 6, 6), padding=(2, 3, 3)) + self.drop1 = nn.Dropout(0.2) + + self.pad1 = nn.ConstantPad3d((1, 1, 0, 0, 0, 0), 0) + self.conv2 = nn.Conv3d(in_channels=16, out_channels=8, kernel_size=(5, 6, 6), padding=0) + self.bn1 = nn.BatchNorm3d(num_features=8, eps=1e-6) + self.drop2 = nn.Dropout(0.2) + + self.pad2 = nn.ConstantPad3d((1, 1, 0, 0, 0, 0), 0) + self.conv3 = nn.Conv3d(in_channels=8, out_channels=8, kernel_size=(5, 6, 6), padding=0) + self.bn2 = nn.BatchNorm3d(num_features=8, eps=1e-6) + self.drop3 = nn.Dropout(0.2) + + self.conv4 = nn.Conv3d(in_channels=8, out_channels=8, kernel_size=(5, 6, 6), padding=0) + self.bn3 = nn.BatchNorm3d(num_features=8, eps=1e-6) + self.drop4 = nn.Dropout(0.2) + + self.avgpool = nn.AvgPool3d((2, 2, 2)) + self.flatten = nn.Flatten() + + self.fakeout = nn.Linear(19152, 1) # The input features for the Linear layer need to be calculated based on the output shape from the previous layers. 
+ self.auxout = nn.Linear(19152, 1) # The same as above for this layer. + + # calculate sum of intensities + def ecal_sum(self, image, daxis): + sum = torch.sum(image, dim=daxis) + return sum + + # angle calculation + def ecal_angle(self, image, daxis1): + image = torch.squeeze(image, dim=daxis1) # squeeze along channel axis + + # get shapes + x_shape = image.shape[1] + y_shape = image.shape[2] + z_shape = image.shape[3] + sumtot = torch.sum(image, dim=(1, 2, 3)) # sum of events + + # get 1. where event sum is 0 and 0 elsewhere + amask = torch.where(sumtot == 0.0, torch.ones_like(sumtot), torch.zeros_like(sumtot)) + masked_events = torch.sum(amask) # counting zero sum events + + # ref denotes barycenter as that is our reference point + x_ref = torch.sum(torch.sum(image, dim=(2, 3)) + * (torch.arange(x_shape, device=image.device, dtype=torch.float32).unsqueeze(0) + 0.5), + dim=1,) # sum for x position * x index + y_ref = torch.sum( + torch.sum(image, dim=(1, 3)) + * (torch.arange(y_shape, device=image.device, dtype=torch.float32).unsqueeze(0) + 0.5), + dim=1,) + z_ref = torch.sum( + torch.sum(image, dim=(1, 2)) + * (torch.arange(z_shape, device=image.device, dtype=torch.float32).unsqueeze(0) + 0.5), + dim=1,) + + x_ref = torch.where(sumtot == 0.0, torch.ones_like(x_ref), x_ref / sumtot) # return max position if sumtot=0 and divide by sumtot otherwise + y_ref = torch.where(sumtot == 0.0, torch.ones_like(y_ref), y_ref / sumtot) + z_ref = torch.where(sumtot == 0.0, torch.ones_like(z_ref), z_ref / sumtot) + + # reshape + x_ref = x_ref.unsqueeze(1) + y_ref = y_ref.unsqueeze(1) + z_ref = z_ref.unsqueeze(1) + + sumz = torch.sum(image, dim=(1, 2)) # sum for x,y planes going along z + + # Get 0 where sum along z is 0 and 1 elsewhere + zmask = torch.where(sumz == 0.0, torch.zeros_like(sumz), torch.ones_like(sumz)) + + x = torch.arange(x_shape, device=image.device).unsqueeze(0) # x indexes + x = (x.unsqueeze(2).float()) + 0.5 + y = torch.arange(y_shape, 
device=image.device).unsqueeze(0) # y indexes + y = (y.unsqueeze(2).float()) + 0.5 + + # barycenter for each z position + x_mid = torch.sum(torch.sum(image, dim=2) * x, dim=1) + y_mid = torch.sum(torch.sum(image, dim=1) * y, dim=1) + + x_mid = torch.where(sumz == 0.0, torch.zeros_like(sumz), x_mid / sumz) # if sum != 0 then divide by sum + y_mid = torch.where(sumz == 0.0, torch.zeros_like(sumz), y_mid / sumz) # if sum != 0 then divide by sum + + # Angle Calculations + z = (torch.arange(z_shape, device=image.device, dtype=torch.float32) + 0.5) * torch.ones_like(z_ref) # Make an array of z indexes for all events + + zproj = torch.sqrt( + torch.max((x_mid - x_ref) ** 2.0 + (z - z_ref) ** 2.0, torch.tensor([torch.finfo(torch.float32).eps]).to(x_mid.device))) # projection from z axis with stability check + #torch.finfo(torch.float32).eps)) + m = torch.where(zproj == 0.0, torch.zeros_like(zproj), (y_mid - y_ref) / zproj) # to avoid divide by zero for zproj =0 + m = torch.where(z < z_ref, -1 * m, m) # sign inversion + ang = (math.pi / 2.0) - torch.atan(m) # angle correction + zmask = torch.where(zproj == 0.0, torch.zeros_like(zproj), zmask) + ang = ang * zmask # place zero where zsum is zero + ang = ang * z # weighted by position + sumz_tot = z * zmask # removing indexes with 0 energies or angles + + # zunmasked = K.sum(zmask, axis=1) # used for simple mean + # ang = K.sum(ang, axis=1)/zunmasked # Mean does not include positions where zsum=0 + + ang = torch.sum(ang, dim=1) / torch.sum(sumz_tot, dim=1) # sum ( measured * weights)/sum(weights) + ang = torch.where(amask == 0.0, ang, 100.0 * torch.ones_like(ang)) # Place 100 for measured angle where no energy is deposited in events + ang = ang.unsqueeze(1) + return ang + + def forward(self, x): + z = self.conv1(x) + z = F.leaky_relu(z) + z = self.drop1(z) + z = self.pad1(z) + z = self.conv2(z) + z = F.leaky_relu(z) + z = self.bn1(z) + z = self.drop2(z) + z = self.pad2(z) + z = self.conv3(z) + z = F.leaky_relu(z) + z = 
self.bn2(z) + z = self.drop3(z) + z = self.conv4(z) + z = F.leaky_relu(z) + z = self.bn3(z) + z = self.drop4(z) + z = self.avgpool(z) + z = self.flatten(z) + + fake = torch.sigmoid(self.fakeout(z)) #generation output that says fake/real + aux = self.auxout(z) #auxiliary output + inv_image = x.pow(1.0 / self.power) + ang = self.ecal_angle(inv_image, 1) # angle calculation + ecal = self.ecal_sum(inv_image, (2, 3, 4)) # sum of energies + + return fake, aux, ang, ecal + + +class ThreeDGAN(pl.LightningModule): + def __init__(self, latent_size=256, batch_size=64, loss_weights=[3, 0.1, 25, 0.1], power=0.85, lr=0.001): + super().__init__() + self.save_hyperparameters() + self.automatic_optimization=False + + self.latent_size = latent_size + self.batch_size = batch_size + self.loss_weights = loss_weights + self.lr = lr + self.power = power + + self.generator = Generator(self.latent_size) + self.discriminator = Discriminator(self.power) + + self.epoch_gen_loss = [] + self.epoch_disc_loss = [] + self.disc_epoch_test_loss = [] + self.gen_epoch_test_loss = [] + self.index = 0 + self.train_history = defaultdict(list) + self.test_history = defaultdict(list) + self.pklfile = "/afs/cern.ch/work/k/ktsolaki/private/projects/GAN_scripts/3DGAN/Accelerated3DGAN/src/Accelerated3DGAN/results/3dgan_history_test.pkl" + + + def BitFlip(self, x, prob=0.05): + """ + Flips a single bit according to a certain probability. 
+ + Args: + x (list): list of bits to be flipped + prob (float): probability of flipping one bit + + Returns: + list: List of flipped bits + + """ + x = np.array(x) + selection = np.random.uniform(0, 1, x.shape) < prob + x[selection] = 1 * np.logical_not(x[selection]) + return x + + def mean_absolute_percentage_error(self, y_true, y_pred): + return torch.mean(torch.abs((y_true - y_pred) / (y_true + 1e-7))) * 100 + + def compute_global_loss(self, labels, predictions, loss_weights=[3, 0.1, 25, 0.1]): + # Can be initialized outside + binary_crossentropy_object = nn.BCEWithLogitsLoss(reduction='none') + #there is no equivalent in pytorch for tf.keras.losses.MeanAbsolutePercentageError --> using the custom "mean_absolute_percentage_error" above! + mean_absolute_percentage_error_object1 = self.mean_absolute_percentage_error(predictions[1], labels[1]) + mean_absolute_percentage_error_object2 = self.mean_absolute_percentage_error(predictions[3], labels[3]) + mae_object = nn.L1Loss(reduction='none') + + binary_example_loss = binary_crossentropy_object(predictions[0], labels[0]) * loss_weights[0] + + #mean_example_loss_1 = mean_absolute_percentage_error_object(predictions[1], labels[1]) * loss_weights[1] + mean_example_loss_1 = mean_absolute_percentage_error_object1 * loss_weights[1] + + mae_example_loss = mae_object(predictions[2], labels[2]) * loss_weights[2] + + #mean_example_loss_2 = mean_absolute_percentage_error_object(predictions[3], labels[3]) * loss_weights[3] + mean_example_loss_2 = mean_absolute_percentage_error_object2 * loss_weights[3] + + binary_loss = binary_example_loss.mean() + mean_loss_1 = mean_example_loss_1.mean() + mae_loss = mae_example_loss.mean() + mean_loss_2 = mean_example_loss_2.mean() + + return [binary_loss, mean_loss_1, mae_loss, mean_loss_2] + + def forward(self, z): + return self.generator(z) + + def training_step(self, batch, batch_idx): + image_batch, energy_batch, ang_batch, ecal_batch = batch['X'], batch['Y'], batch['ang'], batch['ecal'] 
+ + image_batch = image_batch.permute(0, 4, 1, 2, 3) + + image_batch = image_batch.to(self.device) + energy_batch = energy_batch.to(self.device) + ang_batch = ang_batch.to(self.device) + ecal_batch = ecal_batch.to(self.device) + + optimizer_discriminator, optimizer_generator = self.optimizers() + + noise = torch.randn((self.batch_size, self.latent_size - 2)).to(self.device) + generator_ip = torch.cat( + (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), + dim=1,) + generated_images = self.generator(generator_ip) + + # Train discriminator first on real batch + fake_batch = self.BitFlip(np.ones(self.batch_size).astype(np.float32)) + fake_batch = torch.tensor([[el] for el in fake_batch]).to(self.device) + labels = [fake_batch, energy_batch, ang_batch, ecal_batch] + + predictions = self.discriminator(image_batch) + print("calculating real_batch_loss...") + real_batch_loss = self.compute_global_loss( + labels, predictions, self.loss_weights) + self.log("real_batch_loss", sum(real_batch_loss), prog_bar=True, on_step=True, on_epoch=True) + print("real batch disc train") + #the following 3 lines correspond in tf version to: + #gradients = tape.gradient(real_batch_loss, discriminator.trainable_variables) + #optimizer_discriminator.apply_gradients(zip(gradients, discriminator.trainable_variables)) in Tensorflow + optimizer_discriminator.zero_grad() + self.manual_backward(sum(real_batch_loss)) + #sum(real_batch_loss).backward() + #real_batch_loss.backward() + optimizer_discriminator.step() + + # Train discriminator on the fake batch + fake_batch = self.BitFlip(np.zeros(self.batch_size).astype(np.float32)) + fake_batch = torch.tensor([[el] for el in fake_batch]).to(self.device) + labels = [fake_batch, energy_batch, ang_batch, ecal_batch] + + predictions = self.discriminator(generated_images) + + fake_batch_loss = self.compute_global_loss( + labels, predictions, self.loss_weights) + self.log("fake_batch_loss", sum(fake_batch_loss), prog_bar=True, on_step=True, 
on_epoch=True) + print("fake batch disc train") + #the following 3 lines correspond to + #gradients = tape.gradient(fake_batch_loss, discriminator.trainable_variables) + #optimizer_discriminator.apply_gradients(zip(gradients, discriminator.trainable_variables)) in Tensorflow + optimizer_discriminator.zero_grad() + self.manual_backward(sum(fake_batch_loss)) + #sum(fake_batch_loss).backward() + optimizer_discriminator.step() + + trick = np.ones(self.batch_size).astype(np.float32) + fake_batch = torch.tensor([[el] for el in trick]).to(self.device) + labels = [fake_batch, energy_batch.view(-1, 1), ang_batch, ecal_batch] + + gen_losses_train = [] + # Train generator twice using combined model + for _ in range(2): + noise = torch.randn((self.batch_size, self.latent_size - 2)).to(self.device) + generator_ip = torch.cat( + (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), + dim=1,) + + generated_images = self.generator(generator_ip) + predictions = self.discriminator(generated_images) + + loss = self.compute_global_loss( + labels, predictions, self.loss_weights) + self.log("gen_loss", sum(loss), prog_bar=True, on_step=True, on_epoch=True) + print("gen train") + optimizer_generator.zero_grad() + self.manual_backward(sum(loss)) + #sum(loss).backward() + optimizer_generator.step() + + for el in loss: + gen_losses_train.append(el) + + avg_generator_loss = sum(gen_losses_train) / len(gen_losses_train) + self.log("generator_loss", avg_generator_loss.item(), prog_bar=True, on_step=True, on_epoch=True) + #avg_generator_loss = [(a + b) / 2 for a, b in zip(*gen_losses_train)] + #self.log("generator_loss", sum(avg_generator_loss), prog_bar=True, on_step=True, on_epoch=True) + + gen_losses = [] + # I'm not returning anything as in pl you do not return anything when you back-propagate manually + #return_loss = real_batch_loss + real_batch_loss = [real_batch_loss[0], real_batch_loss[1], real_batch_loss[2], real_batch_loss[3]] + fake_batch_loss = [fake_batch_loss[0], 
fake_batch_loss[1], fake_batch_loss[2], fake_batch_loss[3]] + gen_batch_loss = [gen_losses_train[0], gen_losses_train[1], gen_losses_train[2], gen_losses_train[3]] + gen_losses.append(gen_batch_loss) + gen_batch_loss = [gen_losses_train[4], gen_losses_train[5], gen_losses_train[6], gen_losses_train[7]] + gen_losses.append(gen_batch_loss) + + real_batch_loss = [el.cpu().detach().numpy() for el in real_batch_loss] + real_batch_loss_total_loss = np.sum(real_batch_loss) + new_real_batch_loss = [real_batch_loss_total_loss] + for i_weights in range(len(real_batch_loss)): + new_real_batch_loss.append(real_batch_loss[i_weights] / self.loss_weights[i_weights]) + real_batch_loss = new_real_batch_loss + + fake_batch_loss = [el.cpu().detach().numpy() for el in fake_batch_loss] + fake_batch_loss_total_loss = np.sum(fake_batch_loss) + new_fake_batch_loss = [fake_batch_loss_total_loss] + for i_weights in range(len(fake_batch_loss)): + new_fake_batch_loss.append(fake_batch_loss[i_weights] / self.loss_weights[i_weights]) + fake_batch_loss = new_fake_batch_loss + + # if ecal sum has 100% loss(generating empty events) then end the training + if fake_batch_loss[3] == 100.0 and self.index > 10: + print("Empty image with Ecal loss equal to 100.0 for {} batch".format(self.index)) + torch.save(self.generator.state_dict(), "generator_weights.pth") + torch.save(self.discriminator.state_dict(), "discriminator_weights.pth") + print("real_batch_loss", real_batch_loss) + print("fake_batch_loss", fake_batch_loss) + sys.exit() + + # append mean of discriminator loss for real and fake events + self.epoch_disc_loss.append([(a + b) / 2 for a, b in zip(real_batch_loss, fake_batch_loss)]) + + gen_losses[0] = [el.cpu().detach().numpy() for el in gen_losses[0]] + gen_losses_total_loss = np.sum(gen_losses[0]) + new_gen_losses = [gen_losses_total_loss] + for i_weights in range(len(gen_losses[0])): + new_gen_losses.append(gen_losses[0][i_weights] / self.loss_weights[i_weights]) + gen_losses[0] = 
new_gen_losses + + gen_losses[1] = [el.cpu().detach().numpy() for el in gen_losses[1]] + gen_losses_total_loss = np.sum(gen_losses[1]) + new_gen_losses = [gen_losses_total_loss] + for i_weights in range(len(gen_losses[1])): + new_gen_losses.append(gen_losses[1][i_weights] / self.loss_weights[i_weights]) + gen_losses[1] = new_gen_losses + + generator_loss = [(a + b) / 2 for a, b in zip(*gen_losses)] + + self.epoch_gen_loss.append(generator_loss) + + #self.index += 1 #this might be moved after test cycle + + #logging of gen and disc loss done by Trainer + #self.log('epoch_gen_loss', self.epoch_gen_loss, on_step=True, on_epoch=True) + #self.log('epoch_disc_loss', self.epoch_disc_loss, on_step=True, on_epoch=True) + + def on_train_epoch_end(self): #outputs + discriminator_train_loss = np.mean(np.array(self.epoch_disc_loss), axis=0) + generator_train_loss = np.mean(np.array(self.epoch_gen_loss), axis=0) + + self.train_history["generator"].append(generator_train_loss) + self.train_history["discriminator"].append(discriminator_train_loss) + + print("-" * 65) + ROW_FMT = ("{0:<20s} | {1:<4.2f} | {2:<10.2f} | {3:<10.2f}| {4:<10.2f} | {5:<10.2f}") + print(ROW_FMT.format("generator (train)", *self.train_history["generator"][-1])) + print(ROW_FMT.format("discriminator (train)", *self.train_history["discriminator"][-1])) + + torch.save(self.generator.state_dict(), "generator_weights.pth") + torch.save(self.discriminator.state_dict(), "discriminator_weights.pth") + + with open(self.pklfile, "wb") as f: + pickle.dump({"train": self.train_history, "test": self.test_history}, f) + + #pickle.dump({"train": self.train_history}, open(self.pklfile, "wb")) + print("train-loss:" + str(self.train_history["generator"][-1][0])) + + def validation_step(self, batch, batch_idx): + image_batch, energy_batch, ang_batch, ecal_batch = batch['X'], batch['Y'], batch['ang'], batch['ecal'] + + image_batch = image_batch.permute(0, 4, 1, 2, 3) + + image_batch = image_batch.to(self.device) + energy_batch 
= energy_batch.to(self.device) + ang_batch = ang_batch.to(self.device) + ecal_batch = ecal_batch.to(self.device) + + # Generate Fake events with same energy and angle as data batch + noise = torch.randn((self.batch_size, self.latent_size - 2), dtype=torch.float32).to(self.device) + + generator_ip = torch.cat((energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), dim=1) + generated_images = self.generator(generator_ip) + + # concatenate to fake and real batches + X = torch.cat((image_batch, generated_images), dim=0) + + #y = np.array([1] * self.batch_size + [0] * self.batch_size).astype(np.float32) + y = torch.tensor([1] * self.batch_size + [0] * self.batch_size, dtype=torch.float32).to(self.device) + y = y.view(-1, 1) + + ang = torch.cat((ang_batch, ang_batch), dim=0) + ecal = torch.cat((ecal_batch, ecal_batch), dim=0) + aux_y = torch.cat((energy_batch, energy_batch), dim=0) + + #y = [[el] for el in y] + labels = [y, aux_y, ang, ecal] + + # Calculate discriminator loss + disc_eval = self.discriminator(X) + disc_eval_loss = self.compute_global_loss(labels, disc_eval, self.loss_weights) + + # Calculate generator loss + trick = np.ones(self.batch_size).astype(np.float32) + fake_batch = torch.tensor([[el] for el in trick]).to(self.device) + #fake_batch = [[el] for el in trick] + labels = [fake_batch, energy_batch, ang_batch, ecal_batch] + + generated_images = self.generator(generator_ip) + gen_eval = self.discriminator(generated_images) + gen_eval_loss = self.compute_global_loss(labels, gen_eval, self.loss_weights) + + self.log('val_discriminator_loss', sum(disc_eval_loss), on_epoch=True, prog_bar=True) + self.log('val_generator_loss', sum(gen_eval_loss), on_epoch=True, prog_bar=True) + + disc_test_loss = [disc_eval_loss[0], disc_eval_loss[1], disc_eval_loss[2], disc_eval_loss[3]] + gen_test_loss = [gen_eval_loss[0], gen_eval_loss[1], gen_eval_loss[2], gen_eval_loss[3]] + + # Configure the loss so it is equal to the original values + disc_eval_loss = 
[el.cpu().detach().numpy() for el in disc_test_loss] + disc_eval_loss_total_loss = np.sum(disc_eval_loss) + new_disc_eval_loss = [disc_eval_loss_total_loss] + for i_weights in range(len(disc_eval_loss)): + new_disc_eval_loss.append(disc_eval_loss[i_weights] / self.loss_weights[i_weights]) + disc_eval_loss = new_disc_eval_loss + + gen_eval_loss = [el.cpu().detach().numpy() for el in gen_test_loss] + gen_eval_loss_total_loss = np.sum(gen_eval_loss) + new_gen_eval_loss = [gen_eval_loss_total_loss] + for i_weights in range(len(gen_eval_loss)): + new_gen_eval_loss.append(gen_eval_loss[i_weights] / self.loss_weights[i_weights]) + gen_eval_loss = new_gen_eval_loss + + self.index += 1 + # evaluate discriminator loss + self.disc_epoch_test_loss.append(disc_eval_loss) + # evaluate generator loss + self.gen_epoch_test_loss.append(gen_eval_loss) + + def on_validation_epoch_end(self): + discriminator_test_loss = np.mean(np.array(self.disc_epoch_test_loss), axis=0) + generator_test_loss = np.mean(np.array(self.gen_epoch_test_loss), axis=0) + + self.test_history["generator"].append(generator_test_loss) + self.test_history["discriminator"].append(discriminator_test_loss) + + print("-" * 65) + ROW_FMT = ("{0:<20s} | {1:<4.2f} | {2:<10.2f} | {3:<10.2f}| {4:<10.2f} | {5:<10.2f}") + print(ROW_FMT.format("generator (test)", *self.test_history["generator"][-1])) + print(ROW_FMT.format("discriminator (test)", *self.test_history["discriminator"][-1])) + + # save loss dict to pkl file + with open(self.pklfile, "wb") as f: + pickle.dump({"train": self.train_history, "test": self.test_history}, f) + #pickle.dump({"test": self.test_history}, open(self.pklfile, "wb")) + #print("train-loss:" + str(self.train_history["generator"][-1][0])) + + def configure_optimizers(self): + lr = self.lr + + optimizer_discriminator = torch.optim.RMSprop(self.discriminator.parameters(), lr) + optimizer_generator = torch.optim.RMSprop(self.generator.parameters(), lr) + return [optimizer_discriminator, 
optimizer_generator], [] + diff --git a/use-cases/3dgan/pipeline.yaml b/use-cases/3dgan/pipeline.yaml new file mode 100644 index 00000000..ff711405 --- /dev/null +++ b/use-cases/3dgan/pipeline.yaml @@ -0,0 +1,98 @@ +executor: + class_path: itwinai.components.Executor + init_args: + steps: + - class_path: dataloader.Lightning3DGANDownloader + init_args: + data_path: data/ + + - class_path: trainer.Lightning3DGANTrainer + init_args: + # Pytorch lightning config for training + config: + seed_everything: 4231162351 + trainer: + accelerator: auto + accumulate_grad_batches: 1 + barebones: false + benchmark: null + # callbacks: + # # - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping + # # init_args: + # # monitor: val_loss + # # patience: 2 + # - class_path: lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor + # init_args: + # logging_interval: step + # # - class_path: lightning.pytorch.callbacks.ModelCheckpoint + # # init_args: + # # dirpath: checkpoints + # # filename: best-checkpoint + # # mode: min + # # monitor: val_loss + # # save_top_k: 1 + # # verbose: true + check_val_every_n_epoch: 1 + default_root_dir: null + detect_anomaly: false + deterministic: null + devices: [0] + enable_checkpointing: null + enable_model_summary: null + enable_progress_bar: null + fast_dev_run: false + gradient_clip_algorithm: null + gradient_clip_val: null + inference_mode: true + limit_predict_batches: null + limit_test_batches: null + limit_train_batches: null + limit_val_batches: null + log_every_n_steps: null + logger: + class_path: lightning.pytorch.loggers.CSVLogger + init_args: + save_dir: "logs" + max_epochs: 1 + max_steps: 10 + max_time: null + min_epochs: null + min_steps: null + num_sanity_val_steps: null + overfit_batches: 0.0 + plugins: null + profiler: null + reload_dataloaders_every_n_epochs: 0 + strategy: auto + sync_batchnorm: false + use_distributed_sampler: true + val_check_interval: null + + # Lightning Model configuration + model: + 
class_path: model.ThreeDGAN + init_args: + latent_size: 256 + batch_size: 64 + loss_weights: [3, 0.1, 25, 0.1] + power: 0.85 + lr: 0.001 + + # Lightning data module configuration + data: + class_path: dataloader.MyDataModule + init_args: + datapath: /afs/cern.ch/work/k/ktsolaki/private/projects/GAN_scripts/3DGAN/Accelerated3DGAN/src/Accelerated3DGAN/data/*.h5 + batch_size: 64 + + # Torch Optimizer configuration + # optimizer: + # class_path: torch.optim.AdamW + # init_args: + # lr: 0.001 + + # # Torch LR scheduler configuration + # lr_scheduler: + # class_path: torch.optim.lr_scheduler.ExponentialLR + # init_args: + # gamma: 0.1 \ No newline at end of file diff --git a/use-cases/3dgan/requirements.txt b/use-cases/3dgan/requirements.txt new file mode 100644 index 00000000..f3a08124 --- /dev/null +++ b/use-cases/3dgan/requirements.txt @@ -0,0 +1,3 @@ +h5py>=3.7.0 +google>=3.0.0 +protobuf>=4.24.4 \ No newline at end of file diff --git a/use-cases/3dgan/startscript b/use-cases/3dgan/startscript new file mode 100644 index 00000000..b9d2dd08 --- /dev/null +++ b/use-cases/3dgan/startscript @@ -0,0 +1,34 @@ +#!/bin/bash + +# general configuration of the job +#SBATCH --job-name=PrototypeTest +#SBATCH --account=intertwin +#SBATCH --mail-user= +#SBATCH --mail-type=ALL +#SBATCH --output=job.out +#SBATCH --error=job.err +#SBATCH --time=00:30:00 + +# configure node and process count on the CM +#SBATCH --partition=batch +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=4 +#SBATCH --gpus-per-node=4 + +#SBATCH --exclusive + +# gres options have to be disabled for deepv +#SBATCH --gres=gpu:4 + +# load modules +ml --force purge +ml Stages/2023 StdEnv/2023 NVHPC/23.1 OpenMPI/4.1.4 cuDNN/8.6.0.163-CUDA-11.7 Python/3.10.4 HDF5 libaio/0.3.112 GCC/11.3.0 + +# shellcheck source=/dev/null +source ~/.bashrc + +# ON LOGIN NODE download datasets: +# $ micromamba run -p ../../.venv-pytorch python train.py -p pipeline.yaml --download-only + +srun micromamba run -p 
../../.venv-pytorch python train.py -p pipeline.yaml \ No newline at end of file diff --git a/use-cases/3dgan/train.py b/use-cases/3dgan/train.py new file mode 100644 index 00000000..d04596be --- /dev/null +++ b/use-cases/3dgan/train.py @@ -0,0 +1,55 @@ +""" +Training pipeline. To run this script, use the following commands. + +On login node: + +>>> micromamba run -p ../../.venv-pytorch/ \ + python train.py -p pipeline.yaml -d + +On compute nodes: + +>>> micromamba run -p ../../.venv-pytorch/ \ + python train.py -p pipeline.yaml + +""" + +import argparse + +from itwinai.components import Executor +from itwinai.utils import parse_pipe_config +from jsonargparse import ArgumentParser + + +if __name__ == "__main__": + # Create CLI Parser + parser = argparse.ArgumentParser() + parser.add_argument( + "-p", "--pipeline", type=str, required=True, + help='Configuration file to the pipeline to execute.' + ) + parser.add_argument( + '-d', '--download-only', + action=argparse.BooleanOptionalAction, + default=False, + help=('Whether to download only the dataset and exit execution ' + '(suggested on login nodes of HPC systems)') + ) + args = parser.parse_args() + + # Create parser for the pipeline (ordered) + pipe_parser = ArgumentParser() + pipe_parser.add_subclass_arguments(Executor, "executor") + + # Parse, Instantiate pipe + parsed = parse_pipe_config(args.pipeline, pipe_parser) + pipe = pipe_parser.instantiate_classes(parsed) + executor: Executor = getattr(pipe, 'executor') + + if args.download_only: + print('Downloading datasets and exiting...') + executor = executor[:1] + else: + print('Downloading datasets (if not already done) and running...') + executor = executor + executor.setup() + executor() diff --git a/use-cases/3dgan/trainer.py b/use-cases/3dgan/trainer.py new file mode 100644 index 00000000..86938c27 --- /dev/null +++ b/use-cases/3dgan/trainer.py @@ -0,0 +1,45 @@ +import os +from typing import Union, Dict, Tuple, Optional, Any + +from itwinai.components import 
Trainer +from model import ThreeDGAN +from dataloader import MyDataModule +from lightning.pytorch.cli import LightningCLI +from utils import load_yaml + + +class Lightning3DGANTrainer(Trainer): + def __init__(self, config: Union[Dict, str]): + super().__init__() + if isinstance(config, str) and os.path.isfile(config): + # Load from YAML + config = load_yaml(config) + self.conf = config + + def train(self) -> Any: + cli = LightningCLI( + args=self.conf, + model_class=ThreeDGAN, + datamodule_class=MyDataModule, + run=False, + save_config_kwargs={ + "overwrite": True, + "config_filename": "pl-training.yml", + }, + subclass_mode_model=True, + subclass_mode_data=True, + ) + cli.trainer.fit(cli.model, datamodule=cli.datamodule) + + def execute( + self, + config: Optional[Dict] = None + ) -> Tuple[Any, Optional[Dict]]: + result = self.train() + return result, config + + def save_state(self): + return super().save_state() + + def load_state(self): + return super().load_state() diff --git a/use-cases/3dgan/utils.py b/use-cases/3dgan/utils.py new file mode 100644 index 00000000..d04f9e63 --- /dev/null +++ b/use-cases/3dgan/utils.py @@ -0,0 +1,108 @@ +""" +Utilities for itwinai package. +""" +import os +import yaml + +from collections.abc import MutableMapping +from typing import Dict +from omegaconf import OmegaConf +from omegaconf.dictconfig import DictConfig + + +def load_yaml(path: str) -> Dict: + """Load YAML file as dict. + + Args: + path (str): path to YAML file. + + Raises: + exc: yaml.YAMLError for loading/parsing errors. + + Returns: + Dict: nested dict representation of parsed YAML file. + """ + with open(path, "r", encoding="utf-8") as yaml_file: + try: + loaded_config = yaml.safe_load(yaml_file) + except yaml.YAMLError as exc: + print(exc) + raise exc + return loaded_config + + +def load_yaml_with_deps_from_file(path: str) -> DictConfig: + """ + Load YAML file with OmegaConf and merge it with its dependencies + specified in the `conf-dependencies` field. 
+ Assume that the dependencies live in the same folder of the + YAML file which is importing them. + + Args: + path (str): path to YAML file. + + Raises: + exc: yaml.YAMLError for loading/parsing errors. + + Returns: + DictConfig: nested representation of parsed YAML file. + """ + yaml_conf = load_yaml(path) + use_case_dir = os.path.dirname(path) + deps = [] + if yaml_conf.get("conf-dependencies"): + for dependency in yaml_conf["conf-dependencies"]: + deps.append(load_yaml(os.path.join(use_case_dir, dependency))) + + return OmegaConf.merge(yaml_conf, *deps) + + +def load_yaml_with_deps_from_dict(dict_conf, use_case_dir) -> DictConfig: + deps = [] + + if dict_conf.get("conf-dependencies"): + for dependency in dict_conf["conf-dependencies"]: + deps.append(load_yaml(os.path.join(use_case_dir, dependency))) + + return OmegaConf.merge(dict_conf, *deps) + + +def dynamically_import_class(name: str): + """ + Dynamically import class by module path. + Adapted from https://stackoverflow.com/a/547867 + + Args: + name (str): path to the class (e.g., mypackage.mymodule.MyClass) + + Returns: + __class__: class object. + """ + module, class_name = name.rsplit(".", 1) + mod = __import__(module, fromlist=[class_name]) + klass = getattr(mod, class_name) + return klass + + +def flatten_dict( + d: MutableMapping, parent_key: str = "", sep: str = "." +) -> MutableMapping: + """Flatten dictionary + + Args: + d (MutableMapping): nested dictionary to flatten + parent_key (str, optional): prefix for all keys. Defaults to ''. + sep (str, optional): separator for nested key concatenation. + Defaults to '.'. + + Returns: + MutableMapping: flattened dictionary with new keys. 
+ """ + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, MutableMapping): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) From 0f50aaf536d604b7145140a6304c20eacff02903 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Tue, 24 Oct 2023 17:31:13 +0200 Subject: [PATCH 02/57] ADD: Download dataset --- .gitignore | 1 + use-cases/3dgan/dataloader.py | 157 ++++++------- use-cases/3dgan/model.py | 383 +++++++++++++++++++------------ use-cases/3dgan/pipeline.yaml | 6 +- use-cases/3dgan/requirements.txt | 3 +- use-cases/3dgan/trainer.py | 4 + 6 files changed, 322 insertions(+), 232 deletions(-) diff --git a/.gitignore b/.gitignore index e32d500f..3d59ba3b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.pth +exp_data/ TODO /data nohup* diff --git a/use-cases/3dgan/dataloader.py b/use-cases/3dgan/dataloader.py index f25ad7fd..f2af7cb2 100644 --- a/use-cases/3dgan/dataloader.py +++ b/use-cases/3dgan/dataloader.py @@ -1,35 +1,35 @@ from typing import Optional, Tuple, Dict -import lightning as pl - -from torch.utils.data import DataLoader, random_split -from torchvision import transforms - -from itwinai.components import DataGetter +import os import numpy as np import torch -import torch.nn as nn from torch.utils.data import DataLoader, Dataset +import lightning as pl import glob import h5py -import numpy as np -from collections import defaultdict -#import pytorch_lightning as pl -import sys -import pickle +import gdown + +from itwinai.components import DataGetter class Lightning3DGANDownloader(DataGetter): def __init__( self, + data_url: str, data_path: str, name: Optional[str] = None, **kwargs) -> None: super().__init__(name, **kwargs) self.data_path = data_path + self.data_url = data_url - def load(self): - ... 
+ def load(self): + # Download data + if not os.path.exists(self.data_path): + gdown.download_folder( + url=self.data_url, quiet=False, + output=self.data_path + ) def execute( self, @@ -58,66 +58,67 @@ def fetch_data(self, datapath): concatenated_datasets = [] for datafile in Files: - f=h5py.File(datafile,'r') - dataset = self.GetDataAngleParallel(f) - concatenated_datasets.append(dataset) - result = {key: [] for key in concatenated_datasets[0].keys()} # Initialize result dictionary - for d in concatenated_datasets: - for key in result.keys(): - result[key].extend(d[key]) + f = h5py.File(datafile, 'r') + dataset = self.GetDataAngleParallel(f) + concatenated_datasets.append(dataset) + # Initialize result dictionary + result = {key: [] for key in concatenated_datasets[0].keys()} + for d in concatenated_datasets: + for key in result.keys(): + result[key].extend(d[key]) return result def GetDataAngleParallel( - self, - dataset, - xscale=1, - xpower=0.85, - yscale=100, - angscale=1, - angtype="theta", - thresh=1e-4, - daxis=-1,): - """Preprocess function for the dataset - - Args: - dataset (str): Dataset file path - xscale (int, optional): Value to scale the ECAL values. Defaults to 1. - xpower (int, optional): Value to scale the ECAL values, exponentially. Defaults to 1. - yscale (int, optional): Value to scale the energy values. Defaults to 100. - angscale (int, optional): Value to scale the angle values. Defaults to 1. - angtype (str, optional): Which type of angle to use. Defaults to "theta". - thresh (_type_, optional): Maximum value for ECAL values. Defaults to 1e-4. - daxis (int, optional): Axis to expand values. Defaults to -1. 
- - Returns: - Dict: Dictionary containning the preprocessed dataset - """ - X = np.array(dataset.get("ECAL")) * xscale - Y = np.array(dataset.get("energy")) / yscale - X[X < thresh] = 0 - X = X.astype(np.float32) - Y = Y.astype(np.float32) - ecal = np.sum(X, axis=(1, 2, 3)) - indexes = np.where(ecal > 10.0) - X = X[indexes] - Y = Y[indexes] - if angtype in dataset: - ang = np.array(dataset.get(angtype))[indexes] - # else: - # ang = gan.measPython(X) - X = np.expand_dims(X, axis=daxis) - ecal = ecal[indexes] - ecal = np.expand_dims(ecal, axis=daxis) - if xpower != 1.0: - X = np.power(X, xpower) - - Y = np.array([[el] for el in Y]) - ang = np.array([[el] for el in ang]) - ecal = np.array([[el] for el in ecal]) - - final_dataset = {"X": X, "Y": Y, "ang": ang, "ecal": ecal} - - return final_dataset + self, + dataset, + xscale=1, + xpower=0.85, + yscale=100, + angscale=1, + angtype="theta", + thresh=1e-4, + daxis=-1,): + """Preprocess function for the dataset + + Args: + dataset (str): Dataset file path + xscale (int, optional): Value to scale the ECAL values. Defaults to 1. + xpower (int, optional): Value to scale the ECAL values, exponentially. Defaults to 1. + yscale (int, optional): Value to scale the energy values. Defaults to 100. + angscale (int, optional): Value to scale the angle values. Defaults to 1. + angtype (str, optional): Which type of angle to use. Defaults to "theta". + thresh (_type_, optional): Maximum value for ECAL values. Defaults to 1e-4. + daxis (int, optional): Axis to expand values. Defaults to -1. 
+ + Returns: + Dict: Dictionary containning the preprocessed dataset + """ + X = np.array(dataset.get("ECAL")) * xscale + Y = np.array(dataset.get("energy")) / yscale + X[X < thresh] = 0 + X = X.astype(np.float32) + Y = Y.astype(np.float32) + ecal = np.sum(X, axis=(1, 2, 3)) + indexes = np.where(ecal > 10.0) + X = X[indexes] + Y = Y[indexes] + if angtype in dataset: + ang = np.array(dataset.get(angtype))[indexes] + # else: + # ang = gan.measPython(X) + X = np.expand_dims(X, axis=daxis) + ecal = ecal[indexes] + ecal = np.expand_dims(ecal, axis=daxis) + if xpower != 1.0: + X = np.power(X, xpower) + + Y = np.array([[el] for el in Y]) + ang = np.array([[el] for el in ang]) + ecal = np.array([[el] for el in ecal]) + + final_dataset = {"X": X, "Y": Y, "ang": ang, "ecal": ecal} + + return final_dataset class MyDataModule(pl.LightningDataModule): @@ -134,17 +135,17 @@ def setup(self, stage: str = None): self.dataset = MyDataset(self.datapath) dataset_length = len(self.dataset) split_point = int(dataset_length * 0.9) - self.train_dataset, self.val_dataset = torch.utils.data.random_split(self.dataset, [split_point, dataset_length - split_point]) + self.train_dataset, self.val_dataset = torch.utils.data.random_split( + self.dataset, [split_point, dataset_length - split_point]) - #if stage == 'test' or stage is None: - #self.test_dataset = MyDataset(self.data_dir, train=False) + # if stage == 'test' or stage is None: + # self.test_dataset = MyDataset(self.data_dir, train=False) def train_dataloader(self): return DataLoader(self.train_dataset, batch_size=self.batch_size, drop_last=True) def val_dataloader(self): - return DataLoader(self.val_dataset, batch_size=self.batch_size, drop_last=True) - - #def test_dataloader(self): - #return DataLoader(self.test_dataset, batch_size=self.batch_size) + return DataLoader(self.val_dataset, batch_size=self.batch_size, drop_last=True) + # def test_dataloader(self): + # return DataLoader(self.test_dataset, batch_size=self.batch_size) diff 
--git a/use-cases/3dgan/model.py b/use-cases/3dgan/model.py index ac447949..2a86dfa1 100644 --- a/use-cases/3dgan/model.py +++ b/use-cases/3dgan/model.py @@ -1,58 +1,59 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F +import sys +import pickle +from collections import defaultdict import math -import lightning as pl -import numpy as np import torch import torch.nn as nn -from torch.utils.data import DataLoader, Dataset -import glob -import h5py +import torch.nn.functional as F +import lightning as pl import numpy as np -from collections import defaultdict -#import pytorch_lightning as pl -import sys -import pickle class Generator(nn.Module): - def __init__(self, latent_dim): #img_shape + def __init__(self, latent_dim): # img_shape super().__init__() - #self.img_shape = img_shape + # self.img_shape = img_shape self.latent_dim = latent_dim self.l1 = nn.Linear(self.latent_dim, 5184) - self.up1 = nn.Upsample(scale_factor=(6, 6, 6), mode='trilinear', align_corners=False) - self.conv1 = nn.Conv3d(in_channels=8, out_channels=8, kernel_size=(6, 6, 8), padding=0) + self.up1 = nn.Upsample(scale_factor=( + 6, 6, 6), mode='trilinear', align_corners=False) + self.conv1 = nn.Conv3d(in_channels=8, out_channels=8, + kernel_size=(6, 6, 8), padding=0) nn.init.kaiming_uniform_(self.conv1.weight) - self.bn1 = nn.BatchNorm3d(num_features=8, eps=1e-6) #num_features is the number of channels (see doc) + # num_features is the number of channels (see doc) + self.bn1 = nn.BatchNorm3d(num_features=8, eps=1e-6) self.pad1 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0) - self.conv2 = nn.Conv3d(in_channels=8, out_channels=6, kernel_size=(4, 4, 6), padding=0) + self.conv2 = nn.Conv3d(in_channels=8, out_channels=6, + kernel_size=(4, 4, 6), padding=0) nn.init.kaiming_uniform_(self.conv2.weight) self.bn2 = nn.BatchNorm3d(num_features=6, eps=1e-6) self.pad2 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0) - self.conv3 = nn.Conv3d(in_channels=6, out_channels=6, kernel_size=(4, 4, 6), 
padding=0) + self.conv3 = nn.Conv3d(in_channels=6, out_channels=6, + kernel_size=(4, 4, 6), padding=0) nn.init.kaiming_uniform_(self.conv3.weight) self.bn3 = nn.BatchNorm3d(num_features=6, eps=1e-6) self.pad3 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0) - self.conv4 = nn.Conv3d(in_channels=6, out_channels=6, kernel_size=(4, 4, 6), padding=0) + self.conv4 = nn.Conv3d(in_channels=6, out_channels=6, + kernel_size=(4, 4, 6), padding=0) nn.init.kaiming_uniform_(self.conv4.weight) self.bn4 = nn.BatchNorm3d(num_features=6, eps=1e-6) self.pad4 = nn.ConstantPad3d((0, 0, 1, 1, 1, 1), 0) - self.conv5 = nn.Conv3d(in_channels=6, out_channels=6, kernel_size=(3, 3, 5), padding=0) + self.conv5 = nn.Conv3d(in_channels=6, out_channels=6, + kernel_size=(3, 3, 5), padding=0) nn.init.kaiming_uniform_(self.conv5.weight) self.bn5 = nn.BatchNorm3d(num_features=6, eps=1e-6) self.pad5 = nn.ConstantPad3d((0, 0, 1, 1, 1, 1), 0) - self.conv6 = nn.Conv3d(in_channels=6, out_channels=6, kernel_size=(3, 3, 3), padding=0) + self.conv6 = nn.Conv3d(in_channels=6, out_channels=6, + kernel_size=(3, 3, 3), padding=0) nn.init.kaiming_uniform_(self.conv6.weight) - self.conv7 = nn.Conv3d(in_channels=6, out_channels=1, kernel_size=(2, 2, 2), padding=0) + self.conv7 = nn.Conv3d(in_channels=6, out_channels=1, + kernel_size=(2, 2, 2), padding=0) nn.init.xavier_normal_(self.conv7.weight) - def forward(self, z): img = self.l1(z) img = img.view(-1, 8, 9, 9, 8) @@ -85,7 +86,7 @@ def forward(self, z): img = F.relu(img) return img - + class Discriminator(nn.Module): def __init__(self, power): @@ -93,28 +94,33 @@ def __init__(self, power): self.power = power - self.conv1 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=(5, 6, 6), padding=(2, 3, 3)) + self.conv1 = nn.Conv3d( + in_channels=1, out_channels=16, kernel_size=(5, 6, 6), padding=(2, 3, 3)) self.drop1 = nn.Dropout(0.2) self.pad1 = nn.ConstantPad3d((1, 1, 0, 0, 0, 0), 0) - self.conv2 = nn.Conv3d(in_channels=16, out_channels=8, kernel_size=(5, 6, 6), 
padding=0) + self.conv2 = nn.Conv3d( + in_channels=16, out_channels=8, kernel_size=(5, 6, 6), padding=0) self.bn1 = nn.BatchNorm3d(num_features=8, eps=1e-6) self.drop2 = nn.Dropout(0.2) self.pad2 = nn.ConstantPad3d((1, 1, 0, 0, 0, 0), 0) - self.conv3 = nn.Conv3d(in_channels=8, out_channels=8, kernel_size=(5, 6, 6), padding=0) + self.conv3 = nn.Conv3d(in_channels=8, out_channels=8, + kernel_size=(5, 6, 6), padding=0) self.bn2 = nn.BatchNorm3d(num_features=8, eps=1e-6) self.drop3 = nn.Dropout(0.2) - self.conv4 = nn.Conv3d(in_channels=8, out_channels=8, kernel_size=(5, 6, 6), padding=0) + self.conv4 = nn.Conv3d(in_channels=8, out_channels=8, + kernel_size=(5, 6, 6), padding=0) self.bn3 = nn.BatchNorm3d(num_features=8, eps=1e-6) self.drop4 = nn.Dropout(0.2) self.avgpool = nn.AvgPool3d((2, 2, 2)) self.flatten = nn.Flatten() - self.fakeout = nn.Linear(19152, 1) # The input features for the Linear layer need to be calculated based on the output shape from the previous layers. - self.auxout = nn.Linear(19152, 1) # The same as above for this layer. + # The input features for the Linear layer need to be calculated based on the output shape from the previous layers. + self.fakeout = nn.Linear(19152, 1) + self.auxout = nn.Linear(19152, 1) # The same as above for this layer. # calculate sum of intensities def ecal_sum(self, image, daxis): @@ -132,25 +138,33 @@ def ecal_angle(self, image, daxis1): sumtot = torch.sum(image, dim=(1, 2, 3)) # sum of events # get 1. 
where event sum is 0 and 0 elsewhere - amask = torch.where(sumtot == 0.0, torch.ones_like(sumtot), torch.zeros_like(sumtot)) + amask = torch.where(sumtot == 0.0, torch.ones_like( + sumtot), torch.zeros_like(sumtot)) masked_events = torch.sum(amask) # counting zero sum events # ref denotes barycenter as that is our reference point - x_ref = torch.sum(torch.sum(image, dim=(2, 3)) - * (torch.arange(x_shape, device=image.device, dtype=torch.float32).unsqueeze(0) + 0.5), - dim=1,) # sum for x position * x index + x_ref = torch.sum(torch.sum(image, dim=(2, 3)) + * (torch.arange(x_shape, device=image.device, + dtype=torch.float32).unsqueeze(0) + 0.5), + dim=1,) # sum for x position * x index y_ref = torch.sum( torch.sum(image, dim=(1, 3)) - * (torch.arange(y_shape, device=image.device, dtype=torch.float32).unsqueeze(0) + 0.5), + * (torch.arange(y_shape, device=image.device, + dtype=torch.float32).unsqueeze(0) + 0.5), dim=1,) z_ref = torch.sum( torch.sum(image, dim=(1, 2)) - * (torch.arange(z_shape, device=image.device, dtype=torch.float32).unsqueeze(0) + 0.5), + * (torch.arange(z_shape, device=image.device, + dtype=torch.float32).unsqueeze(0) + 0.5), dim=1,) - x_ref = torch.where(sumtot == 0.0, torch.ones_like(x_ref), x_ref / sumtot) # return max position if sumtot=0 and divide by sumtot otherwise - y_ref = torch.where(sumtot == 0.0, torch.ones_like(y_ref), y_ref / sumtot) - z_ref = torch.where(sumtot == 0.0, torch.ones_like(z_ref), z_ref / sumtot) + # return max position if sumtot=0 and divide by sumtot otherwise + x_ref = torch.where( + sumtot == 0.0, torch.ones_like(x_ref), x_ref / sumtot) + y_ref = torch.where( + sumtot == 0.0, torch.ones_like(y_ref), y_ref / sumtot) + z_ref = torch.where( + sumtot == 0.0, torch.ones_like(z_ref), z_ref / sumtot) # reshape x_ref = x_ref.unsqueeze(1) @@ -160,27 +174,35 @@ def ecal_angle(self, image, daxis1): sumz = torch.sum(image, dim=(1, 2)) # sum for x,y planes going along z # Get 0 where sum along z is 0 and 1 elsewhere - zmask = 
torch.where(sumz == 0.0, torch.zeros_like(sumz), torch.ones_like(sumz)) + zmask = torch.where(sumz == 0.0, torch.zeros_like( + sumz), torch.ones_like(sumz)) - x = torch.arange(x_shape, device=image.device).unsqueeze(0) # x indexes + x = torch.arange(x_shape, device=image.device).unsqueeze( + 0) # x indexes x = (x.unsqueeze(2).float()) + 0.5 - y = torch.arange(y_shape, device=image.device).unsqueeze(0) # y indexes + y = torch.arange(y_shape, device=image.device).unsqueeze( + 0) # y indexes y = (y.unsqueeze(2).float()) + 0.5 # barycenter for each z position x_mid = torch.sum(torch.sum(image, dim=2) * x, dim=1) y_mid = torch.sum(torch.sum(image, dim=1) * y, dim=1) - x_mid = torch.where(sumz == 0.0, torch.zeros_like(sumz), x_mid / sumz) # if sum != 0 then divide by sum - y_mid = torch.where(sumz == 0.0, torch.zeros_like(sumz), y_mid / sumz) # if sum != 0 then divide by sum + x_mid = torch.where(sumz == 0.0, torch.zeros_like( + sumz), x_mid / sumz) # if sum != 0 then divide by sum + y_mid = torch.where(sumz == 0.0, torch.zeros_like( + sumz), y_mid / sumz) # if sum != 0 then divide by sum # Angle Calculations - z = (torch.arange(z_shape, device=image.device, dtype=torch.float32) + 0.5) * torch.ones_like(z_ref) # Make an array of z indexes for all events + z = (torch.arange(z_shape, device=image.device, dtype=torch.float32) + 0.5) * \ + torch.ones_like(z_ref) # Make an array of z indexes for all events zproj = torch.sqrt( - torch.max((x_mid - x_ref) ** 2.0 + (z - z_ref) ** 2.0, torch.tensor([torch.finfo(torch.float32).eps]).to(x_mid.device))) # projection from z axis with stability check - #torch.finfo(torch.float32).eps)) - m = torch.where(zproj == 0.0, torch.zeros_like(zproj), (y_mid - y_ref) / zproj) # to avoid divide by zero for zproj =0 + torch.max((x_mid - x_ref) ** 2.0 + (z - z_ref) ** 2.0, torch.tensor([torch.finfo(torch.float32).eps]).to(x_mid.device))) # projection from z axis with stability check + # torch.finfo(torch.float32).eps)) + # to avoid divide by zero 
for zproj =0 + m = torch.where(zproj == 0.0, torch.zeros_like( + zproj), (y_mid - y_ref) / zproj) m = torch.where(z < z_ref, -1 * m, m) # sign inversion ang = (math.pi / 2.0) - torch.atan(m) # angle correction zmask = torch.where(zproj == 0.0, torch.zeros_like(zproj), zmask) @@ -191,8 +213,10 @@ def ecal_angle(self, image, daxis1): # zunmasked = K.sum(zmask, axis=1) # used for simple mean # ang = K.sum(ang, axis=1)/zunmasked # Mean does not include positions where zsum=0 - ang = torch.sum(ang, dim=1) / torch.sum(sumz_tot, dim=1) # sum ( measured * weights)/sum(weights) - ang = torch.where(amask == 0.0, ang, 100.0 * torch.ones_like(ang)) # Place 100 for measured angle where no energy is deposited in events + # sum ( measured * weights)/sum(weights) + ang = torch.sum(ang, dim=1) / torch.sum(sumz_tot, dim=1) + # Place 100 for measured angle where no energy is deposited in events + ang = torch.where(amask == 0.0, ang, 100.0 * torch.ones_like(ang)) ang = ang.unsqueeze(1) return ang @@ -217,20 +241,29 @@ def forward(self, x): z = self.avgpool(z) z = self.flatten(z) - fake = torch.sigmoid(self.fakeout(z)) #generation output that says fake/real - aux = self.auxout(z) #auxiliary output + # generation output that says fake/real + fake = torch.sigmoid(self.fakeout(z)) + aux = self.auxout(z) # auxiliary output inv_image = x.pow(1.0 / self.power) - ang = self.ecal_angle(inv_image, 1) # angle calculation - ecal = self.ecal_sum(inv_image, (2, 3, 4)) # sum of energies + ang = self.ecal_angle(inv_image, 1) # angle calculation + ecal = self.ecal_sum(inv_image, (2, 3, 4)) # sum of energies return fake, aux, ang, ecal - + class ThreeDGAN(pl.LightningModule): - def __init__(self, latent_size=256, batch_size=64, loss_weights=[3, 0.1, 25, 0.1], power=0.85, lr=0.001): + def __init__( + self, + latent_size=256, + batch_size=64, + loss_weights=[3, 0.1, 25, 0.1], + power=0.85, + lr=0.001, + checkpoint_path: str = '3Dgan.pth' + ): super().__init__() self.save_hyperparameters() - 
self.automatic_optimization=False + self.automatic_optimization = False self.latent_size = latent_size self.batch_size = batch_size @@ -248,8 +281,7 @@ def __init__(self, latent_size=256, batch_size=64, loss_weights=[3, 0.1, 25, 0.1 self.index = 0 self.train_history = defaultdict(list) self.test_history = defaultdict(list) - self.pklfile = "/afs/cern.ch/work/k/ktsolaki/private/projects/GAN_scripts/3DGAN/Accelerated3DGAN/src/Accelerated3DGAN/results/3dgan_history_test.pkl" - + self.pklfile = checkpoint_path def BitFlip(self, x, prob=0.05): """ @@ -274,20 +306,26 @@ def mean_absolute_percentage_error(self, y_true, y_pred): def compute_global_loss(self, labels, predictions, loss_weights=[3, 0.1, 25, 0.1]): # Can be initialized outside binary_crossentropy_object = nn.BCEWithLogitsLoss(reduction='none') - #there is no equivalent in pytorch for tf.keras.losses.MeanAbsolutePercentageError --> using the custom "mean_absolute_percentage_error" above! - mean_absolute_percentage_error_object1 = self.mean_absolute_percentage_error(predictions[1], labels[1]) - mean_absolute_percentage_error_object2 = self.mean_absolute_percentage_error(predictions[3], labels[3]) + # there is no equivalent in pytorch for tf.keras.losses.MeanAbsolutePercentageError --> using the custom "mean_absolute_percentage_error" above! 
+ mean_absolute_percentage_error_object1 = self.mean_absolute_percentage_error( + predictions[1], labels[1]) + mean_absolute_percentage_error_object2 = self.mean_absolute_percentage_error( + predictions[3], labels[3]) mae_object = nn.L1Loss(reduction='none') - binary_example_loss = binary_crossentropy_object(predictions[0], labels[0]) * loss_weights[0] + binary_example_loss = binary_crossentropy_object( + predictions[0], labels[0]) * loss_weights[0] - #mean_example_loss_1 = mean_absolute_percentage_error_object(predictions[1], labels[1]) * loss_weights[1] - mean_example_loss_1 = mean_absolute_percentage_error_object1 * loss_weights[1] + # mean_example_loss_1 = mean_absolute_percentage_error_object(predictions[1], labels[1]) * loss_weights[1] + mean_example_loss_1 = mean_absolute_percentage_error_object1 * \ + loss_weights[1] - mae_example_loss = mae_object(predictions[2], labels[2]) * loss_weights[2] + mae_example_loss = mae_object( + predictions[2], labels[2]) * loss_weights[2] - #mean_example_loss_2 = mean_absolute_percentage_error_object(predictions[3], labels[3]) * loss_weights[3] - mean_example_loss_2 = mean_absolute_percentage_error_object2 * loss_weights[3] + # mean_example_loss_2 = mean_absolute_percentage_error_object(predictions[3], labels[3]) * loss_weights[3] + mean_example_loss_2 = mean_absolute_percentage_error_object2 * \ + loss_weights[3] binary_loss = binary_example_loss.mean() mean_loss_1 = mean_example_loss_1.mean() @@ -300,7 +338,8 @@ def forward(self, z): return self.generator(z) def training_step(self, batch, batch_idx): - image_batch, energy_batch, ang_batch, ecal_batch = batch['X'], batch['Y'], batch['ang'], batch['ecal'] + image_batch, energy_batch, ang_batch, ecal_batch = batch[ + 'X'], batch['Y'], batch['ang'], batch['ecal'] image_batch = image_batch.permute(0, 4, 1, 2, 3) @@ -311,7 +350,8 @@ def training_step(self, batch, batch_idx): optimizer_discriminator, optimizer_generator = self.optimizers() - noise = torch.randn((self.batch_size, 
self.latent_size - 2)).to(self.device) + noise = torch.randn( + (self.batch_size, self.latent_size - 2)).to(self.device) generator_ip = torch.cat( (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), dim=1,) @@ -326,15 +366,16 @@ def training_step(self, batch, batch_idx): print("calculating real_batch_loss...") real_batch_loss = self.compute_global_loss( labels, predictions, self.loss_weights) - self.log("real_batch_loss", sum(real_batch_loss), prog_bar=True, on_step=True, on_epoch=True) + self.log("real_batch_loss", sum(real_batch_loss), + prog_bar=True, on_step=True, on_epoch=True) print("real batch disc train") - #the following 3 lines correspond in tf version to: - #gradients = tape.gradient(real_batch_loss, discriminator.trainable_variables) - #optimizer_discriminator.apply_gradients(zip(gradients, discriminator.trainable_variables)) in Tensorflow + # the following 3 lines correspond in tf version to: + # gradients = tape.gradient(real_batch_loss, discriminator.trainable_variables) + # optimizer_discriminator.apply_gradients(zip(gradients, discriminator.trainable_variables)) in Tensorflow optimizer_discriminator.zero_grad() self.manual_backward(sum(real_batch_loss)) - #sum(real_batch_loss).backward() - #real_batch_loss.backward() + # sum(real_batch_loss).backward() + # real_batch_loss.backward() optimizer_discriminator.step() # Train discriminator on the fake batch @@ -346,14 +387,15 @@ def training_step(self, batch, batch_idx): fake_batch_loss = self.compute_global_loss( labels, predictions, self.loss_weights) - self.log("fake_batch_loss", sum(fake_batch_loss), prog_bar=True, on_step=True, on_epoch=True) + self.log("fake_batch_loss", sum(fake_batch_loss), + prog_bar=True, on_step=True, on_epoch=True) print("fake batch disc train") - #the following 3 lines correspond to - #gradients = tape.gradient(fake_batch_loss, discriminator.trainable_variables) - #optimizer_discriminator.apply_gradients(zip(gradients, discriminator.trainable_variables)) in Tensorflow 
+ # the following 3 lines correspond to + # gradients = tape.gradient(fake_batch_loss, discriminator.trainable_variables) + # optimizer_discriminator.apply_gradients(zip(gradients, discriminator.trainable_variables)) in Tensorflow optimizer_discriminator.zero_grad() self.manual_backward(sum(fake_batch_loss)) - #sum(fake_batch_loss).backward() + # sum(fake_batch_loss).backward() optimizer_discriminator.step() trick = np.ones(self.batch_size).astype(np.float32) @@ -363,7 +405,8 @@ def training_step(self, batch, batch_idx): gen_losses_train = [] # Train generator twice using combined model for _ in range(2): - noise = torch.randn((self.batch_size, self.latent_size - 2)).to(self.device) + noise = torch.randn( + (self.batch_size, self.latent_size - 2)).to(self.device) generator_ip = torch.cat( (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), dim=1,) @@ -373,104 +416,124 @@ def training_step(self, batch, batch_idx): loss = self.compute_global_loss( labels, predictions, self.loss_weights) - self.log("gen_loss", sum(loss), prog_bar=True, on_step=True, on_epoch=True) + self.log("gen_loss", sum(loss), prog_bar=True, + on_step=True, on_epoch=True) print("gen train") optimizer_generator.zero_grad() self.manual_backward(sum(loss)) - #sum(loss).backward() + # sum(loss).backward() optimizer_generator.step() for el in loss: gen_losses_train.append(el) avg_generator_loss = sum(gen_losses_train) / len(gen_losses_train) - self.log("generator_loss", avg_generator_loss.item(), prog_bar=True, on_step=True, on_epoch=True) - #avg_generator_loss = [(a + b) / 2 for a, b in zip(*gen_losses_train)] - #self.log("generator_loss", sum(avg_generator_loss), prog_bar=True, on_step=True, on_epoch=True) + self.log("generator_loss", avg_generator_loss.item(), + prog_bar=True, on_step=True, on_epoch=True) + # avg_generator_loss = [(a + b) / 2 for a, b in zip(*gen_losses_train)] + # self.log("generator_loss", sum(avg_generator_loss), prog_bar=True, on_step=True, on_epoch=True) gen_losses = [] # 
I'm not returning anything as in pl you do not return anything when you back-propagate manually - #return_loss = real_batch_loss - real_batch_loss = [real_batch_loss[0], real_batch_loss[1], real_batch_loss[2], real_batch_loss[3]] - fake_batch_loss = [fake_batch_loss[0], fake_batch_loss[1], fake_batch_loss[2], fake_batch_loss[3]] - gen_batch_loss = [gen_losses_train[0], gen_losses_train[1], gen_losses_train[2], gen_losses_train[3]] + # return_loss = real_batch_loss + real_batch_loss = [real_batch_loss[0], real_batch_loss[1], + real_batch_loss[2], real_batch_loss[3]] + fake_batch_loss = [fake_batch_loss[0], fake_batch_loss[1], + fake_batch_loss[2], fake_batch_loss[3]] + gen_batch_loss = [gen_losses_train[0], gen_losses_train[1], + gen_losses_train[2], gen_losses_train[3]] gen_losses.append(gen_batch_loss) - gen_batch_loss = [gen_losses_train[4], gen_losses_train[5], gen_losses_train[6], gen_losses_train[7]] + gen_batch_loss = [gen_losses_train[4], gen_losses_train[5], + gen_losses_train[6], gen_losses_train[7]] gen_losses.append(gen_batch_loss) real_batch_loss = [el.cpu().detach().numpy() for el in real_batch_loss] real_batch_loss_total_loss = np.sum(real_batch_loss) new_real_batch_loss = [real_batch_loss_total_loss] for i_weights in range(len(real_batch_loss)): - new_real_batch_loss.append(real_batch_loss[i_weights] / self.loss_weights[i_weights]) + new_real_batch_loss.append( + real_batch_loss[i_weights] / self.loss_weights[i_weights]) real_batch_loss = new_real_batch_loss fake_batch_loss = [el.cpu().detach().numpy() for el in fake_batch_loss] fake_batch_loss_total_loss = np.sum(fake_batch_loss) new_fake_batch_loss = [fake_batch_loss_total_loss] for i_weights in range(len(fake_batch_loss)): - new_fake_batch_loss.append(fake_batch_loss[i_weights] / self.loss_weights[i_weights]) + new_fake_batch_loss.append( + fake_batch_loss[i_weights] / self.loss_weights[i_weights]) fake_batch_loss = new_fake_batch_loss - + # if ecal sum has 100% loss(generating empty events) then 
end the training if fake_batch_loss[3] == 100.0 and self.index > 10: - print("Empty image with Ecal loss equal to 100.0 for {} batch".format(self.index)) - torch.save(self.generator.state_dict(), "generator_weights.pth") - torch.save(self.discriminator.state_dict(), "discriminator_weights.pth") - print("real_batch_loss", real_batch_loss) - print("fake_batch_loss", fake_batch_loss) - sys.exit() + print( + "Empty image with Ecal loss equal to 100.0 for {} batch".format(self.index)) + torch.save(self.generator.state_dict(), "generator_weights.pth") + torch.save(self.discriminator.state_dict(), + "discriminator_weights.pth") + print("real_batch_loss", real_batch_loss) + print("fake_batch_loss", fake_batch_loss) + sys.exit() # append mean of discriminator loss for real and fake events - self.epoch_disc_loss.append([(a + b) / 2 for a, b in zip(real_batch_loss, fake_batch_loss)]) - + self.epoch_disc_loss.append( + [(a + b) / 2 for a, b in zip(real_batch_loss, fake_batch_loss)]) + gen_losses[0] = [el.cpu().detach().numpy() for el in gen_losses[0]] gen_losses_total_loss = np.sum(gen_losses[0]) new_gen_losses = [gen_losses_total_loss] for i_weights in range(len(gen_losses[0])): - new_gen_losses.append(gen_losses[0][i_weights] / self.loss_weights[i_weights]) + new_gen_losses.append( + gen_losses[0][i_weights] / self.loss_weights[i_weights]) gen_losses[0] = new_gen_losses gen_losses[1] = [el.cpu().detach().numpy() for el in gen_losses[1]] gen_losses_total_loss = np.sum(gen_losses[1]) new_gen_losses = [gen_losses_total_loss] for i_weights in range(len(gen_losses[1])): - new_gen_losses.append(gen_losses[1][i_weights] / self.loss_weights[i_weights]) + new_gen_losses.append( + gen_losses[1][i_weights] / self.loss_weights[i_weights]) gen_losses[1] = new_gen_losses generator_loss = [(a + b) / 2 for a, b in zip(*gen_losses)] self.epoch_gen_loss.append(generator_loss) - #self.index += 1 #this might be moved after test cycle - - #logging of gen and disc loss done by Trainer - 
#self.log('epoch_gen_loss', self.epoch_gen_loss, on_step=True, on_epoch=True) - #self.log('epoch_disc_loss', self.epoch_disc_loss, on_step=True, on_epoch=True) - - def on_train_epoch_end(self): #outputs - discriminator_train_loss = np.mean(np.array(self.epoch_disc_loss), axis=0) + # self.index += 1 #this might be moved after test cycle + + # logging of gen and disc loss done by Trainer + # self.log('epoch_gen_loss', self.epoch_gen_loss, on_step=True, on_epoch=True) + # self.log('epoch_disc_loss', self.epoch_disc_loss, on_step=True, on_epoch=True) + + def on_train_epoch_end(self): # outputs + discriminator_train_loss = np.mean( + np.array(self.epoch_disc_loss), axis=0) generator_train_loss = np.mean(np.array(self.epoch_gen_loss), axis=0) self.train_history["generator"].append(generator_train_loss) self.train_history["discriminator"].append(discriminator_train_loss) print("-" * 65) - ROW_FMT = ("{0:<20s} | {1:<4.2f} | {2:<10.2f} | {3:<10.2f}| {4:<10.2f} | {5:<10.2f}") - print(ROW_FMT.format("generator (train)", *self.train_history["generator"][-1])) - print(ROW_FMT.format("discriminator (train)", *self.train_history["discriminator"][-1])) + ROW_FMT = ( + "{0:<20s} | {1:<4.2f} | {2:<10.2f} | {3:<10.2f}| {4:<10.2f} | {5:<10.2f}") + print(ROW_FMT.format("generator (train)", + *self.train_history["generator"][-1])) + print(ROW_FMT.format("discriminator (train)", + *self.train_history["discriminator"][-1])) torch.save(self.generator.state_dict(), "generator_weights.pth") - torch.save(self.discriminator.state_dict(), "discriminator_weights.pth") + torch.save(self.discriminator.state_dict(), + "discriminator_weights.pth") with open(self.pklfile, "wb") as f: - pickle.dump({"train": self.train_history, "test": self.test_history}, f) + pickle.dump({"train": self.train_history, + "test": self.test_history}, f) - #pickle.dump({"train": self.train_history}, open(self.pklfile, "wb")) + # pickle.dump({"train": self.train_history}, open(self.pklfile, "wb")) print("train-loss:" + 
str(self.train_history["generator"][-1][0])) - + def validation_step(self, batch, batch_idx): - image_batch, energy_batch, ang_batch, ecal_batch = batch['X'], batch['Y'], batch['ang'], batch['ecal'] + image_batch, energy_batch, ang_batch, ecal_batch = batch[ + 'X'], batch['Y'], batch['ang'], batch['ecal'] image_batch = image_batch.permute(0, 4, 1, 2, 3) @@ -479,59 +542,70 @@ def validation_step(self, batch, batch_idx): ang_batch = ang_batch.to(self.device) ecal_batch = ecal_batch.to(self.device) - # Generate Fake events with same energy and angle as data batch - noise = torch.randn((self.batch_size, self.latent_size - 2), dtype=torch.float32).to(self.device) - - generator_ip = torch.cat((energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), dim=1) + # Generate Fake events with same energy and angle as data batch + noise = torch.randn((self.batch_size, self.latent_size - 2), + dtype=torch.float32).to(self.device) + + generator_ip = torch.cat( + (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), dim=1) generated_images = self.generator(generator_ip) # concatenate to fake and real batches X = torch.cat((image_batch, generated_images), dim=0) - #y = np.array([1] * self.batch_size + [0] * self.batch_size).astype(np.float32) - y = torch.tensor([1] * self.batch_size + [0] * self.batch_size, dtype=torch.float32).to(self.device) + # y = np.array([1] * self.batch_size + [0] * self.batch_size).astype(np.float32) + y = torch.tensor([1] * self.batch_size + [0] * + self.batch_size, dtype=torch.float32).to(self.device) y = y.view(-1, 1) ang = torch.cat((ang_batch, ang_batch), dim=0) ecal = torch.cat((ecal_batch, ecal_batch), dim=0) aux_y = torch.cat((energy_batch, energy_batch), dim=0) - #y = [[el] for el in y] + # y = [[el] for el in y] labels = [y, aux_y, ang, ecal] # Calculate discriminator loss disc_eval = self.discriminator(X) - disc_eval_loss = self.compute_global_loss(labels, disc_eval, self.loss_weights) - + disc_eval_loss = self.compute_global_loss( + labels, 
disc_eval, self.loss_weights) + # Calculate generator loss trick = np.ones(self.batch_size).astype(np.float32) fake_batch = torch.tensor([[el] for el in trick]).to(self.device) - #fake_batch = [[el] for el in trick] + # fake_batch = [[el] for el in trick] labels = [fake_batch, energy_batch, ang_batch, ecal_batch] generated_images = self.generator(generator_ip) gen_eval = self.discriminator(generated_images) - gen_eval_loss = self.compute_global_loss(labels, gen_eval, self.loss_weights) + gen_eval_loss = self.compute_global_loss( + labels, gen_eval, self.loss_weights) - self.log('val_discriminator_loss', sum(disc_eval_loss), on_epoch=True, prog_bar=True) - self.log('val_generator_loss', sum(gen_eval_loss), on_epoch=True, prog_bar=True) + self.log('val_discriminator_loss', sum( + disc_eval_loss), on_epoch=True, prog_bar=True) + self.log('val_generator_loss', sum(gen_eval_loss), + on_epoch=True, prog_bar=True) - disc_test_loss = [disc_eval_loss[0], disc_eval_loss[1], disc_eval_loss[2], disc_eval_loss[3]] - gen_test_loss = [gen_eval_loss[0], gen_eval_loss[1], gen_eval_loss[2], gen_eval_loss[3]] + disc_test_loss = [disc_eval_loss[0], disc_eval_loss[1], + disc_eval_loss[2], disc_eval_loss[3]] + gen_test_loss = [gen_eval_loss[0], gen_eval_loss[1], + gen_eval_loss[2], gen_eval_loss[3]] # Configure the loss so it is equal to the original values disc_eval_loss = [el.cpu().detach().numpy() for el in disc_test_loss] disc_eval_loss_total_loss = np.sum(disc_eval_loss) new_disc_eval_loss = [disc_eval_loss_total_loss] for i_weights in range(len(disc_eval_loss)): - new_disc_eval_loss.append(disc_eval_loss[i_weights] / self.loss_weights[i_weights]) + new_disc_eval_loss.append( + disc_eval_loss[i_weights] / self.loss_weights[i_weights]) disc_eval_loss = new_disc_eval_loss gen_eval_loss = [el.cpu().detach().numpy() for el in gen_test_loss] gen_eval_loss_total_loss = np.sum(gen_eval_loss) new_gen_eval_loss = [gen_eval_loss_total_loss] for i_weights in range(len(gen_eval_loss)): - 
new_gen_eval_loss.append(gen_eval_loss[i_weights] / self.loss_weights[i_weights]) + new_gen_eval_loss.append( + gen_eval_loss[i_weights] / self.loss_weights[i_weights]) gen_eval_loss = new_gen_eval_loss self.index += 1 @@ -541,27 +615,34 @@ def validation_step(self, batch, batch_idx): self.gen_epoch_test_loss.append(gen_eval_loss) def on_validation_epoch_end(self): - discriminator_test_loss = np.mean(np.array(self.disc_epoch_test_loss), axis=0) - generator_test_loss = np.mean(np.array(self.gen_epoch_test_loss), axis=0) + discriminator_test_loss = np.mean( + np.array(self.disc_epoch_test_loss), axis=0) + generator_test_loss = np.mean( + np.array(self.gen_epoch_test_loss), axis=0) self.test_history["generator"].append(generator_test_loss) self.test_history["discriminator"].append(discriminator_test_loss) print("-" * 65) - ROW_FMT = ("{0:<20s} | {1:<4.2f} | {2:<10.2f} | {3:<10.2f}| {4:<10.2f} | {5:<10.2f}") - print(ROW_FMT.format("generator (test)", *self.test_history["generator"][-1])) - print(ROW_FMT.format("discriminator (test)", *self.test_history["discriminator"][-1])) + ROW_FMT = ( + "{0:<20s} | {1:<4.2f} | {2:<10.2f} | {3:<10.2f}| {4:<10.2f} | {5:<10.2f}") + print(ROW_FMT.format("generator (test)", + *self.test_history["generator"][-1])) + print(ROW_FMT.format("discriminator (test)", + *self.test_history["discriminator"][-1])) # save loss dict to pkl file with open(self.pklfile, "wb") as f: - pickle.dump({"train": self.train_history, "test": self.test_history}, f) - #pickle.dump({"test": self.test_history}, open(self.pklfile, "wb")) - #print("train-loss:" + str(self.train_history["generator"][-1][0])) - + pickle.dump({"train": self.train_history, + "test": self.test_history}, f) + # pickle.dump({"test": self.test_history}, open(self.pklfile, "wb")) + # print("train-loss:" + str(self.train_history["generator"][-1][0])) + def configure_optimizers(self): lr = self.lr - optimizer_discriminator = torch.optim.RMSprop(self.discriminator.parameters(), lr) - 
optimizer_generator = torch.optim.RMSprop(self.generator.parameters(), lr) + optimizer_discriminator = torch.optim.RMSprop( + self.discriminator.parameters(), lr) + optimizer_generator = torch.optim.RMSprop( + self.generator.parameters(), lr) return [optimizer_discriminator, optimizer_generator], [] - diff --git a/use-cases/3dgan/pipeline.yaml b/use-cases/3dgan/pipeline.yaml index ff711405..9916e2f4 100644 --- a/use-cases/3dgan/pipeline.yaml +++ b/use-cases/3dgan/pipeline.yaml @@ -4,7 +4,8 @@ executor: steps: - class_path: dataloader.Lightning3DGANDownloader init_args: - data_path: data/ + data_path: exp_data/ + data_url: https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX - class_path: trainer.Lightning3DGANTrainer init_args: @@ -77,12 +78,13 @@ executor: loss_weights: [3, 0.1, 25, 0.1] power: 0.85 lr: 0.001 + checkpoint_path: exp_data/3dgan.pth # Lightning data module configuration data: class_path: dataloader.MyDataModule init_args: - datapath: /afs/cern.ch/work/k/ktsolaki/private/projects/GAN_scripts/3DGAN/Accelerated3DGAN/src/Accelerated3DGAN/data/*.h5 + datapath: exp_data/*/*.h5 batch_size: 64 # Torch Optimizer configuration diff --git a/use-cases/3dgan/requirements.txt b/use-cases/3dgan/requirements.txt index f3a08124..c06fe435 100644 --- a/use-cases/3dgan/requirements.txt +++ b/use-cases/3dgan/requirements.txt @@ -1,3 +1,4 @@ h5py>=3.7.0 google>=3.0.0 -protobuf>=4.24.4 \ No newline at end of file +protobuf>=4.24.3 +gdown>=4.7.1 \ No newline at end of file diff --git a/use-cases/3dgan/trainer.py b/use-cases/3dgan/trainer.py index 86938c27..2caabc59 100644 --- a/use-cases/3dgan/trainer.py +++ b/use-cases/3dgan/trainer.py @@ -1,4 +1,5 @@ import os +import sys from typing import Union, Dict, Tuple, Optional, Any from itwinai.components import Trainer @@ -17,6 +18,8 @@ def __init__(self, config: Union[Dict, str]): self.conf = config def train(self) -> Any: + old_argv = sys.argv + sys.argv = ['some_script_placeholder.py'] cli = LightningCLI( 
args=self.conf, model_class=ThreeDGAN, @@ -29,6 +32,7 @@ def train(self) -> Any: subclass_mode_model=True, subclass_mode_data=True, ) + sys.argv = old_argv cli.trainer.fit(cli.model, datamodule=cli.datamodule) def execute( From 41d2b660d57fe332d820463006813e391e584ba4 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Wed, 25 Oct 2023 15:14:30 +0200 Subject: [PATCH 03/57] FIX: DDP distributed training with manual optimization --- use-cases/3dgan/dataloader.py | 32 +++-- use-cases/3dgan/model.py | 237 +++++++++++++++++++++++----------- use-cases/3dgan/pipeline.yaml | 18 +-- 3 files changed, 187 insertions(+), 100 deletions(-) diff --git a/use-cases/3dgan/dataloader.py b/use-cases/3dgan/dataloader.py index f2af7cb2..76893e5d 100644 --- a/use-cases/3dgan/dataloader.py +++ b/use-cases/3dgan/dataloader.py @@ -48,7 +48,8 @@ def __len__(self): return len(self.data["X"]) def __getitem__(self, idx): - return {"X": self.data["X"][idx], "Y": self.data["Y"][idx], "ang": self.data["ang"][idx], "ecal": self.data["ecal"][idx]} + return {"X": self.data["X"][idx], "Y": self.data["Y"][idx], + "ang": self.data["ang"][idx], "ecal": self.data["ecal"][idx]} def fetch_data(self, datapath): @@ -82,12 +83,18 @@ def GetDataAngleParallel( Args: dataset (str): Dataset file path - xscale (int, optional): Value to scale the ECAL values. Defaults to 1. - xpower (int, optional): Value to scale the ECAL values, exponentially. Defaults to 1. - yscale (int, optional): Value to scale the energy values. Defaults to 100. - angscale (int, optional): Value to scale the angle values. Defaults to 1. - angtype (str, optional): Which type of angle to use. Defaults to "theta". - thresh (_type_, optional): Maximum value for ECAL values. Defaults to 1e-4. + xscale (int, optional): Value to scale the ECAL values. + Defaults to 1. + xpower (int, optional): Value to scale the ECAL values, + exponentially. Defaults to 1. + yscale (int, optional): Value to scale the energy values. + Defaults to 100. 
+ angscale (int, optional): Value to scale the angle values. + Defaults to 1. + angtype (str, optional): Which type of angle to use. + Defaults to "theta". + thresh (_type_, optional): Maximum value for ECAL values. + Defaults to 1e-4. daxis (int, optional): Axis to expand values. Defaults to -1. Returns: @@ -135,17 +142,20 @@ def setup(self, stage: str = None): self.dataset = MyDataset(self.datapath) dataset_length = len(self.dataset) split_point = int(dataset_length * 0.9) - self.train_dataset, self.val_dataset = torch.utils.data.random_split( - self.dataset, [split_point, dataset_length - split_point]) + self.train_dataset, self.val_dataset = \ + torch.utils.data.random_split( + self.dataset, [split_point, dataset_length - split_point]) # if stage == 'test' or stage is None: # self.test_dataset = MyDataset(self.data_dir, train=False) def train_dataloader(self): - return DataLoader(self.train_dataset, batch_size=self.batch_size, drop_last=True) + return DataLoader(self.train_dataset, num_workers=4, + batch_size=self.batch_size, drop_last=True) def val_dataloader(self): - return DataLoader(self.val_dataset, batch_size=self.batch_size, drop_last=True) + return DataLoader(self.val_dataset, num_workers=4, + batch_size=self.batch_size, drop_last=True) # def test_dataloader(self): # return DataLoader(self.test_dataset, batch_size=self.batch_size) diff --git a/use-cases/3dgan/model.py b/use-cases/3dgan/model.py index 2a86dfa1..009e267e 100644 --- a/use-cases/3dgan/model.py +++ b/use-cases/3dgan/model.py @@ -17,41 +17,62 @@ def __init__(self, latent_dim): # img_shape self.latent_dim = latent_dim self.l1 = nn.Linear(self.latent_dim, 5184) - self.up1 = nn.Upsample(scale_factor=( - 6, 6, 6), mode='trilinear', align_corners=False) - self.conv1 = nn.Conv3d(in_channels=8, out_channels=8, - kernel_size=(6, 6, 8), padding=0) + self.up1 = nn.Upsample( + scale_factor=(6, 6, 6), + mode='trilinear', + align_corners=False + ) + self.conv1 = nn.Conv3d( + in_channels=8, out_channels=8, 
+ kernel_size=(6, 6, 8), padding=0 + ) nn.init.kaiming_uniform_(self.conv1.weight) # num_features is the number of channels (see doc) self.bn1 = nn.BatchNorm3d(num_features=8, eps=1e-6) self.pad1 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0) - self.conv2 = nn.Conv3d(in_channels=8, out_channels=6, - kernel_size=(4, 4, 6), padding=0) + + self.conv2 = nn.Conv3d( + in_channels=8, out_channels=6, + kernel_size=(4, 4, 6), padding=0 + ) nn.init.kaiming_uniform_(self.conv2.weight) self.bn2 = nn.BatchNorm3d(num_features=6, eps=1e-6) - self.pad2 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0) - self.conv3 = nn.Conv3d(in_channels=6, out_channels=6, - kernel_size=(4, 4, 6), padding=0) + + self.conv3 = nn.Conv3d( + in_channels=6, out_channels=6, + kernel_size=(4, 4, 6), padding=0 + ) nn.init.kaiming_uniform_(self.conv3.weight) self.bn3 = nn.BatchNorm3d(num_features=6, eps=1e-6) self.pad3 = nn.ConstantPad3d((1, 1, 2, 2, 2, 2), 0) - self.conv4 = nn.Conv3d(in_channels=6, out_channels=6, - kernel_size=(4, 4, 6), padding=0) + + self.conv4 = nn.Conv3d( + in_channels=6, out_channels=6, + kernel_size=(4, 4, 6), padding=0 + ) nn.init.kaiming_uniform_(self.conv4.weight) self.bn4 = nn.BatchNorm3d(num_features=6, eps=1e-6) self.pad4 = nn.ConstantPad3d((0, 0, 1, 1, 1, 1), 0) - self.conv5 = nn.Conv3d(in_channels=6, out_channels=6, - kernel_size=(3, 3, 5), padding=0) + + self.conv5 = nn.Conv3d( + in_channels=6, out_channels=6, + kernel_size=(3, 3, 5), padding=0 + ) nn.init.kaiming_uniform_(self.conv5.weight) self.bn5 = nn.BatchNorm3d(num_features=6, eps=1e-6) - self.pad5 = nn.ConstantPad3d((0, 0, 1, 1, 1, 1), 0) - self.conv6 = nn.Conv3d(in_channels=6, out_channels=6, - kernel_size=(3, 3, 3), padding=0) + + self.conv6 = nn.Conv3d( + in_channels=6, out_channels=6, + kernel_size=(3, 3, 3), padding=0 + ) nn.init.kaiming_uniform_(self.conv6.weight) - self.conv7 = nn.Conv3d(in_channels=6, out_channels=1, - kernel_size=(2, 2, 2), padding=0) + + self.conv7 = nn.Conv3d( + in_channels=6, out_channels=1, + 
kernel_size=(2, 2, 2), padding=0 + ) nn.init.xavier_normal_(self.conv7.weight) def forward(self, z): @@ -62,26 +83,30 @@ def forward(self, z): img = F.relu(img) img = self.bn1(img) img = self.pad1(img) + img = self.conv2(img) img = F.relu(img) img = self.bn2(img) - img = self.pad2(img) + img = self.conv3(img) img = F.relu(img) img = self.bn3(img) img = self.pad3(img) + img = self.conv4(img) img = F.relu(img) img = self.bn4(img) img = self.pad4(img) + img = self.conv5(img) img = F.relu(img) img = self.bn5(img) - img = self.pad5(img) + img = self.conv6(img) img = F.relu(img) + img = self.conv7(img) img = F.relu(img) @@ -95,30 +120,39 @@ def __init__(self, power): self.power = power self.conv1 = nn.Conv3d( - in_channels=1, out_channels=16, kernel_size=(5, 6, 6), padding=(2, 3, 3)) + in_channels=1, out_channels=16, + kernel_size=(5, 6, 6), padding=(2, 3, 3) + ) self.drop1 = nn.Dropout(0.2) - self.pad1 = nn.ConstantPad3d((1, 1, 0, 0, 0, 0), 0) + self.conv2 = nn.Conv3d( - in_channels=16, out_channels=8, kernel_size=(5, 6, 6), padding=0) + in_channels=16, out_channels=8, + kernel_size=(5, 6, 6), padding=0 + ) self.bn1 = nn.BatchNorm3d(num_features=8, eps=1e-6) self.drop2 = nn.Dropout(0.2) - self.pad2 = nn.ConstantPad3d((1, 1, 0, 0, 0, 0), 0) - self.conv3 = nn.Conv3d(in_channels=8, out_channels=8, - kernel_size=(5, 6, 6), padding=0) + + self.conv3 = nn.Conv3d( + in_channels=8, out_channels=8, + kernel_size=(5, 6, 6), padding=0 + ) self.bn2 = nn.BatchNorm3d(num_features=8, eps=1e-6) self.drop3 = nn.Dropout(0.2) - self.conv4 = nn.Conv3d(in_channels=8, out_channels=8, - kernel_size=(5, 6, 6), padding=0) + self.conv4 = nn.Conv3d( + in_channels=8, out_channels=8, + kernel_size=(5, 6, 6), padding=0 + ) self.bn3 = nn.BatchNorm3d(num_features=8, eps=1e-6) self.drop4 = nn.Dropout(0.2) self.avgpool = nn.AvgPool3d((2, 2, 2)) self.flatten = nn.Flatten() - # The input features for the Linear layer need to be calculated based on the output shape from the previous layers. 
+ # The input features for the Linear layer need to be calculated based + # on the output shape from the previous layers. self.fakeout = nn.Linear(19152, 1) self.auxout = nn.Linear(19152, 1) # The same as above for this layer. @@ -140,7 +174,7 @@ def ecal_angle(self, image, daxis1): # get 1. where event sum is 0 and 0 elsewhere amask = torch.where(sumtot == 0.0, torch.ones_like( sumtot), torch.zeros_like(sumtot)) - masked_events = torch.sum(amask) # counting zero sum events + # masked_events = torch.sum(amask) # counting zero sum events # ref denotes barycenter as that is our reference point x_ref = torch.sum(torch.sum(image, dim=(2, 3)) @@ -194,11 +228,22 @@ def ecal_angle(self, image, daxis1): sumz), y_mid / sumz) # if sum != 0 then divide by sum # Angle Calculations - z = (torch.arange(z_shape, device=image.device, dtype=torch.float32) + 0.5) * \ - torch.ones_like(z_ref) # Make an array of z indexes for all events - + z = (torch.arange( + z_shape, + device=image.device, + dtype=torch.float32 + # Make an array of z indexes for all events + ) + 0.5) * torch.ones_like(z_ref) + + # projection from z axis with stability check zproj = torch.sqrt( - torch.max((x_mid - x_ref) ** 2.0 + (z - z_ref) ** 2.0, torch.tensor([torch.finfo(torch.float32).eps]).to(x_mid.device))) # projection from z axis with stability check + torch.max( + (x_mid - x_ref) ** 2.0 + (z - z_ref) ** 2.0, + torch.tensor( + [torch.finfo(torch.float32).eps] + ).to(x_mid.device) + ) + ) # torch.finfo(torch.float32).eps)) # to avoid divide by zero for zproj =0 m = torch.where(zproj == 0.0, torch.zeros_like( @@ -211,7 +256,8 @@ def ecal_angle(self, image, daxis1): sumz_tot = z * zmask # removing indexes with 0 energies or angles # zunmasked = K.sum(zmask, axis=1) # used for simple mean - # ang = K.sum(ang, axis=1)/zunmasked # Mean does not include positions where zsum=0 + # Mean does not include positions where zsum=0 + # ang = K.sum(ang, axis=1)/zunmasked # sum ( measured * weights)/sum(weights) ang = 
torch.sum(ang, dim=1) / torch.sum(sumz_tot, dim=1) @@ -225,15 +271,18 @@ def forward(self, x): z = F.leaky_relu(z) z = self.drop1(z) z = self.pad1(z) + z = self.conv2(z) z = F.leaky_relu(z) z = self.bn1(z) z = self.drop2(z) z = self.pad2(z) + z = self.conv3(z) z = F.leaky_relu(z) z = self.bn2(z) z = self.drop3(z) + z = self.conv4(z) z = F.leaky_relu(z) z = self.bn3(z) @@ -303,29 +352,38 @@ def BitFlip(self, x, prob=0.05): def mean_absolute_percentage_error(self, y_true, y_pred): return torch.mean(torch.abs((y_true - y_pred) / (y_true + 1e-7))) * 100 - def compute_global_loss(self, labels, predictions, loss_weights=[3, 0.1, 25, 0.1]): + def compute_global_loss( + self, + labels, + predictions, + loss_weights=(3, 0.1, 25, 0.1) + ): # Can be initialized outside binary_crossentropy_object = nn.BCEWithLogitsLoss(reduction='none') - # there is no equivalent in pytorch for tf.keras.losses.MeanAbsolutePercentageError --> using the custom "mean_absolute_percentage_error" above! - mean_absolute_percentage_error_object1 = self.mean_absolute_percentage_error( - predictions[1], labels[1]) - mean_absolute_percentage_error_object2 = self.mean_absolute_percentage_error( - predictions[3], labels[3]) + # there is no equivalent in pytorch for + # tf.keras.losses.MeanAbsolutePercentageError --> using the + # custom "mean_absolute_percentage_error" above! 
+ mean_absolute_percentage_error_object1 = \ + self.mean_absolute_percentage_error(predictions[1], labels[1]) + mean_absolute_percentage_error_object2 = \ + self.mean_absolute_percentage_error(predictions[3], labels[3]) mae_object = nn.L1Loss(reduction='none') binary_example_loss = binary_crossentropy_object( predictions[0], labels[0]) * loss_weights[0] - # mean_example_loss_1 = mean_absolute_percentage_error_object(predictions[1], labels[1]) * loss_weights[1] - mean_example_loss_1 = mean_absolute_percentage_error_object1 * \ - loss_weights[1] + # mean_example_loss_1 = mean_absolute_percentage_error_object( + # predictions[1], labels[1]) * loss_weights[1] + mean_example_loss_1 = \ + mean_absolute_percentage_error_object1 * loss_weights[1] mae_example_loss = mae_object( predictions[2], labels[2]) * loss_weights[2] - # mean_example_loss_2 = mean_absolute_percentage_error_object(predictions[3], labels[3]) * loss_weights[3] - mean_example_loss_2 = mean_absolute_percentage_error_object2 * \ - loss_weights[3] + # mean_example_loss_2 = mean_absolute_percentage_error_object( + # predictions[3], labels[3]) * loss_weights[3] + mean_example_loss_2 = \ + mean_absolute_percentage_error_object2 * loss_weights[3] binary_loss = binary_example_loss.mean() mean_loss_1 = mean_example_loss_1.mean() @@ -338,8 +396,8 @@ def forward(self, z): return self.generator(z) def training_step(self, batch, batch_idx): - image_batch, energy_batch, ang_batch, ecal_batch = batch[ - 'X'], batch['Y'], batch['ang'], batch['ecal'] + image_batch, energy_batch, ang_batch, ecal_batch = \ + batch['X'], batch['Y'], batch['ang'], batch['ecal'] image_batch = image_batch.permute(0, 4, 1, 2, 3) @@ -351,10 +409,13 @@ def training_step(self, batch, batch_idx): optimizer_discriminator, optimizer_generator = self.optimizers() noise = torch.randn( - (self.batch_size, self.latent_size - 2)).to(self.device) + (self.batch_size, self.latent_size - 2), + device=self.device + ) generator_ip = torch.cat( 
(energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), - dim=1,) + dim=1 + ) generated_images = self.generator(generator_ip) # Train discriminator first on real batch @@ -367,11 +428,13 @@ def training_step(self, batch, batch_idx): real_batch_loss = self.compute_global_loss( labels, predictions, self.loss_weights) self.log("real_batch_loss", sum(real_batch_loss), - prog_bar=True, on_step=True, on_epoch=True) + prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) print("real batch disc train") # the following 3 lines correspond in tf version to: - # gradients = tape.gradient(real_batch_loss, discriminator.trainable_variables) - # optimizer_discriminator.apply_gradients(zip(gradients, discriminator.trainable_variables)) in Tensorflow + # gradients = tape.gradient(real_batch_loss, + # discriminator.trainable_variables) + # optimizer_discriminator.apply_gradients(zip(gradients, + # discriminator.trainable_variables)) in Tensorflow optimizer_discriminator.zero_grad() self.manual_backward(sum(real_batch_loss)) # sum(real_batch_loss).backward() @@ -388,16 +451,20 @@ def training_step(self, batch, batch_idx): fake_batch_loss = self.compute_global_loss( labels, predictions, self.loss_weights) self.log("fake_batch_loss", sum(fake_batch_loss), - prog_bar=True, on_step=True, on_epoch=True) + prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) print("fake batch disc train") # the following 3 lines correspond to - # gradients = tape.gradient(fake_batch_loss, discriminator.trainable_variables) - # optimizer_discriminator.apply_gradients(zip(gradients, discriminator.trainable_variables)) in Tensorflow + # gradients = tape.gradient(fake_batch_loss, + # discriminator.trainable_variables) + # optimizer_discriminator.apply_gradients(zip(gradients, + # discriminator.trainable_variables)) in Tensorflow optimizer_discriminator.zero_grad() self.manual_backward(sum(fake_batch_loss)) # sum(fake_batch_loss).backward() optimizer_discriminator.step() + # avg_disc_loss = 
(sum(real_batch_loss) + sum(fake_batch_loss)) / 2 + trick = np.ones(self.batch_size).astype(np.float32) fake_batch = torch.tensor([[el] for el in trick]).to(self.device) labels = [fake_batch, energy_batch.view(-1, 1), ang_batch, ecal_batch] @@ -409,7 +476,8 @@ def training_step(self, batch, batch_idx): (self.batch_size, self.latent_size - 2)).to(self.device) generator_ip = torch.cat( (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), - dim=1,) + dim=1 + ) generated_images = self.generator(generator_ip) predictions = self.discriminator(generated_images) @@ -417,7 +485,7 @@ def training_step(self, batch, batch_idx): loss = self.compute_global_loss( labels, predictions, self.loss_weights) self.log("gen_loss", sum(loss), prog_bar=True, - on_step=True, on_epoch=True) + on_step=True, on_epoch=True, sync_dist=True) print("gen train") optimizer_generator.zero_grad() self.manual_backward(sum(loss)) @@ -429,12 +497,14 @@ def training_step(self, batch, batch_idx): avg_generator_loss = sum(gen_losses_train) / len(gen_losses_train) self.log("generator_loss", avg_generator_loss.item(), - prog_bar=True, on_step=True, on_epoch=True) + prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) # avg_generator_loss = [(a + b) / 2 for a, b in zip(*gen_losses_train)] - # self.log("generator_loss", sum(avg_generator_loss), prog_bar=True, on_step=True, on_epoch=True) + # self.log("generator_loss", sum(avg_generator_loss), prog_bar=True, + # on_step=True, on_epoch=True, sync_dist=True) gen_losses = [] - # I'm not returning anything as in pl you do not return anything when you back-propagate manually + # I'm not returning anything as in pl you do not return anything when + # you back-propagate manually # return_loss = real_batch_loss real_batch_loss = [real_batch_loss[0], real_batch_loss[1], real_batch_loss[2], real_batch_loss[3]] @@ -463,10 +533,11 @@ def training_step(self, batch, batch_idx): fake_batch_loss[i_weights] / self.loss_weights[i_weights]) fake_batch_loss = 
new_fake_batch_loss - # if ecal sum has 100% loss(generating empty events) then end the training + # if ecal sum has 100% loss(generating empty events) then end + # the training if fake_batch_loss[3] == 100.0 and self.index > 10: - print( - "Empty image with Ecal loss equal to 100.0 for {} batch".format(self.index)) + print("Empty image with Ecal loss equal to 100.0 " + f"for {self.index} batch") torch.save(self.generator.state_dict(), "generator_weights.pth") torch.save(self.discriminator.state_dict(), "discriminator_weights.pth") @@ -498,11 +569,22 @@ def training_step(self, batch, batch_idx): self.epoch_gen_loss.append(generator_loss) + # # MB: verify weight synchronization among workers + # # Ref: https://github.com/Lightning-AI/lightning/issues/9237 + # disc_w = self.discriminator.conv1.weight.reshape(-1)[0:5] + # gen_w = self.generator.conv1.weight.reshape(-1)[0:5] + # print(f"DISC w: {disc_w}") + # print(f"GEN w: {gen_w}") + # self.index += 1 #this might be moved after test cycle # logging of gen and disc loss done by Trainer - # self.log('epoch_gen_loss', self.epoch_gen_loss, on_step=True, on_epoch=True) - # self.log('epoch_disc_loss', self.epoch_disc_loss, on_step=True, on_epoch=True) + # self.log('epoch_gen_loss', self.epoch_gen_loss, on_step=True, + # on_epoch=True, sync_dist=True) + # self.log('epoch_disc_loss', self.epoch_disc_loss, on_step=True, o + # n_epoch=True, sync_dist=True) + + # return avg_disc_loss + avg_generator_loss def on_train_epoch_end(self): # outputs discriminator_train_loss = np.mean( @@ -514,7 +596,8 @@ def on_train_epoch_end(self): # outputs print("-" * 65) ROW_FMT = ( - "{0:<20s} | {1:<4.2f} | {2:<10.2f} | {3:<10.2f}| {4:<10.2f} | {5:<10.2f}") + "{0:<20s} | {1:<4.2f} | {2:<10.2f} | " + "{3:<10.2f}| {4:<10.2f} | {5:<10.2f}") print(ROW_FMT.format("generator (train)", *self.train_history["generator"][-1])) print(ROW_FMT.format("discriminator (train)", @@ -553,7 +636,8 @@ def validation_step(self, batch, batch_idx): # concatenate to 
fake and real batches X = torch.cat((image_batch, generated_images), dim=0) - # y = np.array([1] * self.batch_size + [0] * self.batch_size).astype(np.float32) + # y = np.array([1] * self.batch_size \ + # + [0] * self.batch_size).astype(np.float32) y = torch.tensor([1] * self.batch_size + [0] * self.batch_size, dtype=torch.float32).to(self.device) y = y.view(-1, 1) @@ -582,9 +666,9 @@ def validation_step(self, batch, batch_idx): labels, gen_eval, self.loss_weights) self.log('val_discriminator_loss', sum( - disc_eval_loss), on_epoch=True, prog_bar=True) + disc_eval_loss), on_epoch=True, prog_bar=True, sync_dist=True) self.log('val_generator_loss', sum(gen_eval_loss), - on_epoch=True, prog_bar=True) + on_epoch=True, prog_bar=True, sync_dist=True) disc_test_loss = [disc_eval_loss[0], disc_eval_loss[1], disc_eval_loss[2], disc_eval_loss[3]] @@ -625,7 +709,8 @@ def on_validation_epoch_end(self): print("-" * 65) ROW_FMT = ( - "{0:<20s} | {1:<4.2f} | {2:<10.2f} | {3:<10.2f}| {4:<10.2f} | {5:<10.2f}") + "{0:<20s} | {1:<4.2f} | {2:<10.2f} | " + "{3:<10.2f}| {4:<10.2f} | {5:<10.2f}") print(ROW_FMT.format("generator (test)", *self.test_history["generator"][-1])) print(ROW_FMT.format("discriminator (test)", @@ -642,7 +727,11 @@ def configure_optimizers(self): lr = self.lr optimizer_discriminator = torch.optim.RMSprop( - self.discriminator.parameters(), lr) + self.discriminator.parameters(), + lr + ) optimizer_generator = torch.optim.RMSprop( - self.generator.parameters(), lr) + self.generator.parameters(), + lr + ) return [optimizer_discriminator, optimizer_generator], [] diff --git a/use-cases/3dgan/pipeline.yaml b/use-cases/3dgan/pipeline.yaml index 9916e2f4..abc8905f 100644 --- a/use-cases/3dgan/pipeline.yaml +++ b/use-cases/3dgan/pipeline.yaml @@ -37,7 +37,7 @@ executor: default_root_dir: null detect_anomaly: false deterministic: null - devices: [0] + devices: auto #[0] enable_checkpointing: null enable_model_summary: null enable_progress_bar: null @@ -49,7 +49,7 @@ 
executor: limit_test_batches: null limit_train_batches: null limit_val_batches: null - log_every_n_steps: null + log_every_n_steps: 10 logger: class_path: lightning.pytorch.loggers.CSVLogger init_args: @@ -64,7 +64,7 @@ executor: plugins: null profiler: null reload_dataloaders_every_n_epochs: 0 - strategy: auto + strategy: ddp_find_unused_parameters_true #auto sync_batchnorm: false use_distributed_sampler: true val_check_interval: null @@ -86,15 +86,3 @@ executor: init_args: datapath: exp_data/*/*.h5 batch_size: 64 - - # Torch Optimizer configuration - # optimizer: - # class_path: torch.optim.AdamW - # init_args: - # lr: 0.001 - - # # Torch LR scheduler configuration - # lr_scheduler: - # class_path: torch.optim.lr_scheduler.ExponentialLR - # init_args: - # gamma: 0.1 \ No newline at end of file From ddfa59d39d74d000ee6ae8ce25aac89a7d1e52d6 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Wed, 25 Oct 2023 16:57:56 +0200 Subject: [PATCH 04/57] ADD: log with MLFlow --- .gitignore | 1 + use-cases/3dgan/README.md | 9 +++++++++ use-cases/3dgan/pipeline.yaml | 15 ++++++++++----- use-cases/3dgan/startscript | 2 +- 4 files changed, 21 insertions(+), 6 deletions(-) create mode 100644 use-cases/3dgan/README.md diff --git a/.gitignore b/.gitignore index 3d59ba3b..659e8395 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.pth +*_logs exp_data/ TODO /data diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md new file mode 100644 index 00000000..acfe4208 --- /dev/null +++ b/use-cases/3dgan/README.md @@ -0,0 +1,9 @@ +# 3DGAN use case + +To visualize the logs with MLFLow run the following in the terminal: + +```bash +micromamba run -p ../../.venv-pytorch mlflow ui --backend-store-uri ml_logs/mlflow_logs +``` + +And select the "3DGAN" experiment. 
diff --git a/use-cases/3dgan/pipeline.yaml b/use-cases/3dgan/pipeline.yaml index abc8905f..32d6d5b1 100644 --- a/use-cases/3dgan/pipeline.yaml +++ b/use-cases/3dgan/pipeline.yaml @@ -38,7 +38,7 @@ executor: detect_anomaly: false deterministic: null devices: auto #[0] - enable_checkpointing: null + enable_checkpointing: true enable_model_summary: null enable_progress_bar: null fast_dev_run: false @@ -49,13 +49,18 @@ executor: limit_test_batches: null limit_train_batches: null limit_val_batches: null - log_every_n_steps: 10 + log_every_n_steps: 2 logger: - class_path: lightning.pytorch.loggers.CSVLogger + # - class_path: lightning.pytorch.loggers.CSVLogger + # init_args: + # save_dir: ml_logs/csv_logs + class_path: lightning.pytorch.loggers.MLFlowLogger init_args: - save_dir: "logs" + experiment_name: 3DGAN + save_dir: ml_logs/mlflow_logs + log_model: all max_epochs: 1 - max_steps: 10 + max_steps: 20 max_time: null min_epochs: null min_steps: null diff --git a/use-cases/3dgan/startscript b/use-cases/3dgan/startscript index b9d2dd08..579ce3b3 100644 --- a/use-cases/3dgan/startscript +++ b/use-cases/3dgan/startscript @@ -11,7 +11,7 @@ # configure node and process count on the CM #SBATCH --partition=batch -#SBATCH --nodes=1 +#SBATCH --nodes=2 #SBATCH --ntasks-per-node=1 #SBATCH --cpus-per-task=4 #SBATCH --gpus-per-node=4 From e89a433af960358f3760b9174a0e3d8f1d03422b Mon Sep 17 00:00:00 2001 From: Matteo Bunino <48362942+matbun@users.noreply.github.com> Date: Wed, 25 Oct 2023 14:02:17 +0200 Subject: [PATCH 05/57] Sqaaas code (#88) * Create sqaaas.yml * Update sqaaas.yml * Update sqaaas.yml * Point to the current repo * Remove unnecessary checkout step * Rename step --------- Co-authored-by: orviz --- .github/workflows/sqaaas.yml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 .github/workflows/sqaaas.yml diff --git a/.github/workflows/sqaaas.yml b/.github/workflows/sqaaas.yml new file mode 100644 index 00000000..3a1fd0b5 --- /dev/null +++ 
b/.github/workflows/sqaaas.yml @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright contributors to the Software Quality Assurance as a Service (SQAaaS) project. +# +# SPDX-License-Identifier: GPL-3.0-only +--- +name: SQAaaS + +on: [push] + +jobs: + sqaaas_job: + runs-on: ubuntu-latest + name: Job that triggers SQAaaS platform + steps: + - name: SQAaaS assessment step + uses: eosc-synergy/sqaaas-assessment-action@v1 + with: + repo: 'https://github.com/interTwin-eu/itwinai' + branch: 'sqaaas-code' From adc6c9121cca9a060981f5c70c1b756a7b6a40e6 Mon Sep 17 00:00:00 2001 From: Matteo Bunino <48362942+matbun@users.noreply.github.com> Date: Fri, 27 Oct 2023 16:54:15 +0200 Subject: [PATCH 06/57] Sqaaas code (#89) * Create sqaaas.yml * Update sqaaas.yml * Update sqaaas.yml * Point to the current repo * Remove unnecessary checkout step * Rename step * ADD: adaptive branch discovery for SQAaaS action * Update sqaaas.yml --------- Co-authored-by: orviz --- .github/workflows/sqaaas.yml | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/.github/workflows/sqaaas.yml b/.github/workflows/sqaaas.yml index 3a1fd0b5..6e7ca329 100644 --- a/.github/workflows/sqaaas.yml +++ b/.github/workflows/sqaaas.yml @@ -11,8 +11,16 @@ jobs: runs-on: ubuntu-latest name: Job that triggers SQAaaS platform steps: + - name: Extract branch name + shell: bash + run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT + id: extract_branch + - name: Print current branch name (debug) + shell: bash + run: echo running on branch ${{ steps.extract_branch.outputs.branch }} - name: SQAaaS assessment step uses: eosc-synergy/sqaaas-assessment-action@v1 with: repo: 'https://github.com/interTwin-eu/itwinai' - branch: 'sqaaas-code' + branch: ${{ steps.extract_branch.outputs.branch }} + From 291b4f361671f3038d6a35b9067af53c0598f1bb Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Mon, 6 Nov 2023 13:57:52 +0100 Subject: [PATCH 07/57] ADD: draft predictor and saver --- 
src/itwinai/components.py | 18 +++-- src/itwinai/torch/predictor.py | 124 +++++++++++++++++++++++++++++ use-cases/mnist/torch/inference.py | 0 use-cases/mnist/torch/saver.py | 59 ++++++++++++++ 4 files changed, 193 insertions(+), 8 deletions(-) create mode 100644 src/itwinai/torch/predictor.py create mode 100644 use-cases/mnist/torch/inference.py create mode 100644 use-cases/mnist/torch/saver.py diff --git a/src/itwinai/components.py b/src/itwinai/components.py index 2e7e2e79..0df7b4f6 100644 --- a/src/itwinai/components.py +++ b/src/itwinai/components.py @@ -136,6 +136,7 @@ def _printout(self, msg: str): class Trainer(Executable): + """Trains a machine learning model.""" @abstractmethod def train(self, *args, **kwargs): pass @@ -149,6 +150,13 @@ def load_state(self): pass +class Predictor(Executable): + """Applies a pre-trained machine learning model to unseen data.""" + @abstractmethod + def predict(self, *args, **kwargs): + pass + + class DataGetter(Executable): @abstractmethod def load(self, *args, **kwargs): @@ -167,18 +175,12 @@ def preproc(self, *args, **kwargs): # pass -class Evaluator(Executable): +class Saver(Executable): @abstractmethod - def evaluate(self, *args, **kwargs): + def save(self, *args, **kwargs): pass -# class Saver(Executable): -# @abstractmethod -# def save(self, *args, **kwargs): -# pass - - class Executor(Executable): """Sets-up and executes a sequence of Executable steps.""" diff --git a/src/itwinai/torch/predictor.py b/src/itwinai/torch/predictor.py new file mode 100644 index 00000000..33f7910a --- /dev/null +++ b/src/itwinai/torch/predictor.py @@ -0,0 +1,124 @@ +from typing import Optional, Tuple, Dict, Any, List + +from torch import nn +from torch.utils.data import DataLoader, Dataset + +from ..utils import dynamically_import_class +from .utils import clear_key +from ..components import Predictor +from .types import TorchDistributedStrategy as StrategyT +from .types import Metric + + +class TorchPredictor(Predictor): + """Applies a 
pre-trained torch model to unseen data. + + Args: + model (nn.Module): neural network instance. + test_dataloader_class (str, optional): test dataloader class path. + Defaults to 'torch.utils.data.DataLoader'. + test_dataloader_kwargs (Optional[Dict], optional): constructor + arguments of the test dataloader, except for the dataset + instance. Defaults to None. + strategy (Optional[TorchDistributedStrategy], optional): distributed + strategy. Defaults to StrategyT.NONE.value. + backend (TorchDistributedBackend, optional): computing backend. + Defaults to BackendT.NCCL.value. + shuffle_dataset (bool, optional): whether shuffle dataset before + sampling batches from dataloader. Defaults to False. + use_cuda (bool, optional): whether to use GPU. Defaults to True. + benchrun (bool, optional): sets up a debug run. Defaults to False. + testrun (bool, optional): deterministic training seeding everything. + Defaults to False. + seed (Optional[int], optional): random seed. Defaults to None. + logger (Optional[List[Logger]], optional): logger. Defaults to None. + checkpoint_every (int, optional): how often (epochs) to checkpoint the + best model. Defaults to 10. + cluster (Optional[ClusterEnvironment], optional): cluster environment + object describing the context in which the trainer is executed. + Defaults to None. + train_metrics (Optional[Dict[str, Metric]], optional): + list of metrics computed in the training step on the predictions. + It's a dictionary with the form + ``{'metric_unique_name': CallableMetric}``. Defaults to None. + validation_metrics (Optional[Dict[str, Metric]], optional): same + as ``training_metrics``. If not given, it mirrors the training + metrics. Defaults to None. + + Raises: + RuntimeError: When trying to use DDP without CUDA support. + NotImplementedError: when trying to use a strategy different from the + ones provided by TorchDistributedStrategy. 
+ """ + + model: nn.Module = None + test_dataset: Dataset + test_dataloader: DataLoader = None + _strategy: StrategyT = StrategyT.NONE.value + epoch_idx: int = 0 + train_glob_step: int = 0 + validation_glob_step: int = 0 + train_metrics: Dict[str, Metric] + validation_metrics: Dict[str, Metric] + + def __init__( + self, + model: nn.Module, + test_dataloader_class: str = 'torch.utils.data.DataLoader', + test_dataloader_kwargs: Optional[Dict] = None, + # strategy: str = StrategyT.NONE.value, + # seed: Optional[int] = None, + # logger: Optional[List[Logger]] = None, + # cluster: Optional[ClusterEnvironment] = None, + # test_metrics: Optional[Dict[str, Metric]] = None, + ) -> None: + super().__init__() + self.model = model + # self.seed = seed + # self.strategy = strategy + # self.cluster = cluster + + # Train and validation dataloaders + self.test_dataloader_class = dynamically_import_class( + test_dataloader_class + ) + test_dataloader_kwargs = ( + test_dataloader_kwargs + if test_dataloader_kwargs is not None else {} + ) + self.test_dataloader_kwargs = clear_key( + test_dataloader_kwargs, 'train_dataloader_kwargs', 'dataset' + ) + + # # Loggers + # self.logger = logger if logger is not None else ConsoleLogger() + + # # Metrics + # self.train_metrics = ( + # {} if train_metrics is None else train_metrics + # ) + # self.validation_metrics = ( + # self.train_metrics if validation_metrics is None + # else validation_metrics + # ) + + def execute( + self, + test_dataset: Dataset, + model: nn.Module = None, + config: Optional[Dict] = None + ) -> Tuple[Optional[Tuple], Optional[Dict]]: + self.test_dataset = test_dataset + self.test_dataloader = self.test_dataloader_class( + test_dataset, **self.test_dataloader_kwargs + ) + # Update model passed for "interactive" use + if model is not None: + self.model = model + result = self.predict() + return ((result,), config) + + def predict(self) -> List[Any]: + """Returns a list of predictions.""" + + return [] diff --git 
a/use-cases/mnist/torch/inference.py b/use-cases/mnist/torch/inference.py new file mode 100644 index 00000000..e69de29b diff --git a/use-cases/mnist/torch/saver.py b/use-cases/mnist/torch/saver.py new file mode 100644 index 00000000..7150475f --- /dev/null +++ b/use-cases/mnist/torch/saver.py @@ -0,0 +1,59 @@ +from typing import Optional, List, Dict, Tuple +import os +import shutil +import csv + +from itwinai.components import Saver + + +class TorchMNISTLabelSaver(Saver): + """Serializes to disk the labels predicted for MNIST dataset.""" + + def __init__( + self, + save_dir: str = 'mnist_predictions', + class_labels: Optional[List] = None + ) -> None: + super().__init__() + self.save_dir = save_dir + self.class_labels = ( + class_labels if class_labels is not None + else [f'Digit {i}' for i in range(10)] + ) + + def execute( + self, + predicted_classes: Dict[str, int], + config: Optional[Dict] = None + ) -> Tuple[Optional[Tuple], Optional[Dict]]: + """Translate predictions from class idx to class label and save + them to disk. + + Args: + predicted_classes (Dict[str, int]): maps unique item ID to + the predicted class ID. + config (Optional[Dict], optional): inherited configuration. + Defaults to None. + + Returns: + Tuple[Optional[Tuple], Optional[Dict]]: propagation of inherited + configuration and saver return value. 
+ """ + if os.path.exists(self.save_dir): + shutil.rmtree(self.save_dir) + os.makedirs(self.save_dir) + + # Map class idx (int) to class label (str) + predicted_labels = { + itm_name: self.class_labels[cls_idx] + for itm_name, cls_idx in predicted_classes.items() + } + result = self.save(predicted_labels) + return ((result,), config) + + def save(self, predicted_labels: Dict[str, str]) -> None: + filepath = os.path.join(self.save_dir, 'predictions.csv') + with open(filepath, 'w') as csv_file: + writer = csv.writer(csv_file) + for key, value in predicted_labels.items(): + writer.writerow([key, value]) From 7da9ba40a543b0bc7c21a3317864daf31d38b4b7 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Mon, 6 Nov 2023 17:43:35 +0100 Subject: [PATCH 08/57] ADD: stub for inference pipeline --- src/itwinai/serialization.py | 14 +++ .../torch/{predictor.py => inference.py} | 63 +++++++++++- use-cases/mnist/torch/dataloader.py | 95 ++++++++++++++++++- use-cases/mnist/torch/inference-pipeline.yaml | 20 ++++ 4 files changed, 188 insertions(+), 4 deletions(-) create mode 100644 src/itwinai/serialization.py rename src/itwinai/torch/{predictor.py => inference.py} (70%) create mode 100644 use-cases/mnist/torch/inference-pipeline.yaml diff --git a/src/itwinai/serialization.py b/src/itwinai/serialization.py new file mode 100644 index 00000000..5f8fa769 --- /dev/null +++ b/src/itwinai/serialization.py @@ -0,0 +1,14 @@ +from typing import Any +import abc + + +class ModelLoader(abc.ABC): + """Loads a machine learning model from somewhere.""" + + def __init__(self, model_uri: str) -> None: + super().__init__() + self.model_uri = model_uri + + @abc.abstractmethod + def __call__(self) -> Any: + """Loads model from model URI.""" diff --git a/src/itwinai/torch/predictor.py b/src/itwinai/torch/inference.py similarity index 70% rename from src/itwinai/torch/predictor.py rename to src/itwinai/torch/inference.py index 33f7910a..d83c6d2e 100644 --- a/src/itwinai/torch/predictor.py +++ 
b/src/itwinai/torch/inference.py @@ -1,5 +1,7 @@ -from typing import Optional, Tuple, Dict, Any, List +from typing import Optional, Tuple, Dict, Any, List, Union +import os +import torch from torch import nn from torch.utils.data import DataLoader, Dataset @@ -8,6 +10,60 @@ from ..components import Predictor from .types import TorchDistributedStrategy as StrategyT from .types import Metric +from ..serialization import ModelLoader + + +class TorchModelLoader(ModelLoader): + """Loads a torch model from somewhere. + + Args: + model_uri (str): Can be a path on local filesystem + or an mlflow 'locator' in the form: + 'mlflow+MLFLOW_TRACKING_URI+RUN_ID+ARTIFACT_PATH' + """ + + def __call__(self) -> nn.Module: + """"Loads model from model URI. + + Raises: + ValueError: if the model URI is not recognized + or the model is not found. + + Returns: + nn.Module: torch neural network. + """ + if os.path.exists(self.model_uri): + # Model is on local filesystem. + model = torch.load(self.model_uri) + return model.eval() + + if self.model_uri.startswith('mlflow+'): + # Model is on an MLFLow server + # Form is 'mlflow+MLFLOW_TRACKING_URI+RUN_ID+ARTIFACT_PATH' + import mlflow + from mlflow import MlflowException + _, tracking_uri, run_id, artifact_path = self.model_uri.split('+') + mlflow.set_tracking_uri(tracking_uri) + + # Check that run exists + try: + mlflow.get_run(run_id) + except MlflowException: + raise ValueError(f"Run ID '{run_id}' was not found!") + + # Download model weights + ckpt_path = mlflow.artifacts.download_artifacts( + run_id=run_id, + artifact_path=artifact_path, + dst_path='tmp/', + tracking_uri=mlflow.get_tracking_uri() + ) + model = torch.load(ckpt_path) + return model.eval() + + raise ValueError( + 'Unrecognized model URI: model may not be there!' 
+ ) class TorchPredictor(Predictor): @@ -63,7 +119,7 @@ class TorchPredictor(Predictor): def __init__( self, - model: nn.Module, + model: Union[nn.Module, ModelLoader], test_dataloader_class: str = 'torch.utils.data.DataLoader', test_dataloader_kwargs: Optional[Dict] = None, # strategy: str = StrategyT.NONE.value, @@ -73,7 +129,7 @@ def __init__( # test_metrics: Optional[Dict[str, Metric]] = None, ) -> None: super().__init__() - self.model = model + self.model = model() if isinstance(model, ModelLoader) else model # self.seed = seed # self.strategy = strategy # self.cluster = cluster @@ -120,5 +176,6 @@ def execute( def predict(self) -> List[Any]: """Returns a list of predictions.""" + # TODO: complete return [] diff --git a/use-cases/mnist/torch/dataloader.py b/use-cases/mnist/torch/dataloader.py index 56ef807d..12c0a733 100644 --- a/use-cases/mnist/torch/dataloader.py +++ b/use-cases/mnist/torch/dataloader.py @@ -1,7 +1,9 @@ """Dataloader for Torch-based MNIST use case.""" -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Tuple, Callable, Any +import os +from PIL import Image from torch.utils.data import Dataset from torchvision import transforms, datasets @@ -58,3 +60,94 @@ def execute( # ) # return (train_dataloder, validation_dataloader) return (self.train_dataset, self.val_dataset), config + + +class InferenceMNIST(Dataset): + """Loads a set of MNIST images from a folder of JPG files.""" + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + supported_format: str = '.jpg' + ) -> None: + self.root = root + self.transform = transform + self.supported_format = supported_format + self.data = dict() + self._load() + + def _load(self): + for img_file in os.listdir(self.root): + if not img_file.lower().endswith(self.supported_format): + continue + filename = os.path.basename(img_file) + img = Image.open(img_file) + self.data[filename] = img + + def __len__(self) -> int: + return len(self.data) + + def 
__getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Args: + index (int): Index + + Returns: + tuple: (image_identifier, image) where image_identifier + is the unique identifier for the image (e.g., filename). + """ + img_id, img = list(self.data.items())[index] + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + img = Image.fromarray(img.numpy(), mode="L") + + if self.transform is not None: + img = self.transform(img) + + return img_id, img + + @staticmethod + def generate_jpg_sample( + root: str, + max_items: int = 100 + ): + """Generate a sample dataset of JPG images starting from + LeCun's test dataset. + + Args: + root (str): sample path on disk + max_items (int, optional): max number of images to + generate. Defaults to 100. + """ + test_data = datasets.MNIST(root='.tmp', train=False, download=True) + for idx, (img, _) in enumerate(test_data): + if idx >= max_items: + break + savepath = os.path.join(root, f'digit_{idx}.jpg') + img.save(savepath) + + +class MNISTPredictLoader(DataGetter): + def __init__( + self, + test_data_path: str + ) -> None: + super().__init__() + self.test_data_path = test_data_path + + def execute( + self, + config: Optional[Dict] = None + ) -> Tuple[Tuple[Dataset, Dataset], Optional[Dict]]: + data = self.preproc() + return data, config + + def preproc(self) -> Dataset: + return InferenceMNIST( + root=self.test_data_path, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])) diff --git a/use-cases/mnist/torch/inference-pipeline.yaml b/use-cases/mnist/torch/inference-pipeline.yaml new file mode 100644 index 00000000..f7dbfd42 --- /dev/null +++ b/use-cases/mnist/torch/inference-pipeline.yaml @@ -0,0 +1,20 @@ +executor: + class_path: itwinai.components.Executor + init_args: + steps: + # TODO: complete + - class_path: dataloader.MNISTPredictLoader + init_args: + test_data_path: ... 
+ + - class_path: itwinai.torch.inference.TorchPredictor + init_args: + model: + class_path: itwinai.torch.inference.TorchModelLoader + init_args: + model_uri: ... + ... + + - class_path: saver.TorchMNISTLabelSaver + init_args: + save_dir: ... \ No newline at end of file From c73fb088f1baeaca460158d5b92843ba65039e4e Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Mon, 6 Nov 2023 17:47:25 +0100 Subject: [PATCH 09/57] ADD: small docs --- use-cases/mnist/torch/inference-pipeline.yaml | 4 +++- use-cases/mnist/torch/inference.py | 1 + use-cases/mnist/torch/saver.py | 8 +++++++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/use-cases/mnist/torch/inference-pipeline.yaml b/use-cases/mnist/torch/inference-pipeline.yaml index f7dbfd42..8a9e38a5 100644 --- a/use-cases/mnist/torch/inference-pipeline.yaml +++ b/use-cases/mnist/torch/inference-pipeline.yaml @@ -17,4 +17,6 @@ executor: - class_path: saver.TorchMNISTLabelSaver init_args: - save_dir: ... \ No newline at end of file + save_dir: ... + predictions_file: ... + class_labels: null \ No newline at end of file diff --git a/use-cases/mnist/torch/inference.py b/use-cases/mnist/torch/inference.py index e69de29b..13d0e3dd 100644 --- a/use-cases/mnist/torch/inference.py +++ b/use-cases/mnist/torch/inference.py @@ -0,0 +1 @@ +# Can be replaced by train.py? diff --git a/use-cases/mnist/torch/saver.py b/use-cases/mnist/torch/saver.py index 7150475f..fd54c0cf 100644 --- a/use-cases/mnist/torch/saver.py +++ b/use-cases/mnist/torch/saver.py @@ -1,3 +1,7 @@ +""" +This module is used during inference to save predicted labels to file. 
+""" + from typing import Optional, List, Dict, Tuple import os import shutil @@ -12,10 +16,12 @@ class TorchMNISTLabelSaver(Saver): def __init__( self, save_dir: str = 'mnist_predictions', + predictions_file: str = 'predictions.csv', class_labels: Optional[List] = None ) -> None: super().__init__() self.save_dir = save_dir + self.predictions_file = predictions_file self.class_labels = ( class_labels if class_labels is not None else [f'Digit {i}' for i in range(10)] @@ -52,7 +58,7 @@ def execute( return ((result,), config) def save(self, predicted_labels: Dict[str, str]) -> None: - filepath = os.path.join(self.save_dir, 'predictions.csv') + filepath = os.path.join(self.save_dir, self.predictions_file) with open(filepath, 'w') as csv_file: writer = csv.writer(csv_file) for key, value in predicted_labels.items(): From 1866a81b982faa5a12e97406a0076b51eeb3c0d9 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Tue, 7 Nov 2023 11:39:34 +0100 Subject: [PATCH 10/57] UPDATE: inference pipeline components --- .gitignore | 3 + src/itwinai/components.py | 55 ++++++- src/itwinai/serialization.py | 4 +- src/itwinai/torch/README.md | 36 +++++ src/itwinai/torch/inference.py | 145 +++++++++++------- src/itwinai/types.py | 11 ++ use-cases/mnist/torch/dataloader.py | 14 +- use-cases/mnist/torch/inference-pipeline.yaml | 14 +- use-cases/mnist/torch/inference.py | 1 - 9 files changed, 208 insertions(+), 75 deletions(-) create mode 100644 src/itwinai/torch/README.md create mode 100644 src/itwinai/types.py delete mode 100644 use-cases/mnist/torch/inference.py diff --git a/.gitignore b/.gitignore index 659e8395..739f077d 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,9 @@ mllogs *.err .logs/ pl-training.yml +*-predictions/ +*-data/ +*.pth # Custom envs .venv* diff --git a/src/itwinai/components.py b/src/itwinai/components.py index 0df7b4f6..c1e6e372 100644 --- a/src/itwinai/components.py +++ b/src/itwinai/components.py @@ -1,11 +1,14 @@ from __future__ import annotations -from 
.cluster import ClusterEnvironment -from typing import Iterable, Dict, Any, Optional, Tuple +from typing import Iterable, Dict, Any, Optional, Tuple, Union from abc import ABCMeta, abstractmethod import time # import logging # from logging import Logger as PythonLogger +from .cluster import ClusterEnvironment +from .types import ModelML, DatasetML +from .serialization import ModelLoader + class Executable(metaclass=ABCMeta): """Base Executable class. @@ -152,9 +155,53 @@ def load_state(self): class Predictor(Executable): """Applies a pre-trained machine learning model to unseen data.""" + + model: ModelML + + def __init__( + self, + model: Union[ModelML, ModelLoader], + name: Optional[str] = None, + **kwargs + ) -> None: + super().__init__(name, **kwargs) + self.model = model() if isinstance(model, ModelLoader) else model + + def execute( + self, + predict_dataset: DatasetML, + config: Optional[Dict] = None, + ) -> Tuple[Optional[Tuple], Optional[Dict]]: + """"Execute some operations. + + Args: + predict_dataset (DatasetML): dataset object for inference. + config (Dict, optional): key-value configuration. + Defaults to None. + + Returns: + Tuple[Optional[Tuple], Optional[Dict]]: tuple structured as + (results, config). + """ + return self.predict(predict_dataset), config + @abstractmethod - def predict(self, *args, **kwargs): - pass + def predict( + self, + predict_dataset: DatasetML, + model: Optional[ModelML] = None + ) -> Iterable[Any]: + """Applies a machine learning model on a dataset of samples. + + Args: + predict_dataset (DatasetML): dataset for inference. + model (Optional[ModelML], optional): overrides the internal model, + if given. Defaults to None. + + Returns: + Iterable[Any]: predictions with the same cardinality of the + input dataset. 
+ """ class DataGetter(Executable): diff --git a/src/itwinai/serialization.py b/src/itwinai/serialization.py index 5f8fa769..a7b70cd3 100644 --- a/src/itwinai/serialization.py +++ b/src/itwinai/serialization.py @@ -1,4 +1,4 @@ -from typing import Any +from .types import ModelML import abc @@ -10,5 +10,5 @@ def __init__(self, model_uri: str) -> None: self.model_uri = model_uri @abc.abstractmethod - def __call__(self) -> Any: + def __call__(self) -> ModelML: """Loads model from model URI.""" diff --git a/src/itwinai/torch/README.md b/src/itwinai/torch/README.md new file mode 100644 index 00000000..ac7f978b --- /dev/null +++ b/src/itwinai/torch/README.md @@ -0,0 +1,36 @@ +# Pure torch example on MNIST dataset + +## Training + +```bash +python train.py -p pipeline.yaml [-d] +``` + +Use the `-d` flag to run only the first step in the pipeline. + +## Inference + +1. Create sample dataset + + ```python + from dataloader import InferenceMNIST + InferenceMNIST.generate_jpg_sample('mnist-sample-data/', 10) + ``` + +2. Generate a dummy pre-trained neural network + + ```python + import torch + from model import Net + dummy_nn = Net() + torch.save(dummy_nn, 'mnist-pre-trained.pth') + ``` + +3. Run inference command. This will generate a "mnist-predictions" +folder containing a CSV file with the predictions as rows. + + ```bash + python train.py -p inference-pipeline.yaml + ``` + +Note the same entry point as for training.
diff --git a/src/itwinai/torch/inference.py b/src/itwinai/torch/inference.py index d83c6d2e..7e5e047d 100644 --- a/src/itwinai/torch/inference.py +++ b/src/itwinai/torch/inference.py @@ -1,5 +1,6 @@ -from typing import Optional, Tuple, Dict, Any, List, Union +from typing import Optional, Dict, Any, Union import os +import abc import torch from torch import nn @@ -9,7 +10,7 @@ from .utils import clear_key from ..components import Predictor from .types import TorchDistributedStrategy as StrategyT -from .types import Metric +from .types import Metric, Batch from ..serialization import ModelLoader @@ -67,45 +68,7 @@ def __call__(self) -> nn.Module: class TorchPredictor(Predictor): - """Applies a pre-trained torch model to unseen data. - - Args: - model (nn.Module): neural network instance. - test_dataloader_class (str, optional): test dataloader class path. - Defaults to 'torch.utils.data.DataLoader'. - test_dataloader_kwargs (Optional[Dict], optional): constructor - arguments of the test dataloader, except for the dataset - instance. Defaults to None. - strategy (Optional[TorchDistributedStrategy], optional): distributed - strategy. Defaults to StrategyT.NONE.value. - backend (TorchDistributedBackend, optional): computing backend. - Defaults to BackendT.NCCL.value. - shuffle_dataset (bool, optional): whether shuffle dataset before - sampling batches from dataloader. Defaults to False. - use_cuda (bool, optional): whether to use GPU. Defaults to True. - benchrun (bool, optional): sets up a debug run. Defaults to False. - testrun (bool, optional): deterministic training seeding everything. - Defaults to False. - seed (Optional[int], optional): random seed. Defaults to None. - logger (Optional[List[Logger]], optional): logger. Defaults to None. - checkpoint_every (int, optional): how often (epochs) to checkpoint the - best model. Defaults to 10. 
- cluster (Optional[ClusterEnvironment], optional): cluster environment - object describing the context in which the trainer is executed. - Defaults to None. - train_metrics (Optional[Dict[str, Metric]], optional): - list of metrics computed in the training step on the predictions. - It's a dictionary with the form - ``{'metric_unique_name': CallableMetric}``. Defaults to None. - validation_metrics (Optional[Dict[str, Metric]], optional): same - as ``training_metrics``. If not given, it mirrors the training - metrics. Defaults to None. - - Raises: - RuntimeError: When trying to use DDP without CUDA support. - NotImplementedError: when trying to use a strategy different from the - ones provided by TorchDistributedStrategy. - """ + """Applies a pre-trained torch model to unseen data.""" model: nn.Module = None test_dataset: Dataset @@ -127,9 +90,10 @@ def __init__( # logger: Optional[List[Logger]] = None, # cluster: Optional[ClusterEnvironment] = None, # test_metrics: Optional[Dict[str, Metric]] = None, + name: str = None ) -> None: - super().__init__() - self.model = model() if isinstance(model, ModelLoader) else model + super().__init__(model=model, name=name) + self.model = self.model.eval() # self.seed = seed # self.strategy = strategy # self.cluster = cluster @@ -158,24 +122,91 @@ def __init__( # else validation_metrics # ) - def execute( + def predict( self, test_dataset: Dataset, model: nn.Module = None, - config: Optional[Dict] = None - ) -> Tuple[Optional[Tuple], Optional[Dict]]: - self.test_dataset = test_dataset - self.test_dataloader = self.test_dataloader_class( - test_dataset, **self.test_dataloader_kwargs - ) - # Update model passed for "interactive" use + ) -> Dict[str, Any]: + """Applies a torch model to a dataset for inference. + + Args: + test_dataset (Dataset[str, Any]): each item in this dataset is a + couple (item_unique_id, item) + model (nn.Module, optional): torch model. Overrides the existing + model, if given. Defaults to None. 
+ + Returns: + Dict[str, Any]: maps each item ID to the corresponding predicted + value(s). + """ if model is not None: + # Overrides existing "internal" model self.model = model - result = self.predict() - return ((result,), config) - def predict(self) -> List[Any]: - """Returns a list of predictions.""" - # TODO: complete + test_dataloader = self.test_dataloader_class( + test_dataset, **self.test_dataloader_kwargs + ) + + all_predictions = dict() + for samples_ids, samples in test_dataloader: + with torch.no_grad(): + pred = self.model(samples) + pred = self.transform_predictions(pred) + for idx, pre in zip(samples_ids, pred): + # For each item in the batch + if pre.numel() == 1: + pre = pre.item() + else: + pre = pre.to_dense().tolist() + all_predictions[idx] = pre + return all_predictions + + @abc.abstractmethod + def transform_predictions(self, batch: Batch) -> Batch: + """ + Post-process the predictions of the torch model (e.g., apply + threshold in case of multilabel classifier). + """ + + +class MulticlassTorchPredictor(TorchPredictor): + """ + Applies a pre-trained torch model to unseen data for + multiclass classification. + """ + + def transform_predictions(self, batch: Batch) -> Batch: + batch = batch.argmax(-1) + return batch + + +class MultilabelTorchPredictor(TorchPredictor): + """ + Applies a pre-trained torch model to unseen data for + multilabel classification, applying a threshold on the + output of the neural network. + """ + + def __init__( + self, + model: Union[nn.Module, ModelLoader], + test_dataloader_class: str = 'torch.utils.data.DataLoader', + test_dataloader_kwargs: Optional[Dict] = None, + threshold: float = 0.5, + name: str = None + ) -> None: + super().__init__( + model, test_dataloader_class, test_dataloader_kwargs, name + ) + self.threshold = threshold + + +class RegressionTorchPredictor(TorchPredictor): + """ + Applies a pre-trained torch model to unseen data for + regression, leaving untouched the output of the neural + network. 
+ """ - return [] + def transform_predictions(self, batch: Batch) -> Batch: + return batch diff --git a/src/itwinai/types.py b/src/itwinai/types.py new file mode 100644 index 00000000..9c302eb1 --- /dev/null +++ b/src/itwinai/types.py @@ -0,0 +1,11 @@ +""" +Framework-independent types. +""" + + +class DatasetML: + """A framework-independent machine learning dataset.""" + + +class ModelML: + """A framework-independent machine learning model.""" diff --git a/use-cases/mnist/torch/dataloader.py b/use-cases/mnist/torch/dataloader.py index 12c0a733..39e9b56b 100644 --- a/use-cases/mnist/torch/dataloader.py +++ b/use-cases/mnist/torch/dataloader.py @@ -2,6 +2,7 @@ from typing import Dict, Optional, Tuple, Callable, Any import os +import shutil from PIL import Image from torch.utils.data import Dataset @@ -82,7 +83,7 @@ def _load(self): if not img_file.lower().endswith(self.supported_format): continue filename = os.path.basename(img_file) - img = Image.open(img_file) + img = Image.open(os.path.join(self.root, img_file)) self.data[filename] = img def __len__(self) -> int: @@ -101,7 +102,8 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: # doing this so that it is consistent with all other datasets # to return a PIL Image - img = Image.fromarray(img.numpy(), mode="L") + # print(type(img)) + # img = Image.fromarray(img.numpy(), mode="L") if self.transform is not None: img = self.transform(img) @@ -121,6 +123,10 @@ def generate_jpg_sample( max_items (int, optional): max number of images to generate. Defaults to 100. 
""" + if os.path.exists(root): + shutil.rmtree(root) + os.makedirs(root) + test_data = datasets.MNIST(root='.tmp', train=False, download=True) for idx, (img, _) in enumerate(test_data): if idx >= max_items: @@ -141,10 +147,10 @@ def execute( self, config: Optional[Dict] = None ) -> Tuple[Tuple[Dataset, Dataset], Optional[Dict]]: - data = self.preproc() + data = self.load() return data, config - def preproc(self) -> Dataset: + def load(self) -> Dataset: return InferenceMNIST( root=self.test_data_path, transform=transforms.Compose([ diff --git a/use-cases/mnist/torch/inference-pipeline.yaml b/use-cases/mnist/torch/inference-pipeline.yaml index 8a9e38a5..2afa642f 100644 --- a/use-cases/mnist/torch/inference-pipeline.yaml +++ b/use-cases/mnist/torch/inference-pipeline.yaml @@ -2,21 +2,21 @@ executor: class_path: itwinai.components.Executor init_args: steps: - # TODO: complete - class_path: dataloader.MNISTPredictLoader init_args: - test_data_path: ... + test_data_path: mnist-sample-data/ - - class_path: itwinai.torch.inference.TorchPredictor + - class_path: itwinai.torch.inference.MulticlassTorchPredictor init_args: model: class_path: itwinai.torch.inference.TorchModelLoader init_args: - model_uri: ... - ... + model_uri: mnist-pre-trained.pth + test_dataloader_kwargs: + batch_size: 3 - class_path: saver.TorchMNISTLabelSaver init_args: - save_dir: ... - predictions_file: ... + save_dir: mnist-predictions + predictions_file: predictions.csv class_labels: null \ No newline at end of file diff --git a/use-cases/mnist/torch/inference.py b/use-cases/mnist/torch/inference.py deleted file mode 100644 index 13d0e3dd..00000000 --- a/use-cases/mnist/torch/inference.py +++ /dev/null @@ -1 +0,0 @@ -# Can be replaced by train.py? 
From 22aed46f36aaff5bdeeda2dca7b6dd374e162591 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Tue, 7 Nov 2023 13:23:44 +0100 Subject: [PATCH 11/57] UPDATE: reorg --- {src/itwinai => use-cases/mnist}/torch/README.md | 0 use-cases/mnist/torch/pipeline.yaml | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename {src/itwinai => use-cases/mnist}/torch/README.md (100%) diff --git a/src/itwinai/torch/README.md b/use-cases/mnist/torch/README.md similarity index 100% rename from src/itwinai/torch/README.md rename to use-cases/mnist/torch/README.md diff --git a/use-cases/mnist/torch/pipeline.yaml b/use-cases/mnist/torch/pipeline.yaml index df29c3fe..9bb7fb98 100644 --- a/use-cases/mnist/torch/pipeline.yaml +++ b/use-cases/mnist/torch/pipeline.yaml @@ -25,7 +25,7 @@ executor: batch_size: 32 pin_memory: True shuffle: False - epochs: 2 + epochs: 30 train_metrics: accuracy: class_path: torchmetrics.classification.MulticlassAccuracy From 0242790d405386dde05468532dce747c6911527f Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Tue, 7 Nov 2023 14:36:41 +0100 Subject: [PATCH 12/57] ADD: image generation for inference --- .gitignore | 1 + use-cases/mnist/torch/Dockerfile | 18 +++++++++ use-cases/mnist/torch/README.md | 38 +++++++++++++++++++ use-cases/mnist/torch/inference-pipeline.yaml | 4 +- 4 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 use-cases/mnist/torch/Dockerfile diff --git a/.gitignore b/.gitignore index 739f077d..0c42e0a6 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ pl-training.yml *-predictions/ *-data/ *.pth +*.tar.gz # Custom envs .venv* diff --git a/use-cases/mnist/torch/Dockerfile b/use-cases/mnist/torch/Dockerfile new file mode 100644 index 00000000..b4cf3654 --- /dev/null +++ b/use-cases/mnist/torch/Dockerfile @@ -0,0 +1,18 @@ +FROM python:3.9.12 + +WORKDIR /usr/src/app + +# Install pytorch (cpuonly) +# Ref:https://pytorch.org/get-started/previous-versions/#linux-and-windows-5 +RUN pip install --no-cache-dir 
torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu + +# Install itwinai and dependencies +COPY pyproject.toml ./ +COPY src ./ +RUN pip install --no-cache-dir . + +# Add torch MNIST use case +COPY use-cases/mnist/torch/* ./ + +# Run inference +CMD [ "python", "train.py", "-p", "inference-pipeline.yaml"] \ No newline at end of file diff --git a/use-cases/mnist/torch/README.md b/use-cases/mnist/torch/README.md index ac7f978b..ff755beb 100644 --- a/use-cases/mnist/torch/README.md +++ b/use-cases/mnist/torch/README.md @@ -34,3 +34,41 @@ folder containing a CSV file with the predictions as rows. ``` Note the same entry point as for training. + +### Docker image + +Build from project root with + +```bash +# Local +docker buildx build -t itwinai-mnist-torch-inference -f use-cases/mnist/torch/Dockerfile . + +# Ghcr.io +docker buildx build -t ghcr.io/intertwin-eu/itwinai-mnist-torch-inference:0.0.1 -f use-cases/mnist/torch/Dockerfile . +docker push ghcr.io/intertwin-eu/itwinai-mnist-torch-inference:0.0.1 +``` + +From wherever a sample of MNIST jpg images is available +(folder called 'mnist-sample-data/'): + +```text +├── $PWD +│ ├── mnist-sample-data +| │ ├── digit_0.jpg +| │ ├── digit_1.jpg +| │ ├── digit_2.jpg +... 
+| │ ├── digit_N.jpg +``` + +```bash +docker run -it --rm --name running-inference -v "$PWD":/usr/data itwinai-mnist-torch-inference +``` + +This command will store the results in a folder called "mnist-predictions": + +```text +├── $PWD +│ ├── mnist-predictions +| │ ├── predictions.csv +``` diff --git a/use-cases/mnist/torch/inference-pipeline.yaml b/use-cases/mnist/torch/inference-pipeline.yaml index 2afa642f..ba4f5e86 100644 --- a/use-cases/mnist/torch/inference-pipeline.yaml +++ b/use-cases/mnist/torch/inference-pipeline.yaml @@ -4,7 +4,7 @@ executor: steps: - class_path: dataloader.MNISTPredictLoader init_args: - test_data_path: mnist-sample-data/ + test_data_path: /usr/data/mnist-sample-data - class_path: itwinai.torch.inference.MulticlassTorchPredictor init_args: @@ -17,6 +17,6 @@ executor: - class_path: saver.TorchMNISTLabelSaver init_args: - save_dir: mnist-predictions + save_dir: /usr/data/mnist-predictions predictions_file: predictions.csv class_labels: null \ No newline at end of file From 17915b11cce60f8ac58171a47595c8514fdab10f Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Tue, 7 Nov 2023 14:43:58 +0100 Subject: [PATCH 13/57] update tag --- use-cases/mnist/torch/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/use-cases/mnist/torch/README.md b/use-cases/mnist/torch/README.md index ff755beb..c953671f 100644 --- a/use-cases/mnist/torch/README.md +++ b/use-cases/mnist/torch/README.md @@ -62,7 +62,7 @@ From wherever a sample of MNIST jpg images is available ``` ```bash -docker run -it --rm --name running-inference -v "$PWD":/usr/data itwinai-mnist-torch-inference +docker run -it --rm --name running-inference -v "$PWD":/usr/data ghcr.io/intertwin-eu/itwinai-mnist-torch-inference:0.0.1 ``` This command will store the results in a folder called "mnist-predictions": From c3ff733f7fdf3888b0baeafc54d2374d92f1e4f5 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Tue, 7 Nov 2023 15:21:53 +0100 Subject: [PATCH 14/57] ADD: 
threshold --- src/itwinai/torch/inference.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/itwinai/torch/inference.py b/src/itwinai/torch/inference.py index 7e5e047d..4d7797c6 100644 --- a/src/itwinai/torch/inference.py +++ b/src/itwinai/torch/inference.py @@ -187,6 +187,8 @@ class MultilabelTorchPredictor(TorchPredictor): output of the neural network. """ + threshold: float + def __init__( self, model: Union[nn.Module, ModelLoader], @@ -200,6 +202,9 @@ def __init__( ) self.threshold = threshold + def transform_predictions(self, batch: Batch) -> Batch: + return (batch > self.threshold).float() + class RegressionTorchPredictor(TorchPredictor): """ From 0a0f56e422b0421124c546898da5769391f3b480 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Tue, 7 Nov 2023 16:11:28 +0100 Subject: [PATCH 15/57] ADD: draft inference --- use-cases/3dgan/dataloader.py | 14 +++- use-cases/3dgan/inference-pipeline.yaml | 93 +++++++++++++++++++++++++ use-cases/3dgan/model.py | 27 +++++++ 3 files changed, 132 insertions(+), 2 deletions(-) create mode 100644 use-cases/3dgan/inference-pipeline.yaml diff --git a/use-cases/3dgan/dataloader.py b/use-cases/3dgan/dataloader.py index 76893e5d..71c85141 100644 --- a/use-cases/3dgan/dataloader.py +++ b/use-cases/3dgan/dataloader.py @@ -1,5 +1,6 @@ from typing import Optional, Tuple, Dict import os +from lightning.pytorch.utilities.types import EVAL_DATALOADERS import numpy as np import torch @@ -39,7 +40,7 @@ def execute( return None, config -class MyDataset(Dataset): +class ParticlesDataset(Dataset): def __init__(self, datapath): self.datapath = datapath self.data = self.fetch_data(self.datapath) @@ -139,13 +140,18 @@ def setup(self, stage: str = None): # called on every process in DDP if stage == 'fit' or stage is None: - self.dataset = MyDataset(self.datapath) + self.dataset = ParticlesDataset(self.datapath) dataset_length = len(self.dataset) split_point = int(dataset_length * 0.9) self.train_dataset, self.val_dataset = \ 
torch.utils.data.random_split( self.dataset, [split_point, dataset_length - split_point]) + if stage == 'predict': + # TODO: inference dataset should be different in that it + # does not contain images! + self.predict_dataset = ParticlesDataset(self.datapath) + # if stage == 'test' or stage is None: # self.test_dataset = MyDataset(self.data_dir, train=False) @@ -157,5 +163,9 @@ def val_dataloader(self): return DataLoader(self.val_dataset, num_workers=4, batch_size=self.batch_size, drop_last=True) + def predict_dataloader(self) -> EVAL_DATALOADERS: + return DataLoader(self.predict_dataset, num_workers=4, + batch_size=self.batch_size, drop_last=True) + # def test_dataloader(self): # return DataLoader(self.test_dataset, batch_size=self.batch_size) diff --git a/use-cases/3dgan/inference-pipeline.yaml b/use-cases/3dgan/inference-pipeline.yaml new file mode 100644 index 00000000..32d6d5b1 --- /dev/null +++ b/use-cases/3dgan/inference-pipeline.yaml @@ -0,0 +1,93 @@ +executor: + class_path: itwinai.components.Executor + init_args: + steps: + - class_path: dataloader.Lightning3DGANDownloader + init_args: + data_path: exp_data/ + data_url: https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX + + - class_path: trainer.Lightning3DGANTrainer + init_args: + # Pytorch lightning config for training + config: + seed_everything: 4231162351 + trainer: + accelerator: auto + accumulate_grad_batches: 1 + barebones: false + benchmark: null + # callbacks: + # # - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping + # # init_args: + # # monitor: val_loss + # # patience: 2 + # - class_path: lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor + # init_args: + # logging_interval: step + # # - class_path: lightning.pytorch.callbacks.ModelCheckpoint + # # init_args: + # # dirpath: checkpoints + # # filename: best-checkpoint + # # mode: min + # # monitor: val_loss + # # save_top_k: 1 + # # verbose: true + check_val_every_n_epoch: 1 + default_root_dir: 
null + detect_anomaly: false + deterministic: null + devices: auto #[0] + enable_checkpointing: true + enable_model_summary: null + enable_progress_bar: null + fast_dev_run: false + gradient_clip_algorithm: null + gradient_clip_val: null + inference_mode: true + limit_predict_batches: null + limit_test_batches: null + limit_train_batches: null + limit_val_batches: null + log_every_n_steps: 2 + logger: + # - class_path: lightning.pytorch.loggers.CSVLogger + # init_args: + # save_dir: ml_logs/csv_logs + class_path: lightning.pytorch.loggers.MLFlowLogger + init_args: + experiment_name: 3DGAN + save_dir: ml_logs/mlflow_logs + log_model: all + max_epochs: 1 + max_steps: 20 + max_time: null + min_epochs: null + min_steps: null + num_sanity_val_steps: null + overfit_batches: 0.0 + plugins: null + profiler: null + reload_dataloaders_every_n_epochs: 0 + strategy: ddp_find_unused_parameters_true #auto + sync_batchnorm: false + use_distributed_sampler: true + val_check_interval: null + + # Lightning Model configuration + model: + class_path: model.ThreeDGAN + init_args: + latent_size: 256 + batch_size: 64 + loss_weights: [3, 0.1, 25, 0.1] + power: 0.85 + lr: 0.001 + checkpoint_path: exp_data/3dgan.pth + + # Lightning data module configuration + data: + class_path: dataloader.MyDataModule + init_args: + datapath: exp_data/*/*.h5 + batch_size: 64 diff --git a/use-cases/3dgan/model.py b/use-cases/3dgan/model.py index 009e267e..9840766b 100644 --- a/use-cases/3dgan/model.py +++ b/use-cases/3dgan/model.py @@ -2,6 +2,7 @@ import pickle from collections import defaultdict import math +from typing import Any import torch import torch.nn as nn @@ -412,6 +413,8 @@ def training_step(self, batch, batch_idx): (self.batch_size, self.latent_size - 2), device=self.device ) + print(f'Energy elements: {energy_batch.numel} {energy_batch.shape}') + print(f'Angle elements: {ang_batch.numel} {ang_batch.shape}') generator_ip = torch.cat( (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), 
dim=1 @@ -723,6 +726,30 @@ def on_validation_epoch_end(self): # pickle.dump({"test": self.test_history}, open(self.pklfile, "wb")) # print("train-loss:" + str(self.train_history["generator"][-1][0])) + def predict_step( + self, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0 + ) -> Any: + energy_batch, ang_batch = batch['Y'], batch['ang'] + + energy_batch = energy_batch.to(self.device) + ang_batch = ang_batch.to(self.device) + + # Generate Fake events with same energy and angle as data batch + noise = torch.randn( + (self.batch_size, self.latent_size - 2), + dtype=torch.float32, + device=self.device + ) + + generator_ip = torch.cat( + (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), dim=1) + generated_images = self.generator(generator_ip) + print(f"Generated batch size {generated_images.shape}") + return generated_images + def configure_optimizers(self): lr = self.lr From 95661c18b46b24a71ecc4a7179b61373821b9c83 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Tue, 7 Nov 2023 18:18:24 +0100 Subject: [PATCH 16/57] ADD: draft inference wf --- use-cases/3dgan/README.md | 21 +++++ use-cases/3dgan/dataloader.py | 2 +- use-cases/3dgan/inference-pipeline.yaml | 12 ++- use-cases/3dgan/pipeline.yaml | 2 +- use-cases/3dgan/saver.py | 55 ++++++++++++ use-cases/3dgan/trainer.py | 109 +++++++++++++++++++++++- 6 files changed, 193 insertions(+), 8 deletions(-) create mode 100644 use-cases/3dgan/saver.py diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md index acfe4208..7f3801ad 100644 --- a/use-cases/3dgan/README.md +++ b/use-cases/3dgan/README.md @@ -7,3 +7,24 @@ micromamba run -p ../../.venv-pytorch mlflow ui --backend-store-uri ml_logs/mlfl ``` And select the "3DGAN" experiment. + +## Inference + +The following is preliminary and not 100% ML/scientifically sound. + +1. As inference dataset we can reuse training/validation dataset +2. 
As model, we can create a dummy version of it with: + + ```python + import torch + from model import ThreeDGAN + # Same params as in the training config file! + my_gan = ThreeDGAN() + torch.save(my_gan, '3dgan-inference.pth') + ``` + +3. Run inference with the following command: + + ```bash + TODO + ``` diff --git a/use-cases/3dgan/dataloader.py b/use-cases/3dgan/dataloader.py index 71c85141..d77e363a 100644 --- a/use-cases/3dgan/dataloader.py +++ b/use-cases/3dgan/dataloader.py @@ -129,7 +129,7 @@ def GetDataAngleParallel( return final_dataset -class MyDataModule(pl.LightningDataModule): +class ParticlesDataModule(pl.LightningDataModule): def __init__(self, batch_size: int, datapath): super().__init__() self.batch_size = batch_size diff --git a/use-cases/3dgan/inference-pipeline.yaml b/use-cases/3dgan/inference-pipeline.yaml index 32d6d5b1..52d6abf7 100644 --- a/use-cases/3dgan/inference-pipeline.yaml +++ b/use-cases/3dgan/inference-pipeline.yaml @@ -7,8 +7,12 @@ executor: data_path: exp_data/ data_url: https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX - - class_path: trainer.Lightning3DGANTrainer + - class_path: trainer.Lightning3DGANPredictor init_args: + model: + class_path: trainer.LightningModelLoader + init_args: + model_uri: 3dgan-inference.pth # Pytorch lightning config for training config: seed_everything: 4231162351 @@ -87,7 +91,11 @@ executor: # Lightning data module configuration data: - class_path: dataloader.MyDataModule + class_path: dataloader.ParticlesDataModule init_args: datapath: exp_data/*/*.h5 batch_size: 64 + + - class_path: saver.ParticleImagesSaver + init_args: + save_dir: 3dgan-generated \ No newline at end of file diff --git a/use-cases/3dgan/pipeline.yaml b/use-cases/3dgan/pipeline.yaml index 32d6d5b1..942efeb7 100644 --- a/use-cases/3dgan/pipeline.yaml +++ b/use-cases/3dgan/pipeline.yaml @@ -87,7 +87,7 @@ executor: # Lightning data module configuration data: - class_path: dataloader.MyDataModule + class_path: 
dataloader.ParticlesDataModule init_args: datapath: exp_data/*/*.h5 batch_size: 64 diff --git a/use-cases/3dgan/saver.py b/use-cases/3dgan/saver.py new file mode 100644 index 00000000..9de6fe7a --- /dev/null +++ b/use-cases/3dgan/saver.py @@ -0,0 +1,55 @@ +from typing import Dict, Tuple, Optional +import os +import shutil + +import torch +from torch import Tensor + +from itwinai.components import Saver + + +class ParticleImagesSaver(Saver): + """Saves generated particle trajectories to disk.""" + + def __init__( + self, + save_dir: str = '3dgan-generated', + ) -> None: + super().__init__() + self.save_dir = save_dir + + def execute( + self, + generated_images: Dict[str, Tensor], + config: Optional[Dict] = None + ) -> Tuple[Optional[Tuple], Optional[Dict]]: + """Saves generated images to disk. + + Args: + generated_images (Dict[str, Tensor]): maps unique item ID to + the generated image. + config (Optional[Dict], optional): inherited configuration. + Defaults to None. + + Returns: + Tuple[Optional[Tuple], Optional[Dict]]: propagation of inherited + configuration and saver return value. + """ + result = self.save(generated_images) + return ((result,), config) + + def save(self, generated_images: Dict[str, Tensor]) -> None: + """Saves generated images to disk. + + Args: + generated_images (Dict[str, Tensor]): maps unique item ID to + the generated image. 
+ """ + if os.path.exists(self.save_dir): + shutil.rmtree(self.save_dir) + os.makedirs(self.save_dir) + + # TODO: save as 3D plot image + for img_id, img in generated_images.items(): + img_path = os.path.join(self.save_dir, img_id + '.pth') + torch.save(img, img_path) diff --git a/use-cases/3dgan/trainer.py b/use-cases/3dgan/trainer.py index 2caabc59..2a3d47ff 100644 --- a/use-cases/3dgan/trainer.py +++ b/use-cases/3dgan/trainer.py @@ -2,10 +2,17 @@ import sys from typing import Union, Dict, Tuple, Optional, Any -from itwinai.components import Trainer -from model import ThreeDGAN -from dataloader import MyDataModule +import torch +from torch import Tensor +import lightning as pl from lightning.pytorch.cli import LightningCLI + +from itwinai.components import Trainer, Predictor +from itwinai.serialization import ModelLoader +from itwinai.torch.inference import TorchModelLoader + +from model import ThreeDGAN +from dataloader import ParticlesDataModule from utils import load_yaml @@ -23,7 +30,7 @@ def train(self) -> Any: cli = LightningCLI( args=self.conf, model_class=ThreeDGAN, - datamodule_class=MyDataModule, + datamodule_class=ParticlesDataModule, run=False, save_config_kwargs={ "overwrite": True, @@ -47,3 +54,97 @@ def save_state(self): def load_state(self): return super().load_state() + + +class LightningModelLoader(TorchModelLoader): + """Loads a torch lightning model from somewhere. + + Args: + model_uri (str): Can be a path on local filesystem + or an mlflow 'locator' in the form: + 'mlflow+MLFLOW_TRACKING_URI+RUN_ID+ARTIFACT_PATH' + """ + + def __call__(self) -> pl.LightningModule: + """"Loads model from model URI. + + Raises: + ValueError: if the model URI is not recognized + or the model is not found. + + Returns: + pl.LightningModule: torch lightning module. 
+ """ + # TODO: improve + # # Load best model + # loaded_model = cli.model.load_from_checkpoint( + # ckpt_path, + # lightning_conf['model']['init_args'] + # ) + return super().__call__() + + +class Lightning3DGANPredictor(Predictor): + + def __init__( + self, + model: Union[ModelLoader, pl.LightningModule], + config: Union[Dict, str], + name: Optional[str] = None + ): + super().__init__(model, name) + if isinstance(config, str) and os.path.isfile(config): + # Load from YAML + config = load_yaml(config) + self.conf = config + + def predict( + self, + datamodule: Optional[pl.LightningDataModule] = None, + model: Optional[pl.LightningModule] = None + ) -> Dict[str, Tensor]: + old_argv = sys.argv + sys.argv = ['some_script_placeholder.py'] + cli = LightningCLI( + args=self.conf, + model_class=ThreeDGAN, + datamodule_class=ParticlesDataModule, + run=False, + save_config_kwargs={ + "overwrite": True, + "config_filename": "pl-training.yml", + }, + subclass_mode_model=True, + subclass_mode_data=True, + ) + sys.argv = old_argv + + # Override config file with inline arguments, if given + if datamodule is None: + datamodule = cli.datamodule + if model is None: + model = cli.model + + predictions = cli.trainer.predict(model, datamodule=datamodule) + + predictions_dict = dict() + # TODO: postprocess predictions + for idx, generated_img in enumerate(torch.cat(predictions)): + predictions_dict[str(idx)] = generated_img + return predictions_dict + + def execute( + self, + config: Optional[Dict] = None, + ) -> Tuple[Optional[Tuple], Optional[Dict]]: + """"Execute some operations. + + Args: + config (Dict, optional): key-value configuration. + Defaults to None. + + Returns: + Tuple[Optional[Tuple], Optional[Dict]]: tuple structured as + (results, config). 
+ """ + return self.predict(), config From 94254cf0c8ff95f8795074f6d099d46a2521b4e9 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Wed, 8 Nov 2023 13:27:31 +0100 Subject: [PATCH 17/57] ADD: working inference workflow --- use-cases/3dgan/inference-pipeline.yaml | 3 +++ use-cases/3dgan/model.py | 25 +++++++++++++++++++------ use-cases/3dgan/saver.py | 2 +- use-cases/3dgan/trainer.py | 20 +++++++++++++++++++- 4 files changed, 42 insertions(+), 8 deletions(-) diff --git a/use-cases/3dgan/inference-pipeline.yaml b/use-cases/3dgan/inference-pipeline.yaml index 52d6abf7..296b33c8 100644 --- a/use-cases/3dgan/inference-pipeline.yaml +++ b/use-cases/3dgan/inference-pipeline.yaml @@ -13,6 +13,9 @@ executor: class_path: trainer.LightningModelLoader init_args: model_uri: 3dgan-inference.pth + + max_samples: 10 + # Pytorch lightning config for training config: seed_everything: 4231162351 diff --git a/use-cases/3dgan/model.py b/use-cases/3dgan/model.py index 9840766b..3ac37cf1 100644 --- a/use-cases/3dgan/model.py +++ b/use-cases/3dgan/model.py @@ -410,7 +410,9 @@ def training_step(self, batch, batch_idx): optimizer_discriminator, optimizer_generator = self.optimizers() noise = torch.randn( - (self.batch_size, self.latent_size - 2), + (energy_batch.shape[0], self.latent_size - 2), + # (self.batch_size, self.latent_size - 2), + dtype=torch.float32, device=self.device ) print(f'Energy elements: {energy_batch.numel} {energy_batch.shape}') @@ -629,8 +631,12 @@ def validation_step(self, batch, batch_idx): ecal_batch = ecal_batch.to(self.device) # Generate Fake events with same energy and angle as data batch - noise = torch.randn((self.batch_size, self.latent_size - 2), - dtype=torch.float32).to(self.device) + noise = torch.randn( + (energy_batch.shape[0], self.latent_size - 2), + # (self.batch_size, self.latent_size - 2), + dtype=torch.float32, + device=self.device + ) generator_ip = torch.cat( (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), dim=1) @@ -739,15 +745,22 @@ 
def predict_step( # Generate Fake events with same energy and angle as data batch noise = torch.randn( - (self.batch_size, self.latent_size - 2), + (energy_batch.shape[0], self.latent_size - 2), dtype=torch.float32, device=self.device ) + # print(f"Reshape energy: {energy_batch.view(-1, 1).shape}") + # print(f"Reshape angle: {ang_batch.view(-1, 1).shape}") + # print(f"Noise: {noise.shape}") + generator_ip = torch.cat( - (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), dim=1) + [energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise], + dim=1 + ) + # print(f"Generator input: {generator_ip.shape}") generated_images = self.generator(generator_ip) - print(f"Generated batch size {generated_images.shape}") + # print(f"Generated batch size {generated_images.shape}") return generated_images def configure_optimizers(self): diff --git a/use-cases/3dgan/saver.py b/use-cases/3dgan/saver.py index 9de6fe7a..79ed7eea 100644 --- a/use-cases/3dgan/saver.py +++ b/use-cases/3dgan/saver.py @@ -13,7 +13,7 @@ class ParticleImagesSaver(Saver): def __init__( self, - save_dir: str = '3dgan-generated', + save_dir: str = '3dgan-generated' ) -> None: super().__init__() self.save_dir = save_dir diff --git a/use-cases/3dgan/trainer.py b/use-cases/3dgan/trainer.py index 2a3d47ff..642a13ec 100644 --- a/use-cases/3dgan/trainer.py +++ b/use-cases/3dgan/trainer.py @@ -10,6 +10,7 @@ from itwinai.components import Trainer, Predictor from itwinai.serialization import ModelLoader from itwinai.torch.inference import TorchModelLoader +from itwinai.torch.types import Batch from model import ThreeDGAN from dataloader import ParticlesDataModule @@ -90,6 +91,7 @@ def __init__( self, model: Union[ModelLoader, pl.LightningModule], config: Union[Dict, str], + max_samples: Optional[int] = None, name: Optional[str] = None ): super().__init__(model, name) @@ -97,6 +99,7 @@ def __init__( # Load from YAML config = load_yaml(config) self.conf = config + self.max_samples = max_samples def predict( self, @@ 
-127,12 +130,27 @@ def predict( predictions = cli.trainer.predict(model, datamodule=datamodule) + predictions = [ + self.transform_predictions(pred) for pred in predictions + ] + predictions_dict = dict() - # TODO: postprocess predictions for idx, generated_img in enumerate(torch.cat(predictions)): + if (self.max_samples is not None + and idx >= self.max_samples): + break predictions_dict[str(idx)] = generated_img + + print(len(predictions_dict)) return predictions_dict + def transform_predictions(self, batch: Batch) -> Batch: + """ + Post-process the predictions of the torch model. + """ + # TODO: post-process predictions + return batch + def execute( self, config: Optional[Dict] = None, From 63a7aa07281ec8376f29e3343e4825d573ea5185 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Wed, 8 Nov 2023 15:07:13 +0100 Subject: [PATCH 18/57] ADD: 3D scatter plots --- use-cases/3dgan/README.md | 80 +++++++++++++++++++++++-- use-cases/3dgan/inference-pipeline.yaml | 2 +- use-cases/3dgan/model.py | 4 +- use-cases/3dgan/requirements.txt | 4 +- use-cases/3dgan/saver.py | 71 +++++++++++++++++++++- use-cases/3dgan/trainer.py | 23 ++++--- 6 files changed, 166 insertions(+), 18 deletions(-) diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md index 7f3801ad..1c40b63a 100644 --- a/use-cases/3dgan/README.md +++ b/use-cases/3dgan/README.md @@ -12,8 +12,22 @@ And select the "3DGAN" experiment. The following is preliminary and not 100% ML/scientifically sound. -1. As inference dataset we can reuse training/validation dataset -2. As model, we can create a dummy version of it with: +1. As inference dataset we can reuse training/validation dataset, +for instance the one downloaded from Google Drive folder. +The inference dataset is a set of H5 files stored inside `exp_data` +sub-folders. + + ```text + ├── exp_data + │ ├── data + | │ ├── file_0.h5 + | │ ├── file_1.h5 + ... + | │ ├── file_N.h5 + ``` + +2. 
As model, if a pre-trained checkpoint is not available, +we can create a dummy version of it with: ```python import torch @@ -23,8 +37,66 @@ The following is preliminary and not 100% ML/scientifically sound. torch.save(my_gan, '3dgan-inference.pth') ``` -3. Run inference with the following command: +3. Run inference command. This will generate a "3dgan-generated" +folder containing generated particle traces in form of torch tensors +(.pth files) and 3D scatter plots (.jpg images). ```bash - TODO + python train.py -p inference-pipeline.yaml ``` + +Note the same entry point as for training. + +The inference execution will produce a folder called +"3dgan-generated-data" containing +generated 3D particle trajectories (overwritten if already +there). Each generated 3D image is stored both as a +torch tensor (.pth) and 3D scatter plot (.jpg): + +```text +├── 3dgan-generated-data +│ ├── data +| │ ├── energy=1.296749234199524&angle=1.272539496421814.pth +| │ ├── energy=1.296749234199524&angle=1.272539496421814.jpg +... +| │ ├── energy=1.664689540863037&angle=1.4906378984451294.pth +| │ ├── energy=1.664689540863037&angle=1.4906378984451294.jpg +``` + +### Docker image + +Build from project root with + +```bash +# Local +docker buildx build -t itwinai-mnist-torch-inference -f use-cases/3dgan/Dockerfile . + +# Ghcr.io +docker buildx build -t ghcr.io/intertwin-eu/itwinai-mnist-torch-inference:0.0.1 -f use-cases/mnist/torch/Dockerfile . +docker push ghcr.io/intertwin-eu/itwinai-mnist-torch-inference:0.0.1 +``` + +From wherever a sample of MNIST jpg images is available +(folder called 'mnist-sample-data/'): + +```text +├── $PWD +│ ├── mnist-sample-data +| │ ├── digit_0.jpg +| │ ├── digit_1.jpg +| │ ├── digit_2.jpg +... 
+| │ ├── digit_N.jpg +``` + +```bash +docker run -it --rm --name running-inference -v "$PWD":/usr/data ghcr.io/intertwin-eu/itwinai-mnist-torch-inference:0.0.1 +``` + +This command will store the results in a folder called "mnist-predictions": + +```text +├── $PWD +│ ├── mnist-predictions +| │ ├── predictions.csv +``` diff --git a/use-cases/3dgan/inference-pipeline.yaml b/use-cases/3dgan/inference-pipeline.yaml index 296b33c8..bf9b333a 100644 --- a/use-cases/3dgan/inference-pipeline.yaml +++ b/use-cases/3dgan/inference-pipeline.yaml @@ -101,4 +101,4 @@ executor: - class_path: saver.ParticleImagesSaver init_args: - save_dir: 3dgan-generated \ No newline at end of file + save_dir: 3dgan-generated-data \ No newline at end of file diff --git a/use-cases/3dgan/model.py b/use-cases/3dgan/model.py index 3ac37cf1..6cc92c93 100644 --- a/use-cases/3dgan/model.py +++ b/use-cases/3dgan/model.py @@ -761,7 +761,9 @@ def predict_step( # print(f"Generator input: {generator_ip.shape}") generated_images = self.generator(generator_ip) # print(f"Generated batch size {generated_images.shape}") - return generated_images + return {'images': generated_images, + 'energies': energy_batch, + 'angles': ang_batch} def configure_optimizers(self): lr = self.lr diff --git a/use-cases/3dgan/requirements.txt b/use-cases/3dgan/requirements.txt index c06fe435..f1f3b0bf 100644 --- a/use-cases/3dgan/requirements.txt +++ b/use-cases/3dgan/requirements.txt @@ -1,4 +1,6 @@ h5py>=3.7.0 google>=3.0.0 protobuf>=4.24.3 -gdown>=4.7.1 \ No newline at end of file +gdown>=4.7.1 +# plotly>=5.18.0 +# kaleido>=0.2.1 \ No newline at end of file diff --git a/use-cases/3dgan/saver.py b/use-cases/3dgan/saver.py index 79ed7eea..7aa72429 100644 --- a/use-cases/3dgan/saver.py +++ b/use-cases/3dgan/saver.py @@ -4,6 +4,8 @@ import torch from torch import Tensor +import matplotlib.pyplot as plt +import numpy as np from itwinai.components import Saver @@ -49,7 +51,70 @@ def save(self, generated_images: Dict[str, Tensor]) -> 
None: shutil.rmtree(self.save_dir) os.makedirs(self.save_dir) - # TODO: save as 3D plot image + # Save as torch tensor and jpg image for img_id, img in generated_images.items(): - img_path = os.path.join(self.save_dir, img_id + '.pth') - torch.save(img, img_path) + img_path = os.path.join(self.save_dir, img_id) + torch.save(img, img_path + '.pth') + self._save_image(img, img_id, img_path + '.jpg') + + def _save_image( + self, + img: Tensor, + img_idx: str, + img_path: str, + center: bool = True + ) -> None: + """Converts a 3D tensor to a 3D scatter plot and saves it + to disk as jpg image. + """ + x_offset = img.shape[0] // 2 if center else 0 + y_offset = img.shape[1] // 2 if center else 0 + z_offset = img.shape[2] // 2 if center else 0 + + # Convert tensor dimension IDs to coordinates + x_coords = [] + y_coords = [] + z_coords = [] + values = [] + + for x in range(img.shape[0]): + for y in range(img.shape[1]): + for z in range(img.shape[2]): + if img[x, y, z] > 0.0: + x_coords.append(x - x_offset) + y_coords.append(y - y_offset) + z_coords.append(z - z_offset) + values.append(img[x, y, z]) + + # import plotly.graph_objects as go + # normalize_intensity_by = 1 + # trace = go.Scatter3d( + # x=x_coords, + # y=y_coords, + # z=z_coords, + # mode='markers', + # marker_symbol='square', + # marker_color=[ + # f"rgba(0,0,255,{i*100//normalize_intensity_by/10})" + # for i in values], + # ) + # fig = go.Figure() + # fig.add_trace(trace) + # fig.write_image(img_path) + + values = np.array(values) + # 0-1 scaling + values = (values - values.min()) / (values.max() - values.min()) + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + ax.scatter(x_coords, y_coords, z_coords, alpha=values) + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + + # Extract energy and angle from idx + en, ang = img_idx.split('&') + en = en[7:] + ang = ang[6:] + ax.set_title(f"Energy: {en} - Angle: {ang}") + fig.savefig(img_path) diff --git a/use-cases/3dgan/trainer.py 
b/use-cases/3dgan/trainer.py index 642a13ec..b5407c79 100644 --- a/use-cases/3dgan/trainer.py +++ b/use-cases/3dgan/trainer.py @@ -130,26 +130,33 @@ def predict( predictions = cli.trainer.predict(model, datamodule=datamodule) - predictions = [ - self.transform_predictions(pred) for pred in predictions - ] + # Transpose predictions into images, energies and angles + images = torch.cat(list(map( + lambda pred: self.transform_predictions( + pred['images']), predictions + ))) + energies = torch.cat(list(map( + lambda pred: pred['energies'], predictions + ))) + angles = torch.cat(list(map( + lambda pred: pred['angles'], predictions + ))) predictions_dict = dict() - for idx, generated_img in enumerate(torch.cat(predictions)): + for idx, (img, en, ang) in enumerate(zip(images, energies, angles)): if (self.max_samples is not None and idx >= self.max_samples): break - predictions_dict[str(idx)] = generated_img + sample_key = f"energy={en.item()}&angle={ang.item()}" + predictions_dict[sample_key] = img - print(len(predictions_dict)) return predictions_dict def transform_predictions(self, batch: Batch) -> Batch: """ Post-process the predictions of the torch model. 
""" - # TODO: post-process predictions - return batch + return batch.squeeze(1) def execute( self, From 61c366630e4b5408bc1c7cb50574e02831ba0d2d Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Wed, 8 Nov 2023 16:17:46 +0100 Subject: [PATCH 19/57] ADD: Dockerfile + refactor --- use-cases/3dgan/Dockerfile | 25 ++++++ use-cases/3dgan/README.md | 42 +++++----- use-cases/3dgan/dataloader.py | 100 ++++++++++++++++-------- use-cases/3dgan/inference-pipeline.yaml | 10 +-- use-cases/3dgan/trainer.py | 7 +- 5 files changed, 122 insertions(+), 62 deletions(-) create mode 100644 use-cases/3dgan/Dockerfile diff --git a/use-cases/3dgan/Dockerfile b/use-cases/3dgan/Dockerfile new file mode 100644 index 00000000..515caff0 --- /dev/null +++ b/use-cases/3dgan/Dockerfile @@ -0,0 +1,25 @@ +FROM python:3.9.12 + +WORKDIR /usr/src/app + +RUN pip install --upgrade pip + +# Install pytorch (cpuonly) +# Ref:https://pytorch.org/get-started/previous-versions/#linux-and-windows-5 +RUN pip install --no-cache-dir torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu +RUN pip install --no-cache-dir lightning + +# Add 3DGAN custom requirements +COPY use-cases/3dgan/requirements.txt ./ +RUN pip install --no-cache-dir -r requirements.txt + +# Install itwinai and dependencies +COPY pyproject.toml ./ +COPY src ./ +RUN pip install --no-cache-dir . + +# Add 3DGAN use case files +COPY use-cases/3dgan/* ./ + +# Run inference +CMD [ "python", "train.py", "-p", "inference-pipeline.yaml"] \ No newline at end of file diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md index 1c40b63a..2e607575 100644 --- a/use-cases/3dgan/README.md +++ b/use-cases/3dgan/README.md @@ -13,9 +13,10 @@ And select the "3DGAN" experiment. The following is preliminary and not 100% ML/scientifically sound. 1. As inference dataset we can reuse training/validation dataset, -for instance the one downloaded from Google Drive folder. 
+for instance the one downloaded from Google Drive folder: if the +dataset root folder is not present, the dataset will be downloaded. The inference dataset is a set of H5 files stored inside `exp_data` -sub-folders. +sub-folders: ```text ├── exp_data @@ -55,12 +56,11 @@ torch tensor (.pth) and 3D scatter plot (.jpg): ```text ├── 3dgan-generated-data -│ ├── data -| │ ├── energy=1.296749234199524&angle=1.272539496421814.pth -| │ ├── energy=1.296749234199524&angle=1.272539496421814.jpg +| ├── energy=1.296749234199524&angle=1.272539496421814.pth +| ├── energy=1.296749234199524&angle=1.272539496421814.jpg ... -| │ ├── energy=1.664689540863037&angle=1.4906378984451294.pth -| │ ├── energy=1.664689540863037&angle=1.4906378984451294.jpg +| ├── energy=1.664689540863037&angle=1.4906378984451294.pth +| ├── energy=1.664689540863037&angle=1.4906378984451294.jpg ``` ### Docker image @@ -72,31 +72,35 @@ Build from project root with docker buildx build -t itwinai-mnist-torch-inference -f use-cases/3dgan/Dockerfile . # Ghcr.io -docker buildx build -t ghcr.io/intertwin-eu/itwinai-mnist-torch-inference:0.0.1 -f use-cases/mnist/torch/Dockerfile . -docker push ghcr.io/intertwin-eu/itwinai-mnist-torch-inference:0.0.1 +docker buildx build -t ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.1 -f use-cases/3dgan/Dockerfile . +docker push ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.1 ``` From wherever a sample of MNIST jpg images is available (folder called 'mnist-sample-data/'): ```text -├── $PWD -│ ├── mnist-sample-data -| │ ├── digit_0.jpg -| │ ├── digit_1.jpg -| │ ├── digit_2.jpg +├── $PWD +| ├── exp_data +| │ ├── data +| | │ ├── file_0.h5 +| | │ ├── file_1.h5 ... 
-| │ ├── digit_N.jpg +| | │ ├── file_N.h5 ``` ```bash -docker run -it --rm --name running-inference -v "$PWD":/usr/data ghcr.io/intertwin-eu/itwinai-mnist-torch-inference:0.0.1 +docker run -it --rm --name running-inference -v "$PWD":/usr/data ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.1 ``` -This command will store the results in a folder called "mnist-predictions": +This command will store the results in a folder called "3dgan-generated-data": ```text ├── $PWD -│ ├── mnist-predictions -| │ ├── predictions.csv +| ├── 3dgan-generated-data +| │ ├── energy=1.296749234199524&angle=1.272539496421814.pth +| │ ├── energy=1.296749234199524&angle=1.272539496421814.jpg +... +| │ ├── energy=1.664689540863037&angle=1.4906378984451294.pth +| │ ├── energy=1.664689540863037&angle=1.4906378984451294.jpg ``` diff --git a/use-cases/3dgan/dataloader.py b/use-cases/3dgan/dataloader.py index d77e363a..65585779 100644 --- a/use-cases/3dgan/dataloader.py +++ b/use-cases/3dgan/dataloader.py @@ -41,9 +41,12 @@ def execute( class ParticlesDataset(Dataset): - def __init__(self, datapath): + def __init__(self, datapath: str, max_samples: Optional[int] = None): self.datapath = datapath - self.data = self.fetch_data(self.datapath) + self.max_samples = max_samples + self.data = dict() + + self.fetch_data() def __len__(self): return len(self.data["X"]) @@ -52,34 +55,53 @@ def __getitem__(self, idx): return {"X": self.data["X"][idx], "Y": self.data["Y"][idx], "ang": self.data["ang"][idx], "ecal": self.data["ecal"][idx]} - def fetch_data(self, datapath): - - print("Searching in :", datapath) - Files = sorted(glob.glob(datapath)) - print("Found {} files. ".format(len(Files))) - - concatenated_datasets = [] - for datafile in Files: + def fetch_data(self) -> None: + + print("Searching in :", self.datapath) + files = sorted(glob.glob(self.datapath)) + print("Found {} files. 
".format(len(files))) + if len(files) == 0: + raise RuntimeError(f"No H5 files found at '{self.datapath}'!") + + # concatenated_datasets = [] + # for datafile in files: + # f = h5py.File(datafile, 'r') + # dataset = self.GetDataAngleParallel(f) + # concatenated_datasets.append(dataset) + # # Initialize result dictionary + # result = {key: [] for key in concatenated_datasets[0].keys()} + # for d in concatenated_datasets: + # for key in result.keys(): + # result[key].extend(d[key]) + # return result + + for datafile in files: f = h5py.File(datafile, 'r') dataset = self.GetDataAngleParallel(f) - concatenated_datasets.append(dataset) - # Initialize result dictionary - result = {key: [] for key in concatenated_datasets[0].keys()} - for d in concatenated_datasets: - for key in result.keys(): - result[key].extend(d[key]) - return result + for field, vals_list in dataset.items(): + if self.data.get(field) is not None: + self.data[field].extend(vals_list) + else: + self.data[field] = vals_list + + # Stop loading data, if self.max_samples reached + if (self.max_samples is not None + and len(self.data[field]) >= self.max_samples): + for field, vals_list in self.data.items(): + self.data[field] = self.data[field][:self.max_samples] + break def GetDataAngleParallel( - self, - dataset, - xscale=1, - xpower=0.85, - yscale=100, - angscale=1, - angtype="theta", - thresh=1e-4, - daxis=-1,): + self, + dataset, + xscale=1, + xpower=0.85, + yscale=100, + angscale=1, + angtype="theta", + thresh=1e-4, + daxis=-1 + ): """Preprocess function for the dataset Args: @@ -130,17 +152,28 @@ def GetDataAngleParallel( class ParticlesDataModule(pl.LightningDataModule): - def __init__(self, batch_size: int, datapath): + def __init__( + self, + datapath: str, + batch_size: int, + num_workers: int = 4, + max_samples: Optional[int] = None + ) -> None: super().__init__() self.batch_size = batch_size + self.num_workers = num_workers self.datapath = datapath + self.max_samples = max_samples def 
setup(self, stage: str = None): # make assignments here (val/train/test split) # called on every process in DDP if stage == 'fit' or stage is None: - self.dataset = ParticlesDataset(self.datapath) + self.dataset = ParticlesDataset( + self.datapath, + max_samples=self.max_samples + ) dataset_length = len(self.dataset) split_point = int(dataset_length * 0.9) self.train_dataset, self.val_dataset = \ @@ -150,21 +183,24 @@ def setup(self, stage: str = None): if stage == 'predict': # TODO: inference dataset should be different in that it # does not contain images! - self.predict_dataset = ParticlesDataset(self.datapath) + self.predict_dataset = ParticlesDataset( + self.datapath, + max_samples=self.max_samples + ) # if stage == 'test' or stage is None: # self.test_dataset = MyDataset(self.data_dir, train=False) def train_dataloader(self): - return DataLoader(self.train_dataset, num_workers=4, + return DataLoader(self.train_dataset, num_workers=self.num_workers, batch_size=self.batch_size, drop_last=True) def val_dataloader(self): - return DataLoader(self.val_dataset, num_workers=4, + return DataLoader(self.val_dataset, num_workers=self.num_workers, batch_size=self.batch_size, drop_last=True) def predict_dataloader(self) -> EVAL_DATALOADERS: - return DataLoader(self.predict_dataset, num_workers=4, + return DataLoader(self.predict_dataset, num_workers=self.num_workers, batch_size=self.batch_size, drop_last=True) # def test_dataloader(self): diff --git a/use-cases/3dgan/inference-pipeline.yaml b/use-cases/3dgan/inference-pipeline.yaml index bf9b333a..3939b206 100644 --- a/use-cases/3dgan/inference-pipeline.yaml +++ b/use-cases/3dgan/inference-pipeline.yaml @@ -4,7 +4,7 @@ executor: steps: - class_path: dataloader.Lightning3DGANDownloader init_args: - data_path: exp_data/ + data_path: /usr/data/exp_data/ data_url: https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX - class_path: trainer.Lightning3DGANPredictor @@ -14,8 +14,6 @@ executor: init_args: 
model_uri: 3dgan-inference.pth - max_samples: 10 - # Pytorch lightning config for training config: seed_everything: 4231162351 @@ -96,9 +94,11 @@ executor: data: class_path: dataloader.ParticlesDataModule init_args: - datapath: exp_data/*/*.h5 + datapath: /usr/data/exp_data/*/*.h5 batch_size: 64 + num_workers: 2 + max_samples: 10 - class_path: saver.ParticleImagesSaver init_args: - save_dir: 3dgan-generated-data \ No newline at end of file + save_dir: /usr/data/3dgan-generated-data \ No newline at end of file diff --git a/use-cases/3dgan/trainer.py b/use-cases/3dgan/trainer.py index b5407c79..faf7dc32 100644 --- a/use-cases/3dgan/trainer.py +++ b/use-cases/3dgan/trainer.py @@ -91,7 +91,6 @@ def __init__( self, model: Union[ModelLoader, pl.LightningModule], config: Union[Dict, str], - max_samples: Optional[int] = None, name: Optional[str] = None ): super().__init__(model, name) @@ -99,7 +98,6 @@ def __init__( # Load from YAML config = load_yaml(config) self.conf = config - self.max_samples = max_samples def predict( self, @@ -143,10 +141,7 @@ def predict( ))) predictions_dict = dict() - for idx, (img, en, ang) in enumerate(zip(images, energies, angles)): - if (self.max_samples is not None - and idx >= self.max_samples): - break + for img, en, ang in zip(images, energies, angles): sample_key = f"energy={en.item()}&angle={ang.item()}" predictions_dict[sample_key] = img From d690192b7b051e6febc9167dac806d9e74bc02ee Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Wed, 8 Nov 2023 17:26:10 +0100 Subject: [PATCH 20/57] ADD: .dockerignore --- .dockerignore | 131 ++++++++++++++++++++++++++++++++++++++++++++++++++ .gitignore | 1 - 2 files changed, 131 insertions(+), 1 deletion(-) create mode 100644 .dockerignore diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..ac374575 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,131 @@ +**/TODO +**/mamba* +pl-training.yml +.vscode + +# Project folders/files +use-cases +workflows +tests +CHANGELOG + +# Docs 
+docs + +# Data +**/MNIST +**/*-predictions/ +**/*-data/ +**/*.tar.gz +**/exp_data + +# Logs +**/logs +**/lightning_logs +**/mlruns +**/.logs +**/mllogs +**/nohup* +**/*.out +**/*.err +**/checkpoints/ +**/*_logs +**/tmp* +**/.tmp* + +# Markdown +**/*.md + +# Custom envs +**/.venv* + +# Git +.git +.gitignore +.github + +# CI +.codeclimate.yml +.travis.yml +.taskcluster.yml + +# Docker +docker-compose.yml +.docker +.dockerignore +Dockerfile + +# Byte-compiled / optimized / DLL files +**/__pycache__/ +**/*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +**/eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +**/*.egg-info/ +**/.installed.cfg +**/*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo +*.pot + +# Django stuff: +*.log + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Virtual environment +.env/ +.venv/ +venv/ + +# PyCharm +.idea + +# Python mode for VIM +.ropeproject +*/.ropeproject +*/*/.ropeproject +*/*/*/.ropeproject + +# Vim swap files +*.swp +*/*.swp +*/*/*.swp +*/*/*/*.swp \ No newline at end of file diff --git a/.gitignore b/.gitignore index 0c42e0a6..2f0ad142 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ -*.pth *_logs exp_data/ TODO From a2a9875f806461c51380328332e3c3a388cbb668 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Wed, 8 Nov 2023 17:32:13 +0100 Subject: [PATCH 21/57] Update .dockerignore --- .dockerignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.dockerignore b/.dockerignore index ac374575..697f33f2 100644 --- a/.dockerignore +++ b/.dockerignore @@ -4,7 +4,7 @@ pl-training.yml .vscode # Project 
folders/files -use-cases +# use-cases workflows tests CHANGELOG From 3bcc4103d02c29a9eda5aebc0130dc5af154833b Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Wed, 8 Nov 2023 17:49:01 +0100 Subject: [PATCH 22/57] REMOVE: keras dependency --- src/itwinai/loggers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/itwinai/loggers.py b/src/itwinai/loggers.py index 1116c503..d04becd7 100644 --- a/src/itwinai/loggers.py +++ b/src/itwinai/loggers.py @@ -9,7 +9,7 @@ import wandb import mlflow -import mlflow.keras +# import mlflow.keras BASE_EXP_NAME: str = 'unk_experiment' From be0c115aa686b1277c930f6447f4e71d52e3f12a Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 9 Nov 2023 09:48:25 +0100 Subject: [PATCH 23/57] ADD: skip download option --- use-cases/3dgan/dataloader.py | 4 ++++ use-cases/3dgan/pipeline.yaml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/use-cases/3dgan/dataloader.py b/use-cases/3dgan/dataloader.py index 65585779..5ac9c4b7 100644 --- a/use-cases/3dgan/dataloader.py +++ b/use-cases/3dgan/dataloader.py @@ -25,6 +25,10 @@ def __init__( self.data_url = data_url def load(self): + if self.data_path is None: + print("Data path is None. 
Skipping dataset downloading") + return + # Download data if not os.path.exists(self.data_path): gdown.download_folder( diff --git a/use-cases/3dgan/pipeline.yaml b/use-cases/3dgan/pipeline.yaml index 942efeb7..676424aa 100644 --- a/use-cases/3dgan/pipeline.yaml +++ b/use-cases/3dgan/pipeline.yaml @@ -4,7 +4,7 @@ executor: steps: - class_path: dataloader.Lightning3DGANDownloader init_args: - data_path: exp_data/ + data_path: exp_data/ # Set to null to skip dataset download data_url: https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX - class_path: trainer.Lightning3DGANTrainer From 77d939e2dba603ecf62e5a9f752e6a0a5441efc8 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 9 Nov 2023 13:06:50 +0100 Subject: [PATCH 24/57] ADD: cern pipeline.yaml --- use-cases/3dgan/cern-pipeline.yaml | 93 ++++++++++++++++++++++++++++++ use-cases/3dgan/dataloader.py | 10 ++-- 2 files changed, 98 insertions(+), 5 deletions(-) create mode 100644 use-cases/3dgan/cern-pipeline.yaml diff --git a/use-cases/3dgan/cern-pipeline.yaml b/use-cases/3dgan/cern-pipeline.yaml new file mode 100644 index 00000000..0feaa570 --- /dev/null +++ b/use-cases/3dgan/cern-pipeline.yaml @@ -0,0 +1,93 @@ +executor: + class_path: itwinai.components.Executor + init_args: + steps: + - class_path: dataloader.Lightning3DGANDownloader + init_args: + data_path: /eos/user/k/ktsolaki/data/3dgan_data # exp_data/ + data_url: null # https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX + + - class_path: trainer.Lightning3DGANTrainer + init_args: + # Pytorch lightning config for training + config: + seed_everything: 4231162351 + trainer: + accelerator: auto + accumulate_grad_batches: 1 + barebones: false + benchmark: null + # callbacks: + # # - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping + # # init_args: + # # monitor: val_loss + # # patience: 2 + # - class_path: lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor + # init_args: + # 
logging_interval: step + # # - class_path: lightning.pytorch.callbacks.ModelCheckpoint + # # init_args: + # # dirpath: checkpoints + # # filename: best-checkpoint + # # mode: min + # # monitor: val_loss + # # save_top_k: 1 + # # verbose: true + check_val_every_n_epoch: 1 + default_root_dir: null + detect_anomaly: false + deterministic: null + devices: auto #[0] + enable_checkpointing: true + enable_model_summary: null + enable_progress_bar: null + fast_dev_run: false + gradient_clip_algorithm: null + gradient_clip_val: null + inference_mode: true + limit_predict_batches: null + limit_test_batches: null + limit_train_batches: null + limit_val_batches: null + log_every_n_steps: 2 + logger: + # - class_path: lightning.pytorch.loggers.CSVLogger + # init_args: + # save_dir: ml_logs/csv_logs + class_path: lightning.pytorch.loggers.MLFlowLogger + init_args: + experiment_name: 3DGAN + save_dir: ml_logs/mlflow_logs + log_model: all + max_epochs: 100 + max_steps: -1 + max_time: null + min_epochs: null + min_steps: null + num_sanity_val_steps: null + overfit_batches: 0.0 + plugins: null + profiler: null + reload_dataloaders_every_n_epochs: 0 + strategy: ddp_find_unused_parameters_true #auto + sync_batchnorm: false + use_distributed_sampler: true + val_check_interval: null + + # Lightning Model configuration + model: + class_path: model.ThreeDGAN + init_args: + latent_size: 256 + batch_size: 64 + loss_weights: [3, 0.1, 25, 0.1] + power: 0.85 + lr: 0.001 + checkpoint_path: exp_data/3dgan.pth + + # Lightning data module configuration + data: + class_path: dataloader.ParticlesDataModule + init_args: + datapath: /eos/user/k/ktsolaki/data/3dgan_data/*.h5 # exp_data/*/*.h5 + batch_size: 64 diff --git a/use-cases/3dgan/dataloader.py b/use-cases/3dgan/dataloader.py index 5ac9c4b7..087a5182 100644 --- a/use-cases/3dgan/dataloader.py +++ b/use-cases/3dgan/dataloader.py @@ -16,8 +16,8 @@ class Lightning3DGANDownloader(DataGetter): def __init__( self, - data_url: str, data_path: str, + 
data_url: Optional[str] = None, name: Optional[str] = None, **kwargs) -> None: super().__init__(name, **kwargs) @@ -25,12 +25,12 @@ def __init__( self.data_url = data_url def load(self): - if self.data_path is None: - print("Data path is None. Skipping dataset downloading") - return - # Download data if not os.path.exists(self.data_path): + if self.data_url is None: + print("WARNING! Data URL is None. " + "Skipping dataset downloading") + gdown.download_folder( url=self.data_url, quiet=False, output=self.data_path From b603b05e1667339f441d55785a98b37d40fbeb42 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 9 Nov 2023 13:19:01 +0100 Subject: [PATCH 25/57] UPDATE: dataset loading function --- use-cases/3dgan/dataloader.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/use-cases/3dgan/dataloader.py b/use-cases/3dgan/dataloader.py index 087a5182..4611f229 100644 --- a/use-cases/3dgan/dataloader.py +++ b/use-cases/3dgan/dataloader.py @@ -84,7 +84,11 @@ def fetch_data(self) -> None: dataset = self.GetDataAngleParallel(f) for field, vals_list in dataset.items(): if self.data.get(field) is not None: - self.data[field].extend(vals_list) + # self.data[field].extend(vals_list) + self.data[field].resize( + len(self.data[field]) + len(vals_list) + ) + self.data[field][-len(vals_list):] = vals_list else: self.data[field] = vals_list @@ -92,7 +96,7 @@ def fetch_data(self) -> None: if (self.max_samples is not None and len(self.data[field]) >= self.max_samples): for field, vals_list in self.data.items(): - self.data[field] = self.data[field][:self.max_samples] + self.data[field] = vals_list[:self.max_samples] break def GetDataAngleParallel( From bee1317157b2ee3c858e1225b8858011c33d9f78 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 9 Nov 2023 13:22:59 +0100 Subject: [PATCH 26/57] UPDATE: dataset loading function --- use-cases/3dgan/dataloader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/use-cases/3dgan/dataloader.py b/use-cases/3dgan/dataloader.py index 4611f229..3b0e1e79 100644 --- a/use-cases/3dgan/dataloader.py +++ b/use-cases/3dgan/dataloader.py @@ -85,9 +85,9 @@ def fetch_data(self) -> None: for field, vals_list in dataset.items(): if self.data.get(field) is not None: # self.data[field].extend(vals_list) - self.data[field].resize( - len(self.data[field]) + len(vals_list) - ) + new_shape = list(self.data[field].shape) + new_shape[0] += len(vals_list) + self.data[field].resize(new_shape) self.data[field][-len(vals_list):] = vals_list else: self.data[field] = vals_list From b6c3ee28b7a94f22b1c901282e4b73b3b0274185 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 9 Nov 2023 13:31:14 +0100 Subject: [PATCH 27/57] UPDATE conf --- use-cases/3dgan/cern-pipeline.yaml | 6 ++++-- use-cases/3dgan/dataloader.py | 14 +++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/use-cases/3dgan/cern-pipeline.yaml b/use-cases/3dgan/cern-pipeline.yaml index 0feaa570..2356332b 100644 --- a/use-cases/3dgan/cern-pipeline.yaml +++ b/use-cases/3dgan/cern-pipeline.yaml @@ -79,7 +79,7 @@ executor: class_path: model.ThreeDGAN init_args: latent_size: 256 - batch_size: 64 + batch_size: 128 loss_weights: [3, 0.1, 25, 0.1] power: 0.85 lr: 0.001 @@ -90,4 +90,6 @@ executor: class_path: dataloader.ParticlesDataModule init_args: datapath: /eos/user/k/ktsolaki/data/3dgan_data/*.h5 # exp_data/*/*.h5 - batch_size: 64 + batch_size: 128 + num_workers: 0 + max_samples: null diff --git a/use-cases/3dgan/dataloader.py b/use-cases/3dgan/dataloader.py index 3b0e1e79..fcd90343 100644 --- a/use-cases/3dgan/dataloader.py +++ b/use-cases/3dgan/dataloader.py @@ -82,21 +82,21 @@ def fetch_data(self) -> None: for datafile in files: f = h5py.File(datafile, 'r') dataset = self.GetDataAngleParallel(f) - for field, vals_list in dataset.items(): + for field, vals_array in dataset.items(): if self.data.get(field) is not None: - # self.data[field].extend(vals_list) + # 
Resize to include the new array new_shape = list(self.data[field].shape) - new_shape[0] += len(vals_list) + new_shape[0] += len(vals_array) self.data[field].resize(new_shape) - self.data[field][-len(vals_list):] = vals_list + self.data[field][-len(vals_array):] = vals_array else: - self.data[field] = vals_list + self.data[field] = vals_array # Stop loading data, if self.max_samples reached if (self.max_samples is not None and len(self.data[field]) >= self.max_samples): - for field, vals_list in self.data.items(): - self.data[field] = vals_list[:self.max_samples] + for field, vals_array in self.data.items(): + self.data[field] = vals_array[:self.max_samples] break def GetDataAngleParallel( From 466b15082ac62f08fedf609462d60c76182692e7 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 9 Nov 2023 13:41:22 +0100 Subject: [PATCH 28/57] UPDATE refactor --- use-cases/3dgan/cern-pipeline.yaml | 2 +- use-cases/3dgan/model.py | 24 ++++++++++++++---------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/use-cases/3dgan/cern-pipeline.yaml b/use-cases/3dgan/cern-pipeline.yaml index 2356332b..71c369d4 100644 --- a/use-cases/3dgan/cern-pipeline.yaml +++ b/use-cases/3dgan/cern-pipeline.yaml @@ -83,7 +83,7 @@ executor: loss_weights: [3, 0.1, 25, 0.1] power: 0.85 lr: 0.001 - checkpoint_path: exp_data/3dgan.pth + checkpoint_path: checkpoints/3dgan.pth # Lightning data module configuration data: diff --git a/use-cases/3dgan/model.py b/use-cases/3dgan/model.py index 6cc92c93..4fc5cc99 100644 --- a/use-cases/3dgan/model.py +++ b/use-cases/3dgan/model.py @@ -1,4 +1,5 @@ import sys +import os import pickle from collections import defaultdict import math @@ -332,6 +333,9 @@ def __init__( self.train_history = defaultdict(list) self.test_history = defaultdict(list) self.pklfile = checkpoint_path + checkpoint_dir = os.path.dirname(checkpoint_path) + if not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir) def BitFlip(self, x, prob=0.05): """ @@ -415,8 +419,8 
@@ def training_step(self, batch, batch_idx): dtype=torch.float32, device=self.device ) - print(f'Energy elements: {energy_batch.numel} {energy_batch.shape}') - print(f'Angle elements: {ang_batch.numel} {ang_batch.shape}') + # print(f'Energy elements: {energy_batch.numel} {energy_batch.shape}') + # print(f'Angle elements: {ang_batch.numel} {ang_batch.shape}') generator_ip = torch.cat( (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), dim=1 @@ -429,12 +433,12 @@ def training_step(self, batch, batch_idx): labels = [fake_batch, energy_batch, ang_batch, ecal_batch] predictions = self.discriminator(image_batch) - print("calculating real_batch_loss...") + # print("calculating real_batch_loss...") real_batch_loss = self.compute_global_loss( labels, predictions, self.loss_weights) self.log("real_batch_loss", sum(real_batch_loss), prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) - print("real batch disc train") + # print("real batch disc train") # the following 3 lines correspond in tf version to: # gradients = tape.gradient(real_batch_loss, # discriminator.trainable_variables) @@ -457,7 +461,7 @@ def training_step(self, batch, batch_idx): labels, predictions, self.loss_weights) self.log("fake_batch_loss", sum(fake_batch_loss), prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) - print("fake batch disc train") + # print("fake batch disc train") # the following 3 lines correspond to # gradients = tape.gradient(fake_batch_loss, # discriminator.trainable_variables) @@ -491,7 +495,7 @@ def training_step(self, batch, batch_idx): labels, predictions, self.loss_weights) self.log("gen_loss", sum(loss), prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) - print("gen train") + # print("gen train") optimizer_generator.zero_grad() self.manual_backward(sum(loss)) # sum(loss).backward() @@ -541,13 +545,13 @@ def training_step(self, batch, batch_idx): # if ecal sum has 100% loss(generating empty events) then end # the training if fake_batch_loss[3] == 
100.0 and self.index > 10: - print("Empty image with Ecal loss equal to 100.0 " - f"for {self.index} batch") + # print("Empty image with Ecal loss equal to 100.0 " + # f"for {self.index} batch") torch.save(self.generator.state_dict(), "generator_weights.pth") torch.save(self.discriminator.state_dict(), "discriminator_weights.pth") - print("real_batch_loss", real_batch_loss) - print("fake_batch_loss", fake_batch_loss) + # print("real_batch_loss", real_batch_loss) + # print("fake_batch_loss", fake_batch_loss) sys.exit() # append mean of discriminator loss for real and fake events From 3e1d6ab40a7809651d798b6f8d317c2f6f9f9125 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 9 Nov 2023 13:41:42 +0100 Subject: [PATCH 29/57] UPDATE refactor --- use-cases/3dgan/cern-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/use-cases/3dgan/cern-pipeline.yaml b/use-cases/3dgan/cern-pipeline.yaml index 71c369d4..7d251ae5 100644 --- a/use-cases/3dgan/cern-pipeline.yaml +++ b/use-cases/3dgan/cern-pipeline.yaml @@ -92,4 +92,4 @@ executor: datapath: /eos/user/k/ktsolaki/data/3dgan_data/*.h5 # exp_data/*/*.h5 batch_size: 128 num_workers: 0 - max_samples: null + max_samples: 3000 From ca60e191ad33a12594685c5eae1db7c5e675c179 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 9 Nov 2023 14:13:01 +0100 Subject: [PATCH 30/57] UPDATE training docs --- use-cases/3dgan/README.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md index 2e607575..95a428c8 100644 --- a/use-cases/3dgan/README.md +++ b/use-cases/3dgan/README.md @@ -1,5 +1,21 @@ # 3DGAN use case +## Training + +At CERN, use the dedicated configuration file: + +```bash +cd use-cases/3dgan +python train.py -p cern-pipeline.yaml +``` + +Anywhere else, use the general purpose training configuration: + +```bash +cd use-cases/3dgan +python train.py -p pipeline.yaml +``` + To visualize the logs with MLFLow run the following in 
the terminal: ```bash From 307ed65dfee5352297dcdc6e06eabf23bf6c2703 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 9 Nov 2023 14:49:50 +0100 Subject: [PATCH 31/57] Update readme --- use-cases/3dgan/README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md index 95a428c8..efeed784 100644 --- a/use-cases/3dgan/README.md +++ b/use-cases/3dgan/README.md @@ -1,5 +1,22 @@ # 3DGAN use case +First of all, from the repository root, create a torch environment: + +```bash +make torch-gpu +``` + +Now, install custom requirements for 3DGAN: + +```bash +micromamba activate ./.venv-pytorch +cd use-cases/3dgan +pip install -r requirements.txt +``` + +**NOTE**: Python commands below assumed to be executed from within the +micromamba virtual environment. + ## Training At CERN, use the dedicated configuration file: From f47f40dfb8fdf8e1d4d0acdea903c8d8f9fa53bb Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 9 Nov 2023 14:52:42 +0100 Subject: [PATCH 32/57] update README --- use-cases/3dgan/README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md index efeed784..73646323 100644 --- a/use-cases/3dgan/README.md +++ b/use-cases/3dgan/README.md @@ -24,6 +24,9 @@ At CERN, use the dedicated configuration file: ```bash cd use-cases/3dgan python train.py -p cern-pipeline.yaml + +# Or better: +micromamba run -p ../../.venv-pytorch/ torchrun python train.py -p cern-pipeline.yaml ``` Anywhere else, use the general purpose training configuration: @@ -31,6 +34,9 @@ Anywhere else, use the general purpose training configuration: ```bash cd use-cases/3dgan python train.py -p pipeline.yaml + +# Or better: +micromamba run -p ../../.venv-pytorch/ torchrun python train.py -p pipeline.yaml ``` To visualize the logs with MLFLow run the following in the terminal: From fc0697e55372e729334f3a485b2de9d69bc4e312 Mon Sep 17 00:00:00 2001 From: Matteo Bunino 
Date: Thu, 9 Nov 2023 14:54:15 +0100 Subject: [PATCH 33/57] FIX typo --- use-cases/3dgan/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md index 73646323..2f37f4c7 100644 --- a/use-cases/3dgan/README.md +++ b/use-cases/3dgan/README.md @@ -26,7 +26,7 @@ cd use-cases/3dgan python train.py -p cern-pipeline.yaml # Or better: -micromamba run -p ../../.venv-pytorch/ torchrun python train.py -p cern-pipeline.yaml +micromamba run -p ../../.venv-pytorch/ torchrun train.py -p cern-pipeline.yaml ``` Anywhere else, use the general purpose training configuration: @@ -36,7 +36,7 @@ cd use-cases/3dgan python train.py -p pipeline.yaml # Or better: -micromamba run -p ../../.venv-pytorch/ torchrun python train.py -p pipeline.yaml +micromamba run -p ../../.venv-pytorch/ torchrun train.py -p pipeline.yaml ``` To visualize the logs with MLFLow run the following in the terminal: From a8c9d6d52dd7d866483b259a9c66cf8be2d55875 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 9 Nov 2023 15:03:24 +0100 Subject: [PATCH 34/57] Update README --- use-cases/3dgan/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md index 2f37f4c7..d1bb7bc0 100644 --- a/use-cases/3dgan/README.md +++ b/use-cases/3dgan/README.md @@ -26,7 +26,7 @@ cd use-cases/3dgan python train.py -p cern-pipeline.yaml # Or better: -micromamba run -p ../../.venv-pytorch/ torchrun train.py -p cern-pipeline.yaml +micromamba run -p ../../.venv-pytorch/ torchrun --nproc_per_node gpu train.py -p cern-pipeline.yaml ``` Anywhere else, use the general purpose training configuration: @@ -36,7 +36,7 @@ cd use-cases/3dgan python train.py -p pipeline.yaml # Or better: -micromamba run -p ../../.venv-pytorch/ torchrun train.py -p pipeline.yaml +micromamba run -p ../../.venv-pytorch/ torchrun --nproc_per_node gpu train.py -p pipeline.yaml ``` To visualize the logs with MLFLow 
run the following in the terminal: From 3faa0621c3f9e9d61c345ff596c60adad84c281c Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 9 Nov 2023 15:31:13 +0100 Subject: [PATCH 35/57] Update mkdir --- use-cases/3dgan/model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/use-cases/3dgan/model.py b/use-cases/3dgan/model.py index 4fc5cc99..9d416d1f 100644 --- a/use-cases/3dgan/model.py +++ b/use-cases/3dgan/model.py @@ -334,8 +334,7 @@ def __init__( self.test_history = defaultdict(list) self.pklfile = checkpoint_path checkpoint_dir = os.path.dirname(checkpoint_path) - if not os.path.exists(checkpoint_dir): - os.makedirs(checkpoint_dir) + os.makedirs(checkpoint_dir, exist_ok=True) def BitFlip(self, x, prob=0.05): """ From b50c61029fe3cb2b27d98f7dcd81fb78bdfb699e Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Wed, 15 Nov 2023 14:01:42 +0100 Subject: [PATCH 36/57] UPDATE data paths --- use-cases/3dgan/README.md | 2 +- use-cases/3dgan/inference-pipeline.yaml | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md index d1bb7bc0..d62af5d7 100644 --- a/use-cases/3dgan/README.md +++ b/use-cases/3dgan/README.md @@ -129,7 +129,7 @@ From wherever a sample of MNIST jpg images is available ``` ```bash -docker run -it --rm --name running-inference -v "$PWD":/usr/data ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.1 +docker run -it --rm --name running-inference -v "$PWD":/tmp/data ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.1 ``` This command will store the results in a folder called "3dgan-generated-data": diff --git a/use-cases/3dgan/inference-pipeline.yaml b/use-cases/3dgan/inference-pipeline.yaml index 3939b206..8eac66d2 100644 --- a/use-cases/3dgan/inference-pipeline.yaml +++ b/use-cases/3dgan/inference-pipeline.yaml @@ -4,7 +4,7 @@ executor: steps: - class_path: dataloader.Lightning3DGANDownloader init_args: - data_path: /usr/data/exp_data/ + data_path: 
/tmp/data/exp_data/ data_url: https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX - class_path: trainer.Lightning3DGANPredictor @@ -62,7 +62,7 @@ executor: class_path: lightning.pytorch.loggers.MLFlowLogger init_args: experiment_name: 3DGAN - save_dir: ml_logs/mlflow_logs + save_dir: /tmp/ml_logs/mlflow_logs log_model: all max_epochs: 1 max_steps: 20 @@ -94,11 +94,11 @@ executor: data: class_path: dataloader.ParticlesDataModule init_args: - datapath: /usr/data/exp_data/*/*.h5 + datapath: /tmp/data/exp_data/*/*.h5 batch_size: 64 num_workers: 2 max_samples: 10 - class_path: saver.ParticleImagesSaver init_args: - save_dir: /usr/data/3dgan-generated-data \ No newline at end of file + save_dir: /tmp/data/3dgan-generated-data \ No newline at end of file From 2cedfe793c35176da1346b432f19481d6785da7d Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 16 Nov 2023 13:53:50 +0100 Subject: [PATCH 37/57] UPDATE Dockerfile --- use-cases/3dgan/Dockerfile | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/use-cases/3dgan/Dockerfile b/use-cases/3dgan/Dockerfile index 515caff0..0d075c00 100644 --- a/use-cases/3dgan/Dockerfile +++ b/use-cases/3dgan/Dockerfile @@ -21,5 +21,17 @@ RUN pip install --no-cache-dir . 
# Add 3DGAN use case files COPY use-cases/3dgan/* ./ +# # Create results folder +# RUN mkdir -p /tmp/data +# RUN chmod 0777 -R /tmp/data + +# Create results folder +# TODO: remove once the problem with file system permissions are solved +RUN mkdir -p /tmp/data +RUN mkdir -p /tmp/data/3dgan-generated-data +RUN mkdir -p /tmp/data/exp_data +RUN mkdir -p /tmp/data/exp_data/3dgan_data +RUN chmod 0777 -R /tmp/data + # Run inference CMD [ "python", "train.py", "-p", "inference-pipeline.yaml"] \ No newline at end of file From 1efba3f0a56f5c197d532743861939abcd27b52a Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 16 Nov 2023 16:14:48 +0100 Subject: [PATCH 38/57] UPDATE Dockerfiles --- use-cases/3dgan/Dockerfile | 12 ----------- use-cases/3dgan/Dockerfile.vega | 35 +++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 12 deletions(-) create mode 100644 use-cases/3dgan/Dockerfile.vega diff --git a/use-cases/3dgan/Dockerfile b/use-cases/3dgan/Dockerfile index 0d075c00..515caff0 100644 --- a/use-cases/3dgan/Dockerfile +++ b/use-cases/3dgan/Dockerfile @@ -21,17 +21,5 @@ RUN pip install --no-cache-dir . 
# Add 3DGAN use case files COPY use-cases/3dgan/* ./ -# # Create results folder -# RUN mkdir -p /tmp/data -# RUN chmod 0777 -R /tmp/data - -# Create results folder -# TODO: remove once the problem with file system permissions are solved -RUN mkdir -p /tmp/data -RUN mkdir -p /tmp/data/3dgan-generated-data -RUN mkdir -p /tmp/data/exp_data -RUN mkdir -p /tmp/data/exp_data/3dgan_data -RUN chmod 0777 -R /tmp/data - # Run inference CMD [ "python", "train.py", "-p", "inference-pipeline.yaml"] \ No newline at end of file diff --git a/use-cases/3dgan/Dockerfile.vega b/use-cases/3dgan/Dockerfile.vega new file mode 100644 index 00000000..fd933dd0 --- /dev/null +++ b/use-cases/3dgan/Dockerfile.vega @@ -0,0 +1,35 @@ +FROM python:3.9.12 + +WORKDIR /usr/src/app + +RUN pip install --upgrade pip + +# Install pytorch (cpuonly) +# Ref:https://pytorch.org/get-started/previous-versions/#linux-and-windows-5 +RUN pip install --no-cache-dir torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu +RUN pip install --no-cache-dir lightning + +# Add 3DGAN custom requirements +COPY use-cases/3dgan/requirements.txt /usr/src/app/ +RUN pip install --no-cache-dir -r requirements.txt + +# Install itwinai and dependencies +COPY pyproject.toml /usr/src/app/ +COPY src /usr/src/app/ +RUN pip install --no-cache-dir /usr/src/app + +# Add 3DGAN use case files +COPY use-cases/3dgan/* /usr/src/app/ + +# # Create results folder +# RUN mkdir -p /tmp/data +# RUN chmod 0777 -R /tmp/data + +# Create results folder +# TODO: remove once the problem with file system permissions are solved +RUN mkdir -p /tmp/data/3dgan-generated-data +RUN mkdir -p /tmp/data/exp_data/3dgan_data +RUN chmod 0777 -R /tmp/data + +# Run inference +CMD [ "python", "train.py", "-p", "inference-pipeline.yaml"] \ No newline at end of file From 60ab87db1fba47900ba6e222a64eadb755e9524e Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 16 Nov 2023 18:36:08 +0100 Subject: [PATCH 
39/57] UPDATE for Singularity execution --- use-cases/3dgan/README.md | 13 +++++++++++++ use-cases/3dgan/inference-pipeline.yaml | 10 +++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md index d62af5d7..53aeb92b 100644 --- a/use-cases/3dgan/README.md +++ b/use-cases/3dgan/README.md @@ -143,3 +143,16 @@ This command will store the results in a folder called "3dgan-generated-data": | │ ├── energy=1.664689540863037&angle=1.4906378984451294.pth | │ ├── energy=1.664689540863037&angle=1.4906378984451294.jpg ``` + +### Singularity + +```bash +singularity pull docker://ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.1 +``` + +Run overriding the working directory (`--pwd /usr/src/app`) and providing a +writable filesystem (`-B "$PWD":/usr/data`): + +```bash +singularity exec -e --pwd /usr/src/app -B "$PWD":/usr/data docker://ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.2 +``` diff --git a/use-cases/3dgan/inference-pipeline.yaml b/use-cases/3dgan/inference-pipeline.yaml index 8eac66d2..773dd399 100644 --- a/use-cases/3dgan/inference-pipeline.yaml +++ b/use-cases/3dgan/inference-pipeline.yaml @@ -4,7 +4,7 @@ executor: steps: - class_path: dataloader.Lightning3DGANDownloader init_args: - data_path: /tmp/data/exp_data/ + data_path: /usr/data/exp_data/ data_url: https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX - class_path: trainer.Lightning3DGANPredictor @@ -62,7 +62,7 @@ executor: class_path: lightning.pytorch.loggers.MLFlowLogger init_args: experiment_name: 3DGAN - save_dir: /tmp/ml_logs/mlflow_logs + save_dir: /usr/data/ml_logs/mlflow_logs log_model: all max_epochs: 1 max_steps: 20 @@ -88,17 +88,17 @@ executor: loss_weights: [3, 0.1, 25, 0.1] power: 0.85 lr: 0.001 - checkpoint_path: exp_data/3dgan.pth + checkpoint_path: /usr/data/exp_data/3dgan.pth # Lightning data module configuration data: class_path: dataloader.ParticlesDataModule init_args: - datapath: 
/tmp/data/exp_data/*/*.h5 + datapath: /usr/data/exp_data/*/*.h5 batch_size: 64 num_workers: 2 max_samples: 10 - class_path: saver.ParticleImagesSaver init_args: - save_dir: /tmp/data/3dgan-generated-data \ No newline at end of file + save_dir: /usr/data/3dgan-generated-data \ No newline at end of file From 881ae47febec000f8e9d2588435a429a93c1d9f3 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 16 Nov 2023 18:41:19 +0100 Subject: [PATCH 40/57] FIX version mismatch --- use-cases/3dgan/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md index 53aeb92b..bdc1520d 100644 --- a/use-cases/3dgan/README.md +++ b/use-cases/3dgan/README.md @@ -147,7 +147,7 @@ This command will store the results in a folder called "3dgan-generated-data": ### Singularity ```bash -singularity pull docker://ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.1 +singularity pull docker://ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.2 ``` Run overriding the working directory (`--pwd /usr/src/app`) and providing a From 9ab6ec1498eaeac36bffc04805eea7a9c9f37fe8 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 16 Nov 2023 18:43:48 +0100 Subject: [PATCH 41/57] UPDATE Singularity docs --- use-cases/3dgan/README.md | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md index bdc1520d..b50938be 100644 --- a/use-cases/3dgan/README.md +++ b/use-cases/3dgan/README.md @@ -146,12 +146,8 @@ This command will store the results in a folder called "3dgan-generated-data": ### Singularity -```bash -singularity pull docker://ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.2 -``` - -Run overriding the working directory (`--pwd /usr/src/app`) and providing a -writable filesystem (`-B "$PWD":/usr/data`): +Run overriding the working directory (`--pwd /usr/src/app`, restores Docker's WORKDIR) +and providing a writable filesystem (`-B "$PWD":/usr/data`): 
```bash singularity exec -e --pwd /usr/src/app -B "$PWD":/usr/data docker://ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.2 From 59fd74b95719e7f04d55e5c32169e06a3f2f3532 Mon Sep 17 00:00:00 2001 From: Matteo Bunino <48362942+matbun@users.noreply.github.com> Date: Thu, 23 Nov 2023 10:02:04 +0100 Subject: [PATCH 42/57] Named steps pipe (#100) * ADD: dict steps pipe * Relax dependency constraint --- pyproject.toml | 2 +- src/itwinai/components.py | 72 ++++++++++++++++++++++++++--- use-cases/mnist/torch/pipeline.yaml | 8 ++-- 3 files changed, 72 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7a780b7c..14e7905a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "submitit>=1.4.6", "typing-extensions==4.5.0", "typing_extensions==4.5.0", - "urllib3>=2.0.5", + "urllib3>=1.26.18", ] # dynamic = ["version", "description"] diff --git a/src/itwinai/components.py b/src/itwinai/components.py index c1e6e372..de155236 100644 --- a/src/itwinai/components.py +++ b/src/itwinai/components.py @@ -2,12 +2,15 @@ from typing import Iterable, Dict, Any, Optional, Tuple, Union from abc import ABCMeta, abstractmethod import time +from jsonargparse import ArgumentParser + # import logging # from logging import Logger as PythonLogger from .cluster import ClusterEnvironment from .types import ModelML, DatasetML from .serialization import ModelLoader +from .utils import load_yaml class Executable(metaclass=ABCMeta): @@ -231,12 +234,12 @@ def save(self, *args, **kwargs): class Executor(Executable): """Sets-up and executes a sequence of Executable steps.""" - steps: Iterable[Executable] + steps: Union[Dict[str, Executable], Iterable[Executable]] constructor_args: Dict def __init__( self, - steps: Iterable[Executable], + steps: Union[Dict[str, Executable], Iterable[Executable]], name: Optional[str] = None, # logs_dir: Optional[str] = None, # debug: bool = False, @@ -247,9 +250,20 @@ def __init__( self.steps = steps 
self.constructor_args = kwargs - def __getitem__(self, subscript) -> Executor: + def __getitem__(self, subscript: Union[str, int, slice]) -> Executor: if isinstance(subscript, slice): - s = self.steps[subscript.start:subscript.stop: subscript.step] + # First, convert to list if is a dict + if isinstance(self.steps, dict): + steps = list(self.steps.items()) + else: + steps = self.steps + # Second, perform slicing + s = steps[subscript.start:subscript.stop: subscript.step] + # Third, reconstruct dict, if it is a dict + if isinstance(self.steps, dict): + s = dict(s) + # Fourth, return sliced sub-pipeline, preserving its + # initial structure sliced = self.__class__( steps=s, **self.constructor_args @@ -270,7 +284,12 @@ def setup(self, parent: Optional[Executor] = None) -> None: Defaults to None. """ super().setup(parent) - for step in self.steps: + if isinstance(self.steps, dict): + steps = list(self.steps.values()) + else: + steps = self.steps + + for step in steps: step.setup(self) step.is_setup = True @@ -303,7 +322,12 @@ def execute( Tuple[Optional[Tuple], Optional[Dict]]: tuple structured as (results, config). """ - for step in self.steps: + if isinstance(self.steps, dict): + steps = list(self.steps.values()) + else: + steps = self.steps + + for step in steps: if not step.is_setup: raise RuntimeError( f"Step '{step.name}' was not setup!" 
@@ -318,3 +342,39 @@ def _pack_args(self, args) -> Tuple: if not isinstance(args, tuple): args = (args,) return args + + +def recursive_replace(config: Dict, target_field: str, new_value: Any) -> None: + def _recursive_replace_key(sub_dict: Dict): + if not isinstance(sub_dict, dict): + return + for k, v in sub_dict.items(): + if k == target_field: + sub_dict[k] = new_value + return + else: + _recursive_replace_key(v) + _recursive_replace_key(config) + + +def load_pipeline_step( + pipe: Union[str, Dict], + step_id: Union[str, int], + override_keys: Optional[Dict[str, Any]] = None +) -> Executable: + if isinstance(pipe, str): + # Load pipe from YAML file path + pipe = load_yaml(pipe) + step_dict_config = pipe['executor']['init_args']['steps'][step_id] + + # Override fields + if override_keys is not None: + for key, value in override_keys.items(): + recursive_replace(step_dict_config, key, value) + + # Wrap config under "step" field and parse it + step_dict_config = dict(step=step_dict_config) + step_parser = ArgumentParser() + step_parser.add_subclass_arguments(Executable, "step") + parsed_namespace = step_parser.parse_object(step_dict_config) + return step_parser.instantiate_classes(parsed_namespace)["step"] diff --git a/use-cases/mnist/torch/pipeline.yaml b/use-cases/mnist/torch/pipeline.yaml index 9bb7fb98..67848652 100644 --- a/use-cases/mnist/torch/pipeline.yaml +++ b/use-cases/mnist/torch/pipeline.yaml @@ -2,11 +2,13 @@ executor: class_path: itwinai.components.Executor init_args: steps: - - class_path: dataloader.MNISTDataModuleTorch + dataloading_step: + class_path: dataloader.MNISTDataModuleTorch init_args: save_path: .tmp/ - - class_path: itwinai.torch.trainer.TorchTrainerMG + training_step: + class_path: itwinai.torch.trainer.TorchTrainerMG init_args: model: class_path: model.Net @@ -25,7 +27,7 @@ executor: batch_size: 32 pin_memory: True shuffle: False - epochs: 30 + epochs: 2 train_metrics: accuracy: class_path: 
torchmetrics.classification.MulticlassAccuracy From 8f13d92011c59d5fbe6b934f5dc96254a1d6ae2d Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 23 Nov 2023 10:04:19 +0100 Subject: [PATCH 43/57] UPDATE Singularity exec command --- use-cases/3dgan/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md index b50938be..d91bb092 100644 --- a/use-cases/3dgan/README.md +++ b/use-cases/3dgan/README.md @@ -150,5 +150,6 @@ Run overriding the working directory (`--pwd /usr/src/app`, restores Docker's WO and providing a writable filesystem (`-B "$PWD":/usr/data`): ```bash -singularity exec -e --pwd /usr/src/app -B "$PWD":/usr/data docker://ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.2 +singularity exec -B "$PWD":/usr/data docker://ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.2 \ +bash -c "cd /usr/src/app && python train.py -p pipeline.yaml" ``` From 8e19c62a261a0839bf89b43b773eb378461fc9ad Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 23 Nov 2023 10:05:26 +0100 Subject: [PATCH 44/57] UPDATE: Image version --- use-cases/3dgan/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md index d91bb092..bbcd1da8 100644 --- a/use-cases/3dgan/README.md +++ b/use-cases/3dgan/README.md @@ -111,8 +111,8 @@ Build from project root with docker buildx build -t itwinai-mnist-torch-inference -f use-cases/3dgan/Dockerfile . # Ghcr.io -docker buildx build -t ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.1 -f use-cases/3dgan/Dockerfile . -docker push ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.1 +docker buildx build -t ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.2 -f use-cases/3dgan/Dockerfile . 
+docker push ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.2 ``` From wherever a sample of MNIST jpg images is available @@ -129,7 +129,7 @@ From wherever a sample of MNIST jpg images is available ``` ```bash -docker run -it --rm --name running-inference -v "$PWD":/tmp/data ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.1 +docker run -it --rm --name running-inference -v "$PWD":/tmp/data ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.2 ``` This command will store the results in a folder called "3dgan-generated-data": From d3a2630aa8f19a545dea5c53f1ca66c926600c1b Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 23 Nov 2023 16:03:59 +0100 Subject: [PATCH 45/57] UPDATE: load components from pipeline --- src/itwinai/components.py | 36 +++++++++++++++++++++-------------- use-cases/3dgan/pipeline.yaml | 20 +++++++++++-------- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/src/itwinai/components.py b/src/itwinai/components.py index de155236..8232834c 100644 --- a/src/itwinai/components.py +++ b/src/itwinai/components.py @@ -344,23 +344,27 @@ def _pack_args(self, args) -> Tuple: return args -def recursive_replace(config: Dict, target_field: str, new_value: Any) -> None: - def _recursive_replace_key(sub_dict: Dict): - if not isinstance(sub_dict, dict): - return - for k, v in sub_dict.items(): - if k == target_field: - sub_dict[k] = new_value - return - else: - _recursive_replace_key(v) - _recursive_replace_key(config) +def add_replace_field( + config: Dict, + key_chain: str, + value: Any +) -> None: + sub_config = config + for idx, k in enumerate(key_chain.split('.')): + if idx >= len(key_chain.split('.')) - 1: + # Last key reached + break + if not isinstance(sub_config.get(k), dict): + sub_config[k] = dict() + sub_config = sub_config[k] + sub_config[k] = value def load_pipeline_step( pipe: Union[str, Dict], step_id: Union[str, int], - override_keys: Optional[Dict[str, Any]] = None + override_keys: Optional[Dict[str, Any]] = None, + verbose: 
bool = False ) -> Executable: if isinstance(pipe, str): # Load pipe from YAML file path @@ -369,8 +373,12 @@ def load_pipeline_step( # Override fields if override_keys is not None: - for key, value in override_keys.items(): - recursive_replace(step_dict_config, key, value) + for key_chain, value in override_keys.items(): + add_replace_field(step_dict_config, key_chain, value) + if verbose: + import json + print(f"NEW STEP CONFIG:") + print(json.dumps(step_dict_config, indent=4)) # Wrap config under "step" field and parse it step_dict_config = dict(step=step_dict_config) diff --git a/use-cases/3dgan/pipeline.yaml b/use-cases/3dgan/pipeline.yaml index 676424aa..cd45674d 100644 --- a/use-cases/3dgan/pipeline.yaml +++ b/use-cases/3dgan/pipeline.yaml @@ -2,12 +2,14 @@ executor: class_path: itwinai.components.Executor init_args: steps: - - class_path: dataloader.Lightning3DGANDownloader + dataloading_step: + class_path: dataloader.Lightning3DGANDownloader init_args: data_path: exp_data/ # Set to null to skip dataset download data_url: https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX - - class_path: trainer.Lightning3DGANTrainer + training_step: + class_path: trainer.Lightning3DGANTrainer init_args: # Pytorch lightning config for training config: @@ -49,7 +51,7 @@ executor: limit_test_batches: null limit_train_batches: null limit_val_batches: null - log_every_n_steps: 2 + log_every_n_steps: 1 logger: # - class_path: lightning.pytorch.loggers.CSVLogger # init_args: @@ -59,8 +61,8 @@ executor: experiment_name: 3DGAN save_dir: ml_logs/mlflow_logs log_model: all - max_epochs: 1 - max_steps: 20 + max_epochs: 5 + # max_steps: 2000 max_time: null min_epochs: null min_steps: null @@ -69,7 +71,7 @@ executor: plugins: null profiler: null reload_dataloaders_every_n_epochs: 0 - strategy: ddp_find_unused_parameters_true #auto + strategy: auto #ddp_find_unused_parameters_true #auto sync_batchnorm: false use_distributed_sampler: true val_check_interval: null @@ 
-79,7 +81,7 @@ executor: class_path: model.ThreeDGAN init_args: latent_size: 256 - batch_size: 64 + batch_size: 4 loss_weights: [3, 0.1, 25, 0.1] power: 0.85 lr: 0.001 @@ -90,4 +92,6 @@ executor: class_path: dataloader.ParticlesDataModule init_args: datapath: exp_data/*/*.h5 - batch_size: 64 + batch_size: 4 + num_workers: 0 + max_samples: 12 From 33de0b4cab2865be113b6d09eae455dadb10c6d7 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 23 Nov 2023 16:20:24 +0100 Subject: [PATCH 46/57] ADD: docs --- src/itwinai/components.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/itwinai/components.py b/src/itwinai/components.py index 8232834c..b13bc548 100644 --- a/src/itwinai/components.py +++ b/src/itwinai/components.py @@ -349,6 +349,16 @@ def add_replace_field( key_chain: str, value: Any ) -> None: + """Replace or add (if not present) a field in a dictionary, following a + path of dot-separated keys. In-place operation. + + Args: + config (Dict): dictionary to be modified. + key_chain (str): path of dot-separated keys to specify the location + of the new value (e.g., 'foo.bar.line' adds/overwrites the value + located at config['foo']['bar']['line']). + value (Any): the value to insert. + """ sub_config = config for idx, k in enumerate(key_chain.split('.')): if idx >= len(key_chain.split('.')) - 1: @@ -366,6 +376,26 @@ def load_pipeline_step( override_keys: Optional[Dict[str, Any]] = None, verbose: bool = False ) -> Executable: + """Instantiates a specific step from a pipeline configuration file, given + its ID (index if steps are a list, key if steps are a dictionary). It + allows overriding the step configuration with user-defined values. + + Args: + pipe (Union[str, Dict]): pipeline configuration. Either a path to a + YAML file (if string), or a configuration in memory (if dict object). 
+ step_id (Union[str, int]): step identifier: list index if steps are + represented as a list, string key if steps are represented as a + dictionary. + override_keys (Optional[Dict[str, Any]], optional): if given, maps key + path to the value to add/override. A key path is a string of + dot-separated keys (e.g., 'foo.bar.line' adds/overwrites the value + located at pipe['foo']['bar']['line']). Defaults to None. + verbose (bool, optional): if given, prints to console the new + configuration, obtained after overriding. Defaults to False. + + Returns: + Executable: an instance of the selected step in the pipeline. + """ if isinstance(pipe, str): # Load pipe from YAML file path pipe = load_yaml(pipe) From f2ccfae5b9415d7122fdd34362f805b301465f15 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 23 Nov 2023 16:58:27 +0100 Subject: [PATCH 47/57] Simplify 3DGAN model config --- use-cases/3dgan/model.py | 55 +++++++++++++++++------------------ use-cases/3dgan/pipeline.yaml | 4 +-- 2 files changed, 28 insertions(+), 31 deletions(-) diff --git a/use-cases/3dgan/model.py b/use-cases/3dgan/model.py index 9d416d1f..9653c98e 100644 --- a/use-cases/3dgan/model.py +++ b/use-cases/3dgan/model.py @@ -1,6 +1,6 @@ import sys -import os -import pickle +# import os +# import pickle from collections import defaultdict import math from typing import Any @@ -306,18 +306,16 @@ class ThreeDGAN(pl.LightningModule): def __init__( self, latent_size=256, - batch_size=64, loss_weights=[3, 0.1, 25, 0.1], power=0.85, lr=0.001, - checkpoint_path: str = '3Dgan.pth' + # checkpoint_path: str = '3Dgan.pth' ): super().__init__() self.save_hyperparameters() self.automatic_optimization = False self.latent_size = latent_size - self.batch_size = batch_size self.loss_weights = loss_weights self.lr = lr self.power = power @@ -332,9 +330,9 @@ def __init__( self.index = 0 self.train_history = defaultdict(list) self.test_history = defaultdict(list) - self.pklfile = checkpoint_path - checkpoint_dir = 
os.path.dirname(checkpoint_path) - os.makedirs(checkpoint_dir, exist_ok=True) + # self.pklfile = checkpoint_path + # checkpoint_dir = os.path.dirname(checkpoint_path) + # os.makedirs(checkpoint_dir, exist_ok=True) def BitFlip(self, x, prob=0.05): """ @@ -411,10 +409,10 @@ def training_step(self, batch, batch_idx): ecal_batch = ecal_batch.to(self.device) optimizer_discriminator, optimizer_generator = self.optimizers() + batch_size = energy_batch.shape[0] noise = torch.randn( - (energy_batch.shape[0], self.latent_size - 2), - # (self.batch_size, self.latent_size - 2), + (batch_size, self.latent_size - 2), dtype=torch.float32, device=self.device ) @@ -427,7 +425,7 @@ def training_step(self, batch, batch_idx): generated_images = self.generator(generator_ip) # Train discriminator first on real batch - fake_batch = self.BitFlip(np.ones(self.batch_size).astype(np.float32)) + fake_batch = self.BitFlip(np.ones(batch_size).astype(np.float32)) fake_batch = torch.tensor([[el] for el in fake_batch]).to(self.device) labels = [fake_batch, energy_batch, ang_batch, ecal_batch] @@ -450,7 +448,7 @@ def training_step(self, batch, batch_idx): optimizer_discriminator.step() # Train discriminator on the fake batch - fake_batch = self.BitFlip(np.zeros(self.batch_size).astype(np.float32)) + fake_batch = self.BitFlip(np.zeros(batch_size).astype(np.float32)) fake_batch = torch.tensor([[el] for el in fake_batch]).to(self.device) labels = [fake_batch, energy_batch, ang_batch, ecal_batch] @@ -473,7 +471,7 @@ def training_step(self, batch, batch_idx): # avg_disc_loss = (sum(real_batch_loss) + sum(fake_batch_loss)) / 2 - trick = np.ones(self.batch_size).astype(np.float32) + trick = np.ones(batch_size).astype(np.float32) fake_batch = torch.tensor([[el] for el in trick]).to(self.device) labels = [fake_batch, energy_batch.view(-1, 1), ang_batch, ecal_batch] @@ -481,7 +479,7 @@ def training_step(self, batch, batch_idx): # Train generator twice using combined model for _ in range(2): noise = 
torch.randn( - (self.batch_size, self.latent_size - 2)).to(self.device) + (batch_size, self.latent_size - 2)).to(self.device) generator_ip = torch.cat( (energy_batch.view(-1, 1), ang_batch.view(-1, 1), noise), dim=1 @@ -615,9 +613,9 @@ def on_train_epoch_end(self): # outputs torch.save(self.discriminator.state_dict(), "discriminator_weights.pth") - with open(self.pklfile, "wb") as f: - pickle.dump({"train": self.train_history, - "test": self.test_history}, f) + # with open(self.pklfile, "wb") as f: + # pickle.dump({"train": self.train_history, + # "test": self.test_history}, f) # pickle.dump({"train": self.train_history}, open(self.pklfile, "wb")) print("train-loss:" + str(self.train_history["generator"][-1][0])) @@ -633,10 +631,11 @@ def validation_step(self, batch, batch_idx): ang_batch = ang_batch.to(self.device) ecal_batch = ecal_batch.to(self.device) + batch_size = energy_batch.shape[0] + # Generate Fake events with same energy and angle as data batch noise = torch.randn( - (energy_batch.shape[0], self.latent_size - 2), - # (self.batch_size, self.latent_size - 2), + (batch_size, self.latent_size - 2), dtype=torch.float32, device=self.device ) @@ -648,10 +647,10 @@ def validation_step(self, batch, batch_idx): # concatenate to fake and real batches X = torch.cat((image_batch, generated_images), dim=0) - # y = np.array([1] * self.batch_size \ - # + [0] * self.batch_size).astype(np.float32) - y = torch.tensor([1] * self.batch_size + [0] * - self.batch_size, dtype=torch.float32).to(self.device) + # y = np.array([1] * batch_size \ + # + [0] * batch_size).astype(np.float32) + y = torch.tensor([1] * batch_size + [0] * + batch_size, dtype=torch.float32).to(self.device) y = y.view(-1, 1) ang = torch.cat((ang_batch, ang_batch), dim=0) @@ -667,7 +666,7 @@ def validation_step(self, batch, batch_idx): labels, disc_eval, self.loss_weights) # Calculate generator loss - trick = np.ones(self.batch_size).astype(np.float32) + trick = np.ones(batch_size).astype(np.float32) 
fake_batch = torch.tensor([[el] for el in trick]).to(self.device) # fake_batch = [[el] for el in trick] labels = [fake_batch, energy_batch, ang_batch, ecal_batch] @@ -728,10 +727,10 @@ def on_validation_epoch_end(self): print(ROW_FMT.format("discriminator (test)", *self.test_history["discriminator"][-1])) - # save loss dict to pkl file - with open(self.pklfile, "wb") as f: - pickle.dump({"train": self.train_history, - "test": self.test_history}, f) + # # save loss dict to pkl file + # with open(self.pklfile, "wb") as f: + # pickle.dump({"train": self.train_history, + # "test": self.test_history}, f) # pickle.dump({"test": self.test_history}, open(self.pklfile, "wb")) # print("train-loss:" + str(self.train_history["generator"][-1][0])) diff --git a/use-cases/3dgan/pipeline.yaml b/use-cases/3dgan/pipeline.yaml index cd45674d..1e98d173 100644 --- a/use-cases/3dgan/pipeline.yaml +++ b/use-cases/3dgan/pipeline.yaml @@ -71,7 +71,7 @@ executor: plugins: null profiler: null reload_dataloaders_every_n_epochs: 0 - strategy: auto #ddp_find_unused_parameters_true #auto + strategy: auto #ddp_find_unused_parameters_true sync_batchnorm: false use_distributed_sampler: true val_check_interval: null @@ -81,11 +81,9 @@ executor: class_path: model.ThreeDGAN init_args: latent_size: 256 - batch_size: 4 loss_weights: [3, 0.1, 25, 0.1] power: 0.85 lr: 0.001 - checkpoint_path: exp_data/3dgan.pth # Lightning data module configuration data: From 1af8ba758ba5ce5b75be3e851a4b3b8bad6f8f36 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Thu, 23 Nov 2023 18:13:51 +0100 Subject: [PATCH 48/57] ADD: mlflow autologging support for PL trainer --- src/itwinai/torch/mlflow.py | 77 +++++++++++++++++++++++++++++++++++ use-cases/3dgan/pipeline.yaml | 24 +++++------ use-cases/3dgan/trainer.py | 6 +++ 3 files changed, 95 insertions(+), 12 deletions(-) create mode 100644 src/itwinai/torch/mlflow.py diff --git a/src/itwinai/torch/mlflow.py b/src/itwinai/torch/mlflow.py new file mode 100644 index 
00000000..18a014ff --- /dev/null +++ b/src/itwinai/torch/mlflow.py @@ -0,0 +1,77 @@ +from typing import Dict, Optional +import os + +import mlflow +import yaml + + +def _get_mlflow_logger_conf(pl_config: Dict) -> Optional[Dict]: + """Extract MLFLowLogger configuration from pytorch lightning + configuration file, if present. + + Args: + pl_config (Dict): lightning configuration loaded in memory. + + Returns: + Optional[Dict]: if present, MLFLowLogger constructor arguments + (under 'init_args' key). + """ + if isinstance(pl_config['trainer']['logger'], list): + # If multiple loggers are provided + for logger_conf in pl_config['trainer']['logger']: + if logger_conf['class_path'].endswith('MLFlowLogger'): + return logger_conf['init_args'] + elif pl_config['trainer']['logger']['class_path'].endswith('MLFlowLogger'): + return pl_config['trainer']['logger']['init_args'] + + +def _mlflow_log_pl_config(pl_config: Dict, local_yaml_path: str) -> None: + os.makedirs(os.path.dirname(local_yaml_path), exist_ok=True) + with open(local_yaml_path, 'w') as outfile: + yaml.dump(pl_config, outfile, default_flow_style=False) + mlflow.log_artifact(local_yaml_path) + + +def init_lightning_mlflow( + pl_config: Dict, + default_experiment_name: str = 'Default', + **autolog_kwargs +) -> None: + """Initialize mlflow for pytorch lightning, also setting up + auto-logging (mlflow.pytorch.autolog(...)). Creates a new mlflow + run and attaches it to the mlflow auto-logger. + + Args: + pl_config (Dict): pytorch lightning configuration loaded in memory. + default_experiment_name (str, optional): used as experiment name + if it is not given in the lightning conf. Defaults to 'Default'. + **autolog_kwargs (kwargs): args for mlflow.pytorch.autolog(...). 
+ """ + mlflow_conf: Optional[Dict] = _get_mlflow_logger_conf(pl_config) + if not mlflow_conf: + return + + tracking_uri = mlflow_conf.get('tracking_uri') + if not tracking_uri: + save_path = mlflow_conf.get('save_dir') + tracking_uri = "file://" + os.path.abspath(save_path) + + experiment_name = mlflow_conf.get('experiment_name') + if not experiment_name: + experiment_name = default_experiment_name + + mlflow.set_tracking_uri(tracking_uri) + mlflow.set_experiment(experiment_name) + mlflow.pytorch.autolog(**autolog_kwargs) + mlflow.start_run() + + mlflow_conf['experiment_name'] = experiment_name + mlflow_conf['run_id'] = mlflow.active_run().info.run_id + + _mlflow_log_pl_config(pl_config, '.tmp/pl_config.yml') + + +def teardown_lightning_mlflow() -> None: + """End active mlflow run, if any.""" + if mlflow.active_run() is not None: + mlflow.end_run() diff --git a/use-cases/3dgan/pipeline.yaml b/use-cases/3dgan/pipeline.yaml index 1e98d173..2abf4627 100644 --- a/use-cases/3dgan/pipeline.yaml +++ b/use-cases/3dgan/pipeline.yaml @@ -20,21 +20,21 @@ executor: barebones: false benchmark: null # callbacks: - # # - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping - # # init_args: - # # monitor: val_loss - # # patience: 2 + # - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping + # init_args: + # monitor: real_batch_loss_epoch + # patience: 2 # - class_path: lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor # init_args: # logging_interval: step - # # - class_path: lightning.pytorch.callbacks.ModelCheckpoint - # # init_args: - # # dirpath: checkpoints - # # filename: best-checkpoint - # # mode: min - # # monitor: val_loss - # # save_top_k: 1 - # # verbose: true + # - class_path: lightning.pytorch.callbacks.ModelCheckpoint + # init_args: + # dirpath: checkpoints + # filename: best-checkpoint + # mode: min + # monitor: real_batch_loss_epoch + # save_top_k: 1 + # verbose: true check_val_every_n_epoch: 1 default_root_dir: null 
detect_anomaly: false diff --git a/use-cases/3dgan/trainer.py b/use-cases/3dgan/trainer.py index faf7dc32..5bd2bcdb 100644 --- a/use-cases/3dgan/trainer.py +++ b/use-cases/3dgan/trainer.py @@ -11,6 +11,10 @@ from itwinai.serialization import ModelLoader from itwinai.torch.inference import TorchModelLoader from itwinai.torch.types import Batch +from itwinai.torch.mlflow import ( + init_lightning_mlflow, + teardown_lightning_mlflow +) from model import ThreeDGAN from dataloader import ParticlesDataModule @@ -26,6 +30,7 @@ def __init__(self, config: Union[Dict, str]): self.conf = config def train(self) -> Any: + init_lightning_mlflow(self.conf, registered_model_name='3dgan-lite') old_argv = sys.argv sys.argv = ['some_script_placeholder.py'] cli = LightningCLI( @@ -42,6 +47,7 @@ def train(self) -> Any: ) sys.argv = old_argv cli.trainer.fit(cli.model, datamodule=cli.datamodule) + teardown_lightning_mlflow() def execute( self, From acf7782469a0d9cde5603796c18d4d6ac8895ddf Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Fri, 24 Nov 2023 09:20:22 +0100 Subject: [PATCH 49/57] UPDATE container info --- use-cases/3dgan/README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md index bbcd1da8..dda0bf23 100644 --- a/use-cases/3dgan/README.md +++ b/use-cases/3dgan/README.md @@ -111,8 +111,8 @@ Build from project root with docker buildx build -t itwinai-mnist-torch-inference -f use-cases/3dgan/Dockerfile . # Ghcr.io -docker buildx build -t ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.2 -f use-cases/3dgan/Dockerfile . -docker push ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.2 +docker buildx build -t ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.3 -f use-cases/3dgan/Dockerfile . 
+docker push ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.3 ``` From wherever a sample of MNIST jpg images is available @@ -129,7 +129,7 @@ From wherever a sample of MNIST jpg images is available ``` ```bash -docker run -it --rm --name running-inference -v "$PWD":/tmp/data ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.2 +docker run -it --rm --name running-inference -v "$PWD":/tmp/data ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.3 ``` This command will store the results in a folder called "3dgan-generated-data": @@ -150,6 +150,6 @@ Run overriding the working directory (`--pwd /usr/src/app`, restores Docker's WO and providing a writable filesystem (`-B "$PWD":/usr/data`): ```bash -singularity exec -B "$PWD":/usr/data docker://ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.2 / -bash -c "cd /usr/src/app && python train.py -p pipeline.yaml" +singularity exec -B "$PWD":/usr/data docker://ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.3 / +bash -c "cd /usr/src/app && python train.py -p inference-pipeline.yaml" ``` From 656ab674d5ea1dfb9c09e7e649046cc3eba99631 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Fri, 1 Dec 2023 15:23:11 +0100 Subject: [PATCH 50/57] Refactor --- .../{Dockerfile => Dockerfile.inference} | 0 use-cases/3dgan/Dockerfile.vega | 35 ----------------- use-cases/3dgan/README.md | 4 +- use-cases/3dgan/cern-pipeline.yaml | 36 +++++++++--------- use-cases/3dgan/pipeline.yaml | 38 +++++++++---------- 5 files changed, 37 insertions(+), 76 deletions(-) rename use-cases/3dgan/{Dockerfile => Dockerfile.inference} (100%) delete mode 100644 use-cases/3dgan/Dockerfile.vega diff --git a/use-cases/3dgan/Dockerfile b/use-cases/3dgan/Dockerfile.inference similarity index 100% rename from use-cases/3dgan/Dockerfile rename to use-cases/3dgan/Dockerfile.inference diff --git a/use-cases/3dgan/Dockerfile.vega b/use-cases/3dgan/Dockerfile.vega deleted file mode 100644 index fd933dd0..00000000 --- a/use-cases/3dgan/Dockerfile.vega +++ /dev/null @@ -1,35 
+0,0 @@ -FROM python:3.9.12 - -WORKDIR /usr/src/app - -RUN pip install --upgrade pip - -# Install pytorch (cpuonly) -# Ref:https://pytorch.org/get-started/previous-versions/#linux-and-windows-5 -RUN pip install --no-cache-dir torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu -RUN pip install --no-cache-dir lightning - -# Add 3DGAN custom requirements -COPY use-cases/3dgan/requirements.txt /usr/src/app/ -RUN pip install --no-cache-dir -r requirements.txt - -# Install itwinai and dependencies -COPY pyproject.toml /usr/src/app/ -COPY src /usr/src/app/ -RUN pip install --no-cache-dir /usr/src/app - -# Add 3DGAN use case files -COPY use-cases/3dgan/* /usr/src/app/ - -# # Create results folder -# RUN mkdir -p /tmp/data -# RUN chmod 0777 -R /tmp/data - -# Create results folder -# TODO: remove once the problem with file system permissions are solved -RUN mkdir -p /tmp/data/3dgan-generated-data -RUN mkdir -p /tmp/data/exp_data/3dgan_data -RUN chmod 0777 -R /tmp/data - -# Run inference -CMD [ "python", "train.py", "-p", "inference-pipeline.yaml"] \ No newline at end of file diff --git a/use-cases/3dgan/README.md b/use-cases/3dgan/README.md index dda0bf23..7d1d3d16 100644 --- a/use-cases/3dgan/README.md +++ b/use-cases/3dgan/README.md @@ -108,10 +108,10 @@ Build from project root with ```bash # Local -docker buildx build -t itwinai-mnist-torch-inference -f use-cases/3dgan/Dockerfile . +docker buildx build -t itwinai-mnist-torch-inference -f use-cases/3dgan/Dockerfile.inference . # Ghcr.io -docker buildx build -t ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.3 -f use-cases/3dgan/Dockerfile . +docker buildx build -t ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.3 -f use-cases/3dgan/Dockerfile.inference . 
docker push ghcr.io/intertwin-eu/itwinai-3dgan-inference:0.0.3 ``` diff --git a/use-cases/3dgan/cern-pipeline.yaml b/use-cases/3dgan/cern-pipeline.yaml index 7d251ae5..57245450 100644 --- a/use-cases/3dgan/cern-pipeline.yaml +++ b/use-cases/3dgan/cern-pipeline.yaml @@ -4,7 +4,7 @@ executor: steps: - class_path: dataloader.Lightning3DGANDownloader init_args: - data_path: /eos/user/k/ktsolaki/data/3dgan_data # exp_data/ + data_path: /eos/user/k/ktsolaki/data/3dgan_data data_url: null # https://drive.google.com/drive/folders/1uPpz0tquokepptIfJenTzGpiENfo2xRX - class_path: trainer.Lightning3DGANTrainer @@ -17,22 +17,22 @@ executor: accumulate_grad_batches: 1 barebones: false benchmark: null - # callbacks: - # # - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping - # # init_args: - # # monitor: val_loss - # # patience: 2 - # - class_path: lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor - # init_args: - # logging_interval: step - # # - class_path: lightning.pytorch.callbacks.ModelCheckpoint - # # init_args: - # # dirpath: checkpoints - # # filename: best-checkpoint - # # mode: min - # # monitor: val_loss - # # save_top_k: 1 - # # verbose: true + callbacks: + - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping + init_args: + monitor: val_generator_loss + patience: 2 + - class_path: lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + dirpath: checkpoints + filename: best-checkpoint + mode: min + monitor: val_generator_loss + save_top_k: 1 + verbose: true check_val_every_n_epoch: 1 default_root_dir: null detect_anomaly: false @@ -92,4 +92,4 @@ executor: datapath: /eos/user/k/ktsolaki/data/3dgan_data/*.h5 # exp_data/*/*.h5 batch_size: 128 num_workers: 0 - max_samples: 3000 + max_samples: 10000 diff --git a/use-cases/3dgan/pipeline.yaml b/use-cases/3dgan/pipeline.yaml index 2abf4627..82665304 100644 --- 
a/use-cases/3dgan/pipeline.yaml +++ b/use-cases/3dgan/pipeline.yaml @@ -19,22 +19,22 @@ executor: accumulate_grad_batches: 1 barebones: false benchmark: null - # callbacks: - # - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping - # init_args: - # monitor: real_batch_loss_epoch - # patience: 2 - # - class_path: lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor - # init_args: - # logging_interval: step - # - class_path: lightning.pytorch.callbacks.ModelCheckpoint - # init_args: - # dirpath: checkpoints - # filename: best-checkpoint - # mode: min - # monitor: real_batch_loss_epoch - # save_top_k: 1 - # verbose: true + callbacks: + - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping + init_args: + monitor: val_generator_loss + patience: 2 + - class_path: lightning.pytorch.callbacks.lr_monitor.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + dirpath: checkpoints + filename: best-checkpoint + mode: min + monitor: val_generator_loss + save_top_k: 1 + verbose: true check_val_every_n_epoch: 1 default_root_dir: null detect_anomaly: false @@ -53,16 +53,12 @@ executor: limit_val_batches: null log_every_n_steps: 1 logger: - # - class_path: lightning.pytorch.loggers.CSVLogger - # init_args: - # save_dir: ml_logs/csv_logs class_path: lightning.pytorch.loggers.MLFlowLogger init_args: experiment_name: 3DGAN save_dir: ml_logs/mlflow_logs log_model: all max_epochs: 5 - # max_steps: 2000 max_time: null min_epochs: null min_steps: null @@ -92,4 +88,4 @@ executor: datapath: exp_data/*/*.h5 batch_size: 4 num_workers: 0 - max_samples: 12 + max_samples: 48 From b176abf336eefafb8ab104f05a52c8b3cdc4cdb3 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Fri, 1 Dec 2023 15:26:13 +0100 Subject: [PATCH 51/57] UPDATE dependencies --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 14e7905a..e1ca8fd3 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,8 @@ dependencies = [ "typing-extensions==4.5.0", "typing_extensions==4.5.0", "urllib3>=1.26.18", + "lightning>=2.0.0", + "torchmetrics>=1.2.0", ] # dynamic = ["version", "description"] From 087c7ec5098a5d453d1b93cd7b4cce80519a10ba Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Fri, 1 Dec 2023 15:31:09 +0100 Subject: [PATCH 52/57] FIX linter problem --- src/itwinai/components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/itwinai/components.py b/src/itwinai/components.py index b13bc548..8b64796e 100644 --- a/src/itwinai/components.py +++ b/src/itwinai/components.py @@ -394,7 +394,7 @@ def load_pipeline_step( configuration, obtained after overriding. Defaults to False. Returns: - Executable: an instance of the selected step in the pipeline. + Executable: an instance of the selected step in the pipeline. """ if isinstance(pipe, str): # Load pipe from YAML file path From 8d9f51f1aa6e8c5795999682105995a41d43733b Mon Sep 17 00:00:00 2001 From: Matteo Bunino <48362942+matbun@users.noreply.github.com> Date: Wed, 13 Dec 2023 13:25:55 +0100 Subject: [PATCH 53/57] Simplified workflow configuration (#108) * Add SQAaaS dynamic badge for dev branch (#104) * Add SQAaaS dynamic badge * Upgrade to sqaaas-assessment-action@v2 * Add draft example * UPDATE credits field * ADD docs * REFACTOR components and pipeline code * UPDATE docstring * UPDATE mnist torch uc * ADD config file parser draft * ADD itwinaiCLI and ConfigParser * ADD docs * ADD pipeline parser and serializer plus tests * UPDATE docs * ADD adapter component and tests (incl parser) * ADD splitter component, improve pipeline, tests * UPDATE test * REMOVE todos * ADD component tests * ADD serializer tests * FIX linter * ADD basic workflow tutorial * ADD basic intermediate tutorial * ADD advanced tutorial * UPDATE advanced tutorial * UPDATE use cases * UPDATE save parameters * FIX linter * FIX cyclones use case workflow --------- 
Co-authored-by: orviz --- .github/workflows/sqaaas.yml | 18 +- README.md | 3 +- experimental/cli/example.yaml | 9 + experimental/cli/itwinai-conf.yaml | 14 + experimental/cli/itwinaicli.py | 29 + experimental/cli/mycode.py | 35 + experimental/cli/parser-bk.py | 46 ++ experimental/cli/parser.py | 29 + experimental/workflow/train.yaml | 53 ++ pyproject.toml | 3 +- src/itwinai/cli.py | 47 ++ src/itwinai/components.py | 641 ++++++++++-------- src/itwinai/{ => experimental}/executors.py | 14 +- src/itwinai/parser.py | 485 +++++++++++++ src/itwinai/pipeline.py | 101 +++ src/itwinai/serialization.py | 168 ++++- src/itwinai/tensorflow/trainer.py | 15 +- src/itwinai/tests/__init__.py | 11 + src/itwinai/tests/dummy_components.py | 97 +++ src/itwinai/torch/inference.py | 6 +- src/itwinai/torch/trainer.py | 9 +- src/itwinai/types.py | 8 +- src/itwinai/utils.py | 82 ++- tests/components/conftest.py | 72 ++ tests/components/test_components.py | 156 +++++ tests/components/test_pipe_parser.py | 216 ++++++ tests/components/test_pipeline.py | 83 +++ tests/test_components.py | 9 - tests/test_utils.py | 83 ++- tests/use-cases/conftest.py | 1 - tutorials/ml-workflows/basic_components.py | 91 +++ .../ml-workflows/tutorial_0_basic_workflow.py | 71 ++ .../tutorial_1_intermediate_workflow.py | 98 +++ .../tutorial_2_advanced_workflow.py | 86 +++ use-cases/3dgan/cern-pipeline.yaml | 4 +- use-cases/3dgan/dataloader.py | 27 +- use-cases/3dgan/inference-pipeline.yaml | 4 +- use-cases/3dgan/pipeline.yaml | 4 +- use-cases/3dgan/saver.py | 28 +- use-cases/3dgan/train.py | 25 +- use-cases/3dgan/trainer.py | 39 +- use-cases/3dgan/utils.py | 108 --- use-cases/cyclones/.gitignore | 2 + use-cases/cyclones/dataloader.py | 43 +- use-cases/cyclones/executor.py | 75 -- use-cases/cyclones/pipeline.yaml | 15 +- use-cases/cyclones/train.py | 97 ++- use-cases/cyclones/trainer.py | 38 +- use-cases/mnist/tensorflow/dataloader.py | 32 +- use-cases/mnist/tensorflow/pipeline.yaml | 6 +- 
use-cases/mnist/tensorflow/train.py | 25 +- use-cases/mnist/tensorflow/trainer.py | 25 +- use-cases/mnist/torch-lightning/dataloader.py | 25 +- use-cases/mnist/torch-lightning/pipeline.yaml | 4 +- use-cases/mnist/torch-lightning/train.py | 25 +- use-cases/mnist/torch-lightning/trainer.py | 15 +- use-cases/mnist/torch/dataloader.py | 67 +- use-cases/mnist/torch/inference-pipeline.yaml | 4 +- use-cases/mnist/torch/pipeline.yaml | 4 +- use-cases/mnist/torch/saver.py | 22 +- use-cases/mnist/torch/train.py | 25 +- use-cases/zebra2horse/train.py | 2 +- 62 files changed, 2784 insertions(+), 895 deletions(-) create mode 100644 experimental/cli/example.yaml create mode 100644 experimental/cli/itwinai-conf.yaml create mode 100644 experimental/cli/itwinaicli.py create mode 100644 experimental/cli/mycode.py create mode 100644 experimental/cli/parser-bk.py create mode 100644 experimental/cli/parser.py create mode 100644 experimental/workflow/train.yaml rename src/itwinai/{ => experimental}/executors.py (92%) create mode 100644 src/itwinai/parser.py create mode 100644 src/itwinai/pipeline.py create mode 100644 src/itwinai/tests/__init__.py create mode 100644 src/itwinai/tests/dummy_components.py create mode 100644 tests/components/conftest.py create mode 100644 tests/components/test_components.py create mode 100644 tests/components/test_pipe_parser.py create mode 100644 tests/components/test_pipeline.py delete mode 100644 tests/test_components.py create mode 100644 tutorials/ml-workflows/basic_components.py create mode 100644 tutorials/ml-workflows/tutorial_0_basic_workflow.py create mode 100644 tutorials/ml-workflows/tutorial_1_intermediate_workflow.py create mode 100644 tutorials/ml-workflows/tutorial_2_advanced_workflow.py delete mode 100644 use-cases/3dgan/utils.py create mode 100644 use-cases/cyclones/.gitignore delete mode 100644 use-cases/cyclones/executor.py diff --git a/.github/workflows/sqaaas.yml b/.github/workflows/sqaaas.yml index d3a61803..3e2bb7b2 100644 --- 
a/.github/workflows/sqaaas.yml +++ b/.github/workflows/sqaaas.yml @@ -4,10 +4,10 @@ --- name: SQAaaS -on: - push: +on: + push: branches: [main, dev] - pull_request: + pull_request: branches: [main, dev] jobs: @@ -15,15 +15,5 @@ jobs: runs-on: ubuntu-latest name: Job that triggers SQAaaS platform steps: - - name: Extract branch name - shell: bash - run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> "$GITHUB_OUTPUT" - id: extract_branch - - name: Print current branch name (debug) - shell: bash - run: echo running on branch ${{ steps.extract_branch.outputs.branch }} - name: SQAaaS assessment step - uses: eosc-synergy/sqaaas-assessment-action@v1 - with: - repo: 'https://github.com/interTwin-eu/itwinai' - branch: ${{ steps.extract_branch.outputs.branch }} + uses: eosc-synergy/sqaaas-assessment-action@v2 diff --git a/README.md b/README.md index 55c8484a..b0f1d66d 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ [![GitHub Super-Linter](https://github.com/interTwin-eu/T6.5-AI-and-ML/actions/workflows/lint.yml/badge.svg)](https://github.com/marketplace/actions/super-linter) [![GitHub Super-Linter](https://github.com/interTwin-eu/T6.5-AI-and-ML/actions/workflows/check-links.yml/badge.svg)](https://github.com/marketplace/actions/markdown-link-check) + [![SQAaaS source code](https://github.com/EOSC-synergy/itwinai.assess.sqaaas/raw/dev/.badge/status_shields.svg)](https://sqaaas.eosc-synergy.eu/#/full-assessment/report/https://raw.githubusercontent.com/eosc-synergy/itwinai.assess.sqaaas/dev/.report/assessment_output.json) See the latest version of our [docs](https://intertwin-eu.github.io/T6.5-AI-and-ML/) for a quick overview of this platform for advanced AI/ML workflows in digital twin applications. 
@@ -104,7 +105,7 @@ To run tests on itwinai package: # Activate env micromamba activate ./.venv-pytorch # or ./.venv-tf -pytest -v -m "not slurm" tests/ +pytest -v -m "not slurm" tests/ ``` However, some tests are intended to be executed only on an HPC system, diff --git a/experimental/cli/example.yaml b/experimental/cli/example.yaml new file mode 100644 index 00000000..ef6a342e --- /dev/null +++ b/experimental/cli/example.yaml @@ -0,0 +1,9 @@ +server: + class_path: mycode.ServerOptions + init_args: + host: localhost + port: 80 +client: + class_path: mycode.ClientOptions + init_args: + url: http://${server.init_args.host}:${server.init_args.port}/ \ No newline at end of file diff --git a/experimental/cli/itwinai-conf.yaml b/experimental/cli/itwinai-conf.yaml new file mode 100644 index 00000000..0cb662df --- /dev/null +++ b/experimental/cli/itwinai-conf.yaml @@ -0,0 +1,14 @@ +pipeline: + class_path: itwinai.pipeline.Pipeline + steps: [server, client] + +server: + class_path: mycode.ServerOptions + init_args: + host: localhost + port: 80 + +client: + class_path: mycode.ClientOptions + init_args: + url: http://${server.init_args.host}:${server.init_args.port}/ \ No newline at end of file diff --git a/experimental/cli/itwinaicli.py b/experimental/cli/itwinaicli.py new file mode 100644 index 00000000..6a22bfb1 --- /dev/null +++ b/experimental/cli/itwinaicli.py @@ -0,0 +1,29 @@ +""" +>>> python itwinaicli.py --config itwinai-conf.yaml --help +>>> python itwinaicli.py --config itwinai-conf.yaml --server.port 333 +""" + + +from itwinai.parser import ConfigParser2 +from itwinai.parser import ItwinaiCLI + +cli = ItwinaiCLI() +print(cli.pipeline) +print(cli.pipeline.steps) +print(cli.pipeline.steps['server'].port) + + +parser = ConfigParser2( + config='itwinai-conf.yaml', + override_keys={ + 'server.init_args.port': 777 + } +) +pipeline = parser.parse_pipeline() +print(pipeline) +print(pipeline.steps) +print(pipeline.steps['server'].port) + +server = 
parser.parse_step('server') +print(server) +print(server.port) diff --git a/experimental/cli/mycode.py b/experimental/cli/mycode.py new file mode 100644 index 00000000..5da07624 --- /dev/null +++ b/experimental/cli/mycode.py @@ -0,0 +1,35 @@ +# from dataclasses import dataclass +from itwinai.components import BaseComponent + + +class ServerOptions(BaseComponent): + host: str + port: int + + def __init__(self, host: str, port: int) -> None: + self.host = host + self.port = port + + def execute(): + ... + + +class ClientOptions(BaseComponent): + url: str + + def __init__(self, url: str) -> None: + self.url = url + + def execute(): + ... + + +class ServerOptions2(BaseComponent): + host: str + port: int + + def __init__(self, client: ClientOptions) -> None: + self.client = client + + def execute(): + ... diff --git a/experimental/cli/parser-bk.py b/experimental/cli/parser-bk.py new file mode 100644 index 00000000..8f87bf37 --- /dev/null +++ b/experimental/cli/parser-bk.py @@ -0,0 +1,46 @@ +""" +Provide functionalities to manage configuration files, including parsing, +execution, and dynamic override of fields. +""" + +from typing import Any +from jsonargparse import ArgumentParser, ActionConfigFile, Namespace + +from .components import BaseComponent + + +class ItwinaiCLI: + _parser: ArgumentParser + pipeline: BaseComponent + + def __init__( + self, + pipeline_nested_key: str = "pipeline", + args: Any = None, + parser_mode: str = "omegaconf" + ) -> None: + self.pipeline_nested_key = pipeline_nested_key + self.args = args + self.parser_mode = parser_mode + self._init_parser() + self._parse_args() + pipeline_inst = self._parser.instantiate_classes(self._config) + self.pipeline = pipeline_inst[self.pipeline_nested_key] + + def _init_parser(self): + self._parser = ArgumentParser(parser_mode=self.parser_mode) + self._parser.add_argument( + "-c", "--config", action=ActionConfigFile, + required=True, + help="Path to a configuration file in json or yaml format." 
+ ) + self._parser.add_subclass_arguments( + baseclass=BaseComponent, + nested_key=self.pipeline_nested_key + ) + + def _parse_args(self): + if isinstance(self.args, (dict, Namespace)): + self._config = self._parser.parse_object(self.args) + else: + self._config = self._parser.parse_args(self.args) diff --git a/experimental/cli/parser.py b/experimental/cli/parser.py new file mode 100644 index 00000000..f400466f --- /dev/null +++ b/experimental/cli/parser.py @@ -0,0 +1,29 @@ +""" +Example of dynamic override of config files with (sub)class arguments, +and variable interpolation with omegaconf. + +Run with: +>>> python parser.py + +Or (after clearing the arguments in parse_args(...)): +>>> python parser.py --config example.yaml --server.port 212 +See the help page of each class: +>>> python parser.py --server.help mycode.ServerOptions +""" + +from jsonargparse import ArgumentParser, ActionConfigFile +from mycode import ServerOptions, ClientOptions + +if __name__ == "__main__": + parser = ArgumentParser(parser_mode="omegaconf") + parser.add_subclass_arguments(ServerOptions, "server") + parser.add_subclass_arguments(ClientOptions, "client") + parser.add_argument("--config", action=ActionConfigFile) + + # Example of dynamic CLI override + # cfg = parser.parse_args(["--config=example.yaml", "--server.port=212"]) + cfg = parser.parse_args() + cfg = parser.instantiate_classes(cfg) + print(cfg.client) + print(cfg.client.url) + print(cfg.server.port) diff --git a/experimental/workflow/train.yaml b/experimental/workflow/train.yaml new file mode 100644 index 00000000..c21d4141 --- /dev/null +++ b/experimental/workflow/train.yaml @@ -0,0 +1,53 @@ +# AI workflow metadata/header. +# They are optional and easily extensible in the future. +version: 0.0.1 +name: Experiment name +description: This is a textual description +credits: + - author1 + - author2 + +# Provide a unified place where this *template* can be configured. 
+# Variables which can be overridden at runtime as env vars, e.g.: +# - Execution environment details (e.g., path in container vs. in laptop, MLFlow tracking URI) +# - Tunable parameters (e.g., learning rate) +# - Intrinsically dynamic values (e.g., MLFLow run ID is a random value) +# These variables are interpolated with OmegaConf. +vars: + images_dataset_path: some/path/disk + mlflow_tracking_uri: http://localhost:5000 + training_lr: 0.001 + +# Runner-independent workflow steps. +# Each step is designed to be minimal, but easily extensible +# to accommodate future needs by adding new fields. +# The only required field is 'command'. New fields can be added +# to support future workflow executors. +steps: + preprocessing-step: + command: + class_path: itwinai.torch.Preprocessor + init_args: + save_path: ${vars.images_dataset_path} + after: null + env: null + + training-step: + command: + class_path: itwinai.torch.Trainer + init_args: + lr: ${vars.training_lr} + tracking_uri: ${vars.mlflow_tracking_uri} + after: preprocessing-step + env: null + + sth_step: + command: python inference.py -p pipeline.yaml + after: [preprocessing-step, training-step] + env: docker+ghcr.io/intertwin-eu/itwinai:training-0.0.1 + + sth_step2: + command: python train.py -p pipeline.yaml + after: null + env: conda+path/to/my/local/env + diff --git a/pyproject.toml b/pyproject.toml index e1ca8fd3..5e93f3ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,8 @@ dependencies = [ "submitit>=1.4.6", "typing-extensions==4.5.0", "typing_extensions==4.5.0", + "rich>=13.5.3", + "typer>=0.9.0", "urllib3>=1.26.18", "lightning>=2.0.0", "torchmetrics>=1.2.0", @@ -45,7 +47,6 @@ dependencies = [ # TODO: add torch and tensorflow # torch = [] # tf = [] -cli = ["rich>=13.5.3", "typer>=0.9.0"] dev = [ "pytest>=7.4.2", "pytest-mock>=3.11.1", diff --git a/src/itwinai/cli.py b/src/itwinai/cli.py index bc1b852e..12954fbf 100644 --- a/src/itwinai/cli.py +++ b/src/itwinai/cli.py @@ -10,12 +10,59 @@ # 
NOTE: import libs in the command"s function, not here. # Otherwise this will slow the whole CLI. +from typing import Optional, List +from typing_extensions import Annotated +from pathlib import Path import typer app = typer.Typer() +@app.command() +def exec_pipeline( + config: Annotated[Path, typer.Option( + help="Path to the configuration file of the pipeline to execute." + )], + pipe_key: Annotated[str, typer.Option( + help=("Key in the configuration file identifying " + "the pipeline object to execute.") + )] = "pipeline", + overrides_list: Annotated[ + Optional[List[str]], typer.Option( + "--override", "-o", + help=( + "Nested key to dynamically override elements in the " + "configuration file with the " + "corresponding new value, joined by '='. It is also possible " + "to index elements in lists using their list index. " + "Example: [...] " + "-o pipeline.init_args.trainer.init_args.lr=0.001 " + "-o pipeline.my_list.2.batch_size=64 " + ) + ) + ] = None +): + """Execute a pipeline from configuration file. + Allows dynamic override of fields. + """ + # Add working directory to python path so that the interpreter is able + # to find the local python files imported from the pipeline file + import os + import sys + sys.path.append(os.getcwd()) + + # Parse and execute pipeline + from itwinai.parser import ConfigParser + overrides = { + k: v for k, v + in map(lambda x: (x.split('=')[0], x.split('=')[1]), overrides_list) + } + parser = ConfigParser(config=config, override_keys=overrides) + pipeline = parser.parse_pipeline(pipeline_nested_key=pipe_key) + pipeline.execute() + + @app.command() def mlflow_ui( path: str = typer.Option("ml-logs/", help="Path to logs storage."), diff --git a/src/itwinai/components.py b/src/itwinai/components.py index 8b64796e..0c628e0c 100644 --- a/src/itwinai/components.py +++ b/src/itwinai/components.py @@ -1,61 +1,118 @@ +""" +This module provides the base classes to define modular and reproducible ML +workflows. 
The base component classes provide a template to follow for +extending existing components or creating new ones. + +There are two ways of creating workflows: simple and advanced workflows. + +Simple workflows can be obtained by creating a sequence of components +wrapped in a Pipeline object, which executes them in cascade, passing the +output of a component as the input of the following one. It is responsibility +of the user to prevent mismatches among outputs and inputs of component +sequences. This pipeline can be configured +both in terms of parameters and structure, with a configuration file +representing the whole pipeline. This configuration file can be executed +using itwinai CLI without the need of python files. + +Example: + +>>> from itwinai.components import DataGetter, Saver +>>> from itwinai.pipeline import Pipeline +>>> +>>> my_pipe = Pipeline({"getter": DataGetter(...), "data_saver": Saver(...)}) +>>> my_pipe.execute() +>>> my_pipe.to_yaml("training_pipe.yaml") +>>> +>>> # The pipeline can be parsed back to Python with: +>>> from itwinai.parser import PipeParser +>>> my_pipe = PipeParser("training_pipe.yaml") +>>> my_pipe.execute() +>>> +>>> # Run the pipeline from configuration file with dynamic override +>>> itwinai exec-pipeline --config training_pipe.yaml \ +>>> --override pipeline.init_args.steps.data_saver.some_param 42 + + +Advanced workflows foresee more complicated connections between the +components and it is very difficult to define a structure beforehand +without risking over-constraining the user. Therefore, advanced +workflows are defined by explicitly connecting component outputs +to the inputs of other components, without a wrapper Pipeline object. +In this case, the configuration files enable the user to persist the +parameters passed to the argument parser, enabling reuse through +configuration files, with the possibility of dynamic overrides of parameters.
+ +Example: + +>>> from jsonargparse import ArgumentParser, ActionConfigFile +>>> +>>> parser = ArgumentParser(description='PyTorch MNIST Example') +>>> parser.add_argument('--batch-size', type=int, default=64, +>>> help='input batch size for training (default: 64)') +>>> parser.add_argument('--epochs', type=int, default=10, +>>> help='number of epochs to train (default: 10)') +>>> parser.add_argument('--lr', type=float, default=0.01, +>>> help='learning rate (default: 0.01)') +>>> parser.add_argument( +>>> "-c", "--config", action=ActionConfigFile, +>>> required=True, +>>> help="Path to a configuration file in json or yaml format." +>>> ) +>>> args = parser.parse_args() +>>> +>>> from itwinai.components import ( +>>> DataGetter, Saver, DataSplitter, Trainer +>>> ) +>>> getter = DataGetter(...) +>>> splitter = DataSplitter(...) +>>> data_saver = Saver(...) +>>> model_saver = Saver(...) +>>> trainer = Trainer( +>>> batch_size=args.batch_size, lr=args.lr, epochs=args.epochs +>>> ) +>>> +>>> # Compose workflow +>>> my_dataset = getter.execute() +>>> train_set, valid_set, test_set = splitter.execute(my_dataset) +>>> data_saver.execute("train_dataset.pkl", test_set) +>>> _, _, _, trained_model = trainer(train_set, valid_set) +>>> model_saver.execute(trained_model) +>>> +>>> # Run the script using a previous configuration with dynamic override +>>> python my_train.py --config training_pipe.yaml --lr 0.002 +""" + + from __future__ import annotations -from typing import Iterable, Dict, Any, Optional, Tuple, Union -from abc import ABCMeta, abstractmethod +from typing import Any, Optional, Tuple, Union, Callable, Dict, List +from abc import ABC, abstractmethod import time -from jsonargparse import ArgumentParser - +import functools # import logging # from logging import Logger as PythonLogger -from .cluster import ClusterEnvironment -from .types import ModelML, DatasetML -from .serialization import ModelLoader -from .utils import load_yaml - +from .types import MLModel, 
MLDataset, MLArtifact +from .serialization import ModelLoader, Serializable -class Executable(metaclass=ABCMeta): - """Base Executable class. - Args: - name (Optional[str], optional): unique identifier for a step. - Defaults to None. - logs_path (Optional[str], optional): where to store the logs - produced by Python logging. Defaults to None. - """ - name: str = 'unnamed' - is_setup: bool = False - cluster: ClusterEnvironment = None - parent: Executor = None - # logs_dir: str = None - # log_file: str = None - # console: PythonLogger = None - def __init__( - self, - name: Optional[str] = None, - # logs_dir: Optional[str] = None, - # debug: bool = False, - **kwargs - ) -> None: - self.name = name if name is not None else self.__class__.__name__ - # self.logs_dir = logs_dir - # self.debug = debug +def monitor_exec(method: Callable) -> Callable: + """Decorator for execute method of a component class. + Computes execution time and gives some information about + the execution of the component. - def __call__( - self, - *args: Any, - config: Optional[Dict] = None, - **kwargs: Any - ) -> Tuple[Optional[Tuple], Optional[Dict]]: - # WAIT! This method SHOULD NOT be overridden. This is just a wrapper. - # Override execute() instead! + Args: + method (Callable): class method. + """ + @functools.wraps(method) + def monitored_method(self: BaseComponent, *args, **kwargs) -> Any: msg = f"Starting execution of '{self.name}'..." self._printout(msg) start_t = time.time() try: # print(f'ARGS: {args}') # print(f'KWARGS: {kwargs}') - result = self.execute(*args, **kwargs, config=config) + result = method(self, *args, **kwargs) finally: self.cleanup() self.exec_t = time.time() - start_t @@ -63,25 +120,47 @@ def __call__( self._printout(msg) return result - @abstractmethod - def execute( - self, - *args, - config: Optional[Dict] = None, - **kwargs - ) -> Tuple[Optional[Tuple], Optional[Dict]]: - """"Execute some operations. 
+ return monitored_method + + +class BaseComponent(ABC, Serializable): + """Base component class. Each component provides a simple interface + to foster modularity in machine learning code. Each component class + implements the `execute` method, which received some input ML artifacts + (e.g., datasets), performs some operations and returns new artifacts. + The components are meant to be assembled in complex ML workflows, + represented as pipelines. Args: - args (Any, optional): generic input of the executable step. - config (Dict, optional): key-value configuration. + name (Optional[str], optional): unique identifier for a step. Defaults to None. - - Returns: - Tuple[Optional[Tuple], Optional[Dict]]: tuple structured as - (results, config). """ - return args, config + _name: str = None + parameters: Dict[Any, Any] = None + + def __init__( + self, + name: Optional[str] = None, + # logs_dir: Optional[str] = None, + # debug: bool = False, + ) -> None: + self.save_parameters(name=name) + self.name = name + + @property + def name(self) -> str: + return ( + self._name if self._name is not None else self.__class__.__name__ + ) + + @name.setter + def name(self, name: str) -> None: + self._name = name + + @abstractmethod + @monitor_exec + def execute(self, *args, **kwargs) -> Any: + """"Execute some operations.""" # def setup_console(self): # """Setup Python logging""" @@ -104,48 +183,37 @@ def execute( # ) # self.console = logging.getLogger(self.name) - def setup(self, parent: Optional[Executor] = None) -> None: - """Inherit properties from parent Executor instance. - - Args: - parent (Optional[Executor], optional): parent executor. - Defaults to None. 
- """ - if parent is None: - # # Setup Python logging ("console") - # self.logs_dir = '.logs' - # os.makedirs(self.logs_dir, exist_ok=True) - # self.setup_console() - self.is_setup = True - return - if self.cluster is None: - self.cluster = parent.cluster - - # # Python logging ("console") - # if self.logs_dir is None: - # self.logs_dir = parent.logs_dir - # if self.log_file is None: - # self.log_file = parent.log_file - # if self.console is None: - # self.console = logging.getLogger(self.name) - - self.is_setup = True - def cleanup(self): - pass + """Cleanup resources allocated by this component.""" - def _printout(self, msg: str): + @staticmethod + def _printout(msg: str): msg = f"# {msg} #" print("#"*len(msg)) print(msg) print("#"*len(msg)) -class Trainer(Executable): +class Trainer(BaseComponent): """Trains a machine learning model.""" + @abstractmethod - def train(self, *args, **kwargs): - pass + @monitor_exec + def execute( + self, + train_dataset: MLDataset, + validation_dataset: MLDataset + ) -> Tuple[MLDataset, MLDataset, MLModel]: + """Trains a machine learning model. + + Args: + train_dataset (DatasetML): training dataset. + validation_dataset (DatasetML): validation dataset. + + Returns: + Tuple[DatasetML, DatasetML, ModelML]: training dataset, + validation dataset, trained model. 
+ """ @abstractmethod def save_state(self): @@ -156,44 +224,27 @@ def load_state(self): pass -class Predictor(Executable): +class Predictor(BaseComponent): """Applies a pre-trained machine learning model to unseen data.""" - model: ModelML + model: MLModel def __init__( self, - model: Union[ModelML, ModelLoader], + model: Union[MLModel, ModelLoader], name: Optional[str] = None, - **kwargs ) -> None: - super().__init__(name, **kwargs) + super().__init__(name=name) + self.save_parameters(model=model, name=name) self.model = model() if isinstance(model, ModelLoader) else model - def execute( - self, - predict_dataset: DatasetML, - config: Optional[Dict] = None, - ) -> Tuple[Optional[Tuple], Optional[Dict]]: - """"Execute some operations. - - Args: - predict_dataset (DatasetML): dataset object for inference. - config (Dict, optional): key-value configuration. - Defaults to None. - - Returns: - Tuple[Optional[Tuple], Optional[Dict]]: tuple structured as - (results, config). - """ - return self.predict(predict_dataset), config - @abstractmethod - def predict( + @monitor_exec + def execute( self, - predict_dataset: DatasetML, - model: Optional[ModelML] = None - ) -> Iterable[Any]: + predict_dataset: MLDataset, + model: Optional[MLModel] = None + ) -> MLDataset: """Applies a machine learning model on a dataset of samples. Args: @@ -202,217 +253,205 @@ def predict( if given. Defaults to None. Returns: - Iterable[Any]: predictions with the same cardinality of the + DatasetML: predictions with the same cardinality of the input dataset. """ -class DataGetter(Executable): +class DataGetter(BaseComponent): + """Retrieves a dataset.""" + @abstractmethod - def load(self, *args, **kwargs): - pass + @monitor_exec + def execute(self) -> MLDataset: + """Retrieves a dataset. + + Returns: + MLDataset: retrieved dataset. 
+ """ -class DataPreproc(Executable): +class DataPreproc(BaseComponent): + """Performs dataset pre-processing.""" + @abstractmethod - def preproc(self, *args, **kwargs): - pass + @monitor_exec + def execute(self, dataset: MLDataset) -> MLDataset: + """Pre-processes a dataset. + + Args: + dataset (MLDataset): dataset. + + Returns: + MLDataset: pre-processed dataset. + """ -# class StatGetter(Executable): -# @abstractmethod -# def stats(self, *args, **kwargs): -# pass +class Saver(BaseComponent): + """Saves artifact to disk.""" -class Saver(Executable): @abstractmethod - def save(self, *args, **kwargs): - pass + @monitor_exec + def execute(self, artifact: MLArtifact) -> MLArtifact: + """Saves an ML artifact to disk. + Args: + artifact (MLArtifact): artifact to save. -class Executor(Executable): - """Sets-up and executes a sequence of Executable steps.""" + Returns: + MLArtifact: the same input artifact, after saving it. + """ - steps: Union[Dict[str, Executable], Iterable[Executable]] - constructor_args: Dict - def __init__( - self, - steps: Union[Dict[str, Executable], Iterable[Executable]], - name: Optional[str] = None, - # logs_dir: Optional[str] = None, - # debug: bool = False, - **kwargs - ): - # super().__init__(name=name, logs_dir=logs_dir, debug=debug, **kwargs) - super().__init__(name=name, **kwargs) - self.steps = steps - self.constructor_args = kwargs - - def __getitem__(self, subscript: Union[str, int, slice]) -> Executor: - if isinstance(subscript, slice): - # First, convert to list if is a dict - if isinstance(self.steps, dict): - steps = list(self.steps.items()) - else: - steps = self.steps - # Second, perform slicing - s = steps[subscript.start:subscript.stop: subscript.step] - # Third, reconstruct dict, if it is a dict - if isinstance(self.steps, dict): - s = dict(s) - # Fourth, return sliced sub-pipeline, preserving its - # initial structure - sliced = self.__class__( - steps=s, - **self.constructor_args - ) - return sliced - else: - return 
self.steps[subscript] +class Adapter(BaseComponent): + """Connects two components in a sequential pipeline, allowing to + control with greater detail how intermediate results are propagated + among the components. + + Args: + policy (List[Any]): list of the same length of the output of this + component, describing how to map the input args to the output. + name (Optional[str], optional): name of the component. + Defaults to None. + + The adapter allows to define a policy with which inputs are re-arranged + before being propagated to the next component. + Some examples: [policy]: (input) -> (output) + - ["INPUT_ARG#2", "INPUT_ARG#1", "INPUT_ARG#0"]: (11,22,33) -> (33,22,11) + - ["INPUT_ARG#0", "INPUT_ARG#2", None]: (11, 22, 33) -> (11, 33, None) + - []: (11, 22, 33) -> () + - [42, "INPUT_ARG#2", "hello"]: (11,22,33,44,55) -> (42, 33, "hello") + - [None, 33, 3.14]: () -> (None, 33, 3.14) + - [None, 33, 3.14]: ("double", 44, None, True) -> (None, 33, 3.14) + """ - def __len__(self) -> int: - return len(self.steps) + policy: List[Any] + INPUT_PREFIX: str = "INPUT_ARG#" - def setup(self, parent: Optional[Executor] = None) -> None: - """Inherit properties from parent Executor instance, then - propagates its properties to its own child steps. + def __init__(self, policy: List[Any], name: Optional[str] = None) -> None: + super().__init__(name=name) + self.save_parameters(policy=policy, name=name) + self.name = name + self.policy = policy + + @monitor_exec + def execute(self, *args) -> Tuple: + """Produces an output tuple by arranging input arguments according + to the policy specified in the constructor. Args: - parent (Optional[Executor], optional): parent executor. - Defaults to None. + args (Tuple): input arguments. + + Returns: + Tuple: input args arranged according to some policy. 
""" - super().setup(parent) - if isinstance(self.steps, dict): - steps = list(self.steps.values()) - else: - steps = self.steps - - for step in steps: - step.setup(self) - step.is_setup = True - - # def setup(self, config: Dict = None): - # """Pass a key-value based configuration down the pipeline, - # to propagate information computed at real-time. - - # Args: - # config (Dict, optional): key-value configuration. - # Defaults to None. - # """ - # for step in self.steps: - # config = step.setup(config) + result = [] + for itm in self.policy: + if isinstance(itm, str) and itm.startswith(self.INPUT_PREFIX): + arg_idx = int(itm[len(self.INPUT_PREFIX):]) + if arg_idx >= len(args): + max_idx = max(map( + lambda itm: int(itm[len(self.INPUT_PREFIX):]), + filter( + lambda el: ( + isinstance(el, str) + and el.startswith(self.INPUT_PREFIX) + ), + self.policy + ))) + raise IndexError( + f"The args received as input by '{self.name}' " + "are not consistent with the given adapter policy " + "because input args are too few! " + f"Input args are {len(args)} but the policy foresees " + f"at least {max_idx+1} items." 
+ ) + result.append(args[arg_idx]) + else: + result.append(itm) + return tuple(result) + + +class DataSplitter(BaseComponent): + """Splits a dataset into train, validation, and test splits.""" + _train_proportion: Union[int, float] + _validation_proportion: Union[int, float] + _test_proportion: Union[int, float] + + def __init__( + self, + train_proportion: Union[int, float], + validation_proportion: Union[int, float], + test_proportion: Union[int, float], + name: Optional[str] = None + ) -> None: + super().__init__(name) + self.save_parameters( + train_proportion=train_proportion, + validation_proportion=validation_proportion, + test_proportion=test_proportion, + name=name + ) + self.train_proportion = train_proportion + self.validation_proportion = validation_proportion + self.test_proportion = test_proportion + + @property + def train_proportion(self) -> Union[int, float]: + """Training set proportion.""" + return self._train_proportion + + @train_proportion.setter + def train_proportion(self, prop: Union[int, float]) -> None: + if isinstance(prop, float) and not 0.0 <= prop <= 1.0: + raise ValueError( + "Train proportion should be in the interval [0.0, 1.0] " + f"if given as float. Received {prop}" + ) + self._train_proportion = prop + + @property + def validation_proportion(self) -> Union[int, float]: + """Validation set proportion.""" + return self._validation_proportion + + @validation_proportion.setter + def validation_proportion(self, prop: Union[int, float]) -> None: + if isinstance(prop, float) and not 0.0 <= prop <= 1.0: + raise ValueError( + "Validation proportion should be in the interval [0.0, 1.0] " + f"if given as float. 
Received {prop}" + ) + self._validation_proportion = prop + + @property + def test_proportion(self) -> Union[int, float]: + """Test set proportion.""" + return self._test_proportion + + @test_proportion.setter + def test_proportion(self, prop: Union[int, float]) -> None: + if isinstance(prop, float) and not 0.0 <= prop <= 1.0: + raise ValueError( + "Test proportion should be in the interval [0.0, 1.0] " + f"if given as float. Received {prop}" + ) + self._test_proportion = prop + + @abstractmethod + @monitor_exec def execute( self, - *args, - config: Optional[Dict] = None, - **kwargs - ) -> Tuple[Optional[Tuple], Optional[Dict]]: - """"Execute some operations. + dataset: MLDataset + ) -> Tuple[MLDataset, MLDataset, MLDataset]: + """Splits a dataset into train, validation and test splits. Args: - args (Tuple, optional): generic input of the first executable step - in the pipeline. - config (Dict, optional): key-value configuration. - Defaults to None. + dataset (MLDataset): input dataset. Returns: - Tuple[Optional[Tuple], Optional[Dict]]: tuple structured as - (results, config). + Tuple[MLDataset, MLDataset, MLDataset]: tuple of + train, validation and test splits. """ - if isinstance(self.steps, dict): - steps = list(self.steps.values()) - else: - steps = self.steps - - for step in steps: - if not step.is_setup: - raise RuntimeError( - f"Step '{step.name}' was not setup!" - ) - args = self._pack_args(args) - args, config = step(*args, **kwargs, config=config) - - return args, config - - def _pack_args(self, args) -> Tuple: - args = () if args is None else args - if not isinstance(args, tuple): - args = (args,) - return args - - -def add_replace_field( - config: Dict, - key_chain: str, - value: Any -) -> None: - """Replace or add (if not present) a field in a dictionary, following a - path of dot-separated keys. Inplace operation. - - Args: - config (Dict): dictionary to be modified. 
- key_chain (str): path of dot-separated keys to specify the location - if the new value (e.g., 'foo.bar.line' adds/overwrites the value - located at config['foo']['bar']['line']). - value (Any): the value to insert. - """ - sub_config = config - for idx, k in enumerate(key_chain.split('.')): - if idx >= len(key_chain.split('.')) - 1: - # Last key reached - break - if not isinstance(sub_config.get(k), dict): - sub_config[k] = dict() - sub_config = sub_config[k] - sub_config[k] = value - - -def load_pipeline_step( - pipe: Union[str, Dict], - step_id: Union[str, int], - override_keys: Optional[Dict[str, Any]] = None, - verbose: bool = False -) -> Executable: - """Instantiates a specific step from a pipeline configuration file, given - its ID (index if steps are a list, key if steps are a dictionary). It - allows to override the step configuration with user defined values. - - Args: - pipe (Union[str, Dict]): pipeline configuration. Either a path to a - YAML file (if string), or a configuration in memory (if dict object). - step_id (Union[str, int]): step identifier: list index if steps are - represented as a list, string key if steps are represented as a - dictionary. - override_keys (Optional[Dict[str, Any]], optional): if given, maps key - path to the value to add/override. A key path is a string of - dot-separated keys (e.g., 'foo.bar.line' adds/overwrites the value - located at pipe['foo']['bar']['line']). Defaults to None. - verbose (bool, optional): if given, prints to console the new - configuration, obtained after overriding. Defaults to False. - - Returns: - Executable: an instance of the selected step in the pipeline. 
- """ - if isinstance(pipe, str): - # Load pipe from YAML file path - pipe = load_yaml(pipe) - step_dict_config = pipe['executor']['init_args']['steps'][step_id] - - # Override fields - if override_keys is not None: - for key_chain, value in override_keys.items(): - add_replace_field(step_dict_config, key_chain, value) - if verbose: - import json - print(f"NEW STEP CONFIG:") - print(json.dumps(step_dict_config, indent=4)) - - # Wrap config under "step" field and parse it - step_dict_config = dict(step=step_dict_config) - step_parser = ArgumentParser() - step_parser.add_subclass_arguments(Executable, "step") - parsed_namespace = step_parser.parse_object(step_dict_config) - return step_parser.instantiate_classes(parsed_namespace)["step"] diff --git a/src/itwinai/executors.py b/src/itwinai/experimental/executors.py similarity index 92% rename from src/itwinai/executors.py rename to src/itwinai/experimental/executors.py index d94e1c0f..2c89f1c3 100644 --- a/src/itwinai/executors.py +++ b/src/itwinai/experimental/executors.py @@ -8,11 +8,11 @@ from ray import air, tune from jsonargparse import ArgumentParser -from .components import Executor, Executable -from .utils import parse_pipe_config +from ..components import Pipeline, BaseComponent +from ..utils import parse_pipe_config -class LocalExecutor(Executor): +class LocalExecutor(Pipeline): def __init__(self, pipeline, class_dict): # Create parser for the pipeline (ordered) pipe_parser = ArgumentParser() @@ -40,7 +40,7 @@ def setup(self, args): args = executable.setup(args) -class RayExecutor(Executor): +class RayExecutor(Pipeline): def __init__(self, pipeline, class_dict, param_space): self.class_dict = class_dict self.param_space = param_space @@ -91,10 +91,10 @@ def setup(self, args): pass -class ParallelExecutor(Executor): +class ParallelExecutor(Pipeline): """Execute a pipeline in parallel: multiprocessing and multi-node.""" - def __init__(self, steps: Iterable[Executable]): + def __init__(self, steps: 
Iterable[BaseComponent]): super().__init__(steps) def setup(self, config: Dict = None): @@ -112,7 +112,7 @@ class HPCExecutor(ParallelExecutor): network access. """ - def __init__(self, steps: Iterable[Executable]): + def __init__(self, steps: Iterable[BaseComponent]): super().__init__(steps) def setup(self, config: Dict = None): diff --git a/src/itwinai/parser.py b/src/itwinai/parser.py new file mode 100644 index 00000000..8e393652 --- /dev/null +++ b/src/itwinai/parser.py @@ -0,0 +1,485 @@ +""" +Provide functionalities to manage configuration files, including parsing, +execution, and dynamic override of fields. +""" + +import logging +import os +from typing import Dict, Any, List, Type, Union, Optional +from jsonargparse import ArgumentParser as JAPArgumentParser +from jsonargparse import ActionConfigFile +import json +from jsonargparse._formatters import DefaultHelpFormatter +from omegaconf import OmegaConf +from pathlib import Path + +from .components import BaseComponent +from .pipeline import Pipeline +from .utils import load_yaml + + +def add_replace_field( + config: Dict, + key_chain: str, + value: Any +) -> None: + """Replace or add (if not present) a field in a dictionary, following a + path of dot-separated keys. Adding is not supported for list items. + Inplace operation. + Args: + config (Dict): dictionary to be modified. + key_chain (str): path of nested (dot-separated) keys to specify the + location + of the new value (e.g., 'foo.bar.line' adds/overwrites the value + located at config['foo']['bar']['line']). + value (Any): the value to insert. 
+ """ + sub_config = config + for idx, k in enumerate(key_chain.split('.')): + if idx >= len(key_chain.split('.')) - 1: + # Last key reached + break + + if isinstance(sub_config, (list, tuple)): + k = int(k) + next_elem = sub_config[k] + else: + next_elem = sub_config.get(k) + + if not isinstance(next_elem, (dict, list, tuple)): + sub_config[k] = dict() + + sub_config = sub_config[k] + if isinstance(sub_config, (list, tuple)): + k = int(k) + sub_config[k] = value + + +class ConfigParser: + """ + Parses a pipeline from a configuration file. + It also provides functionalities for dynamic override + of fields by means of nested key notation. + + Args: + config (Union[str, Dict]): path to YAML configuration file + or dict storing a configuration. + override_keys (Optional[Dict[str, Any]], optional): dict mapping + nested keys to the value to override. Defaults to None. + + Example: + + >>> # pipeline.yaml file + >>> pipeline: + >>> class_path: itwinai.pipeline.Pipeline + >>> init_args: + >>> steps: + >>> - class_path: dataloader.MNISTDataModuleTorch + >>> init_args: + >>> save_path: .tmp/ + >>> + >>> - class_path: itwinai.torch.trainer.TorchTrainerMG + >>> init_args: + >>> model: + >>> class_path: model.Net + >>> loss: + >>> class_path: torch.nn.NLLLoss + >>> init_args: + >>> reduction: mean + + >>> from itwinai.parser import ConfigParser + >>> + >>> parser = ConfigParser( + >>> config='pipeline.yaml', + >>> override_keys={ + >>> 'pipeline.init_args.steps.0.init_args.save_path': /save/path + >>> } + >>> ) + >>> pipeline = parser.parse_pipeline() + >>> print(pipeline) + >>> print(pipeline.steps) + >>> + >>> dataloader = parser.parse_step(0) + >>> print(dataloader) + >>> print(dataloader.save_path) + """ + + config: Dict + pipeline: Pipeline + + def __init__( + self, + config: Union[str, Dict], + override_keys: Optional[Dict[str, Any]] = None + ) -> None: + self.config = config + self.override_keys = override_keys + if isinstance(self.config, (str, Path)): + self.config 
= load_yaml(self.config) + self._dynamic_override_keys() + self._omegaconf_interpolate() + + def _dynamic_override_keys(self): + if self.override_keys is not None: + for key_chain, value in self.override_keys.items(): + add_replace_field(self.config, key_chain, value) + + def _omegaconf_interpolate(self) -> None: + """Performs variable interpolation with OmegaConf on internal + configuration file. + """ + conf = OmegaConf.create(self.config) + self.config = OmegaConf.to_container(conf, resolve=True) + + def parse_pipeline( + self, + pipeline_nested_key: str = "pipeline", + verbose: bool = False + ) -> Pipeline: + """Merges steps into pipeline and parses it. + + Args: + pipeline_nested_key (str, optional): nested key in the + configuration file identifying the pipeline object. + Defaults to "pipeline". + verbose (bool): if True, prints the assembled pipeline + to console formatted as JSON. + + Returns: + Pipeline: instantiated pipeline. + """ + pipe_parser = JAPArgumentParser() + pipe_parser.add_subclass_arguments(Pipeline, "pipeline") + + pipe_dict = self.config + for key in pipeline_nested_key.split('.'): + pipe_dict = pipe_dict[key] + # pipe_dict = self.config[pipeline_nested_key] + pipe_dict = {"pipeline": pipe_dict} + + if verbose: + print("Assembled pipeline:") + print(json.dumps(pipe_dict, indent=4)) + + # Parse pipeline dict once merged with steps + conf = pipe_parser.parse_object(pipe_dict) + pipe = pipe_parser.instantiate_classes(conf) + self.pipeline = pipe["pipeline"] + return self.pipeline + + def parse_step( + self, + step_idx: Union[str, int], + pipeline_nested_key: str = "pipeline", + verbose: bool = False + ) -> BaseComponent: + pipeline_dict = self.config + for key in pipeline_nested_key.split('.'): + pipeline_dict = pipeline_dict[key] + + step_dict_config = pipeline_dict['init_args']['steps'][step_idx] + + if verbose: + print(f"STEP '{step_idx}' CONFIG:") + print(json.dumps(step_dict_config, indent=4)) + + # Wrap config under "step" field and 
parse it + step_dict_config = {'step': step_dict_config} + step_parser = JAPArgumentParser() + step_parser.add_subclass_arguments(BaseComponent, "step") + parsed_namespace = step_parser.parse_object(step_dict_config) + return step_parser.instantiate_classes(parsed_namespace)["step"] + + +class ArgumentParser(JAPArgumentParser): + def __init__( + self, + *args, + env_prefix: Union[bool, str] = True, + formatter_class: Type[DefaultHelpFormatter] = DefaultHelpFormatter, + exit_on_error: bool = True, + logger: Union[bool, str, dict, logging.Logger] = False, + version: Optional[str] = None, + print_config: Optional[str] = "--print_config", + parser_mode: str = "yaml", + dump_header: Optional[List[str]] = None, + default_config_files: Optional[List[Union[str, os.PathLike]]] = None, + default_env: bool = False, + default_meta: bool = True, + **kwargs, + ) -> None: + """Initializer for ArgumentParser instance. + + All the arguments from the initializer of `argparse.ArgumentParser + `_ + are supported. Additionally it accepts: + + Args: + env_prefix: Prefix for environment variables. ``True`` to derive + from ``prog``. + formatter_class: Class for printing help messages. + logger: Configures the logger, see :class:`.LoggerProperty`. + version: Program version which will be printed by the --version + argument. + print_config: Add this as argument to print config, set None to + disable. + parser_mode: Mode for parsing config files: ``'yaml'``, + ``'jsonnet'`` or ones added via :func:`.set_loader`. + dump_header: Header to include as comment when dumping a config + object. + default_config_files: Default config file locations, e.g. + :code:`['~/.config/myapp/*.yaml']`. + default_env: Set the default value on whether to parse environment + variables. + default_meta: Set the default value on whether to include metadata + in config objects. 
+ """ + super().__init__( + *args, env_prefix=env_prefix, formatter_class=formatter_class, + exit_on_error=exit_on_error, logger=logger, version=version, + print_config=print_config, parser_mode=parser_mode, + dump_header=dump_header, default_config_files=default_config_files, + default_env=default_env, + default_meta=default_meta, **kwargs) + self.add_argument( + "-c", "--config", action=ActionConfigFile, + help="Path to a configuration file in json or yaml format." + ) + + +# class ConfigParser2: +# """ +# Deprecated: this pipeline structure does not allow for +# nested pipelines. However, it is more readable and the linking +# from name to step data could be achieved with OmegaConf. This +# could be reused in the future: left as example. + +# Parses a configuration file, merging the steps into +# the pipeline and returning a pipeline object. +# It also provides functionalities for dynamic override +# of fields by means of nested key notation. + +# Example: + +# >>> # pipeline.yaml +# >>> pipeline: +# >>> class_path: itwinai.pipeline.Pipeline +# >>> steps: [server, client] +# >>> +# >>> server: +# >>> class_path: mycode.ServerOptions +# >>> init_args: +# >>> host: localhost +# >>> port: 80 +# >>> +# >>> client: +# >>> class_path: mycode.ClientOptions +# >>> init_args: +# >>> url: http://${server.init_args.host}:${server.init_args.port}/ + +# >>> from itwinai.parser import ConfigParser2 +# >>> +# >>> parser = ConfigParser2( +# >>> config='pipeline.yaml', +# >>> override_keys={ +# >>> 'server.init_args.port': 777 +# >>> } +# >>> ) +# >>> pipeline = parser.parse_pipeline() +# >>> print(pipeline) +# >>> print(pipeline.steps) +# >>> print(pipeline.steps['server'].port) +# >>> +# >>> server = parser.parse_step('server') +# >>> print(server) +# >>> print(server.port) +# """ + +# config: Dict +# pipeline: Pipeline + +# def __init__( +# self, +# config: Union[str, Dict], +# override_keys: Optional[Dict[str, Any]] = None +# ) -> None: +# self.config = config +# 
self.override_keys = override_keys +# if isinstance(self.config, str): +# self.config = load_yaml(self.config) +# self._dynamic_override_keys() +# self._omegaconf_interpolate() + +# def _dynamic_override_keys(self): +# if self.override_keys is not None: +# for key_chain, value in self.override_keys.items(): +# add_replace_field(self.config, key_chain, value) + +# def _omegaconf_interpolate(self) -> None: +# """Performs variable interpolation with OmegaConf on internal +# configuration file. +# """ +# conf = OmegaConf.create(self.config) +# self.config = OmegaConf.to_container(conf, resolve=True) + +# def parse_pipeline( +# self, +# pipeline_nested_key: str = "pipeline", +# verbose: bool = False +# ) -> Pipeline: +# """Merges steps into pipeline and parses it. + +# Args: +# pipeline_nested_key (str, optional): nested key in the +# configuration file identifying the pipeline object. +# Defaults to "pipeline". +# verbose (bool): if True, prints the assembled pipeline +# to console formatted as JSON. + +# Returns: +# Pipeline: instantiated pipeline. 
+# """ +# pipe_parser = JAPArgumentParser() +# pipe_parser.add_subclass_arguments(Pipeline, pipeline_nested_key) +# pipe_dict = self.config[pipeline_nested_key] + +# # Pop steps list from pipeline dictionary +# steps_list = pipe_dict['steps'] +# del pipe_dict['steps'] + +# # Link steps with respective dictionaries +# if not pipe_dict.get('init_args'): +# pipe_dict['init_args'] = {} +# steps_dict = pipe_dict['init_args']['steps'] = {} +# for step_name in steps_list: +# steps_dict[step_name] = self.config[step_name] +# pipe_dict = {pipeline_nested_key: pipe_dict} + +# if verbose: +# print("Assembled pipeline:") +# print(json.dumps(pipe_dict, indent=4)) + +# # Parse pipeline dict once merged with steps +# conf = pipe_parser.parse_object(pipe_dict) +# pipe = pipe_parser.instantiate_classes(conf) +# self.pipeline = pipe[pipeline_nested_key] +# return self.pipeline + +# def parse_step( +# self, +# step_name: str, +# verbose: bool = False +# ) -> BaseComponent: +# step_dict_config = self.config[step_name] + +# if verbose: +# print(f"STEP '{step_name}' CONFIG:") +# print(json.dumps(step_dict_config, indent=4)) + +# # Wrap config under "step" field and parse it +# step_dict_config = {'step': step_dict_config} +# step_parser = JAPArgumentParser() +# step_parser.add_subclass_arguments(BaseComponent, "step") +# parsed_namespace = step_parser.parse_object(step_dict_config) +# return step_parser.instantiate_classes(parsed_namespace)["step"] + + +# class ItwinaiCLI2: +# """ +# Deprecated: the dynamic override does not work with nested parameters +# and may be confusing. + +# CLI tool for executing a configuration file, with dynamic +# override of fields and variable interpolation with Omegaconf. 
+ +# Example: + +# >>> # train.py +# >>> from itwinai.parser import ItwinaiCLI +# >>> cli = ItwinaiCLI() +# >>> cli.pipeline.execute() + +# >>> # pipeline.yaml +# >>> pipeline: +# >>> class_path: itwinai.pipeline.Pipeline +# >>> steps: [server, client] +# >>> +# >>> server: +# >>> class_path: mycode.ServerOptions +# >>> init_args: +# >>> host: localhost +# >>> port: 80 +# >>> +# >>> client: +# >>> class_path: mycode.ClientOptions +# >>> init_args: +# >>> url: http://${server.init_args.host}:${server.init_args.port}/ + +# From command line: + +# >>> python train.py --config itwinai-conf.yaml --help +# >>> python train.py --config itwinai-conf.yaml +# >>> python train.py --config itwinai-conf.yaml --server.port 8080 +# """ +# _parser: JAPArgumentParser +# _config: Dict +# pipeline: Pipeline + +# def __init__( +# self, +# pipeline_nested_key: str = "pipeline", +# parser_mode: str = "omegaconf" +# ) -> None: +# self.pipeline_nested_key = pipeline_nested_key +# self.parser_mode = parser_mode +# self._init_parser() +# self._parser.add_argument(f"--{self.pipeline_nested_key}", type=dict) +# self._add_steps_arguments() +# self._config = self._parser.parse_args() + +# # Merge steps into pipeline and parse it +# del self._config['config'] +# pipe_parser = ConfigParser2(config=self._config.as_dict()) +# self.pipeline = pipe_parser.parse_pipeline( +# pipeline_nested_key=self.pipeline_nested_key +# ) + +# def _init_parser(self): +# self._parser = JAPArgumentParser(parser_mode=self.parser_mode) +# self._parser.add_argument( +# "-c", "--config", action=ActionConfigFile, +# required=True, +# help="Path to a configuration file in json or yaml format." +# ) + +# def _add_steps_arguments(self): +# """Pre-parses the configuration file, dynamically adding all the +# component classes under 'steps' as arguments of the parser. +# """ +# if "--config" not in sys.argv: +# raise ValueError( +# "--config parameter has to be specified with a " +# "valid path to a configuration file." 
+# ) +# config_path = sys.argv.index("--config") + 1 +# config_path = sys.argv[config_path] +# config = load_yaml(config_path) + +# # Add steps to parser +# steps = filter( +# lambda itm: itm[0] != self.pipeline_nested_key, +# config.items() +# ) +# steps = { +# step_name: step_data['class_path'] +# for step_name, step_data in steps +# } + +# for st_nested_key, step_class_str in steps.items(): +# step_class = dynamically_import_class(step_class_str) +# self._add_step_arguments( +# step_class=step_class, nested_key=st_nested_key) + +# def _add_step_arguments(self, step_class, nested_key): +# self._parser.add_subclass_arguments( +# baseclass=step_class, nested_key=nested_key) diff --git a/src/itwinai/pipeline.py b/src/itwinai/pipeline.py new file mode 100644 index 00000000..1391bfef --- /dev/null +++ b/src/itwinai/pipeline.py @@ -0,0 +1,101 @@ +""" +This module provides the functionalities to execute workflows defined +in form of pipelines. +""" +from __future__ import annotations +from typing import Iterable, Dict, Any, Tuple, Union, Optional + +from .components import BaseComponent, monitor_exec +from .utils import SignatureInspector + + +class Pipeline(BaseComponent): + """Executes a set of components arranged as a pipeline.""" + + steps: Union[Dict[str, BaseComponent], Iterable[BaseComponent]] + + def __init__( + self, + steps: Union[Dict[str, BaseComponent], Iterable[BaseComponent]], + name: Optional[str] = None + ): + super().__init__(name=name) + self.save_parameters(steps=steps, name=name) + self.steps = steps + + def __getitem__(self, subscript: Union[str, int, slice]) -> Pipeline: + if isinstance(subscript, slice): + # First, convert to list if is a dict + if isinstance(self.steps, dict): + steps = list(self.steps.items()) + else: + steps = self.steps + # Second, perform slicing + s = steps[subscript.start:subscript.stop: subscript.step] + # Third, reconstruct dict, if it is a dict + if isinstance(self.steps, dict): + s = dict(s) + # Fourth, return
sliced sub-pipeline, preserving its + # initial structure + sliced = self.__class__( + steps=s, + name=self.name + ) + return sliced + else: + return self.steps[subscript] + + def __len__(self) -> int: + return len(self.steps) + + @monitor_exec + def execute(self, *args) -> Any: + """Execute components sequentially.""" + if isinstance(self.steps, dict): + steps = list(self.steps.values()) + else: + steps = self.steps + + for step in steps: + step: BaseComponent + args = self._pack_args(args) + self.validate_args(args, step) + args = step.execute(*args) + + return args + + @staticmethod + def _pack_args(args) -> Tuple: + """Wraps args in a tuple, if needed.""" + args = () if args is None else args + if not isinstance(args, tuple): + args = (args,) + return args + + @staticmethod + def validate_args(input_args: Tuple, component: BaseComponent): + """Verify that the number of input args provided to some component + match with the number of the non-default args in the component. + + Args: + input_args (Tuple): input args to be fed to the component. + component (BaseComponent): component to be executed. + + Raises: + TypeError: in case of args mismatch. + """ + inspector = SignatureInspector(component.execute) + if inspector.min_params_num > len(input_args): + raise TypeError( + f"Component '{component.name}' received too few " + f"input arguments: {input_args}. Expected at least " + f"{inspector.min_params_num}, with names: " + f"{inspector.required_params}." + ) + if (inspector.max_params_num != inspector.INFTY + and len(input_args) > inspector.max_params_num): + raise TypeError( + f"Component '{component.name}' received too many " + f"input arguments: {input_args}. Expected at most " + f"{inspector.max_params_num}."
+ ) diff --git a/src/itwinai/serialization.py b/src/itwinai/serialization.py index a7b70cd3..9c1c8563 100644 --- a/src/itwinai/serialization.py +++ b/src/itwinai/serialization.py @@ -1,8 +1,170 @@ -from .types import ModelML +from typing import Dict, Any, Union import abc +import json +import yaml +from pathlib import Path +import inspect +from .types import MLModel +from .utils import SignatureInspector -class ModelLoader(abc.ABC): + +def is_jsonable(x): + try: + json.dumps(x) + return True + except Exception: + return False + + +def fullname(o): + klass = o.__class__ + module = klass.__module__ + if module == 'builtins': + return klass.__qualname__ # avoid outputs like 'builtins.str' + return module + '.' + klass.__qualname__ + + +class SerializationError(Exception): + ... + + +class Serializable: + parameters: Dict[Any, Any] = None + + def save_parameters(self, **kwargs) -> None: + """Simplified way to store constructor arguments as class + attributes. Keeps track of the parameters to enable + YAML/JSON serialization. + """ + if self.parameters is None: + self.parameters = {} + self.parameters.update(kwargs) + + # for k, v in kwargs.items(): + # self.__setattr__(k, v) + + @staticmethod + def locals2params(locals: Dict, pop_self: bool = True) -> Dict: + """Remove ``self`` from the output of ``locals()``. + + Args: + locals (Dict): output of ``locals()`` called in the constructor + of a class. + pop_self (bool, optional): whether to remove ``self``. + Defaults to True. + + Returns: + Dict: cleaned ``locals()``.
+ """ + if pop_self: + locals.pop('self', None) + return locals + + def update_parameters(self, **kwargs) -> None: + """Updates stored parameters.""" + self.save_parameters(**kwargs) + + def to_dict(self) -> Dict: + """Returns a dict serialization of the current object.""" + self._validate_parameters() + class_path = self._get_class_path() + init_args = dict() + for par_name, par in self._saved_constructor_parameters().items(): + init_args[par_name] = self._recursive_serialization(par, par_name) + return dict(class_path=class_path, init_args=init_args) + + def _validate_parameters(self) -> None: + if self.parameters is None: + raise SerializationError( + f"{self.__class__.__name__} cannot be serialized " + "because its constructor arguments were not saved. " + "Please add 'self.save_parameters(param_1=param_1, " + "..., param_n=param_n)' as first instruction of its " + "constructor." + ) + + init_inspector = SignatureInspector(self.__init__) + for par_name in init_inspector.required_params: + if self.parameters.get(par_name) is None: + raise SerializationError( + f"Required parameter '{par_name}' of " + f"{self.__class__.__name__} class not present in " + "saved parameters. " + "Please add 'self.save_parameters(param_1=param_1, " + "..., param_n=param_n)' as first instruction of its " + f"constructor, including also '{par_name}'." + ) + + def _get_class_path(self) -> str: + class_path = fullname(self) + if "<locals>" in class_path: + raise SerializationError( + f"{self.__class__.__name__} is " + "defined locally, which is not supported for serialization. " + "Move the class to a separate Python file and import it " + "from there." + ) + return class_path + + def _saved_constructor_parameters(self) -> Dict[str, Any]: + """Extracts the current constructor parameters from all + the saved parameters, as some of them may have been added by + superclasses. + + Returns: + Dict[str, Any]: subset of saved parameters containing only + the constructor parameters for this class.
+ """ + init_params = inspect.signature(self.__init__).parameters.items() + init_par_nam = map(lambda x: x[0], init_params) + return { + par_name: self.parameters[par_name] for par_name in init_par_nam + if self.parameters.get(par_name, inspect._empty) != inspect._empty + } + + def _recursive_serialization(self, item: Any, item_name: str) -> Any: + if isinstance(item, (tuple, list, set)): + return [self._recursive_serialization(x, item_name) for x in item] + elif isinstance(item, dict): + return { + k: self._recursive_serialization(v, item_name) + for k, v in item.items() + } + elif is_jsonable(item): + return item + elif isinstance(item, Serializable): + return item.to_dict() + else: + raise SerializationError( + f"{self.__class__.__name__} cannot be serialized " + f"because its constructor argument '{item_name}' " + "is not a Python built-in type and it does not " + "extend 'itwinai.serialization.Serializable' class." + ) + + def to_json(self, file_path: Union[str, Path], nested_key: str) -> None: + """Save a component to JSON file. + + Args: + file_path (Union[str, Path]): JSON file path. + nested_key (str): root field containing the serialized object. + """ + with open(file_path, "w") as fp: + json.dump({nested_key: self.to_dict()}, fp) + + def to_yaml(self, file_path: Union[str, Path], nested_key: str) -> None: + """Save a component to YAML file. + + Args: + file_path (Union[str, Path]): YAML file path. + nested_key (str): root field containing the serialized object. 
+ """ + with open(file_path, "w") as fp: + yaml.dump({nested_key: self.to_dict()}, fp) + + +class ModelLoader(abc.ABC, Serializable): """Loads a machine learning model from somewhere.""" def __init__(self, model_uri: str) -> None: @@ -10,5 +172,5 @@ def __init__(self, model_uri: str) -> None: self.model_uri = model_uri @abc.abstractmethod - def __call__(self) -> ModelML: + def __call__(self) -> MLModel: """Loads model from model URI.""" diff --git a/src/itwinai/tensorflow/trainer.py b/src/itwinai/tensorflow/trainer.py index 3f51f000..f1a10214 100644 --- a/src/itwinai/tensorflow/trainer.py +++ b/src/itwinai/tensorflow/trainer.py @@ -4,7 +4,7 @@ from jsonargparse import ArgumentParser import tensorflow as tf -from ..components import Trainer +from ..components import Trainer, monitor_exec def import_class(name): @@ -38,6 +38,7 @@ def __init__( strategy ): super().__init__() + self.save_parameters(**self.locals2params(locals())) self.strategy = strategy self.epochs = epochs self.batch_size = batch_size @@ -96,7 +97,8 @@ def instantiate_compile_conf(conf: Dict) -> Dict: conf[item_name] = instance_from_dict(item) return conf - def train(self, train_dataset, validation_dataset): + @monitor_exec + def execute(self, train_dataset, validation_dataset) -> Any: # Set batch size to the dataset # train = train.batch(self.batch_size, drop_remainder=True) # test = test.batch(self.batch_size, drop_remainder=True) @@ -169,7 +171,8 @@ def train(self, train_dataset, validation_dataset): # # TODO: move loss, optimizer and metrics instantiation under # # here # # Ref: -# # https://www.tensorflow.org/guide/distributed_training#use_tfdistributestrategy_with_keras_modelfit +# # https://www.tensorflow.org/guide/distributed_training\ +# #use_tfdistributestrategy_with_keras_modelfit # else: # self.model = parser.instantiate_classes(model_dict).model # self.model.compile(**compile_conf) @@ -191,8 +194,10 @@ def train(self, train_dataset, validation_dataset): # n_test = 
test.cardinality().numpy() # # TODO: read -# # https://github.com/tensorflow/tensorflow/issues/56773#issuecomment-1188693881 -# # https://www.tensorflow.org/guide/distributed_training#use_tfdistributestrategy_with_keras_modelfit +# # https://github.com/tensorflow/tensorflow/issues/56773\ +# #issuecomment-1188693881 +# # https://www.tensorflow.org/guide/distributed_training\ +# #use_tfdistributestrategy_with_keras_modelfit # # Distribute dataset # if self.strategy: diff --git a/src/itwinai/tests/__init__.py b/src/itwinai/tests/__init__.py new file mode 100644 index 00000000..5486fb7a --- /dev/null +++ b/src/itwinai/tests/__init__.py @@ -0,0 +1,11 @@ +from .dummy_components import ( + FakeGetter, FakeGetterExec, FakePreproc, FakePreprocExec, + FakeSaver, FakeSaverExec, FakeSplitter, FakeSplitterExec, + FakeTrainer, FakeTrainerExec +) + +_ = ( + FakeGetter, FakeGetterExec, FakePreproc, FakePreprocExec, + FakeSaver, FakeSaverExec, FakeSplitter, FakeSplitterExec, + FakeTrainer, FakeTrainerExec +) diff --git a/src/itwinai/tests/dummy_components.py b/src/itwinai/tests/dummy_components.py new file mode 100644 index 00000000..b60f1df0 --- /dev/null +++ b/src/itwinai/tests/dummy_components.py @@ -0,0 +1,97 @@ +from typing import Optional +from ..components import BaseComponent, monitor_exec + + +class FakeGetter(BaseComponent): + def __init__(self, data_uri: str, name: Optional[str] = None + ) -> None: + super().__init__(name) + self.save_parameters(data_uri=data_uri, name=name) + self.data_uri = data_uri + + def execute(self): + ... + + +class FakeGetterExec(FakeGetter): + result: str = "dataset" + + @monitor_exec + def execute(self): + return self.result + + +class FakeSplitter(BaseComponent): + def __init__(self, train_prop: float, name: Optional[str] = None + ) -> None: + super().__init__(name) + self.save_parameters(train_prop=train_prop, name=name) + self.train_prop = train_prop + + def execute(self): + ... 
+ + +class FakeSplitterExec(FakeSplitter): + result: tuple = ("train_dataset", "val_dataset", "test_dataset") + + @monitor_exec + def execute(self, dataset): + return self.result + + +class FakePreproc(BaseComponent): + def __init__(self, max_items: int, name: Optional[str] = None + ) -> None: + super().__init__(name) + self.save_parameters(max_items=max_items, name=name) + self.max_items = max_items + + def execute(self): + ... + + +class FakePreprocExec(FakePreproc): + @monitor_exec + def execute(self, train_dataset, val_dataset, test_dataset): + return train_dataset, val_dataset, test_dataset + + +class FakeTrainer(BaseComponent): + def __init__( + self, + lr: float, + batch_size: int, + name: Optional[str] = None + ) -> None: + super().__init__(name) + self.save_parameters(lr=lr, batch_size=batch_size, name=name) + self.lr = lr + self.batch_size = batch_size + + def execute(self): + ... + + +class FakeTrainerExec(FakeTrainer): + model: str = "trained_model" + + @monitor_exec + def execute(self, train_dataset, val_dataset, test_dataset): + return train_dataset, val_dataset, test_dataset, self.model + + +class FakeSaver(BaseComponent): + def __init__(self, save_path: str, name: Optional[str] = None) -> None: + super().__init__(name) + self.save_parameters(save_path=save_path, name=name) + self.save_path = save_path + + def execute(self): + ... 
+ + +class FakeSaverExec(FakeSaver): + @monitor_exec + def execute(self, artifact): + return artifact diff --git a/src/itwinai/torch/inference.py b/src/itwinai/torch/inference.py index 4d7797c6..02882f06 100644 --- a/src/itwinai/torch/inference.py +++ b/src/itwinai/torch/inference.py @@ -8,7 +8,7 @@ from ..utils import dynamically_import_class from .utils import clear_key -from ..components import Predictor +from ..components import Predictor, monitor_exec from .types import TorchDistributedStrategy as StrategyT from .types import Metric, Batch from ..serialization import ModelLoader @@ -93,6 +93,7 @@ def __init__( name: str = None ) -> None: super().__init__(model=model, name=name) + self.save_parameters(**self.locals2params(locals())) self.model = self.model.eval() # self.seed = seed # self.strategy = strategy @@ -122,7 +123,8 @@ def __init__( # else validation_metrics # ) - def predict( + @monitor_exec + def execute( self, test_dataset: Dataset, model: nn.Module = None, diff --git a/src/itwinai/torch/trainer.py b/src/itwinai/torch/trainer.py index 6d8a1771..31794c49 100644 --- a/src/itwinai/torch/trainer.py +++ b/src/itwinai/torch/trainer.py @@ -17,7 +17,7 @@ import torch.nn as nn from torch.optim.optimizer import Optimizer -from ..components import Trainer +from ..components import Trainer, monitor_exec from .utils import seed_worker, par_allgather_obj, clear_key from .types import ( Batch, Loss, LrScheduler, Metric @@ -205,6 +205,7 @@ def __init__( Makes the model a DDP model. 
""" super().__init__() + self.save_parameters(**self.locals2params(locals())) self.model = model self.loss = loss self.epochs = epochs @@ -309,6 +310,7 @@ def set_seed(self, seed: Optional[int] = None): if self.cluster.is_cuda_available(): torch.cuda.manual_seed(seed) + @monitor_exec def execute( self, train_dataset: Dataset, @@ -316,8 +318,7 @@ def execute( model: nn.Module = None, optimizer: Optimizer = None, lr_scheduler: LrScheduler = None, - config: Optional[Dict] = None - ) -> Tuple[Optional[Tuple], Optional[Dict]]: + ) -> Any: self.train_dataset = train_dataset self.validation_dataset = validation_dataset @@ -337,7 +338,7 @@ def execute( result = self._train(0) # Return value compliant with Executable.execute format - return ((result,), config) + return result def _train( self, diff --git a/src/itwinai/types.py b/src/itwinai/types.py index 9c302eb1..977068b9 100644 --- a/src/itwinai/types.py +++ b/src/itwinai/types.py @@ -3,9 +3,13 @@ """ -class DatasetML: +class MLArtifact: + """A framework-independent machine learning artifact.""" + + +class MLDataset(MLArtifact): """A framework-independent machine learning dataset.""" -class ModelML: +class MLModel(MLArtifact): """A framework-independent machine learning model.""" diff --git a/src/itwinai/utils.py b/src/itwinai/utils.py index 1314423a..52279aeb 100644 --- a/src/itwinai/utils.py +++ b/src/itwinai/utils.py @@ -1,8 +1,10 @@ """ Utilities for itwinai package. """ -from typing import Dict, Type +from typing import Dict, Type, Callable, Tuple import os +import sys +import inspect from collections.abc import MutableMapping import yaml from omegaconf import OmegaConf @@ -67,9 +69,25 @@ def dynamically_import_class(name: str) -> Type: Returns: __class__: class type. 
""" - module, class_name = name.rsplit(".", 1) - mod = __import__(module, fromlist=[class_name]) - klass = getattr(mod, class_name) + try: + module, class_name = name.rsplit(".", 1) + mod = __import__(module, fromlist=[class_name]) + klass = getattr(mod, class_name) + except ModuleNotFoundError as err: + print( + f"Module not found when trying to dynamically import '{name}'. " + "Make sure that the module's file is reachable from your current " + "directory." + ) + raise err + except Exception as err: + print( + f"Exception occurred when trying to dynamically import '{name}'. " + "Make sure that the module's file is reachable from your current " + "directory and that the class is present in that module." + ) + raise err + return klass @@ -107,3 +125,59 @@ def parse_pipe_config(yaml_file, parser): raise exc return parser.parse_object(config) + + +class SignatureInspector: + """Provides the functionalities to inspect the signature of a function + or a method. + + Args: + func (Callable): function to be inspected. + """ + + INFTY: int = sys.maxsize + + def __init__(self, func: Callable) -> None: + self.func = func + self.func_params = inspect.signature(func).parameters.items() + + @property + def has_varargs(self) -> bool: + """Checks if the function has ``*args`` parameter.""" + return any(map( + lambda p: p[1].kind == p[1].VAR_POSITIONAL, + self.func_params + )) + + @property + def has_kwargs(self) -> bool: + """Checks if the function has ``**kwargs`` parameter.""" + return any(map( + lambda p: p[1].kind == p[1].VAR_KEYWORD, + self.func_params + )) + + @property + def required_params(self) -> Tuple[str]: + """Names of required parameters. 
Class method's 'self' is skipped.""" + required_params = list(filter( + lambda p: (p[0] != 'self' and p[1].default == inspect._empty + and p[1].kind != p[1].VAR_POSITIONAL + and p[1].kind != p[1].VAR_KEYWORD), + self.func_params + )) + return tuple(map(lambda p: p[0], required_params)) + + @property + def min_params_num(self) -> int: + """Minimum number of arguments required.""" + return len(self.required_params) + + @property + def max_params_num(self) -> int: + """Max number of supported input arguments. + If no limit, ``SignatureInspector.INFTY`` is returned. + """ + if self.has_kwargs or self.has_varargs: + return self.INFTY + return len(self.func_params) diff --git a/tests/components/conftest.py b/tests/components/conftest.py new file mode 100644 index 00000000..0ba66af1 --- /dev/null +++ b/tests/components/conftest.py @@ -0,0 +1,72 @@ +import pytest + +pytest.PIPE_LIST_YAML = """ +my-list-pipeline: + class_path: itwinai.pipeline.Pipeline + init_args: + steps: + - class_path: itwinai.tests.dummy_components.FakePreproc + init_args: + max_items: 32 + name: my-preproc + + - class_path: itwinai.tests.dummy_components.FakeTrainer + init_args: + lr: 0.001 + batch_size: 32 + name: my-trainer + + - class_path: itwinai.tests.dummy_components.FakeSaver + init_args: + save_path: ./some/path + name: my-saver +""" + +pytest.PIPE_DICT_YAML = """ +my-dict-pipeline: + class_path: itwinai.pipeline.Pipeline + init_args: + steps: + preproc-step: + class_path: itwinai.tests.dummy_components.FakePreproc + init_args: + max_items: 32 + name: my-preproc + + train-step: + class_path: itwinai.tests.dummy_components.FakeTrainer + init_args: + lr: 0.001 + batch_size: 32 + name: my-trainer + + save-step: + class_path: itwinai.tests.dummy_components.FakeSaver + init_args: + save_path: ./some/path + name: my-saver +""" + +pytest.NESTED_PIPELINE = """ +some: + field: + nst-pipeline: + class_path: itwinai.pipeline.Pipeline + init_args: + steps: + - class_path: itwinai.tests.FakePreproc + 
init_args: + max_items: 32 + name: my-preproc + + - class_path: itwinai.tests.FakeTrainer + init_args: + lr: 0.001 + batch_size: 32 + name: my-trainer + + - class_path: itwinai.tests.FakeSaver + init_args: + save_path: ./some/path + name: my-saver +""" diff --git a/tests/components/test_components.py b/tests/components/test_components.py new file mode 100644 index 00000000..364b4917 --- /dev/null +++ b/tests/components/test_components.py @@ -0,0 +1,156 @@ +import pytest + +from itwinai.components import Trainer, Adapter +from itwinai.pipeline import Pipeline +from itwinai.tests import ( + FakeGetterExec, FakeSplitterExec, FakeTrainerExec, FakeSaverExec +) +from itwinai.serialization import SerializationError + + +def test_serializable(): + """Test serialization of components.""" + comp = FakeGetterExec(data_uri='http://...') + dict_serializ = comp.to_dict() + assert isinstance(dict_serializ, dict) + assert comp.name == "FakeGetterExec" + assert dict_serializ == dict( + class_path="itwinai.tests.dummy_components.FakeGetterExec", + init_args=dict(data_uri='http://...', name=None) + ) + + # List + comp = FakeGetterExec(data_uri=[1, 2, 3]) + dict_serializ = comp.to_dict() + assert isinstance(dict_serializ, dict) + assert comp.name == "FakeGetterExec" + assert dict_serializ == dict( + class_path="itwinai.tests.dummy_components.FakeGetterExec", + init_args=dict(data_uri=[1, 2, 3], name=None) + ) + + # Tuple + comp = FakeGetterExec(data_uri=(1, 2, 3)) + dict_serializ = comp.to_dict() + assert isinstance(dict_serializ, dict) + assert comp.name == "FakeGetterExec" + assert dict_serializ == dict( + class_path="itwinai.tests.dummy_components.FakeGetterExec", + init_args=dict(data_uri=[1, 2, 3], name=None) + ) + + # Set + comp = FakeGetterExec(data_uri={1, 2, 3}) + dict_serializ = comp.to_dict() + assert isinstance(dict_serializ, dict) + assert comp.name == "FakeGetterExec" + assert dict_serializ == dict( + class_path="itwinai.tests.dummy_components.FakeGetterExec", + 
init_args=dict(data_uri=[1, 2, 3], name=None) + ) + + # Dict + comp = FakeGetterExec(data_uri=dict(foo=12, bar="123", hl=3.14)) + dict_serializ = comp.to_dict() + assert isinstance(dict_serializ, dict) + assert comp.name == "FakeGetterExec" + assert dict_serializ == dict( + class_path="itwinai.tests.dummy_components.FakeGetterExec", + init_args=dict(data_uri=dict(foo=12, bar="123", hl=3.14), name=None) + ) + + # Non serializable obj + class NonSerializable: + ... + + comp = FakeGetterExec(data_uri=NonSerializable()) + with pytest.raises(SerializationError) as exc_info: + dict_serializ = comp.to_dict() + assert ("is not a Python built-in type and it does not extend" + in str(exc_info.value)) + + # Local component class + class MyTrainer(Trainer): + def execute(self): + ... + + def save_state(self): + ... + + def load_state(self): + ... + comp = MyTrainer() + with pytest.raises(SerializationError) as exc_info: + dict_serializ = comp.to_dict() + assert ("is defined locally, which is not supported for serialization." 
+ in str(exc_info.value)) + + +def test_adapter(): + """Test Adapter component.""" + prefix = Adapter.INPUT_PREFIX + adapter = Adapter( + policy=[f"{prefix}{3-i}" for i in range(4)] + ) + result = adapter.execute(0, 1, 2, 3) + assert result == (3, 2, 1, 0) + + result = adapter.execute(*tuple(range(100))) + assert result == (3, 2, 1, 0) + + adapter = Adapter( + policy=[f"{prefix}0" for i in range(4)] + ) + result = adapter.execute(0, 1, 2, 3) + assert result == (0, 0, 0, 0) + + adapter = Adapter( + policy=[f"{prefix}{i%2}" for i in range(4)] + ) + result = adapter.execute(0, 1, 2, 3) + assert result == (0, 1, 0, 1) + + adapter = Adapter( + policy=[f"{prefix}2", "hello", "world", 3.14] + ) + result = adapter.execute(0, 1, 2, 3) + assert result == (2, "hello", "world", 3.14) + + adapter = Adapter( + policy=[1, 3, 5, 7, 11] + ) + result = adapter.execute(0, 1, 2, 3) + assert result == (1, 3, 5, 7, 11) + + adapter = Adapter( + policy=[f"{prefix}{9-i}" for i in range(10)] + ) + with pytest.raises(IndexError) as exc_info: + result = adapter.execute(0, 1) + assert str(exc_info.value) == ( + "The args received as input by 'Adapter' are not consistent with " + "the given adapter policy because input args are too few! Input " + "args are 2 but the policy foresees at least 10 items." + ) + + adapter = Adapter( + policy=[] + ) + result = adapter.execute(*tuple(range(100))) + assert result == () + + +@pytest.mark.integration +def test_adapter_integration_pipeline(): + """Test integration of Adapter component in the pipeline, + connecting other components. 
+ """ + pipeline = Pipeline([ + FakeGetterExec(data_uri='http://...'), + FakeSplitterExec(train_prop=.7), + FakeTrainerExec(lr=1e-3, batch_size=32), + Adapter(policy=[f"{Adapter.INPUT_PREFIX}-1"]), + FakeSaverExec(save_path="my_model.pth") + ]) + saved_model = pipeline.execute() + assert saved_model == FakeTrainerExec.model diff --git a/tests/components/test_pipe_parser.py b/tests/components/test_pipe_parser.py new file mode 100644 index 00000000..f26d105d --- /dev/null +++ b/tests/components/test_pipe_parser.py @@ -0,0 +1,216 @@ +import yaml +import pytest + +from itwinai.components import BaseComponent +from itwinai.parser import ConfigParser, add_replace_field +from itwinai.tests import FakeTrainer, FakePreproc, FakeSaver + + +def test_add_replace_field(): + conf = {} + add_replace_field(conf, "some.key.chain", 123) + target1 = dict(some=dict(key=dict(chain=123))) + assert conf == target1 + + add_replace_field(conf, "some.key.chain", 222) + target2 = dict(some=dict(key=dict(chain=222))) + assert conf == target2 + + add_replace_field(conf, "some.key.field", 333) + target3 = dict(some=dict(key=dict(chain=222, field=333))) + assert conf == target3 + + conf['some']['list'] = [1, 2, 3] + add_replace_field(conf, "some.list.0", 3) + target4 = dict(some=dict( + key=dict(chain=222, field=333), + list=[3, 2, 3] + )) + assert conf == target4 + + add_replace_field(conf, "some.list.0.some.el", 7) + target5 = dict(some=dict( + key=dict(chain=222, field=333), + list=[dict(some=dict(el=7)), 2, 3] + )) + assert conf == target5 + + conf2 = dict(first=dict(list1=[[0, 1], [2, 3]], el=0)) + add_replace_field(conf2, "first.list1.1.0", 77) + target6 = dict(first=dict(list1=[[0, 1], [77, 3]], el=0)) + assert conf2 == target6 + + conf3 = dict(first=dict( + list1=[[0, dict(nst=("el", dict(ciao="ciao")))], [2, 3]], el=0)) + add_replace_field(conf3, "first.list1.0.1.nst.1.ciao", "hello") + target7 = dict(first=dict( + list1=[[0, dict(nst=("el", dict(ciao="hello")))], [2, 3]], el=0)) + 
assert conf3 == target7 + + add_replace_field(conf3, "first.list1.0.1.nst.1.ciao.I.am.john", True) + target8 = dict(first=dict( + list1=[ + [0, dict(nst=("el", dict(ciao=dict(I=dict(am=dict(john=True))))))], + [2, 3] + ], el=0)) + assert conf3 == target8 + + +def test_parse_list_pipeline(): + """Parse a pipeline from config file, + where the pipeline is defined as a list of components. + """ + config = yaml.safe_load(pytest.PIPE_LIST_YAML) + parser = ConfigParser(config=config) + pipe = parser.parse_pipeline( + pipeline_nested_key="my-list-pipeline" + ) + + assert isinstance(pipe.steps, list) + for step in pipe.steps: + assert isinstance(step, BaseComponent) + + +def test_parse_dict_pipeline(): + """Parse a pipeline from config file, + where the pipeline is defined as a dict of components. + """ + config = yaml.safe_load(pytest.PIPE_DICT_YAML) + parser = ConfigParser(config=config) + pipe = parser.parse_pipeline( + pipeline_nested_key="my-dict-pipeline" + ) + + assert isinstance(pipe.steps, dict) + for step in pipe.steps.values(): + assert isinstance(step, BaseComponent) + + +def test_parse_non_existing_pipeline(): + """Parse a pipeline from config file, + where the pipeline key is wrong. + """ + config = yaml.safe_load(pytest.PIPE_DICT_YAML) + parser = ConfigParser(config=config) + with pytest.raises(KeyError): + _ = parser.parse_pipeline( + pipeline_nested_key="non-existing-pipeline" + ) + + +def test_parse_nested_pipeline(): + """Parse a pipeline from config file, + where the pipeline key is nested. + """ + config = yaml.safe_load(pytest.NESTED_PIPELINE) + parser = ConfigParser(config=config) + _ = parser.parse_pipeline( + pipeline_nested_key="some.field.nst-pipeline" + ) + + +def test_dynamic_override_parser_pipeline_dict(): + """Parse a pipeline from config file, + and verify that dynamic override works + in a pipeline composed of a dict of components.
+ """ + config = yaml.safe_load(pytest.PIPE_DICT_YAML) + + override_keys = { + "my-dict-pipeline.init_args.steps.preproc-step.init_args.max_items": 33 + } + parser = ConfigParser(config=config, override_keys=override_keys) + pipe = parser.parse_pipeline( + pipeline_nested_key="my-dict-pipeline" + ) + assert pipe.steps['preproc-step'].max_items == 33 + + +def test_dynamic_override_parser_pipeline_list(): + """Parse a pipeline from config file, + and verify that dynamic override works + in a pipeline composed of a list of components. + """ + config = yaml.safe_load(pytest.PIPE_LIST_YAML) + + override_keys = { + "my-list-pipeline.init_args.steps.0.init_args.max_items": 42 + } + parser = ConfigParser(config=config, override_keys=override_keys) + pipe = parser.parse_pipeline( + pipeline_nested_key="my-list-pipeline" + ) + assert pipe.steps[0].max_items == 42 + + +def test_parse_step_list_pipeline(): + """Parse a pipeline step from config file, + where the pipeline is defined as a list of components. + """ + config = yaml.safe_load(pytest.PIPE_LIST_YAML) + parser = ConfigParser(config=config) + step = parser.parse_step( + step_idx=1, + pipeline_nested_key="my-list-pipeline" + ) + + assert isinstance(step, BaseComponent) + assert isinstance(step, FakeTrainer) + + with pytest.raises(IndexError): + _ = parser.parse_step( + step_idx=12, + pipeline_nested_key="my-list-pipeline" + ) + with pytest.raises(TypeError): + _ = parser.parse_step( + step_idx='my-step-name', + pipeline_nested_key="my-list-pipeline" + ) + + +def test_parse_step_dict_pipeline(): + """Parse a pipeline step from config file, + where the pipeline is defined as a dict of components.
+ """ + config = yaml.safe_load(pytest.PIPE_DICT_YAML) + parser = ConfigParser(config=config) + step = parser.parse_step( + step_idx='preproc-step', + pipeline_nested_key="my-dict-pipeline" + ) + + assert isinstance(step, BaseComponent) + assert isinstance(step, FakePreproc) + + with pytest.raises(KeyError): + _ = parser.parse_step( + step_idx='unk-step', + pipeline_nested_key="my-dict-pipeline" + ) + with pytest.raises(KeyError): + _ = parser.parse_step( + step_idx=0, + pipeline_nested_key="my-dict-pipeline" + ) + + +def test_parse_step_nested_pipeline(): + """Parse a pipeline step from config file, + where the pipeline is nested under some field. + """ + config = yaml.safe_load(pytest.NESTED_PIPELINE) + parser = ConfigParser(config=config) + step = parser.parse_step( + step_idx=2, + pipeline_nested_key="some.field.nst-pipeline" + ) + + assert isinstance(step, BaseComponent) + assert isinstance(step, FakeSaver) + + with pytest.raises(KeyError): + _ = parser.parse_step( + step_idx=2, + pipeline_nested_key="my-pipeline" + ) diff --git a/tests/components/test_pipeline.py b/tests/components/test_pipeline.py new file mode 100644 index 00000000..a61198b6 --- /dev/null +++ b/tests/components/test_pipeline.py @@ -0,0 +1,83 @@ +import yaml +import pytest + +from itwinai.pipeline import Pipeline +from itwinai.parser import ConfigParser +from itwinai.tests import ( + FakeGetterExec, FakeSplitterExec, FakeTrainerExec, FakeSaverExec +) + + +def test_slice_into_sub_pipelines(): + """Test slicing the pipeline to obtain + sub-pipelines as Pipeline objects. 
+ """ + p = Pipeline(['step1', 'step2', 'step3']) + sub_pipe1, sub_pipe2 = p[:1], p[1:] + assert isinstance(sub_pipe1, Pipeline) + assert isinstance(sub_pipe2, Pipeline) + assert len(sub_pipe1) == 1 + assert sub_pipe1[0] == "step1" + assert len(sub_pipe2) == 2 + + p = Pipeline(dict(step1="step1", step2="step2", step3="step3")) + sub_pipe1, sub_pipe2 = p[:1], p[1:] + assert isinstance(sub_pipe1, Pipeline) + assert isinstance(sub_pipe2, Pipeline) + assert len(sub_pipe1) == 1 + assert sub_pipe1["step1"] == "step1" + assert len(sub_pipe2) == 2 + + +def test_serialization_pipe_list(): + """Test dict serialization of pipeline + defined as list of BaseComponent objects. + """ + config = yaml.safe_load(pytest.PIPE_LIST_YAML) + parser = ConfigParser(config=config) + pipe = parser.parse_pipeline( + pipeline_nested_key="my-list-pipeline" + ) + + dict_pipe = pipe.to_dict() + del dict_pipe['init_args']['name'] + dict_pipe = {"my-list-pipeline": dict_pipe} + assert dict_pipe == config + + +def test_serialization_pipe_dict(): + """Test dict serialization of pipeline + defined as dict of BaseComponent objects. 
+ """ + config = yaml.safe_load(pytest.PIPE_DICT_YAML) + parser = ConfigParser(config=config) + pipe = parser.parse_pipeline( + pipeline_nested_key="my-dict-pipeline" + ) + + dict_pipe = pipe.to_dict() + del dict_pipe['init_args']['name'] + dict_pipe = {"my-dict-pipeline": dict_pipe} + assert dict_pipe == config + + +def test_arguments_mismatch(): + """Test mismatch of arguments passed among components in a pipeline.""" + pipeline = Pipeline([ + FakeGetterExec(data_uri='http://...'), + FakeSplitterExec(train_prop=.7), + FakeTrainerExec(lr=1e-3, batch_size=32), + # Adapter(policy=[f"{Adapter.INPUT_PREFIX}-1"]), + FakeSaverExec(save_path="my_model.pth") + ]) + with pytest.raises(TypeError) as exc_info: + _ = pipeline.execute() + assert "received too many input arguments" in str(exc_info.value) + + pipeline = Pipeline([ + FakeGetterExec(data_uri='http://...'), + FakeTrainerExec(lr=1e-3, batch_size=32), + ]) + with pytest.raises(TypeError) as exc_info: + _ = pipeline.execute() + assert "received too few input arguments" in str(exc_info.value) diff --git a/tests/test_components.py b/tests/test_components.py deleted file mode 100644 index f7396214..00000000 --- a/tests/test_components.py +++ /dev/null @@ -1,9 +0,0 @@ -from itwinai.components import Executor - - -def test_slice(): - p = Executor(['step1', 'step2', 'step3'], pippo=2) - assert len(p[:1]) == 1 - assert p[:1][0] == 'step1' - assert len(p[1:]) == 2 - assert p[1:].constructor_args['pippo'] == 2 diff --git a/tests/test_utils.py b/tests/test_utils.py index bbeb61fa..5fb7b936 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,7 @@ Tests for itwinai.utils module. 
""" -from itwinai.utils import flatten_dict +from itwinai.utils import flatten_dict, SignatureInspector def test_flatten_dict(): @@ -16,3 +16,84 @@ def test_flatten_dict(): assert flattened.get("b.b1") == 2 assert flattened.get("b.b2") == 3 assert len(flattened) == 3 + + +def test_signature_inspector(): + """Test SignatureInspector class.""" + def f(): + ... + + inspector = SignatureInspector(f) + assert not inspector.has_varargs + assert not inspector.has_kwargs + assert inspector.required_params == () + assert inspector.min_params_num == 0 + assert inspector.max_params_num == 0 + + def f(*args): + ... + + inspector = SignatureInspector(f) + assert inspector.has_varargs + assert not inspector.has_kwargs + assert inspector.required_params == () + assert inspector.min_params_num == 0 + assert inspector.max_params_num == SignatureInspector.INFTY + + def f(foo, *args): + ... + + inspector = SignatureInspector(f) + assert inspector.has_varargs + assert not inspector.has_kwargs + assert inspector.required_params == ("foo",) + assert inspector.min_params_num == 1 + assert inspector.max_params_num == SignatureInspector.INFTY + + def f(foo, bar=123): + ... + + inspector = SignatureInspector(f) + assert not inspector.has_varargs + assert not inspector.has_kwargs + assert inspector.required_params == ("foo",) + assert inspector.min_params_num == 1 + assert inspector.max_params_num == 2 + + def f(foo, *args, bar=123): + ... + + inspector = SignatureInspector(f) + assert inspector.has_varargs + assert not inspector.has_kwargs + assert inspector.required_params == ("foo",) + assert inspector.min_params_num == 1 + assert inspector.max_params_num == SignatureInspector.INFTY + + def f(*args, **kwargs): + ... 
+ + inspector = SignatureInspector(f) + assert inspector.has_varargs + assert inspector.has_kwargs + assert inspector.required_params == () + assert inspector.min_params_num == 0 + assert inspector.max_params_num == SignatureInspector.INFTY + + def f(foo, /, bar, *arg, **kwargs): + ... + inspector = SignatureInspector(f) + assert inspector.has_varargs + assert inspector.has_kwargs + assert inspector.required_params == ("foo", "bar") + assert inspector.min_params_num == 2 + assert inspector.max_params_num == SignatureInspector.INFTY + + def f(foo, /, bar, *, hello, **kwargs): + ... + inspector = SignatureInspector(f) + assert not inspector.has_varargs + assert inspector.has_kwargs + assert inspector.required_params == ("foo", "bar", "hello") + assert inspector.min_params_num == 3 + assert inspector.max_params_num == SignatureInspector.INFTY diff --git a/tests/use-cases/conftest.py b/tests/use-cases/conftest.py index c965799c..5c36e2ee 100644 --- a/tests/use-cases/conftest.py +++ b/tests/use-cases/conftest.py @@ -8,7 +8,6 @@ FNAMES = [ 'pipeline.yaml', 'startscript', - 'train.py', ] diff --git a/tutorials/ml-workflows/basic_components.py b/tutorials/ml-workflows/basic_components.py new file mode 100644 index 00000000..49e74180 --- /dev/null +++ b/tutorials/ml-workflows/basic_components.py @@ -0,0 +1,91 @@ +""" +Here we show how to implement component interfaces in a simple way. +""" +from typing import List, Optional, Tuple, Any +from itwinai.components import ( + DataGetter, DataSplitter, Trainer, Saver, monitor_exec +) + + +class MyDataGetter(DataGetter): + def __init__(self, data_size: int, name: Optional[str] = None) -> None: + super().__init__(name) + self.save_parameters(data_size=data_size) + + @monitor_exec + def execute(self) -> List[int]: + """Return a list dataset. 
+ + Returns: + List[int]: dataset + """ + return list(range(self.data_size)) + + +class MyDatasetSplitter(DataSplitter): + @monitor_exec + def execute( + self, + dataset: List[int] + ) -> Tuple[List[int], List[int], List[int]]: + """Splits a list dataset into train, validation and test datasets. + + Args: + dataset (List[int]): input list dataset. + + Returns: + Tuple[List[int], List[int], List[int]]: train, validation, and + test datasets. + """ + train_n = int(len(dataset)*self.train_proportion) + valid_n = int(len(dataset)*self.validation_proportion) + train_set = dataset[:train_n] + vaild_set = dataset[train_n:train_n+valid_n] + test_set = dataset[train_n+valid_n:] + return train_set, vaild_set, test_set + + +class MyTrainer(Trainer): + def __init__(self, lr: float = 1e-3, name: Optional[str] = None) -> None: + super().__init__(name) + self.save_parameters(name=name, lr=lr) + + @monitor_exec + def execute( + self, + train_set: List[int], + vaild_set: List[int], + test_set: List[int] + ) -> Tuple[List[int], List[int], List[int], str]: + """Dummy ML trainer mocking a ML training algorithm. + + Args: + train_set (List[int]): training dataset. + vaild_set (List[int]): validation dataset. + test_set (List[int]): test dataset. + + Returns: + Tuple[List[int], List[int], List[int], str]: train, validation, + test datasets, and trained model. + """ + return train_set, vaild_set, test_set, "my_trained_model" + + def save_state(self): + return super().save_state() + + def load_state(self): + return super().load_state() + + +class MySaver(Saver): + @monitor_exec + def execute(self, artifact: Any) -> Any: + """Saves an artifact to disk. + + Args: + artifact (Any): artifact to save (e.g., dataset, model). + + Returns: + Any: input artifact. 
+ """ + return artifact diff --git a/tutorials/ml-workflows/tutorial_0_basic_workflow.py b/tutorials/ml-workflows/tutorial_0_basic_workflow.py new file mode 100644 index 00000000..98861777 --- /dev/null +++ b/tutorials/ml-workflows/tutorial_0_basic_workflow.py @@ -0,0 +1,71 @@ +""" +The most simple workflow that you can write is a sequential pipeline of steps, +where the outputs of a component are fed as input to the following component, +employing a scikit-learn-like Pipeline. + +In itwinai, a step is also called "component" and is implemented by extending +the ``itwinai.components.BaseComponent`` class. Each component implements +the `execute(...)` method, which provides a unified interface to interact with +each component. + +The aim of itwinai components is to provide reusable machine learning best +practices, and some common operations are already encoded in some abstract +components. Some examples are: +- ``DataGetter``: has no input and returns a dataset, collected from somewhere +(e.g., downloaded). +- ``DataSplitter``: splits an input dataset into train, validation and test. +- ``DataPreproc``: perform preprocessing on train, validation, and test +datasets. +- ``Trainer``: trains an ML model and returns the trained model. +- ``Saver``: saved an ML artifact (e.g., dataset, model) to disk. + +In this tutorial you will see how to create new components and how they +are assembled into sequential pipelines. Newly created components are +in a separate file called 'basic_components.py'. 
+""" +from itwinai.pipeline import Pipeline + +# Import the custom components from file +from basic_components import MyDataGetter, MyDatasetSplitter, MyTrainer + +if __name__ == "__main__": + # Assemble them in a scikit-learn like pipeline + pipeline = Pipeline([ + MyDataGetter(data_size=100), + MyDatasetSplitter( + train_proportion=.5, + validation_proportion=.25, + test_proportion=0.25 + ), + MyTrainer() + ]) + + # Inspect steps + print(pipeline[0]) + print(pipeline[2].name) + print(pipeline[1].train_proportion) + + # Run pipeline + _, _, _, trained_model = pipeline.execute() + print("Trained model: ", trained_model) + + # You can also create a Pipeline from a dict of components, which + # simplifies their retrieval by name + pipeline = Pipeline({ + "datagetter": MyDataGetter(data_size=100), + "splitter": MyDatasetSplitter( + train_proportion=.5, + validation_proportion=.25, + test_proportion=0.25 + ), + "trainer": MyTrainer() + }) + + # Inspect steps + print(pipeline["datagetter"]) + print(pipeline["trainer"].name) + print(pipeline["splitter"].train_proportion) + + # Run pipeline + _, _, _, trained_model = pipeline.execute() + print("Trained model: ", trained_model) diff --git a/tutorials/ml-workflows/tutorial_1_intermediate_workflow.py b/tutorials/ml-workflows/tutorial_1_intermediate_workflow.py new file mode 100644 index 00000000..6604df13 --- /dev/null +++ b/tutorials/ml-workflows/tutorial_1_intermediate_workflow.py @@ -0,0 +1,98 @@ +""" +In the previous tutorial we saw how to create new components and assemble them +into a Pipeline for a simplified workflow execution. The Pipeline executes +the components in the order in which they are given, *assuming* that the +outputs of a component will fit as inputs of the following component. +This is not always true, thus you can use the ``Adapter`` component to +compensate for mismatches. This component allows to define a policy to +rearrange intermediate results between two components. 
+ +Moreover, it is good for reproducibility to keep track of the pipeline +configuration used to achieve some outstanding ML results. It would be a shame +to forget how you achieved state-of-the-art results! + +itwinai allows to export the Pipeline form Python code to configuration file, +to persist both parameters and workflow structure. Exporting to configuration +file assumes that each component class resides in a separate python file, so +that the pipeline configuration is agnostic from the current python script. + +Once the Pipeline has been exported to configuration file (YAML), it can +be executed directly from CLI: + +>>> itwinai exec-pipeline --config my-pipeline.yaml --override nested.key=42 + +The itwinai CLI allows for dynamic override of configuration fields, by means +of nested key notation. Also list indices are supported: + +>>> itwinai exec-pipeline --config my-pipe.yaml --override nested.list.2.0=42 + +""" +import subprocess +from itwinai.pipeline import Pipeline +from itwinai.parser import ConfigParser +from itwinai.components import Adapter + +from basic_components import ( + MyDataGetter, MyDatasetSplitter, MyTrainer, MySaver +) + +if __name__ == "__main__": + + # In this pipeline, the MyTrainer produces 4 elements as output: train, + # validation, test datasets, and trained model. The Adapter selects the + # trained model only, and forwards it to the saver, which expects a single + # item as input. 
+ pipeline = Pipeline([ + MyDataGetter(data_size=100), + MyDatasetSplitter( + train_proportion=.5, + validation_proportion=.25, + test_proportion=0.25 + ), + MyTrainer(), + Adapter(policy=[f"{Adapter.INPUT_PREFIX}-1"]), + MySaver() + ]) + + # Run pipeline + trained_model = pipeline.execute() + print("Trained model: ", trained_model) + print("\n" + "="*50 + "\n") + + # Serialize pipeline to YAML + pipeline.to_yaml("basic_pipeline_example.yaml", "pipeline") + + # Below, we show how to run a pre-existing pipeline stored as + # a configuration file, with the possibility of dynamically + # override some fields + + # Load pipeline from saved YAML (dynamic serialization) + parser = ConfigParser( + config="basic_pipeline_example.yaml", + override_keys={ + "pipeline.init_args.steps.0.init_args.data_size": 200 + } + ) + pipeline = parser.parse_pipeline() + print(f"MyDataGetter's data_size is now: {pipeline.steps[0].data_size}\n") + + # Run parsed pipeline, with new data_size for MyDataGetter + trained_model = pipeline.execute() + print("Trained model (2): ", trained_model) + + # Save new pipeline to YAML file + pipeline.to_yaml("basic_pipeline_example_v2.yaml", "pipeline") + + print("\n" + "="*50 + "\n") + + # Emulate pipeline execution from CLI, with dynamic override of + # pipeline configuration fields + subprocess.run( + ["itwinai", "exec-pipeline", "--config", + "basic_pipeline_example_v2.yaml", + "--override", + "pipeline.init_args.steps.0.init_args.data_size=300", + "--override", + "pipeline.init_args.steps.1.init_args.train_proportion=0.4" + ] + ) diff --git a/tutorials/ml-workflows/tutorial_2_advanced_workflow.py b/tutorials/ml-workflows/tutorial_2_advanced_workflow.py new file mode 100644 index 00000000..6c437fb2 --- /dev/null +++ b/tutorials/ml-workflows/tutorial_2_advanced_workflow.py @@ -0,0 +1,86 @@ +""" +In the first two tutorials we saw how to define simple sequential workflows by +means of the Pipeline object, which feds the outputs of the previous component 
+as inputs of the following one. + +In this tutorial we show how to create more complex workflows, with +non-sequential data flows. Here, components can be arranges as an directed +acyclic graph (DAG). Under the DAG assumption, outputs of each block can be fed +as input potentially to any other component, granting great flexibility to the +experimenter. + +The trade-off for improved flexibility is a change in the way we define +configuration files. From now on, it will only be possible to configure the +parameters used by the training script, but not its structure through the +Pipeline. + +itwinai provides a wrapper of jsonarparse's ArgumentParser which supports +configuration files by default. + +To run as usual: +>>> python my_script.py -d 20 --train-prop 0.7 --val-prop 0.2 --lr 1e-5 + +To reuse the parameters saved in a configuration file and override some +parameter (e.g., learning rate): +>>> python my_script.py --config advanced_tutorial_conf.yaml --lr 2e-3 + +""" +from typing import Any +from itwinai.parser import ArgumentParser +from itwinai.components import Predictor, monitor_exec + +from basic_components import ( + MyDataGetter, MyDatasetSplitter, MyTrainer, MySaver +) + + +class MyEnsemblePredictor(Predictor): + @monitor_exec + def execute(self, dataset, model_ensemble) -> Any: + # do some predictions with model on dataset... + return dataset + + +if __name__ == "__main__": + parser = ArgumentParser(description="itwinai advanced workflows tutorial") + parser.add_argument( + "--data-size", "-d", type=int, required=True, + help="Dataset cardinality.") + parser.add_argument( + "--train-prop", type=float, required=True, + help="Train split proportion.") + parser.add_argument( + "--val-prop", type=float, required=True, + help="Validation split proportion.") + parser.add_argument( + "--lr", type=float, help="Training learning rate.") + args = parser.parse_args() + + # Save parsed arguments to configuration file. 
+ # Previous configurations are overwritten, which is not good, + # but the versioning of configuration files is out of the scope + # of this tutorial. + parser.save( + args, "advanced_tutorial_conf.yaml", format='yaml', overwrite=True) + + # Define workflow components + getter = MyDataGetter(data_size=args.data_size) + splitter = MyDatasetSplitter( + train_proportion=args.train_prop, + validation_proportion=args.val_prop, + test_proportion=1-args.train_prop-args.val_prop + ) + trainer1 = MyTrainer(lr=args.lr) + trainer2 = MyTrainer(lr=args.lr) + saver = MySaver() + predictor = MyEnsemblePredictor(model=None) + + # Define ML workflow + dataset = getter.execute() + train_spl, val_spl, test_spl = splitter.execute(dataset) + _, _, _, trained_model1 = trainer1.execute(train_spl, val_spl, test_spl) + _, _, _, trained_model2 = trainer2.execute(train_spl, val_spl, test_spl) + _ = saver.execute(trained_model1) + predictions = predictor.execute(test_spl, [trained_model1, trained_model2]) + print() + print("Predictions: " + str(predictions)) diff --git a/use-cases/3dgan/cern-pipeline.yaml b/use-cases/3dgan/cern-pipeline.yaml index 57245450..0bc9a756 100644 --- a/use-cases/3dgan/cern-pipeline.yaml +++ b/use-cases/3dgan/cern-pipeline.yaml @@ -1,5 +1,5 @@ -executor: - class_path: itwinai.components.Executor +pipeline: + class_path: itwinai.pipeline.Pipeline init_args: steps: - class_path: dataloader.Lightning3DGANDownloader diff --git a/use-cases/3dgan/dataloader.py b/use-cases/3dgan/dataloader.py index d6e5a880..f21e57d9 100644 --- a/use-cases/3dgan/dataloader.py +++ b/use-cases/3dgan/dataloader.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Dict +from typing import Optional import os from lightning.pytorch.utilities.types import EVAL_DATALOADERS @@ -10,21 +10,23 @@ import h5py import gdown -from itwinai.components import DataGetter +from itwinai.components import DataGetter, monitor_exec class Lightning3DGANDownloader(DataGetter): def __init__( - self, - data_path: 
str, - data_url: Optional[str] = None, - name: Optional[str] = None, - **kwargs) -> None: - super().__init__(name, **kwargs) + self, + data_path: str, + data_url: Optional[str] = None, + name: Optional[str] = None, + ) -> None: + self.save_parameters(**self.locals2params(locals())) + super().__init__(name) self.data_path = data_path self.data_url = data_url - def load(self): + @monitor_exec + def execute(self): # Download data if not os.path.exists(self.data_path): if self.data_url is None: @@ -36,13 +38,6 @@ def load(self): output=self.data_path ) - def execute( - self, - config: Optional[Dict] = None - ) -> Tuple[None, Optional[Dict]]: - self.load() - return None, config - class ParticlesDataset(Dataset): def __init__(self, datapath: str, max_samples: Optional[int] = None): diff --git a/use-cases/3dgan/inference-pipeline.yaml b/use-cases/3dgan/inference-pipeline.yaml index 773dd399..883533b3 100644 --- a/use-cases/3dgan/inference-pipeline.yaml +++ b/use-cases/3dgan/inference-pipeline.yaml @@ -1,5 +1,5 @@ -executor: - class_path: itwinai.components.Executor +pipeline: + class_path: itwinai.pipeline.Pipeline init_args: steps: - class_path: dataloader.Lightning3DGANDownloader diff --git a/use-cases/3dgan/pipeline.yaml b/use-cases/3dgan/pipeline.yaml index 82665304..d6bade54 100644 --- a/use-cases/3dgan/pipeline.yaml +++ b/use-cases/3dgan/pipeline.yaml @@ -1,5 +1,5 @@ -executor: - class_path: itwinai.components.Executor +pipeline: + class_path: itwinai.pipeline.Pipeline init_args: steps: dataloading_step: diff --git a/use-cases/3dgan/saver.py b/use-cases/3dgan/saver.py index 7aa72429..fd9bd710 100644 --- a/use-cases/3dgan/saver.py +++ b/use-cases/3dgan/saver.py @@ -1,4 +1,4 @@ -from typing import Dict, Tuple, Optional +from typing import Dict import os import shutil @@ -7,7 +7,7 @@ import matplotlib.pyplot as plt import numpy as np -from itwinai.components import Saver +from itwinai.components import Saver, monitor_exec class ParticleImagesSaver(Saver): @@ -17,30 
+17,12 @@ def __init__( self, save_dir: str = '3dgan-generated' ) -> None: + self.save_parameters(**self.locals2params(locals())) super().__init__() self.save_dir = save_dir - def execute( - self, - generated_images: Dict[str, Tensor], - config: Optional[Dict] = None - ) -> Tuple[Optional[Tuple], Optional[Dict]]: - """Saves generated images to disk. - - Args: - generated_images (Dict[str, Tensor]): maps unique item ID to - the generated image. - config (Optional[Dict], optional): inherited configuration. - Defaults to None. - - Returns: - Tuple[Optional[Tuple], Optional[Dict]]: propagation of inherited - configuration and saver return value. - """ - result = self.save(generated_images) - return ((result,), config) - - def save(self, generated_images: Dict[str, Tensor]) -> None: + @monitor_exec + def execute(self, generated_images: Dict[str, Tensor]) -> None: """Saves generated images to disk. Args: diff --git a/use-cases/3dgan/train.py b/use-cases/3dgan/train.py index d04596be..d12ee05e 100644 --- a/use-cases/3dgan/train.py +++ b/use-cases/3dgan/train.py @@ -15,13 +15,10 @@ import argparse -from itwinai.components import Executor -from itwinai.utils import parse_pipe_config -from jsonargparse import ArgumentParser +from itwinai.parser import ConfigParser if __name__ == "__main__": - # Create CLI Parser parser = argparse.ArgumentParser() parser.add_argument( "-p", "--pipeline", type=str, required=True, @@ -36,20 +33,12 @@ ) args = parser.parse_args() - # Create parser for the pipeline (ordered) - pipe_parser = ArgumentParser() - pipe_parser.add_subclass_arguments(Executor, "executor") - - # Parse, Instantiate pipe - parsed = parse_pipe_config(args.pipeline, pipe_parser) - pipe = pipe_parser.instantiate_classes(parsed) - executor: Executor = getattr(pipe, 'executor') + # Create parser for the pipeline + pipe_parser = ConfigParser(config=args.pipeline) + pipeline = pipe_parser.parse_pipeline() if args.download_only: print('Downloading datasets and exiting...') - 
executor = executor[:1] - else: - print('Downloading datasets (if not already done) and running...') - executor = executor - executor.setup() - executor() + pipeline = pipeline[:1] + + pipeline.execute() diff --git a/use-cases/3dgan/trainer.py b/use-cases/3dgan/trainer.py index 5bd2bcdb..30a55e08 100644 --- a/use-cases/3dgan/trainer.py +++ b/use-cases/3dgan/trainer.py @@ -1,36 +1,38 @@ import os import sys -from typing import Union, Dict, Tuple, Optional, Any +from typing import Union, Dict, Optional, Any import torch from torch import Tensor import lightning as pl from lightning.pytorch.cli import LightningCLI -from itwinai.components import Trainer, Predictor +from itwinai.components import Trainer, Predictor, monitor_exec from itwinai.serialization import ModelLoader from itwinai.torch.inference import TorchModelLoader from itwinai.torch.types import Batch +from itwinai.utils import load_yaml from itwinai.torch.mlflow import ( init_lightning_mlflow, teardown_lightning_mlflow ) + from model import ThreeDGAN from dataloader import ParticlesDataModule -from utils import load_yaml class Lightning3DGANTrainer(Trainer): def __init__(self, config: Union[Dict, str]): + self.save_parameters(**self.locals2params(locals())) super().__init__() if isinstance(config, str) and os.path.isfile(config): # Load from YAML config = load_yaml(config) self.conf = config - def train(self) -> Any: - init_lightning_mlflow(self.conf, registered_model_name='3dgan-lite') + @monitor_exec + def execute(self) -> Any: old_argv = sys.argv sys.argv = ['some_script_placeholder.py'] cli = LightningCLI( @@ -49,13 +51,6 @@ def train(self) -> Any: cli.trainer.fit(cli.model, datamodule=cli.datamodule) teardown_lightning_mlflow() - def execute( - self, - config: Optional[Dict] = None - ) -> Tuple[Any, Optional[Dict]]: - result = self.train() - return result, config - def save_state(self): return super().save_state() @@ -99,13 +94,15 @@ def __init__( config: Union[Dict, str], name: Optional[str] = None 
): + self.save_parameters(**self.locals2params(locals())) super().__init__(model, name) if isinstance(config, str) and os.path.isfile(config): # Load from YAML config = load_yaml(config) self.conf = config - def predict( + @monitor_exec + def execute( self, datamodule: Optional[pl.LightningDataModule] = None, model: Optional[pl.LightningModule] = None @@ -158,19 +155,3 @@ def transform_predictions(self, batch: Batch) -> Batch: Post-process the predictions of the torch model. """ return batch.squeeze(1) - - def execute( - self, - config: Optional[Dict] = None, - ) -> Tuple[Optional[Tuple], Optional[Dict]]: - """"Execute some operations. - - Args: - config (Dict, optional): key-value configuration. - Defaults to None. - - Returns: - Tuple[Optional[Tuple], Optional[Dict]]: tuple structured as - (results, config). - """ - return self.predict(), config diff --git a/use-cases/3dgan/utils.py b/use-cases/3dgan/utils.py deleted file mode 100644 index d04f9e63..00000000 --- a/use-cases/3dgan/utils.py +++ /dev/null @@ -1,108 +0,0 @@ -""" -Utilities for itwinai package. -""" -import os -import yaml - -from collections.abc import MutableMapping -from typing import Dict -from omegaconf import OmegaConf -from omegaconf.dictconfig import DictConfig - - -def load_yaml(path: str) -> Dict: - """Load YAML file as dict. - - Args: - path (str): path to YAML file. - - Raises: - exc: yaml.YAMLError for loading/parsing errors. - - Returns: - Dict: nested dict representation of parsed YAML file. - """ - with open(path, "r", encoding="utf-8") as yaml_file: - try: - loaded_config = yaml.safe_load(yaml_file) - except yaml.YAMLError as exc: - print(exc) - raise exc - return loaded_config - - -def load_yaml_with_deps_from_file(path: str) -> DictConfig: - """ - Load YAML file with OmegaConf and merge it with its dependencies - specified in the `conf-dependencies` field. - Assume that the dependencies live in the same folder of the - YAML file which is importing them. 
- - Args: - path (str): path to YAML file. - - Raises: - exc: yaml.YAMLError for loading/parsing errors. - - Returns: - DictConfig: nested representation of parsed YAML file. - """ - yaml_conf = load_yaml(path) - use_case_dir = os.path.dirname(path) - deps = [] - if yaml_conf.get("conf-dependencies"): - for dependency in yaml_conf["conf-dependencies"]: - deps.append(load_yaml(os.path.join(use_case_dir, dependency))) - - return OmegaConf.merge(yaml_conf, *deps) - - -def load_yaml_with_deps_from_dict(dict_conf, use_case_dir) -> DictConfig: - deps = [] - - if dict_conf.get("conf-dependencies"): - for dependency in dict_conf["conf-dependencies"]: - deps.append(load_yaml(os.path.join(use_case_dir, dependency))) - - return OmegaConf.merge(dict_conf, *deps) - - -def dynamically_import_class(name: str): - """ - Dynamically import class by module path. - Adapted from https://stackoverflow.com/a/547867 - - Args: - name (str): path to the class (e.g., mypackage.mymodule.MyClass) - - Returns: - __class__: class object. - """ - module, class_name = name.rsplit(".", 1) - mod = __import__(module, fromlist=[class_name]) - klass = getattr(mod, class_name) - return klass - - -def flatten_dict( - d: MutableMapping, parent_key: str = "", sep: str = "." -) -> MutableMapping: - """Flatten dictionary - - Args: - d (MutableMapping): nested dictionary to flatten - parent_key (str, optional): prefix for all keys. Defaults to ''. - sep (str, optional): separator for nested key concatenation. - Defaults to '.'. - - Returns: - MutableMapping: flattened dictionary with new keys. 
- """ - items = [] - for k, v in d.items(): - new_key = parent_key + sep + k if parent_key else k - if isinstance(v, MutableMapping): - items.extend(flatten_dict(v, new_key, sep=sep).items()) - else: - items.append((new_key, v)) - return dict(items) diff --git a/use-cases/cyclones/.gitignore b/use-cases/cyclones/.gitignore new file mode 100644 index 00000000..255b69f5 --- /dev/null +++ b/use-cases/cyclones/.gitignore @@ -0,0 +1,2 @@ +data +experiments \ No newline at end of file diff --git a/use-cases/cyclones/dataloader.py b/use-cases/cyclones/dataloader.py index 8c837822..ee19b805 100644 --- a/use-cases/cyclones/dataloader.py +++ b/use-cases/cyclones/dataloader.py @@ -1,8 +1,7 @@ -import logging from os import listdir from os.path import join, exists -from itwinai.components import DataGetter -from typing import List, Dict, Optional, Tuple +from itwinai.components import DataGetter, monitor_exec +from typing import List, Dict from lib.macros import ( PatchType, LabelNoCyclone, @@ -29,6 +28,7 @@ class TensorflowDataGetter(DataGetter): def __init__( self, + data_url: str, patch_type: PatchType, shuffle: bool, split_ratio: List[float], @@ -38,11 +38,14 @@ def __init__( target_scale: bool, label_no_cyclone: LabelNoCyclone, aug_type: AugmentationType, - experiment: dict, + experiment: Dict, + global_config: Dict, shuffle_buffer: int = None, data_path: str = "tmp_data" ): super().__init__() + self.save_parameters(**self.locals2params(locals())) + self.data_url = data_url self.batch_size = batch_size self.split_ratio = split_ratio self.epochs = epochs @@ -52,6 +55,7 @@ def __init__( self.aug_type = aug_type.value self.patch_type = patch_type.value self.augment = augment + self.global_config = global_config self.shuffle = shuffle self.data_path = data_path self.drv_vars, self.coo_vars = ( @@ -87,6 +91,9 @@ def __init__( else: self.aug_fns = {} + # Parse global config + self.setup_config(self.global_config) + def split_files(self, files, ratio): n = len(files) return ( @@ 
-94,7 +101,8 @@ def split_files(self, files, ratio): files[int(ratio[0] * n): int((ratio[0] + ratio[1]) * n)], ) - def load(self): + @monitor_exec + def execute(self): # divide into train, valid and test dataset files train_c_fs, valid_c_fs = self.split_files( files=self.cyclone_files, ratio=self.split_ratio @@ -160,30 +168,16 @@ def load(self): patch_type=self.patch_type, aug_type=self.aug_type, ) - return train_dataset, valid_dataset + return train_dataset, valid_dataset, self.channels - def execute( - self, - config: Optional[Dict] = None - ) -> Tuple[Optional[Tuple], Optional[Dict]]: - config = self.setup_config(config) - train, test = self.load() - logging.debug("Train, valid and test datasets loaded.") - return (train, test), config - - def setup_config(self, config: Optional[Dict] = None) -> Dict: - config = config if config is not None else {} + def setup_config(self, config: Dict) -> None: self.shape = config["shape"] root_dir = config["root_dir"] # Download data - url = ( - "https://drive.google.com/drive/folders/" - "15DEq33MmtRvIpe2bNCg44lnfvEiHcPaf" - ) if not exists(join(root_dir, self.data_path)): gdown.download_folder( - url=url, quiet=False, + url=self.data_url, quiet=False, output=join(root_dir, self.data_path) ) @@ -228,8 +222,3 @@ def setup_config(self, config: Optional[Dict] = None) -> Dict: PatchType.RANDOM.value) ] ) - - config["epochs"] = self.epochs - config["batch_size"] = self.batch_size - config["channels"] = self.channels - return config diff --git a/use-cases/cyclones/executor.py b/use-cases/cyclones/executor.py deleted file mode 100644 index 9c00af43..00000000 --- a/use-cases/cyclones/executor.py +++ /dev/null @@ -1,75 +0,0 @@ -import logging -from os.path import join -from os import makedirs -from datetime import datetime -from typing import Tuple, Dict, Optional, Iterable - -from lib.macros import PATCH_SIZE as patch_size, SHAPE as shape -from itwinai.components import Executor, Executable - - -class CycloneExecutor(Executor): - def 
__init__( - self, - run_name: str, - steps: Iterable[Executable], - name: Optional[str] = None - ): - super().__init__(steps=steps, name=name) - self.run_name = run_name - - def execute( - self, - root_dir, - config: Optional[Dict] = None, - ) -> Tuple[Optional[Tuple], Optional[Dict]]: - self.root_dir = root_dir - print(f" Data will be stored at: {self.root_dir}") - config = self.setup_config(config) - super().execute(config=config) - - def setup_config(self, config: Optional[Dict] = None) -> Dict: - config = config if config is not None else {} - - # Paths, Folders - FORMATTED_DATETIME = str(datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) - MODEL_BACKUP_DIR = join(self.root_dir, "models/") - EXPERIMENTS_DIR = join(self.root_dir, "experiments") - RUN_DIR = join(EXPERIMENTS_DIR, self.run_name + - "_" + FORMATTED_DATETIME) - SCALER_DIR = join(RUN_DIR, "scalers") - TENSORBOARD_DIR = join(RUN_DIR, "tensorboard") - CHECKPOINTS_DIR = join(RUN_DIR, "checkpoints") - - # Files - LOG_FILE = join(RUN_DIR, "run.log") - - # Create folders - makedirs(EXPERIMENTS_DIR, exist_ok=True) - makedirs(RUN_DIR, exist_ok=True) - makedirs(SCALER_DIR, exist_ok=True) - makedirs(TENSORBOARD_DIR, exist_ok=True) - makedirs(CHECKPOINTS_DIR, exist_ok=True) - - config = { - "root_dir": self.root_dir, - "experiment_dir": EXPERIMENTS_DIR, - "run_dir": RUN_DIR, - "scaler_dir": SCALER_DIR, - "tensorboard_dir": TENSORBOARD_DIR, - "checkpoints_dir": CHECKPOINTS_DIR, - "backup_dir": MODEL_BACKUP_DIR, - "log_file": LOG_FILE, - "shape": shape, - "patch_size": patch_size, - } - self.args = config - - # initialize logger - logging.basicConfig( - format="[%(asctime)s] %(levelname)s : %(message)s", - level=logging.DEBUG, - filename=LOG_FILE, - datefmt="%Y-%m-%d %H:%M:%S", - ) - return config diff --git a/use-cases/cyclones/pipeline.yaml b/use-cases/cyclones/pipeline.yaml index de52df9b..97cfc083 100644 --- a/use-cases/cyclones/pipeline.yaml +++ b/use-cases/cyclones/pipeline.yaml @@ -1,10 +1,11 @@ -executor: - 
class_path: executor.CycloneExecutor +pipeline: + class_path: itwinai.pipeline.Pipeline init_args: - run_name: 'default' steps: - - class_path: dataloader.TensorflowDataGetter + download-step: + class_path: dataloader.TensorflowDataGetter init_args: + data_url: https://drive.google.com/drive/folders/15DEq33MmtRvIpe2bNCg44lnfvEiHcPaf patch_type: NEAREST shuffle: False split_ratio: [0.75, 0.25] @@ -19,8 +20,12 @@ executor: 'COO_VARS_1': ['patch_cyclone'], 'MSK_VAR_1': None } - - class_path: trainer.TensorflowTrainer + + training-step: + class_path: trainer.TensorflowTrainer init_args: + epochs: ${pipeline.init_args.steps.download-step.init_args.epochs} + batch_size: ${pipeline.init_args.steps.download-step.init_args.batch_size} network: VGG_V1 activation: LINEAR regularization_strength: NONE diff --git a/use-cases/cyclones/train.py b/use-cases/cyclones/train.py index 82a6d15d..0146dddf 100644 --- a/use-cases/cyclones/train.py +++ b/use-cases/cyclones/train.py @@ -11,22 +11,76 @@ """ +from typing import Dict import argparse +import logging +from os.path import join +from os import makedirs +from datetime import datetime -from itwinai.components import Executor -from itwinai.utils import parse_pipe_config -from jsonargparse import ArgumentParser -from executor import CycloneExecutor +from itwinai.parser import ConfigParser, ArgumentParser + +from lib.macros import PATCH_SIZE, SHAPE + + +def setup_config(args) -> Dict: + config = {} + + # Paths, Folders + FORMATTED_DATETIME = str(datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) + MODEL_BACKUP_DIR = join(args.root_dir, "models/") + EXPERIMENTS_DIR = join(args.root_dir, "experiments") + RUN_DIR = join(EXPERIMENTS_DIR, args.run_name + + "_" + FORMATTED_DATETIME) + SCALER_DIR = join(RUN_DIR, "scalers") + TENSORBOARD_DIR = join(RUN_DIR, "tensorboard") + CHECKPOINTS_DIR = join(RUN_DIR, "checkpoints") + + # Files + LOG_FILE = join(RUN_DIR, "run.log") + + # Create folders + makedirs(EXPERIMENTS_DIR, exist_ok=True) + 
makedirs(RUN_DIR, exist_ok=True) + makedirs(SCALER_DIR, exist_ok=True) + makedirs(TENSORBOARD_DIR, exist_ok=True) + makedirs(CHECKPOINTS_DIR, exist_ok=True) + + config = { + "root_dir": args.root_dir, + "experiment_dir": EXPERIMENTS_DIR, + "run_dir": RUN_DIR, + "scaler_dir": SCALER_DIR, + "tensorboard_dir": TENSORBOARD_DIR, + "checkpoints_dir": CHECKPOINTS_DIR, + "backup_dir": MODEL_BACKUP_DIR, + "log_file": LOG_FILE, + "shape": SHAPE, + "patch_size": PATCH_SIZE, + # "epochs": args.epochs, + # "batch_size": args.batch_size + } + + # initialize logger + logging.basicConfig( + format="[%(asctime)s] %(levelname)s : %(message)s", + level=logging.DEBUG, + filename=LOG_FILE, + datefmt="%Y-%m-%d %H:%M:%S", + ) + return config if __name__ == "__main__": - # Create CLI Parser - parser = argparse.ArgumentParser() + parser = ArgumentParser() parser.add_argument( "-p", "--pipeline", type=str, required=True, help='Configuration file to the pipeline to execute.' ) parser.add_argument("-r", "--root_dir", type=str, default='./data') + parser.add_argument("-n", "--run_name", default="noname", type=str) + parser.add_argument("-e", "--epochs", default=1, type=int) + parser.add_argument("-b", "--batch_size", default=32, type=int) parser.add_argument( '-d', '--download-only', action=argparse.BooleanOptionalAction, @@ -35,21 +89,24 @@ '(suggested on login nodes of HPC systems)') ) args = parser.parse_args() + global_config = setup_config(args) - # Create parser for the pipeline (ordered) - pipe_parser = ArgumentParser() - pipe_parser.add_subclass_arguments(Executor, "executor") - - # Parse, Instantiate pipe - parsed = parse_pipe_config(args.pipeline, pipe_parser) - pipe = pipe_parser.instantiate_classes(parsed) - executor: CycloneExecutor = getattr(pipe, 'executor') + # Create parser for the pipeline + downloader_params = "pipeline.init_args.steps.download-step.init_args." + trainer_params = "pipeline.init_args.steps.training-step.init_args." 
+ pipe_parser = ConfigParser( + config=args.pipeline, + override_keys={ + downloader_params + "epochs": args.epochs, + downloader_params + "batch_size": args.batch_size, + downloader_params + "global_config": global_config, + trainer_params + "global_config": global_config + } + ) + pipeline = pipe_parser.parse_pipeline() if args.download_only: print('Downloading datasets and exiting...') - executor = executor[:1] - else: - print('Downloading datasets (if not already done) and running...') - executor = executor - executor.setup() - executor(root_dir=args.root_dir) + pipeline = pipeline[:1] + + pipeline.execute() diff --git a/use-cases/cyclones/trainer.py b/use-cases/cyclones/trainer.py index 8760e4bc..2fb3c1bc 100644 --- a/use-cases/cyclones/trainer.py +++ b/use-cases/cyclones/trainer.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple +from typing import Dict, Any import logging from os.path import join, exists @@ -6,7 +6,7 @@ from lib.strategy import get_mirrored_strategy from lib.utils import get_network_config, load_model -from itwinai.components import Trainer +from itwinai.components import Trainer, monitor_exec from lib.callbacks import ProcessBenchmark from lib.macros import ( Network, @@ -24,12 +24,18 @@ def __init__( regularization_strength: RegularizationStrength, learning_rate: float, loss: Losses, + epochs: int, + batch_size: int, + global_config: Dict[str, Any], kernel_size: int = None, model_backup: str = None, cores: int = None, ): super().__init__() - # Configurable + self.save_parameters(**self.locals2params(locals())) + self.epochs = epochs + self.batch_size = batch_size + self.global_config = global_config self.cores = cores self.model_backup = model_backup self.network = network.value @@ -43,7 +49,11 @@ def __init__( # Optimizers, Losses self.optimizer = keras.optimizers.Adam(learning_rate=learning_rate) - def train(self, train_data, validation_data): + # Parse global config + self.setup_config(self.global_config) + + @monitor_exec + 
def execute(self, train_data, validation_data, channels) -> None: train_dataset, n_train = train_data valid_dataset, n_valid = validation_data @@ -68,7 +78,7 @@ def train(self, train_data, validation_data): activation=self.activation, regularizer=self.regularizer, kernel_size=self.kernel_size, - channels=self.channels, + channels=channels, ) logging.debug("New model created") else: @@ -103,24 +113,10 @@ def train(self, train_data, validation_data): model.save(self.last_model_name) logging.debug("Saved training history") - def execute( - self, - train_dataset, - validation_dataset, - config: Optional[Dict] = None, - ) -> Tuple[Optional[Tuple], Optional[Dict]]: - config = self.setup_config(config) - train_result = self.train(train_dataset, validation_dataset) - return (train_result,), config - - def setup_config(self, config: Optional[Dict] = None) -> Dict: - config = config if config is not None else {} + def setup_config(self, config: Dict) -> None: self.experiment_dir = config["experiment_dir"] self.run_dir = config["run_dir"] - self.epochs = config["epochs"] - self.batch_size = config["batch_size"] self.patch_size = config["patch_size"] - self.channels = config["channels"] # Paths CHECKPOINTS_DIR = join(self.run_dir, "checkpoints") @@ -159,8 +155,6 @@ def setup_config(self, config: Optional[Dict] = None) -> Dict: self.best_model_name = join(self.model_backup, "best_model.h5") self.last_model_name = join(self.run_dir, "last_model.h5") - return config - def load_state(self): return super().load_state() diff --git a/use-cases/mnist/tensorflow/dataloader.py b/use-cases/mnist/tensorflow/dataloader.py index 920e0dba..cc95153e 100644 --- a/use-cases/mnist/tensorflow/dataloader.py +++ b/use-cases/mnist/tensorflow/dataloader.py @@ -1,31 +1,32 @@ -from typing import Optional, Dict, Tuple +from typing import Tuple import tensorflow.keras as keras import tensorflow as tf -from itwinai.components import DataGetter, DataPreproc +from itwinai.components import DataGetter, 
DataPreproc, monitor_exec class MNISTDataGetter(DataGetter): def __init__(self): super().__init__() + self.save_parameters(**self.locals2params(locals())) - def load(self): - return keras.datasets.mnist.load_data() - - def execute( - self, - config: Optional[Dict] = None - ) -> Tuple[Optional[Tuple], Optional[Dict]]: - train, test = self.load() - return ([train, test],), config + @monitor_exec + def execute(self) -> Tuple: + train, test = keras.datasets.mnist.load_data() + return train, test class MNISTDataPreproc(DataPreproc): def __init__(self, classes: int): super().__init__() + self.save_parameters(**self.locals2params(locals())) self.classes = classes - def preproc(self, datasets) -> Tuple: + @monitor_exec + def execute( + self, + *datasets, + ) -> Tuple: options = tf.data.Options() options.experimental_distribute.auto_shard_policy = ( tf.data.experimental.AutoShardPolicy.FILE) @@ -37,10 +38,3 @@ def preproc(self, datasets) -> Tuple: sliced = sliced.with_options(options) preprocessed.append(sliced) return tuple(preprocessed) - - def execute( - self, - datasets, - config: Optional[Dict] = None - ) -> Tuple[Optional[Tuple], Optional[Dict]]: - return self.preproc(datasets), config diff --git a/use-cases/mnist/tensorflow/pipeline.yaml b/use-cases/mnist/tensorflow/pipeline.yaml index aa34e0d4..9fced327 100644 --- a/use-cases/mnist/tensorflow/pipeline.yaml +++ b/use-cases/mnist/tensorflow/pipeline.yaml @@ -1,5 +1,5 @@ -executor: - class_path: itwinai.components.Executor +pipeline: + class_path: itwinai.pipeline.Pipeline init_args: steps: - class_path: dataloader.MNISTDataGetter @@ -29,7 +29,7 @@ executor: input_shape: [ 28, 28, 1 ] output_shape: 10 - strategy: + strategy: class_path: tensorflow.python.distribute.mirrored_strategy.MirroredStrategy logger: diff --git a/use-cases/mnist/tensorflow/train.py b/use-cases/mnist/tensorflow/train.py index 65e12c78..26a90f81 100644 --- a/use-cases/mnist/tensorflow/train.py +++ b/use-cases/mnist/tensorflow/train.py @@ -13,13 
+13,10 @@ import argparse -from itwinai.components import Executor -from itwinai.utils import parse_pipe_config -from jsonargparse import ArgumentParser +from itwinai.parser import ConfigParser if __name__ == "__main__": - # Create CLI Parser parser = argparse.ArgumentParser() parser.add_argument( "-p", "--pipeline", type=str, required=True, @@ -34,20 +31,12 @@ ) args = parser.parse_args() - # Create parser for the pipeline (ordered) - pipe_parser = ArgumentParser() - pipe_parser.add_subclass_arguments(Executor, "executor") - - # Parse, Instantiate pipe - parsed = parse_pipe_config(args.pipeline, pipe_parser) - pipe = pipe_parser.instantiate_classes(parsed) - executor: Executor = getattr(pipe, 'executor') + # Create parser for the pipeline + pipe_parser = ConfigParser(config=args.pipeline) + pipeline = pipe_parser.parse_pipeline() if args.download_only: print('Downloading datasets and exiting...') - executor = executor[:1] - else: - print('Downloading datasets (if not already done) and running...') - executor = executor - executor.setup() - executor() + pipeline = pipeline[:1] + + pipeline.execute() diff --git a/use-cases/mnist/tensorflow/trainer.py b/use-cases/mnist/tensorflow/trainer.py index dfbc06c7..17ef19a5 100644 --- a/use-cases/mnist/tensorflow/trainer.py +++ b/use-cases/mnist/tensorflow/trainer.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple, Any +from typing import Dict, List, Optional, Any # from tensorflow.keras.optimizers import Optimizer # from tensorflow.keras.losses import Loss @@ -6,6 +6,7 @@ from itwinai.tensorflow.trainer import TensorflowTrainer from itwinai.loggers import Logger +from itwinai.components import monitor_exec class MNISTTrainer(TensorflowTrainer): @@ -19,29 +20,21 @@ def __init__( strategy: Optional[MirroredStrategy] = None, logger: Optional[List[Logger]] = None ): - # Configurable - self.logger = logger if logger is not None else [] - compile_conf = dict(loss=loss, optimizer=optimizer) - print(f'STRATEGY: 
{strategy}') super().__init__( epochs=epochs, batch_size=batch_size, callbacks=[], model_dict=model, - compile_conf=compile_conf, + compile_conf=dict(loss=loss, optimizer=optimizer), strategy=strategy ) + self.save_parameters(**self.locals2params(locals())) + print(f'STRATEGY: {strategy}') + self.logger = logger if logger is not None else [] - def train(self, train_dataset, validation_dataset) -> Any: - return super().train(train_dataset, validation_dataset) - - def execute( - self, - train_dataset, - validation_dataset, - config: Optional[Dict] = None, - ) -> Tuple[Optional[Tuple], Optional[Dict]]: - return (self.train(train_dataset, validation_dataset),), config + @monitor_exec + def execute(self, train_dataset, validation_dataset) -> Any: + return super().execute(train_dataset, validation_dataset) def load_state(self): return super().load_state() diff --git a/use-cases/mnist/torch-lightning/dataloader.py b/use-cases/mnist/torch-lightning/dataloader.py index 28ec236d..1f062fe5 100644 --- a/use-cases/mnist/torch-lightning/dataloader.py +++ b/use-cases/mnist/torch-lightning/dataloader.py @@ -1,20 +1,21 @@ -from typing import Optional, Tuple, Dict +from typing import Optional import lightning as L from torchvision.datasets import MNIST from torch.utils.data import DataLoader, random_split from torchvision import transforms -from itwinai.components import DataGetter +from itwinai.components import DataGetter, monitor_exec class LightningMNISTDownloader(DataGetter): def __init__( - self, - data_path: str, - name: Optional[str] = None, - **kwargs) -> None: - super().__init__(name, **kwargs) + self, + data_path: str, + name: Optional[str] = None + ) -> None: + super().__init__(name) + self.save_parameters(**self.locals2params(locals())) self.data_path = data_path self._downloader = MNISTDataModule( data_path=self.data_path, download=True, @@ -22,19 +23,13 @@ def __init__( batch_size=1, train_prop=.5, ) - def load(self): + @monitor_exec + def execute(self) -> None: # 
Simulate dataset creation to force data download self._downloader.setup(stage='fit') self._downloader.setup(stage='test') self._downloader.setup(stage='predict') - def execute( - self, - config: Optional[Dict] = None - ) -> Tuple[None, Optional[Dict]]: - self.load() - return None, config - class MNISTDataModule(L.LightningModule): def __init__( diff --git a/use-cases/mnist/torch-lightning/pipeline.yaml b/use-cases/mnist/torch-lightning/pipeline.yaml index 33ae0a94..cf754b2f 100644 --- a/use-cases/mnist/torch-lightning/pipeline.yaml +++ b/use-cases/mnist/torch-lightning/pipeline.yaml @@ -1,5 +1,5 @@ -executor: - class_path: itwinai.components.Executor +pipeline: + class_path: itwinai.pipeline.Pipeline init_args: steps: - class_path: dataloader.LightningMNISTDownloader diff --git a/use-cases/mnist/torch-lightning/train.py b/use-cases/mnist/torch-lightning/train.py index 50c91988..97f53093 100644 --- a/use-cases/mnist/torch-lightning/train.py +++ b/use-cases/mnist/torch-lightning/train.py @@ -15,13 +15,10 @@ import argparse -from itwinai.components import Executor -from itwinai.utils import parse_pipe_config -from jsonargparse import ArgumentParser +from itwinai.parser import ConfigParser if __name__ == "__main__": - # Create CLI Parser parser = argparse.ArgumentParser() parser.add_argument( "-p", "--pipeline", type=str, required=True, @@ -36,20 +33,12 @@ ) args = parser.parse_args() - # Create parser for the pipeline (ordered) - pipe_parser = ArgumentParser() - pipe_parser.add_subclass_arguments(Executor, "executor") - - # Parse, Instantiate pipe - parsed = parse_pipe_config(args.pipeline, pipe_parser) - pipe = pipe_parser.instantiate_classes(parsed) - executor: Executor = getattr(pipe, 'executor') + # Create parser for the pipeline + pipe_parser = ConfigParser(config=args.pipeline) + pipeline = pipe_parser.parse_pipeline() if args.download_only: print('Downloading datasets and exiting...') - executor = executor[:1] - else: - print('Downloading datasets (if not 
already done) and running...') - executor = executor - executor.setup() - executor() + pipeline = pipeline[:1] + + pipeline.execute() diff --git a/use-cases/mnist/torch-lightning/trainer.py b/use-cases/mnist/torch-lightning/trainer.py index 72454cea..128cf5c6 100644 --- a/use-cases/mnist/torch-lightning/trainer.py +++ b/use-cases/mnist/torch-lightning/trainer.py @@ -1,7 +1,7 @@ import os -from typing import Union, Dict, Tuple, Optional, Any +from typing import Union, Dict, Any -from itwinai.components import Trainer +from itwinai.components import Trainer, monitor_exec from itwinai.torch.models.mnist import MNISTModel from dataloader import MNISTDataModule from lightning.pytorch.cli import LightningCLI @@ -11,12 +11,14 @@ class LightningMNISTTrainer(Trainer): def __init__(self, config: Union[Dict, str]): super().__init__() + self.save_parameters(**self.locals2params(locals())) if isinstance(config, str) and os.path.isfile(config): # Load from YAML config = load_yaml(config) self.conf = config - def train(self) -> Any: + @monitor_exec + def execute(self) -> Any: cli = LightningCLI( args=self.conf, model_class=MNISTModel, @@ -31,13 +33,6 @@ def train(self) -> Any: ) cli.trainer.fit(cli.model, datamodule=cli.datamodule) - def execute( - self, - config: Optional[Dict] = None - ) -> Tuple[Any, Optional[Dict]]: - result = self.train() - return result, config - def save_state(self): return super().save_state() diff --git a/use-cases/mnist/torch/dataloader.py b/use-cases/mnist/torch/dataloader.py index 39e9b56b..e4243763 100644 --- a/use-cases/mnist/torch/dataloader.py +++ b/use-cases/mnist/torch/dataloader.py @@ -1,6 +1,6 @@ """Dataloader for Torch-based MNIST use case.""" -from typing import Dict, Optional, Tuple, Callable, Any +from typing import Optional, Tuple, Callable, Any import os import shutil @@ -8,59 +8,33 @@ from torch.utils.data import Dataset from torchvision import transforms, datasets -from itwinai.components import DataGetter +from itwinai.components 
import DataGetter, monitor_exec class MNISTDataModuleTorch(DataGetter): """Download MNIST dataset for torch.""" - def __init__( - self, - save_path: str = '.tmp/', - # batch_size: int = 32, - # pin_memory: bool = True, - # num_workers: int = 4 - ) -> None: + def __init__(self, save_path: str = '.tmp/',) -> None: super().__init__() + self.save_parameters(**self.locals2params(locals())) self.save_path = save_path - # self.batch_size = batch_size - # self.pin_memory = pin_memory - # self.num_workers = num_workers - def load(self): - self.train_dataset = datasets.MNIST( + @monitor_exec + def execute(self) -> Tuple[Dataset, Dataset]: + train_dataset = datasets.MNIST( self.save_path, train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])) - self.val_dataset = datasets.MNIST( + validation_dataset = datasets.MNIST( self.save_path, train=False, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])) - - def execute( - self, - config: Optional[Dict] = None - ) -> Tuple[Tuple[Dataset, Dataset], Optional[Dict]]: - self.load() - print("Train and valid datasets loaded.") - # train_dataloder = DataLoader( - # self.train_dataset, - # batch_size=self.batch_size, - # pin_memory=self.pin_memory, - # num_workers=self.num_workers - # ) - # validation_dataloader = DataLoader( - # self.val_dataset, - # batch_size=self.batch_size, - # pin_memory=self.pin_memory, - # num_workers=self.num_workers - # ) - # return (train_dataloder, validation_dataloader) - return (self.train_dataset, self.val_dataset), config + print("Train and validation datasets loaded.") + return train_dataset, validation_dataset class InferenceMNIST(Dataset): @@ -100,11 +74,6 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: """ img_id, img = list(self.data.items())[index] - # doing this so that it is consistent with all other datasets - # to return a PIL Image - # print(type(img)) - 
# img = Image.fromarray(img.numpy(), mode="L") - if self.transform is not None: img = self.transform(img) @@ -136,21 +105,13 @@ def generate_jpg_sample( class MNISTPredictLoader(DataGetter): - def __init__( - self, - test_data_path: str - ) -> None: + def __init__(self, test_data_path: str) -> None: super().__init__() + self.save_parameters(**self.locals2params(locals())) self.test_data_path = test_data_path - def execute( - self, - config: Optional[Dict] = None - ) -> Tuple[Tuple[Dataset, Dataset], Optional[Dict]]: - data = self.load() - return data, config - - def load(self) -> Dataset: + @monitor_exec + def execute(self) -> Dataset: return InferenceMNIST( root=self.test_data_path, transform=transforms.Compose([ diff --git a/use-cases/mnist/torch/inference-pipeline.yaml b/use-cases/mnist/torch/inference-pipeline.yaml index ba4f5e86..5edf6ce9 100644 --- a/use-cases/mnist/torch/inference-pipeline.yaml +++ b/use-cases/mnist/torch/inference-pipeline.yaml @@ -1,5 +1,5 @@ -executor: - class_path: itwinai.components.Executor +pipeline: + class_path: itwinai.pipeline.Pipeline init_args: steps: - class_path: dataloader.MNISTPredictLoader diff --git a/use-cases/mnist/torch/pipeline.yaml b/use-cases/mnist/torch/pipeline.yaml index 67848652..99f35c73 100644 --- a/use-cases/mnist/torch/pipeline.yaml +++ b/use-cases/mnist/torch/pipeline.yaml @@ -1,5 +1,5 @@ -executor: - class_path: itwinai.components.Executor +pipeline: + class_path: itwinai.pipeline.Pipeline init_args: steps: dataloading_step: diff --git a/use-cases/mnist/torch/saver.py b/use-cases/mnist/torch/saver.py index fd54c0cf..e1ce56ac 100644 --- a/use-cases/mnist/torch/saver.py +++ b/use-cases/mnist/torch/saver.py @@ -2,12 +2,12 @@ This module is used during inference to save predicted labels to file. 
""" -from typing import Optional, List, Dict, Tuple +from typing import Optional, List, Dict import os import shutil import csv -from itwinai.components import Saver +from itwinai.components import Saver, monitor_exec class TorchMNISTLabelSaver(Saver): @@ -20,6 +20,7 @@ def __init__( class_labels: Optional[List] = None ) -> None: super().__init__() + self.save_parameters(**self.locals2params(locals())) self.save_dir = save_dir self.predictions_file = predictions_file self.class_labels = ( @@ -27,23 +28,17 @@ def __init__( else [f'Digit {i}' for i in range(10)] ) - def execute( - self, - predicted_classes: Dict[str, int], - config: Optional[Dict] = None - ) -> Tuple[Optional[Tuple], Optional[Dict]]: + @monitor_exec + def execute(self, predicted_classes: Dict[str, int],) -> Dict[str, int]: """Translate predictions from class idx to class label and save them to disk. Args: predicted_classes (Dict[str, int]): maps unique item ID to the predicted class ID. - config (Optional[Dict], optional): inherited configuration. - Defaults to None. Returns: - Tuple[Optional[Tuple], Optional[Dict]]: propagation of inherited - configuration and saver return value. + Dict[str, int]: predicted classes. 
""" if os.path.exists(self.save_dir): shutil.rmtree(self.save_dir) @@ -54,12 +49,11 @@ def execute( itm_name: self.class_labels[cls_idx] for itm_name, cls_idx in predicted_classes.items() } - result = self.save(predicted_labels) - return ((result,), config) - def save(self, predicted_labels: Dict[str, str]) -> None: + # Save to disk filepath = os.path.join(self.save_dir, self.predictions_file) with open(filepath, 'w') as csv_file: writer = csv.writer(csv_file) for key, value in predicted_labels.items(): writer.writerow([key, value]) + return predicted_labels diff --git a/use-cases/mnist/torch/train.py b/use-cases/mnist/torch/train.py index 50c91988..97f53093 100644 --- a/use-cases/mnist/torch/train.py +++ b/use-cases/mnist/torch/train.py @@ -15,13 +15,10 @@ import argparse -from itwinai.components import Executor -from itwinai.utils import parse_pipe_config -from jsonargparse import ArgumentParser +from itwinai.parser import ConfigParser if __name__ == "__main__": - # Create CLI Parser parser = argparse.ArgumentParser() parser.add_argument( "-p", "--pipeline", type=str, required=True, @@ -36,20 +33,12 @@ ) args = parser.parse_args() - # Create parser for the pipeline (ordered) - pipe_parser = ArgumentParser() - pipe_parser.add_subclass_arguments(Executor, "executor") - - # Parse, Instantiate pipe - parsed = parse_pipe_config(args.pipeline, pipe_parser) - pipe = pipe_parser.instantiate_classes(parsed) - executor: Executor = getattr(pipe, 'executor') + # Create parser for the pipeline + pipe_parser = ConfigParser(config=args.pipeline) + pipeline = pipe_parser.parse_pipeline() if args.download_only: print('Downloading datasets and exiting...') - executor = executor[:1] - else: - print('Downloading datasets (if not already done) and running...') - executor = executor - executor.setup() - executor() + pipeline = pipeline[:1] + + pipeline.execute() diff --git a/use-cases/zebra2horse/train.py b/use-cases/zebra2horse/train.py index 08a91fd2..c33b9402 100644 --- 
a/use-cases/zebra2horse/train.py +++ b/use-cases/zebra2horse/train.py @@ -2,7 +2,7 @@ from trainer import Zebra2HorseTrainer from dataloader import Zebra2HorseDataLoader -from itwinai.executors import LocalExecutor # , RayExecutor +from itwinai.experimental.executors import LocalExecutor # , RayExecutor if __name__ == "__main__": From dd2c5ead1e0da2ac712e98e5d38b6b4d1c359773 Mon Sep 17 00:00:00 2001 From: Matteo Bunino <48362942+matbun@users.noreply.github.com> Date: Wed, 13 Dec 2023 13:51:18 +0100 Subject: [PATCH 54/57] Simplified workflow configuration (#109) * Add SQAaaS dynamic badge for dev branch (#104) * Add SQAaaS dynamic badge * Upgrade to sqaaas-assessment-action@v2 * Add draft example * UPDATE credits field * ADD docs * REFACTOR components and pipeline code * UPDATE docstring * UPDATE mnist torch uc * ADD config file parser draft * ADD itwinaiCLI and ConfigParser * ADD docs * ADD pipeline parser and serializer plus tests * UPDATE docs * ADD adapter component and tests (incl parser) * ADD splitter component, improve pipeline, tests * UPDATE test * REMOVE todos * ADD component tests * ADD serializer tests * FIX linter * ADD basic workflow tutorial * ADD basic intermediate tutorial * ADD advanced tutorial * UPDATE advanced tutorial * UPDATE use cases * UPDATE save parameters * FIX linter * FIX cyclones use case workflow * ADD slurm jobscript * FIX merge error * FIX components template --------- Co-authored-by: orviz --- src/itwinai/components.py | 52 +++++++++++++++++++------------------ tests/all_tests_startscript | 32 +++++++++++++++++++++++ use-cases/3dgan/trainer.py | 1 + 3 files changed, 60 insertions(+), 25 deletions(-) create mode 100644 tests/all_tests_startscript diff --git a/src/itwinai/components.py b/src/itwinai/components.py index 0c628e0c..1f41bacd 100644 --- a/src/itwinai/components.py +++ b/src/itwinai/components.py @@ -95,7 +95,6 @@ from .serialization import ModelLoader, Serializable - def monitor_exec(method: Callable) -> Callable: 
"""Decorator for execute method of a component class. Computes execution time and gives some information about @@ -144,7 +143,7 @@ def __init__( # logs_dir: Optional[str] = None, # debug: bool = False, ) -> None: - self.save_parameters(name=name) + self.save_parameters(**self.locals2params(locals())) self.name = name @property @@ -202,17 +201,19 @@ class Trainer(BaseComponent): def execute( self, train_dataset: MLDataset, - validation_dataset: MLDataset - ) -> Tuple[MLDataset, MLDataset, MLModel]: + validation_dataset: MLDataset, + test_dataset: MLDataset + ) -> Tuple[MLDataset, MLDataset, MLDataset, MLModel]: """Trains a machine learning model. Args: - train_dataset (DatasetML): training dataset. - validation_dataset (DatasetML): validation dataset. + train_dataset (MLDataset): training dataset. + validation_dataset (MLDataset): validation dataset. + test_dataset (MLDataset): test dataset. Returns: - Tuple[DatasetML, DatasetML, ModelML]: training dataset, - validation dataset, trained model. + Tuple[MLDataset, MLDataset, MLDataset]: training dataset, + validation dataset, test dataset, trained model. """ @abstractmethod @@ -235,7 +236,7 @@ def __init__( name: Optional[str] = None, ) -> None: super().__init__(name=name) - self.save_parameters(model=model, name=name) + self.save_parameters(**self.locals2params(locals())) self.model = model() if isinstance(model, ModelLoader) else model @abstractmethod @@ -248,12 +249,12 @@ def execute( """Applies a machine learning model on a dataset of samples. Args: - predict_dataset (DatasetML): dataset for inference. - model (Optional[ModelML], optional): overrides the internal model, + predict_dataset (MLDataset): dataset for inference. + model (Optional[MLModel], optional): overrides the internal model, if given. Defaults to None. Returns: - DatasetML: predictions with the same cardinality of the + MLDataset: predictions with the same cardinality of the input dataset. 
""" @@ -276,18 +277,25 @@ class DataPreproc(BaseComponent): @abstractmethod @monitor_exec - def execute(self, dataset: MLDataset) -> MLDataset: - """Pre-processes a dataset. + def execute( + self, + train_dataset: MLDataset, + validation_dataset: MLDataset, + test_dataset: MLDataset + ) -> Tuple[MLDataset, MLDataset, MLDataset]: + """Trains a machine learning model. Args: - dataset (MLDataset): dataset. + train_dataset (MLDataset): training dataset. + validation_dataset (MLDataset): validation dataset. + test_dataset (MLDataset): test dataset. Returns: - MLDataset: pre-processed dataset. + Tuple[MLDataset, MLDataset, MLDataset]: preprocessed training + dataset, validation dataset, test dataset. """ - class Saver(BaseComponent): """Saves artifact to disk.""" @@ -331,7 +339,7 @@ class Adapter(BaseComponent): def __init__(self, policy: List[Any], name: Optional[str] = None) -> None: super().__init__(name=name) - self.save_parameters(policy=policy, name=name) + self.save_parameters(**self.locals2params(locals())) self.name = name self.policy = policy @@ -379,7 +387,6 @@ class DataSplitter(BaseComponent): _validation_proportion: Union[int, float] _test_proportion: Union[int, float] - def __init__( self, train_proportion: Union[int, float], @@ -388,12 +395,7 @@ def __init__( name: Optional[str] = None ) -> None: super().__init__(name) - self.save_parameters( - train_proportion=train_proportion, - validation_proportion=validation_proportion, - test_proportion=test_proportion, - name=name - ) + self.save_parameters(**self.locals2params(locals())) self.train_proportion = train_proportion self.validation_proportion = validation_proportion self.test_proportion = test_proportion diff --git a/tests/all_tests_startscript b/tests/all_tests_startscript new file mode 100644 index 00000000..1dc92c0e --- /dev/null +++ b/tests/all_tests_startscript @@ -0,0 +1,32 @@ +#!/bin/bash + +# general configuration of the job +#SBATCH --job-name=PrototypeTest +#SBATCH --account=intertwin 
+#SBATCH --mail-user= +#SBATCH --mail-type=ALL +#SBATCH --output=job.out +#SBATCH --error=job.err +#SBATCH --time=00:30:00 + +# configure node and process count on the CM +#SBATCH --partition=batch +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=4 +#SBATCH --gpus-per-node=4 + +# SBATCH --exclusive + +# gres options have to be disabled for deepv +#SBATCH --gres=gpu:4 + +# load modules +ml --force purge +ml Stages/2023 StdEnv/2023 NVHPC/23.1 OpenMPI/4.1.4 cuDNN/8.6.0.163-CUDA-11.7 Python/3.10.4 HDF5 libaio/0.3.112 GCC/11.3.0 + +# shellcheck source=/dev/null +source ~/.bashrc + +# from repo's root dir +srun micromamba run -p ./.venv-pytorch pytest -v tests/ \ No newline at end of file diff --git a/use-cases/3dgan/trainer.py b/use-cases/3dgan/trainer.py index 30a55e08..3bb5a1fd 100644 --- a/use-cases/3dgan/trainer.py +++ b/use-cases/3dgan/trainer.py @@ -33,6 +33,7 @@ def __init__(self, config: Union[Dict, str]): @monitor_exec def execute(self) -> Any: + init_lightning_mlflow(self.conf, registered_model_name='3dgan-lite') old_argv = sys.argv sys.argv = ['some_script_placeholder.py'] cli = LightningCLI( From debc6a4e887e076f17823e9c28924840b2a365fb Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Wed, 13 Dec 2023 15:14:38 +0100 Subject: [PATCH 55/57] ADD integration tests --- src/itwinai/cli.py | 1 + tests/use-cases/conftest.py | 5 +- tests/use-cases/test_3dgan.py | 71 +++++++++++++++++++++++++ use-cases/3dgan/inference-pipeline.yaml | 2 - 4 files changed, 75 insertions(+), 4 deletions(-) create mode 100644 tests/use-cases/test_3dgan.py diff --git a/src/itwinai/cli.py b/src/itwinai/cli.py index 12954fbf..1bf2feb9 100644 --- a/src/itwinai/cli.py +++ b/src/itwinai/cli.py @@ -50,6 +50,7 @@ def exec_pipeline( # to find the local python files imported from the pipeline file import os import sys + sys.path.append(os.path.dirname(config)) sys.path.append(os.getcwd()) # Parse and execute pipeline diff --git a/tests/use-cases/conftest.py 
b/tests/use-cases/conftest.py index 5c36e2ee..d080e0a8 100644 --- a/tests/use-cases/conftest.py +++ b/tests/use-cases/conftest.py @@ -1,4 +1,5 @@ import os +from typing import Callable import pytest import subprocess @@ -12,7 +13,7 @@ @pytest.fixture -def check_folder_structure(): +def check_folder_structure() -> Callable: """ Verify that the use case folder complies with some predefined structure. @@ -25,7 +26,7 @@ def _check_structure(root: str): @pytest.fixture -def install_requirements(): +def install_requirements() -> Callable: """Install requirements.txt, if present in root folder.""" def _install_reqs(root: str, env_prefix: str): req_path = os.path.join(root, 'requirements.txt') diff --git a/tests/use-cases/test_3dgan.py b/tests/use-cases/test_3dgan.py new file mode 100644 index 00000000..9a6eebdf --- /dev/null +++ b/tests/use-cases/test_3dgan.py @@ -0,0 +1,71 @@ +""" +Tests for CERN use case (3DGAN). +""" +import pytest +import subprocess +from itwinai.utils import dynamically_import_class + +CERN_PATH = "use-cases/3dgan" +CKPT_PATH = "3dgan-inference.pth" + + +@pytest.fixture(scope="module") +def fake_model_checkpoint() -> None: + """ + Create a dummy model checkpoint for inference. + """ + import sys + import torch + sys.path.append(CERN_PATH) + from model import ThreeDGAN + ThreeDGAN = dynamically_import_class('model.ThreeDGAN') + net = ThreeDGAN() + torch.save(net, CKPT_PATH) + + +def test_structure_3dgan(check_folder_structure): + """Test 3DGAN folder structure.""" + check_folder_structure(CERN_PATH) + + +@pytest.mark.functional +def test_3dgan_train(install_requirements): + """ + Test 3DGAN torch lightning trainer by running it end-to-end. 
+ """ + install_requirements(CERN_PATH, pytest.TORCH_PREFIX) + # cmd = (f"micromamba run -p {pytest.TORCH_PREFIX} python " + # f"{CERN_PATH}/train.py -p {CERN_PATH}/pipeline.yaml") + cmd = (f"micromamba run -p {pytest.TORCH_PREFIX} itwinai exec-pipeline " + f"--config {CERN_PATH}/pipeline.yaml") + subprocess.run(cmd.split(), check=True) + + +@pytest.mark.functional +def test_3dgan_inference(install_requirements, fake_model_checkpoint): + """ + Test 3DGAN torch lightning trainer by running it end-to-end. + """ + install_requirements(CERN_PATH, pytest.TORCH_PREFIX) + # cmd = (f"micromamba run -p {pytest.TORCH_PREFIX} python " + # f"{CERN_PATH}/train.py -p {CERN_PATH}/pipeline.yaml") + # cmd = (f"micromamba run -p {pytest.TORCH_PREFIX} itwinai exec-pipeline " + # f"--config {CERN_PATH}/inference-pipeline.yaml") + + getter_params = "pipeline.init_args.steps.0.init_args" + trainer_params = "pipeline.init_args.steps.1.init_args" + logger_params = trainer_params + ".config.trainer.logger.init_args" + data_params = trainer_params + ".config.data.init_args" + saver_params = "pipeline.init_args.steps.2.init_args" + cmd = ( + 'itwinai exec-pipeline ' + '--config use-cases/3dgan/inference-pipeline.yaml ' + f'-o {getter_params}.data_path=exp_data ' + f'-o {trainer_params}.model.init_args.model_uri="{CKPT_PATH}" ' + f'-o {logger_params}.save_dir=ml_logs/mlflow_logs ' + f'-o {data_params}.datapath="exp_data/*/*.h5" ' + f'-o {saver_params}.save_dir=3dgan-generated-data ' + ) + raise ValueError(cmd) + + subprocess.run(cmd.split(), check=True) diff --git a/use-cases/3dgan/inference-pipeline.yaml b/use-cases/3dgan/inference-pipeline.yaml index 883533b3..59d8f54f 100644 --- a/use-cases/3dgan/inference-pipeline.yaml +++ b/use-cases/3dgan/inference-pipeline.yaml @@ -84,11 +84,9 @@ pipeline: class_path: model.ThreeDGAN init_args: latent_size: 256 - batch_size: 64 loss_weights: [3, 0.1, 25, 0.1] power: 0.85 lr: 0.001 - checkpoint_path: /usr/data/exp_data/3dgan.pth # Lightning data module 
configuration data: From 9e8eafefac2ee9e59a134ef87d23764b963e9079 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Wed, 13 Dec 2023 15:15:29 +0100 Subject: [PATCH 56/57] FIX test --- tests/use-cases/test_3dgan.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/use-cases/test_3dgan.py b/tests/use-cases/test_3dgan.py index 9a6eebdf..2373c76f 100644 --- a/tests/use-cases/test_3dgan.py +++ b/tests/use-cases/test_3dgan.py @@ -66,6 +66,4 @@ def test_3dgan_inference(install_requirements, fake_model_checkpoint): f'-o {data_params}.datapath="exp_data/*/*.h5" ' f'-o {saver_params}.save_dir=3dgan-generated-data ' ) - raise ValueError(cmd) - subprocess.run(cmd.split(), check=True) From c9b1c17bd18c76b462d91d8ed1dfea57cd221e80 Mon Sep 17 00:00:00 2001 From: Matteo Bunino Date: Wed, 13 Dec 2023 16:17:25 +0100 Subject: [PATCH 57/57] FIX 3dgan inference test --- tests/use-cases/test_3dgan.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/use-cases/test_3dgan.py b/tests/use-cases/test_3dgan.py index 2373c76f..10d6b46c 100644 --- a/tests/use-cases/test_3dgan.py +++ b/tests/use-cases/test_3dgan.py @@ -3,7 +3,7 @@ """ import pytest import subprocess -from itwinai.utils import dynamically_import_class +# from itwinai.utils import dynamically_import_class CERN_PATH = "use-cases/3dgan" CKPT_PATH = "3dgan-inference.pth" @@ -18,7 +18,7 @@ def fake_model_checkpoint() -> None: import torch sys.path.append(CERN_PATH) from model import ThreeDGAN - ThreeDGAN = dynamically_import_class('model.ThreeDGAN') + # ThreeDGAN = dynamically_import_class('model.ThreeDGAN') net = ThreeDGAN() torch.save(net, CKPT_PATH) @@ -61,9 +61,9 @@ def test_3dgan_inference(install_requirements, fake_model_checkpoint): 'itwinai exec-pipeline ' '--config use-cases/3dgan/inference-pipeline.yaml ' f'-o {getter_params}.data_path=exp_data ' - f'-o {trainer_params}.model.init_args.model_uri="{CKPT_PATH}" ' + f'-o {trainer_params}.model.init_args.model_uri={CKPT_PATH} ' f'-o 
{logger_params}.save_dir=ml_logs/mlflow_logs ' - f'-o {data_params}.datapath="exp_data/*/*.h5" ' + f'-o {data_params}.datapath=exp_data/*/*.h5 ' f'-o {saver_params}.save_dir=3dgan-generated-data ' ) subprocess.run(cmd.split(), check=True)