From 586c0957797d3d48bc1de0f4c967de11b0dc0559 Mon Sep 17 00:00:00 2001
From: adityajain07
Date: Thu, 24 Oct 2024 15:04:55 -0400
Subject: [PATCH 01/25] Add basic readme

---
 src/classification/README.md | 3 +++
 1 file changed, 3 insertions(+)
 create mode 100644 src/classification/README.md

diff --git a/src/classification/README.md b/src/classification/README.md
new file mode 100644
index 0000000..7ac7055
--- /dev/null
+++ b/src/classification/README.md
@@ -0,0 +1,3 @@
+# Model Training
+
+Scripts for training a classification model from webdataset files.
\ No newline at end of file

From 916aa3cd3de6caff954fe3fab1a91730879f2441 Mon Sep 17 00:00:00 2001
From: adityajain07
Date: Thu, 24 Oct 2024 15:05:05 -0400
Subject: [PATCH 02/25] Fix basic formatting

---
 research/eccv2024/model_evaluation/utils.py | 3 ++-
 src/localization/training.py                | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/research/eccv2024/model_evaluation/utils.py b/research/eccv2024/model_evaluation/utils.py
index b962ffd..47547ed 100644
--- a/research/eccv2024/model_evaluation/utils.py
+++ b/research/eccv2024/model_evaluation/utils.py
@@ -6,9 +6,10 @@
 import pathlib
 
 import PIL
-import wandb
 from torchvision import transforms
 
+import wandb
+
 
 def download_model(artifact: str, model_dir: str):
     """Download the model from Weights and Biases"""
diff --git a/src/localization/training.py b/src/localization/training.py
index 6e091a1..69660ab 100644
--- a/src/localization/training.py
+++ b/src/localization/training.py
@@ -26,7 +26,6 @@
 disable_beta_transforms_warning()
 
 import torchvision.transforms.v2 as T
-import wandb
 from data.custom_datasets import SplitDataset, TrainingDataset
 from torch import nn, optim
 from torch.utils.data import DataLoader, random_split
@@ -34,6 +33,8 @@
 from tqdm import tqdm
 from utils import SupportedModels, bounding_box_to_tensor, load_model, set_random_seed
 
+import wandb
+
 SupportedLRSchedulers = tp.Literal["multisteplr", "cosineannealinglr"]

From 4e30f9417b02feb6e7257dffd7e4390a4cbd6572 Mon Sep 17 00:00:00 2001
From: adityajain07
Date: Mon, 28 Oct 2024 12:41:56 -0400
Subject: [PATCH 03/25] Ignore W&B folder

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index d9aa5ec..6eefb1e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -167,3 +167,4 @@ cython_debug/
 
 # Training related files
 **/*.out
+wandb/

From cd1692f840c152b3a241be7994e38f0a21655d95 Mon Sep 17 00:00:00 2001
From: adityajain07
Date: Mon, 28 Oct 2024 15:11:26 -0400
Subject: [PATCH 04/25] Add classification tools module

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index ce0a7f3..e870c24 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ pytest = "^8.1.1"
 
 [tool.poetry.scripts]
 ami-dataset = "src.dataset_tools.cli:cli"
+ami-classification = "src.classification.cli:cli"
 
 
 [build-system]
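With the `ami-classification` script registered above, the CLI module introduced in the next patch becomes invocable from the shell once the package is installed. A minimal usage sketch, assuming a Poetry environment (the `train-model` subcommand is the one defined in the following patch):

    poetry install
    ami-classification --help
    ami-classification train-model --random_seed 42

This mirrors the existing `ami-dataset` entry point, so both tools follow the same invocation pattern.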
From 7c091b94cc037bee2757324237e0513fca9df6cb Mon Sep 17 00:00:00 2001
From: adityajain07
Date: Mon, 28 Oct 2024 15:23:27 -0400
Subject: [PATCH 05/25] Add basic cli

---
 src/classification/cli.py | 97 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 97 insertions(+)
 create mode 100644 src/classification/cli.py

diff --git a/src/classification/cli.py b/src/classification/cli.py
new file mode 100644
index 0000000..dcee5e5
--- /dev/null
+++ b/src/classification/cli.py
@@ -0,0 +1,97 @@
+"""
+Command Line Interface for the training module
+
+To add a new command, create a new function below following these instructions:
+
+- Create a new Command key constant
+    <new>_CMD = "<new>_cmd"
+
+- Add the new Command key constant to the COMMAND_KEYS frozenset
+
+- Add the command name and help text in the COMMANDS and COMMANDS_HELP dictionaries
+
+- Create a new function named <new>_command(), alongside the appropriate
+  options
+  - Make sure to reuse appropriate options before creating duplicate ones.
+  - The contents of the CLI command should be minimal: execution of an imported
+    function
+
+- If unsure, take some time to look at how other commands have been implemented
+
+- Make sure to use lazy loading when importing modules that are only used by 1 command
+"""
+
+import click
+
+# Command key constants
+# Make sure to add them to COMMAND_KEYS frozenset
+TRAIN_CMD = "train_cmd"
+
+# This is most useful to automatically test the CLI
+COMMAND_KEYS = frozenset([TRAIN_CMD])
+
+# Command dictionary
+COMMANDS = {
+    TRAIN_CMD: "train-model",
+}
+
+# Command help text dictionary
+COMMANDS_HELP = {TRAIN_CMD: "Train a classification model"}
+
+
+# # # # # # #
+#  Commands #
+# # # # # # #
+
+# The order of declaration of the commands affects the order
+# in which they appear in the CLI
+
+
+#
+# Train Model Command
+#
+@click.command(
+    name=COMMANDS[TRAIN_CMD],
+    help=COMMANDS_HELP[TRAIN_CMD],
+    context_settings={"show_default": True},
+)
+@click.option(
+    "--random_seed",
+    type=int,
+    default=42,
+    help="Random seed for reproducibility",
+)
+def train_model_command(random_seed: int):
+    from src.classification.train_model import train_model
+
+    train_model(random_seed=random_seed)
+
+
+# # # # # # # # # # # # # #
+#  Main CLI configuration #
+# # # # # # # # # # # # # #
+class OrderCommands(click.Group):
+    """This class is necessary to order the commands the way we want to."""
+
+    def list_commands(self, ctx: click.Context) -> list[str]:
+        return list(self.commands)
+
+
+@click.group(cls=OrderCommands)
+def cli():
+    """This is the main command line interface for the classification tools."""
+    pass
+
+
+# Following is an automated way to add all functions containing the word `command`
+# in their name instead of manually having to add them.
+all_objects = globals()
+functions = [
+    obj for name, obj in all_objects.items() if callable(obj) and "command" in name
+]
+
+for command_function in functions:
+    cli.add_command(command_function)
+
+if __name__ == "__main__":
+    cli()
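The module docstring above spells out the recipe for adding commands. A hypothetical sketch of a second command following that recipe; the names `EVAL_CMD`, `evaluate-model`, and `src.classification.evaluate` are illustrative only and are not part of this patch series:

    EVAL_CMD = "eval_cmd"  # hypothetical; also add to COMMAND_KEYS, COMMANDS, COMMANDS_HELP

    @click.command(name=COMMANDS[EVAL_CMD], help=COMMANDS_HELP[EVAL_CMD])
    @click.option("--random_seed", type=int, default=42, help="Random seed for reproducibility")
    def evaluate_model_command(random_seed: int):
        # Lazy import, so the module is only loaded when this command actually runs
        from src.classification.evaluate import evaluate_model  # hypothetical module

        evaluate_model(random_seed=random_seed)

Because the function name contains the word "command", the registration loop at the end of cli.py picks it up automatically; no manual `cli.add_command()` call is needed.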
From f24c24cfacf2dd82fc5c6905969d27a63afbadf4 Mon Sep 17 00:00:00 2001
From: adityajain07
Date: Mon, 28 Oct 2024 15:54:59 -0400
Subject: [PATCH 06/25] Test setting random seed

---
 src/classification/cli.py   |   2 +-
 src/classification/train.py |  37 +++++++++
 src/classification/utils.py | 144 ++++++++++++++++++++++++++++++++++++
 3 files changed, 182 insertions(+), 1 deletion(-)
 create mode 100644 src/classification/train.py
 create mode 100644 src/classification/utils.py

diff --git a/src/classification/cli.py b/src/classification/cli.py
index dcee5e5..3ab2e35 100644
--- a/src/classification/cli.py
+++ b/src/classification/cli.py
@@ -62,7 +62,7 @@
     help="Random seed for reproducibility",
 )
 def train_model_command(random_seed: int):
-    from src.classification.train_model import train_model
+    from src.classification.train import train_model
 
     train_model(random_seed=random_seed)
 
diff --git a/src/classification/train.py b/src/classification/train.py
new file mode 100644
index 0000000..3e05580
--- /dev/null
+++ b/src/classification/train.py
@@ -0,0 +1,37 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+
+""" Main script for training classification models
+"""
+
+# package imports
+import torch
+
+# 3rd party packages
+from dotenv import load_dotenv
+
+from src.classification.utils import set_random_seeds
+
+# Load secrets and config from optional .env file
+load_dotenv()
+
+
+def train_model_one_epoch():
+    """Training model for one epoch"""
+
+
+def prepare_dataloader():
+    """Returns the training, validation and test data loaders,
+    which have different transforms
+    (data augmentation is only applied on the training set)
+    """
+
+
+def train_model(random_seed: int) -> None:
+    """Main training function"""
+
+    # Basic initialization
+    set_random_seeds(random_seed)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"The available device is {device}.")
diff --git a/src/classification/utils.py b/src/classification/utils.py
new file mode 100644
index 0000000..0e01554
--- /dev/null
+++ b/src/classification/utils.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+""" Utility functions
+"""
+
+import os
+import random
+
+import numpy as np
+import timm
+import torch
+
+# import webdataset as wds
+
+
+def set_random_seeds(random_seed: int) -> None:
+    """Set random seeds for reproducibility"""
+
+    random.seed(random_seed)
+    np.random.seed(random_seed)
+    torch.manual_seed(random_seed)
+    torch.cuda.manual_seed(random_seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def model_builder(model_name: str, num_classes: int, pretrained: bool = True):
+    """Model builder"""
+
+    if model_name == "timm_efficientnetv2-b3":
+        model = timm.create_model(
+            "tf_efficientnetv2_b3", pretrained=pretrained, num_classes=num_classes
+        )
+    elif model_name == "timm_efficientnetv2-s-in21k":
+        model = timm.create_model(
+            "tf_efficientnetv2_s_in21k", pretrained=pretrained, num_classes=num_classes
+        )
+    elif model_name == "timm_swin-s":
+        model = timm.create_model(
+            "swin_small_patch4_window7_224",
+            pretrained=pretrained,
+            num_classes=num_classes,
+        )
+
+    elif model_name == "timm_mobilenetv3large":
+        model = timm.create_model(
+            "mobilenetv3_large_100", pretrained=pretrained, num_classes=num_classes
+        )
+    elif model_name == "timm_resnet50":
+        model = timm.create_model(
+            
"resnet50", pretrained=pretrained, num_classes=num_classes + ) + elif model_name == "timm_convnext-t": + model = timm.create_model( + "convnext_tiny_in22k", pretrained=pretrained, num_classes=num_classes + ) + elif model_name == "timm_convnext-b": + model = timm.create_model( + "convnext_base_in22k", pretrained=pretrained, num_classes=num_classes + ) + elif model_name == "timm_vit-b16-128": + model = timm.create_model( + "vit_base_patch16_224_in21k", + pretrained=pretrained, + img_size=128, + num_classes=num_classes, + ) + elif model_name == "timm_vit-b16-224": + model = timm.create_model( + "vit_base_patch16_224_in21k", + pretrained=pretrained, + num_classes=num_classes, + ) + elif model_name == "timm_vit-b16-384": + model = timm.create_model( + "vit_base_patch16_384", + pretrained=pretrained, + num_classes=num_classes, + ) + else: + raise RuntimeError(f"Model {model_name} not implemented") + + return model + + +# def get_transforms(input_size: int, preprocess_mode: str, square_pad: bool): +# """Transformation applied to each image""" + +# if preprocess_mode == "torch": +# mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] +# elif preprocess_mode == "tf": +# mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] +# else: +# mean, std = [0.0, 0.0, 0.0], [1.0, 1.0, 1.0] + +# if square_pad: +# pass + + +def identity(x): + return x + + +# def webdataset_pipeline( +# sharedurl: str, +# input_size: int, +# batch_size: int, +# preprocess_mode: str, +# num_workers: int, +# square_pad: bool, +# is_training: bool = False, +# ) -> None: +# """Main dataset builder and loader function""" + +# # Load the webdataset +# if is_training: +# dataset = wds.WebDataset(sharedurl, shardshuffle=True) +# dataset = dataset.shuffle(10000) +# else: +# dataset = wds.WebDataset(sharedurl, shardshuffle=False) + +# # Get image transforms +# img_transform = get_transforms(input_size, preprocess_mode, square_pad) + +# # Decode dataset +# dataset = ( +# dataset.decode("pil").to_tuple("jpg", "cls").map_tuple(img_transform, identity) +# ) + +# loader = torch.utils.data.DataLoader( +# dataset, num_workers=num_workers, batch_size=batch_size +# ) + +# pass + + +def get_num_workers() -> int: + """Gets the optimal number of DatLoader workers to use in the current job.""" + + if "SLURM_CPUS_PER_TASK" in os.environ: + return int(os.environ["SLURM_CPUS_PER_TASK"]) + if hasattr(os, "sched_getaffinity"): + return len(os.sched_getaffinity(0)) + return torch.multiprocessing.cpu_count() From b0448f0e5e95b6cd85096ab6803eddf888006fd8 Mon Sep 17 00:00:00 2001 From: adityajain07 Date: Tue, 29 Oct 2024 16:09:38 -0400 Subject: [PATCH 07/25] Model builder implemented --- src/classification/cli.py | 34 +++++++++++-- src/classification/models.py | 68 ++++++++++++++++++++++++++ src/classification/train.py | 32 +++++++----- src/classification/utils.py | 94 +++++++++++++++--------------------- 4 files changed, 158 insertions(+), 70 deletions(-) create mode 100644 src/classification/models.py diff --git a/src/classification/cli.py b/src/classification/cli.py index 3ab2e35..768a6ca 100644 --- a/src/classification/cli.py +++ b/src/classification/cli.py @@ -21,8 +21,13 @@ - Make sure to use lazy loading when importing modules that are only used by 1 command """ +import typing as tp +from typing import Optional + import click +from src.classification.utils import SupportedModels + # Command key constants # Make sure to add them to COMMAND_KEYS frozenset TRAIN_CMD = "train_cmd" @@ -61,10 +66,34 @@ default=42, help="Random seed for reproducibility", ) -def 
train_model_command(random_seed: int): +@click.option( + "--model_type", + type=click.Choice(tp.get_args(SupportedModels)), + required=True, + help="Model architecture", +) +@click.option( + "--num_classes", + type=int, + required=True, + help="Number of model's output classes", +) +@click.option( + "--existing_weights", + type=str, + help="Existing weights to be loaded, if available", +) +def train_model_command( + random_seed: int, model_type: str, num_classes: int, existing_weights: Optional[str] +): from src.classification.train import train_model - train_model(random_seed=random_seed) + train_model( + random_seed=random_seed, + model_type=model_type, + num_classes=num_classes, + existing_weights=existing_weights, + ) # # # # # # # # # # # # # # @@ -80,7 +109,6 @@ def list_commands(self, ctx: click.Context) -> list[str]: @click.group(cls=OrderCommands) def cli(): """This is the main command line interface for the classification tools.""" - pass # Following is an automated way to add all functions containing the word `command` diff --git a/src/classification/models.py b/src/classification/models.py new file mode 100644 index 0000000..e6bf46e --- /dev/null +++ b/src/classification/models.py @@ -0,0 +1,68 @@ +""" List of available models to train +""" + +import timm +import torch +from torchvision import models + + +def model_list(model_name: str, num_classes: int, pretrained: bool): + """Main model builder function""" + + if model_name == "efficientnetv2-b3": + model = timm.create_model( + "tf_efficientnetv2_b3", pretrained=pretrained, num_classes=num_classes + ) + elif model_name == "efficientnetv2-s-in21k": + model = timm.create_model( + "tf_efficientnetv2_s_in21k", pretrained=pretrained, num_classes=num_classes + ) + elif model_name == "swin-s": + model = timm.create_model( + "swin_small_patch4_window7_224", + pretrained=pretrained, + num_classes=num_classes, + ) + elif model_name == "resnet50": + model = models.resnet50(weights="IMAGENET1K_V1" if pretrained else None) + num_ftrs = model.fc.in_features + model.fc = torch.nn.Linear(num_ftrs, num_classes) + elif model_name == "timm_mobilenetv3large": + model = timm.create_model( + "mobilenetv3_large_100", pretrained=pretrained, num_classes=num_classes + ) + elif model_name == "timm_resnet50": + model = timm.create_model( + "resnet50", pretrained=pretrained, num_classes=num_classes + ) + elif model_name == "timm_convnext-t": + model = timm.create_model( + "convnext_tiny_in22k", pretrained=pretrained, num_classes=num_classes + ) + elif model_name == "timm_convnext-b": + model = timm.create_model( + "convnext_base_in22k", pretrained=pretrained, num_classes=num_classes + ) + elif model_name == "timm_vit-b16-128": + model = timm.create_model( + "vit_base_patch16_224_in21k", + pretrained=pretrained, + img_size=128, + num_classes=num_classes, + ) + elif model_name == "timm_vit-b16-224": + model = timm.create_model( + "vit_base_patch16_224_in21k", + pretrained=pretrained, + num_classes=num_classes, + ) + elif model_name == "timm_vit-b16-384": + model = timm.create_model( + "vit_base_patch16_384", + pretrained=pretrained, + num_classes=num_classes, + ) + else: + raise RuntimeError(f"Model {model_name} not implemented") + + return model diff --git a/src/classification/train.py b/src/classification/train.py index 3e05580..60f50fa 100644 --- a/src/classification/train.py +++ b/src/classification/train.py @@ -4,21 +4,12 @@ """ Main script for training classification models """ +from typing import Optional # package imports import torch -# 3rd party 
packages -from dotenv import load_dotenv - -from src.classification.utils import set_random_seeds - -# Load secrets and config from optional .env file -load_dotenv() - - -def train_model_one_epoch(): - """Training model for one epoch""" +from src.classification.utils import model_builder, set_random_seeds def prepare_dataloader(): @@ -28,10 +19,25 @@ def prepare_dataloader(): """ -def train_model(random_seed: int) -> None: +def train_model_one_epoch(): + """Training model for one epoch""" + + +def train_model( + random_seed: int, model_type: str, num_classes: int, existing_weights: Optional[str] +) -> None: """Main training function""" - # Basic initialization + # Set random seeds set_random_seeds(random_seed) + + # Model initialization device = "cuda" if torch.cuda.is_available() else "cpu" print(f"The available device is {device}.") + model = model_builder(device, model_type, num_classes, existing_weights) + print(model) + + # Setup dataloaders + # train_dataloader = ... + # val_dataloader = ... + # test_dataloader = ... diff --git a/src/classification/utils.py b/src/classification/utils.py index 0e01554..1e9e5bb 100644 --- a/src/classification/utils.py +++ b/src/classification/utils.py @@ -6,11 +6,15 @@ import os import random +import typing as tp import numpy as np -import timm + +# import timm import torch +from src.classification.models import model_list + # import webdataset as wds @@ -24,61 +28,42 @@ def set_random_seeds(random_seed: int) -> None: torch.backends.cudnn.deterministic = True -def model_builder(model_name: str, num_classes: int, pretrained: bool = True): +SupportedModels = tp.Literal[ + "efficientnetv2-b3", + "efficientnetv2-s-in21k", + "swin-s", + "resnet50", + "timm_mobilenetv3large", + "timm_resnet50", + "timm_convnext-t", + "timm_convnext-b", + "timm_vit-b16-128", + "timm_vit-b16-224", + "timm_vit-b16-384", +] + + +def model_builder( + device: str, + model_type: str, + num_classes: int, + existing_weights: tp.Optional[str], + pretrained: bool = True, +): """Model builder""" - if model_name == "timm_efficientnetv2-b3": - model = timm.create_model( - "tf_efficientnetv2_b3", pretrained=pretrained, num_classes=num_classes - ) - elif model_name == "timm_efficientnetv2-s-in21k": - model = timm.create_model( - "tf_efficientnetv2_s_in21k", pretrained=pretrained, num_classes=num_classes - ) - elif model_name == "timm_swin-s": - model = timm.create_model( - "swin_small_patch4_window7_224", - pretrained=pretrained, - num_classes=num_classes, - ) - - elif model_name == "timm_mobilenetv3large": - model = timm.create_model( - "mobilenetv3_large_100", pretrained=pretrained, num_classes=num_classes - ) - elif model_name == "timm_resnet50": - model = timm.create_model( - "resnet50", pretrained=pretrained, num_classes=num_classes - ) - elif model_name == "timm_convnext-t": - model = timm.create_model( - "convnext_tiny_in22k", pretrained=pretrained, num_classes=num_classes - ) - elif model_name == "timm_convnext-b": - model = timm.create_model( - "convnext_base_in22k", pretrained=pretrained, num_classes=num_classes - ) - elif model_name == "timm_vit-b16-128": - model = timm.create_model( - "vit_base_patch16_224_in21k", - pretrained=pretrained, - img_size=128, - num_classes=num_classes, - ) - elif model_name == "timm_vit-b16-224": - model = timm.create_model( - "vit_base_patch16_224_in21k", - pretrained=pretrained, - num_classes=num_classes, - ) - elif model_name == "timm_vit-b16-384": - model = timm.create_model( - "vit_base_patch16_384", - pretrained=pretrained, - 
num_classes=num_classes, - ) - else: - raise RuntimeError(f"Model {model_name} not implemented") + model = model_list(model_type, num_classes, pretrained) + + # If available, load existing weights + if existing_weights: + print("Loading existing model weights.") + state_dict = torch.load(existing_weights, map_location=torch.device(device)) + model.load_state_dict(state_dict, strict=False) + + if torch.cuda.device_count() > 1: + model = torch.nn.DataParallel(model) + + model = model.to(device) return model @@ -98,6 +83,7 @@ def model_builder(model_name: str, num_classes: int, pretrained: bool = True): def identity(x): + """Identity function""" return x From 640a87b80e6d97c19092be751d1721bf17ebe3db Mon Sep 17 00:00:00 2001 From: adityajain07 Date: Wed, 30 Oct 2024 12:49:35 -0400 Subject: [PATCH 08/25] Move model building fn --- src/classification/models.py | 29 ++++++++++- src/classification/train.py | 10 ++-- src/classification/utils.py | 94 ------------------------------------ 3 files changed, 33 insertions(+), 100 deletions(-) diff --git a/src/classification/models.py b/src/classification/models.py index e6bf46e..215d5e3 100644 --- a/src/classification/models.py +++ b/src/classification/models.py @@ -1,12 +1,14 @@ """ List of available models to train """ +import typing as tp + import timm import torch from torchvision import models -def model_list(model_name: str, num_classes: int, pretrained: bool): +def _model_list(model_name: str, num_classes: int, pretrained: bool): """Main model builder function""" if model_name == "efficientnetv2-b3": @@ -66,3 +68,28 @@ def model_list(model_name: str, num_classes: int, pretrained: bool): raise RuntimeError(f"Model {model_name} not implemented") return model + + +def model_builder( + device: str, + model_type: str, + num_classes: int, + existing_weights: tp.Optional[str], + pretrained: bool = True, +): + """Model builder""" + + model = _model_list(model_type, num_classes, pretrained) + + # If available, load existing weights + if existing_weights: + print("Loading existing model weights.") + state_dict = torch.load(existing_weights, map_location=torch.device(device)) + model.load_state_dict(state_dict, strict=False) + + if torch.cuda.device_count() > 1: + model = torch.nn.DataParallel(model) + + model = model.to(device) + + return model diff --git a/src/classification/train.py b/src/classification/train.py index 60f50fa..80208d9 100644 --- a/src/classification/train.py +++ b/src/classification/train.py @@ -4,12 +4,14 @@ """ Main script for training classification models """ + from typing import Optional -# package imports import torch -from src.classification.utils import model_builder, set_random_seeds +# from src.classification.dataloader import webdataset_pipeline +from src.classification.models import model_builder +from src.classification.utils import set_random_seeds def prepare_dataloader(): @@ -38,6 +40,4 @@ def train_model( print(model) # Setup dataloaders - # train_dataloader = ... - # val_dataloader = ... - # test_dataloader = ... 
+ # train_data = webdataset_pipeline() diff --git a/src/classification/utils.py b/src/classification/utils.py index 1e9e5bb..dcc3461 100644 --- a/src/classification/utils.py +++ b/src/classification/utils.py @@ -4,19 +4,12 @@ """ Utility functions """ -import os import random import typing as tp import numpy as np - -# import timm import torch -from src.classification.models import model_list - -# import webdataset as wds - def set_random_seeds(random_seed: int) -> None: """Set random seeds for reproducibility""" @@ -41,90 +34,3 @@ def set_random_seeds(random_seed: int) -> None: "timm_vit-b16-224", "timm_vit-b16-384", ] - - -def model_builder( - device: str, - model_type: str, - num_classes: int, - existing_weights: tp.Optional[str], - pretrained: bool = True, -): - """Model builder""" - - model = model_list(model_type, num_classes, pretrained) - - # If available, load existing weights - if existing_weights: - print("Loading existing model weights.") - state_dict = torch.load(existing_weights, map_location=torch.device(device)) - model.load_state_dict(state_dict, strict=False) - - if torch.cuda.device_count() > 1: - model = torch.nn.DataParallel(model) - - model = model.to(device) - - return model - - -# def get_transforms(input_size: int, preprocess_mode: str, square_pad: bool): -# """Transformation applied to each image""" - -# if preprocess_mode == "torch": -# mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] -# elif preprocess_mode == "tf": -# mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] -# else: -# mean, std = [0.0, 0.0, 0.0], [1.0, 1.0, 1.0] - -# if square_pad: -# pass - - -def identity(x): - """Identity function""" - return x - - -# def webdataset_pipeline( -# sharedurl: str, -# input_size: int, -# batch_size: int, -# preprocess_mode: str, -# num_workers: int, -# square_pad: bool, -# is_training: bool = False, -# ) -> None: -# """Main dataset builder and loader function""" - -# # Load the webdataset -# if is_training: -# dataset = wds.WebDataset(sharedurl, shardshuffle=True) -# dataset = dataset.shuffle(10000) -# else: -# dataset = wds.WebDataset(sharedurl, shardshuffle=False) - -# # Get image transforms -# img_transform = get_transforms(input_size, preprocess_mode, square_pad) - -# # Decode dataset -# dataset = ( -# dataset.decode("pil").to_tuple("jpg", "cls").map_tuple(img_transform, identity) -# ) - -# loader = torch.utils.data.DataLoader( -# dataset, num_workers=num_workers, batch_size=batch_size -# ) - -# pass - - -def get_num_workers() -> int: - """Gets the optimal number of DatLoader workers to use in the current job.""" - - if "SLURM_CPUS_PER_TASK" in os.environ: - return int(os.environ["SLURM_CPUS_PER_TASK"]) - if hasattr(os, "sched_getaffinity"): - return len(os.sched_getaffinity(0)) - return torch.multiprocessing.cpu_count() From fc9e0cb05f5fdafc16011f785ca6adfb44eb8006 Mon Sep 17 00:00:00 2001 From: adityajain07 Date: Fri, 1 Nov 2024 17:10:24 -0400 Subject: [PATCH 09/25] Basic dataloader complete (w/ error) --- src/classification/dataloader.py | 128 +++++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 src/classification/dataloader.py diff --git a/src/classification/dataloader.py b/src/classification/dataloader.py new file mode 100644 index 0000000..cc8cae8 --- /dev/null +++ b/src/classification/dataloader.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python +# coding: utf-8 + +""" Functions related to dataset loading and image transformations +""" + +import os +from functools import partial +from typing import Any + +import numpy as np 
+import PIL +import torch +import webdataset as wds +from torchvision import transforms + + +def _pad_to_square(image: PIL.Image.Image) -> PIL.Image.Image: + """Padding transformation to make the image square""" + + width, height = image.size + if height < width: + transform = transforms.Pad(padding=[0, 0, 0, width - height]) + elif height > width: + transform = transforms.Pad(padding=[0, 0, height - width, 0]) + else: + transform = transforms.Pad(padding=[0, 0, 0, 0]) + + return transform(image) + + +def _normalization(preprocess_mode: str) -> tuple[list[float], list[float]]: + """Get the mean and std for normalization""" + + if preprocess_mode == "torch": + mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] + elif preprocess_mode == "tf": + mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + else: + mean, std = [0.0, 0.0, 0.0], [1.0, 1.0, 1.0] + + return mean, std + + +def _random_resize(image: PIL.Image.Image, full_size: int) -> PIL.Image.Image: + random_num = np.random.uniform() + if random_num <= 0.25: + transform = transforms.Resize((int(0.5 * full_size), int(0.5 * full_size))) + image = transform(image) + elif random_num <= 0.5: + transform = transforms.Resize((int(0.25 * full_size), int(0.25 * full_size))) + image = transform(image) + + return image + + +# TODO: Add return type of this function +def _get_transforms(input_size: int, is_training: bool, preprocess_mode: str = "torch"): + """Transformation applied to each image""" + + # Add square padding + final_transforms = [transforms.Lambda(_pad_to_square)] + + if is_training: + f_random_resize = partial(_random_resize, input_size) + final_transforms += [ + transforms.Lambda(f_random_resize), # mixed resolution + transforms.RandomResizedCrop(input_size, scale=(0.3, 1)), + transforms.RandomHorizontalFlip(), + transforms.RandAugment(num_ops=2, magnitude=9), + ] + else: + final_transforms += [transforms.Resize((input_size, input_size))] + + # Normalization + mean, std = _normalization(preprocess_mode) + final_transforms += [transforms.ToTensor(), transforms.Normalize(mean, std)] + + return transforms.Compose(final_transforms) + + +def build_webdataset_pipeline( + sharedurl: str, + input_size: int, + batch_size: int, + preprocess_mode: str, + is_training: bool = False, +) -> torch.utils.data.DataLoader: + """Main dataset builder and loader function""" + + # Load the webdataset + if is_training: + dataset = wds.WebDataset(sharedurl, shardshuffle=True) + dataset = dataset.shuffle(10000) + else: + dataset = wds.WebDataset(sharedurl, shardshuffle=False) + + # Get image transforms + image_transform = _get_transforms(input_size, is_training, preprocess_mode) + + # Decode dataset + dataset_decoded = ( + dataset.decode("pil") + .to_tuple("jpg", "cls") + .map_tuple(image_transform, _identity) + ) + + # Create dataLoader + dataset_loader = torch.utils.data.DataLoader( + dataset_decoded, num_workers=_get_num_workers(), batch_size=batch_size + ) + + return dataset_loader + + +def _identity(x: Any) -> Any: + """Identity function""" + return x + + +def _get_num_workers() -> int: + """Gets the optimal number of dataloader workers to use in the current job""" + + if "SLURM_CPUS_PER_TASK" in os.environ: + return int(os.environ["SLURM_CPUS_PER_TASK"]) + if hasattr(os, "sched_getaffinity"): + return len(os.sched_getaffinity(0)) + return torch.multiprocessing.cpu_count() From e1567ca23444509f9efdf108a0ec3a3c3b65ed8e Mon Sep 17 00:00:00 2001 From: adityajain07 Date: Fri, 1 Nov 2024 17:10:47 -0400 Subject: [PATCH 10/25] Add more CLI arguments --- 
src/classification/cli.py | 53 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/src/classification/cli.py b/src/classification/cli.py index 768a6ca..80e4ca2 100644 --- a/src/classification/cli.py +++ b/src/classification/cli.py @@ -83,8 +83,53 @@ type=str, help="Existing weights to be loaded, if available", ) +@click.option( + "--train_webdataset", + type=str, + required=True, + help="Webdataset files for the training set", +) +@click.option( + "--val_webdataset", + type=str, + required=True, + help="Webdataset files for the validation set", +) +@click.option( + "--test_webdataset", + type=str, + required=True, + help="Webdataset files for the test set", +) +@click.option( + "--image_input_size", + type=int, + default=128, + help="Image input size for training and inference", +) +@click.option( + "--batch_size", + type=int, + default=32, + help="Batch size for training", +) +@click.option( + "--preprocess_mode", + type=click.Choice(["torch", "tf", "other"]), + default="torch", + help="Preprocessing mode for normalization", +) def train_model_command( - random_seed: int, model_type: str, num_classes: int, existing_weights: Optional[str] + random_seed: int, + model_type: str, + num_classes: int, + existing_weights: Optional[str], + train_webdataset: str, + val_webdataset: str, + test_webdataset: str, + image_input_size: int, + batch_size: int, + preprocess_mode: str, ): from src.classification.train import train_model @@ -93,6 +138,12 @@ def train_model_command( model_type=model_type, num_classes=num_classes, existing_weights=existing_weights, + train_webdataset=train_webdataset, + val_webdataset=val_webdataset, + test_webdataset=test_webdataset, + image_input_size=image_input_size, + batch_size=batch_size, + preprocess_mode=preprocess_mode, ) From 46a0332e40ed312ccc451e6b13a714590e43e0cc Mon Sep 17 00:00:00 2001 From: adityajain07 Date: Mon, 4 Nov 2024 13:32:43 -0500 Subject: [PATCH 11/25] Dataloader implemented and tested --- src/classification/dataloader.py | 4 +++- src/classification/train.py | 37 ++++++++++++++++++++++---------- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/classification/dataloader.py b/src/classification/dataloader.py index cc8cae8..904e06f 100644 --- a/src/classification/dataloader.py +++ b/src/classification/dataloader.py @@ -43,6 +43,8 @@ def _normalization(preprocess_mode: str) -> tuple[list[float], list[float]]: def _random_resize(image: PIL.Image.Image, full_size: int) -> PIL.Image.Image: + """Mixed resolution transformation""" + random_num = np.random.uniform() if random_num <= 0.25: transform = transforms.Resize((int(0.5 * full_size), int(0.5 * full_size))) @@ -62,7 +64,7 @@ def _get_transforms(input_size: int, is_training: bool, preprocess_mode: str = " final_transforms = [transforms.Lambda(_pad_to_square)] if is_training: - f_random_resize = partial(_random_resize, input_size) + f_random_resize = partial(_random_resize, full_size=input_size) final_transforms += [ transforms.Lambda(f_random_resize), # mixed resolution transforms.RandomResizedCrop(input_size, scale=(0.3, 1)), diff --git a/src/classification/train.py b/src/classification/train.py index 80208d9..78adfde 100644 --- a/src/classification/train.py +++ b/src/classification/train.py @@ -9,24 +9,26 @@ import torch -# from src.classification.dataloader import webdataset_pipeline +from src.classification.dataloader import build_webdataset_pipeline from src.classification.models import model_builder from src.classification.utils 
import set_random_seeds -def prepare_dataloader(): - """Returns the training, validation and test data loaders, - which have different transforms - (data augmentation is only applied on the training set) - """ - - -def train_model_one_epoch(): +def _train_model_for_one_epoch(): """Training model for one epoch""" def train_model( - random_seed: int, model_type: str, num_classes: int, existing_weights: Optional[str] + random_seed: int, + model_type: str, + num_classes: int, + existing_weights: Optional[str], + train_webdataset: str, + val_webdataset: str, + test_webdataset: str, + image_input_size: int, + batch_size: int, + preprocess_mode: str, ) -> None: """Main training function""" @@ -40,4 +42,17 @@ def train_model( print(model) # Setup dataloaders - # train_data = webdataset_pipeline() + training_dataloader = build_webdataset_pipeline( + train_webdataset, + image_input_size, + batch_size, + preprocess_mode, + is_training=True, + ) + validation_dataloader = build_webdataset_pipeline( + val_webdataset, image_input_size, batch_size, preprocess_mode + ) + test_dataloader = build_webdataset_pipeline( + test_webdataset, image_input_size, batch_size, preprocess_mode + ) + print(training_dataloader, validation_dataloader, test_dataloader) From f091d1da76328f8de5c58552163ca741f5b87176 Mon Sep 17 00:00:00 2001 From: adityajain07 Date: Mon, 4 Nov 2024 16:52:53 -0500 Subject: [PATCH 12/25] Add optimizer --- src/classification/cli.py | 24 ++++++++++++++++++++++++ src/classification/train.py | 28 ++++++++++++++++++++++++---- src/classification/utils.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 4 deletions(-) diff --git a/src/classification/cli.py b/src/classification/cli.py index 80e4ca2..b361e88 100644 --- a/src/classification/cli.py +++ b/src/classification/cli.py @@ -119,6 +119,24 @@ default="torch", help="Preprocessing mode for normalization", ) +@click.option( + "--optimizer", + type=click.Choice(["adamw", "sgd"]), + default="adamw", + help="Optimizer type", +) +@click.option( + "--learning_rate", + type=float, + default=0.001, + help="Initial learning rate", +) +@click.option( + "--weight_decay", + type=float, + default=1e-5, + help="Weight decay for regularization", +) def train_model_command( random_seed: int, model_type: str, @@ -130,6 +148,9 @@ def train_model_command( image_input_size: int, batch_size: int, preprocess_mode: str, + optimizer_type: str, + learning_rate: float, + weight_decay: float, ): from src.classification.train import train_model @@ -144,6 +165,9 @@ def train_model_command( image_input_size=image_input_size, batch_size=batch_size, preprocess_mode=preprocess_mode, + optimizer_type=optimizer_type, + learning_rate=learning_rate, + weight_decay=weight_decay, ) diff --git a/src/classification/train.py b/src/classification/train.py index 78adfde..21eb173 100644 --- a/src/classification/train.py +++ b/src/classification/train.py @@ -11,10 +11,14 @@ from src.classification.dataloader import build_webdataset_pipeline from src.classification.models import model_builder -from src.classification.utils import set_random_seeds +from src.classification.utils import ( + get_learning_rate_scheduler, + get_optimizer, + set_random_seeds, +) -def _train_model_for_one_epoch(): +def _train_model_for_one_epoch() -> None: """Training model for one epoch""" @@ -29,6 +33,9 @@ def train_model( image_input_size: int, batch_size: int, preprocess_mode: str, + optimizer_type: str, + learning_rate: float, + weight_decay: float, ) -> None: """Main training function""" @@ 
-39,7 +46,6 @@ def train_model( device = "cuda" if torch.cuda.is_available() else "cpu" print(f"The available device is {device}.") model = model_builder(device, model_type, num_classes, existing_weights) - print(model) # Setup dataloaders training_dataloader = build_webdataset_pipeline( @@ -55,4 +61,18 @@ def train_model( test_dataloader = build_webdataset_pipeline( test_webdataset, image_input_size, batch_size, preprocess_mode ) - print(training_dataloader, validation_dataloader, test_dataloader) + + # Other training ingredients + optimizer = get_optimizer(optimizer_type, model, learning_rate, weight_decay) + learning_rate_scheduler = get_learning_rate_scheduler() + # loss = ... + print( + optimizer, + training_dataloader, + validation_dataloader, + test_dataloader, + learning_rate_scheduler, + ) + + # Model training + _train_model_for_one_epoch() diff --git a/src/classification/utils.py b/src/classification/utils.py index dcc3461..8c81ef2 100644 --- a/src/classification/utils.py +++ b/src/classification/utils.py @@ -34,3 +34,32 @@ def set_random_seeds(random_seed: int) -> None: "timm_vit-b16-224", "timm_vit-b16-384", ] + + +def get_optimizer( + optimizer_type: str, + model: torch.nn.Module, + learning_rate: float, + weight_decay: float, + momentum: float = 0.9, +) -> torch.optim.Optimizer: + """Optimizer definitions""" + + if optimizer_type == "adamw": + return torch.optim.AdamW( + model.parameters(), lr=learning_rate, weight_decay=weight_decay + ) + elif optimizer_type == "sgd": + return torch.optim.SGD( + model.parameters(), + lr=learning_rate, + momentum=momentum, + weight_decay=weight_decay, + ) + else: + raise RuntimeError(f"{optimizer_type} optimizer is not implemented.") + + +def get_learning_rate_scheduler() -> None: + """Scheduler definitions""" + pass From b99a678f02675815744950a22ae63310545c5ab0 Mon Sep 17 00:00:00 2001 From: adityajain07 Date: Mon, 4 Nov 2024 17:12:59 -0500 Subject: [PATCH 13/25] Add a constants file --- src/classification/constants.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 src/classification/constants.py diff --git a/src/classification/constants.py b/src/classification/constants.py new file mode 100644 index 0000000..1c9539b --- /dev/null +++ b/src/classification/constants.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# coding: utf-8 + +"""Constants to be used within the project +""" + +EFFICIENTNETV2_B3 = "efficientnetv2_b3" +EFFICIENTNETV2_S_IN21K = "efficientnetv2_s_in21k" +SWIN_S = "swin_small_patch4_window7_224" +RESNET50 = "resnet50" +TIMM_MOBILENETV3LARGE = "mobilenetv3_large_100" +TIMM_RESNET50 = "resnet50" +TIMM_CONVNEXT_T = "convnext_tiny_in22k" +TIMM_CONVNEXT_B = "convnext_base_in22k" +TIMM_VIT_B16_128 = "vit_base_patch16_224_in21k" +TIMM_VIT_B16_224 = "vit_base_patch16_224_in21k" +TIMM_VIT_B16_384 = "vit_base_patch16_224_in21k" + +AVAILABLE_MODELS = frozenset( + [ + EFFICIENTNETV2_B3, + EFFICIENTNETV2_S_IN21K, + SWIN_S, + RESNET50, + TIMM_MOBILENETV3LARGE, + TIMM_RESNET50, + TIMM_CONVNEXT_T, + TIMM_CONVNEXT_B, + TIMM_VIT_B16_128, + TIMM_VIT_B16_224, + TIMM_VIT_B16_384, + ] +) From a1b51af5935cd619c998f88bf2cee56f341204c7 Mon Sep 17 00:00:00 2001 From: adityajain07 Date: Mon, 4 Nov 2024 17:13:31 -0500 Subject: [PATCH 14/25] Add header --- src/classification/models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/classification/models.py b/src/classification/models.py index 215d5e3..60897d4 100644 --- a/src/classification/models.py +++ b/src/classification/models.py @@ 
-1,4 +1,7 @@ -""" List of available models to train +#!/usr/bin/env python +# coding: utf-8 + +""" Model-related functions """ import typing as tp From 72e536d3404909b566378d87e0e7df2963838a26 Mon Sep 17 00:00:00 2001 From: adityajain07 Date: Tue, 5 Nov 2024 17:10:48 -0500 Subject: [PATCH 15/25] Change model building method --- src/classification/cli.py | 4 +- src/classification/constants.py | 32 +++++------ src/classification/models.py | 98 --------------------------------- src/classification/train.py | 38 ++++++++----- src/classification/utils.py | 53 +++++++++++++----- 5 files changed, 79 insertions(+), 146 deletions(-) delete mode 100644 src/classification/models.py diff --git a/src/classification/cli.py b/src/classification/cli.py index b361e88..a205833 100644 --- a/src/classification/cli.py +++ b/src/classification/cli.py @@ -26,7 +26,9 @@ import click -from src.classification.utils import SupportedModels +from src.classification.constants import AVAILABLE_MODELS + +SupportedModels = tp.Literal[*AVAILABLE_MODELS] # Command key constants # Make sure to add them to COMMAND_KEYS frozenset diff --git a/src/classification/constants.py b/src/classification/constants.py index 1c9539b..12310e7 100644 --- a/src/classification/constants.py +++ b/src/classification/constants.py @@ -1,33 +1,31 @@ #!/usr/bin/env python # coding: utf-8 -"""Constants to be used within the project +"""Constants accessible in the entire project """ -EFFICIENTNETV2_B3 = "efficientnetv2_b3" -EFFICIENTNETV2_S_IN21K = "efficientnetv2_s_in21k" +EFFICIENTNETV2_B3 = "tf_efficientnetv2_b3" +EFFICIENTNETV2_S_IN21K = "tf_efficientnetv2_s_in21k" SWIN_S = "swin_small_patch4_window7_224" +MOBILENETV3LARGE = "mobilenetv3_large_100" RESNET50 = "resnet50" -TIMM_MOBILENETV3LARGE = "mobilenetv3_large_100" -TIMM_RESNET50 = "resnet50" -TIMM_CONVNEXT_T = "convnext_tiny_in22k" -TIMM_CONVNEXT_B = "convnext_base_in22k" -TIMM_VIT_B16_128 = "vit_base_patch16_224_in21k" -TIMM_VIT_B16_224 = "vit_base_patch16_224_in21k" -TIMM_VIT_B16_384 = "vit_base_patch16_224_in21k" +CONVNEXT_T = "convnext_tiny_in22k" +CONVNEXT_B = "convnext_base_in22k" +VIT_B16_128 = "vit_base_patch16_128_in21k" +VIT_B16_224 = "vit_base_patch16_224_in21k" +VIT_B16_384 = "vit_base_patch16_384" AVAILABLE_MODELS = frozenset( [ EFFICIENTNETV2_B3, EFFICIENTNETV2_S_IN21K, SWIN_S, + MOBILENETV3LARGE, RESNET50, - TIMM_MOBILENETV3LARGE, - TIMM_RESNET50, - TIMM_CONVNEXT_T, - TIMM_CONVNEXT_B, - TIMM_VIT_B16_128, - TIMM_VIT_B16_224, - TIMM_VIT_B16_384, + CONVNEXT_T, + CONVNEXT_B, + VIT_B16_128, + VIT_B16_224, + VIT_B16_384, ] ) diff --git a/src/classification/models.py b/src/classification/models.py deleted file mode 100644 index 60897d4..0000000 --- a/src/classification/models.py +++ /dev/null @@ -1,98 +0,0 @@ -#!/usr/bin/env python -# coding: utf-8 - -""" Model-related functions -""" - -import typing as tp - -import timm -import torch -from torchvision import models - - -def _model_list(model_name: str, num_classes: int, pretrained: bool): - """Main model builder function""" - - if model_name == "efficientnetv2-b3": - model = timm.create_model( - "tf_efficientnetv2_b3", pretrained=pretrained, num_classes=num_classes - ) - elif model_name == "efficientnetv2-s-in21k": - model = timm.create_model( - "tf_efficientnetv2_s_in21k", pretrained=pretrained, num_classes=num_classes - ) - elif model_name == "swin-s": - model = timm.create_model( - "swin_small_patch4_window7_224", - pretrained=pretrained, - num_classes=num_classes, - ) - elif model_name == "resnet50": - model = 
models.resnet50(weights="IMAGENET1K_V1" if pretrained else None) - num_ftrs = model.fc.in_features - model.fc = torch.nn.Linear(num_ftrs, num_classes) - elif model_name == "timm_mobilenetv3large": - model = timm.create_model( - "mobilenetv3_large_100", pretrained=pretrained, num_classes=num_classes - ) - elif model_name == "timm_resnet50": - model = timm.create_model( - "resnet50", pretrained=pretrained, num_classes=num_classes - ) - elif model_name == "timm_convnext-t": - model = timm.create_model( - "convnext_tiny_in22k", pretrained=pretrained, num_classes=num_classes - ) - elif model_name == "timm_convnext-b": - model = timm.create_model( - "convnext_base_in22k", pretrained=pretrained, num_classes=num_classes - ) - elif model_name == "timm_vit-b16-128": - model = timm.create_model( - "vit_base_patch16_224_in21k", - pretrained=pretrained, - img_size=128, - num_classes=num_classes, - ) - elif model_name == "timm_vit-b16-224": - model = timm.create_model( - "vit_base_patch16_224_in21k", - pretrained=pretrained, - num_classes=num_classes, - ) - elif model_name == "timm_vit-b16-384": - model = timm.create_model( - "vit_base_patch16_384", - pretrained=pretrained, - num_classes=num_classes, - ) - else: - raise RuntimeError(f"Model {model_name} not implemented") - - return model - - -def model_builder( - device: str, - model_type: str, - num_classes: int, - existing_weights: tp.Optional[str], - pretrained: bool = True, -): - """Model builder""" - - model = _model_list(model_type, num_classes, pretrained) - - # If available, load existing weights - if existing_weights: - print("Loading existing model weights.") - state_dict = torch.load(existing_weights, map_location=torch.device(device)) - model.load_state_dict(state_dict, strict=False) - - if torch.cuda.device_count() > 1: - model = torch.nn.DataParallel(model) - - model = model.to(device) - - return model diff --git a/src/classification/train.py b/src/classification/train.py index 21eb173..d4c76b9 100644 --- a/src/classification/train.py +++ b/src/classification/train.py @@ -10,12 +10,7 @@ import torch from src.classification.dataloader import build_webdataset_pipeline -from src.classification.models import model_builder -from src.classification.utils import ( - get_learning_rate_scheduler, - get_optimizer, - set_random_seeds, -) +from src.classification.utils import build_model, get_optimizer, set_random_seeds def _train_model_for_one_epoch() -> None: @@ -45,7 +40,8 @@ def train_model( # Model initialization device = "cuda" if torch.cuda.is_available() else "cpu" print(f"The available device is {device}.") - model = model_builder(device, model_type, num_classes, existing_weights) + model = build_model(device, model_type, num_classes, existing_weights) + print(model) # Setup dataloaders training_dataloader = build_webdataset_pipeline( @@ -64,15 +60,27 @@ def train_model( # Other training ingredients optimizer = get_optimizer(optimizer_type, model, learning_rate, weight_decay) - learning_rate_scheduler = get_learning_rate_scheduler() + # learning_rate_scheduler = get_learning_rate_scheduler() # loss = ... 
- print( - optimizer, - training_dataloader, - validation_dataloader, - test_dataloader, - learning_rate_scheduler, - ) + print(optimizer, training_dataloader, validation_dataloader, test_dataloader) # Model training _train_model_for_one_epoch() + + +if __name__ == "__main__": + train_model( + random_seed=42, + model_type="vit_base_patch16_128_in21k", + num_classes=29176, + existing_weights=None, + train_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/train/train450-000000.tar", + val_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/val/val450-000000.tar", + test_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/test/test450-000000.tar", + image_input_size=128, + batch_size=1, + preprocess_mode="torch", + optimizer_type="adamw", + learning_rate=0.001, + weight_decay=1e-5, + ) diff --git a/src/classification/utils.py b/src/classification/utils.py index 8c81ef2..d5914e0 100644 --- a/src/classification/utils.py +++ b/src/classification/utils.py @@ -8,8 +8,11 @@ import typing as tp import numpy as np +import timm import torch +from src.classification.constants import AVAILABLE_MODELS, VIT_B16_128 + def set_random_seeds(random_seed: int) -> None: """Set random seeds for reproducibility""" @@ -21,21 +24,6 @@ def set_random_seeds(random_seed: int) -> None: torch.backends.cudnn.deterministic = True -SupportedModels = tp.Literal[ - "efficientnetv2-b3", - "efficientnetv2-s-in21k", - "swin-s", - "resnet50", - "timm_mobilenetv3large", - "timm_resnet50", - "timm_convnext-t", - "timm_convnext-b", - "timm_vit-b16-128", - "timm_vit-b16-224", - "timm_vit-b16-384", -] - - def get_optimizer( optimizer_type: str, model: torch.nn.Module, @@ -63,3 +51,38 @@ def get_optimizer( def get_learning_rate_scheduler() -> None: """Scheduler definitions""" pass + + +def build_model( + device: str, + model_type: str, + num_classes: int, + existing_weights: tp.Optional[str], + pretrained: bool = True, +): + """Model builder""" + + if model_type not in AVAILABLE_MODELS: + raise RuntimeError(f"Model {model_type} not implemented") + + model_arguments = {"pretrained": pretrained, "num_classes": num_classes} + if model_type == VIT_B16_128: + # There is no off-the-shelf ViT model for 128x128 image size, + # so we use 224x224 model with a custom input image size + model_type = "vit_base_patch16_224_in21k" + model_arguments["img_size"] = 128 + + model = timm.create_model(model_type, **model_arguments) + + # If available, load existing weights + if existing_weights: + print("Loading existing model weights.") + state_dict = torch.load(existing_weights, map_location=torch.device(device)) + model.load_state_dict(state_dict, strict=False) + + # Make use of multiple GPUs, if available + if torch.cuda.device_count() > 1: + model = torch.nn.DataParallel(model) + model = model.to(device) + + return model From 49c7672c9982070f8df3ec37cc6c15bc91db9eab Mon Sep 17 00:00:00 2001 From: adityajain07 Date: Mon, 11 Nov 2024 17:19:05 -0500 Subject: [PATCH 16/25] Optimizer added; to test scheduler --- src/classification/cli.py | 57 +++++++++++++++++++++++++++++--- src/classification/constants.py | 14 ++++++++ src/classification/dataloader.py | 4 ++- src/classification/train.py | 53 +++++++++++++++++------------ src/classification/utils.py | 45 ++++++++++++++++++++++--- 5 files changed, 141 insertions(+), 32 deletions(-) diff --git a/src/classification/cli.py b/src/classification/cli.py index a205833..3c01bd6 100644 --- a/src/classification/cli.py +++ b/src/classification/cli.py @@ -26,9 
+26,20 @@
 
 import click
 
-from src.classification.constants import AVAILABLE_MODELS
+from src.classification.constants import (
+    ADAMW,
+    AVAILABLE_LOSS_FUNCTIONS,
+    AVAILABLE_LR_SCHEDULERS,
+    AVAILABLE_MODELS,
+    AVAILABLE_OPTIMIZERS,
+    COSINE_LR_SCHEDULER,
+    CROSS_ENTROPY_LOSS,
+)
 
 SupportedModels = tp.Literal[*AVAILABLE_MODELS]
+SupportedLossFunctions = tp.Literal[*AVAILABLE_LOSS_FUNCTIONS]
+SupportedOptimizers = tp.Literal[*AVAILABLE_OPTIMIZERS]
+SupportedLearningRateSchedulers = tp.Literal[*AVAILABLE_LR_SCHEDULERS]
 
 # Command key constants
 # Make sure to add them to COMMAND_KEYS frozenset
@@ -83,8 +94,16 @@
 @click.option(
     "--existing_weights",
     type=str,
+    default=None,
     help="Existing weights to be loaded, if available",
 )
+@click.option(
+    "--total_epochs",
+    type=int,
+    default=30,
+    help="Total number of training epochs",
+)
+@click.option("--warmup_epochs", type=int, default=2, help="Number of warmup epochs")
 @click.option(
     "--train_webdataset",
     type=str,
@@ -112,7 +131,7 @@
 @click.option(
     "--batch_size",
     type=int,
-    default=32,
+    default=64,
     help="Batch size for training",
 )
 @click.option(
@@ -122,9 +141,9 @@
     help="Preprocessing mode for normalization",
 )
 @click.option(
-    "--optimizer",
-    type=click.Choice(["adamw", "sgd"]),
-    default="adamw",
+    "--optimizer_type",
+    type=click.Choice(tp.get_args(SupportedOptimizers)),
+    default=ADAMW,
     help="Optimizer type",
 )
 @click.option(
@@ -133,17 +152,37 @@
     default=0.001,
     help="Initial learning rate",
 )
+@click.option(
+    "--learning_rate_scheduler_type",
+    type=click.Choice(tp.get_args(SupportedLearningRateSchedulers)),
+    default=COSINE_LR_SCHEDULER,
+    help="Learning rate scheduler",
+)
 @click.option(
     "--weight_decay",
     type=float,
     default=1e-5,
     help="Weight decay for regularization",
 )
+@click.option(
+    "--loss_function_type",
+    type=click.Choice(tp.get_args(SupportedLossFunctions)),
+    default=CROSS_ENTROPY_LOSS,
+    help="Loss function",
+)
+@click.option(
+    "--label_smoothing",
+    type=float,
+    default=0.1,
+    help="Label smoothing for model regularization. No smoothing if 0.0",
+)
 def train_model_command(
     random_seed: int,
     model_type: str,
     num_classes: int,
     existing_weights: Optional[str],
+    total_epochs: int,
+    warmup_epochs: int,
     train_webdataset: str,
     val_webdataset: str,
     test_webdataset: str,
@@ -152,7 +191,10 @@ def train_model_command(
     preprocess_mode: str,
     optimizer_type: str,
     learning_rate: float,
+    learning_rate_scheduler_type: str,
     weight_decay: float,
+    loss_function_type: str,
+    label_smoothing: float,
 ):
     from src.classification.train import train_model
 
@@ -161,6 +203,8 @@ def train_model_command(
         model_type=model_type,
         num_classes=num_classes,
         existing_weights=existing_weights,
+        total_epochs=total_epochs,
+        warmup_epochs=warmup_epochs,
         train_webdataset=train_webdataset,
         val_webdataset=val_webdataset,
         test_webdataset=test_webdataset,
@@ -169,7 +213,10 @@ def train_model_command(
         preprocess_mode=preprocess_mode,
         optimizer_type=optimizer_type,
         learning_rate=learning_rate,
+        learning_rate_scheduler_type=learning_rate_scheduler_type,
         weight_decay=weight_decay,
+        loss_function_type=loss_function_type,
+        label_smoothing=label_smoothing,
     )
 
diff --git a/src/classification/constants.py b/src/classification/constants.py
index 12310e7..c67214e 100644
--- a/src/classification/constants.py
+++ b/src/classification/constants.py
@@ -15,6 +15,13 @@
 VIT_B16_224 = "vit_base_patch16_224_in21k"
 VIT_B16_384 = "vit_base_patch16_384"
 
+ADAMW = "adamw"
+SGD = "sgd"
+
+CROSS_ENTROPY_LOSS = "cross_entropy"
+
+COSINE_LR_SCHEDULER = "cosine"
+
 AVAILABLE_MODELS = frozenset(
     [
         EFFICIENTNETV2_B3,
@@ -29,3 +36,10 @@
         VIT_B16_384,
     ]
 )
+
+
+AVAILABLE_LOSS_FUNCTIONS = frozenset([CROSS_ENTROPY_LOSS])
+
+AVAILABLE_OPTIMIZERS = frozenset([ADAMW, SGD])
+
+AVAILABLE_LR_SCHEDULERS = frozenset([COSINE_LR_SCHEDULER])
diff --git a/src/classification/dataloader.py b/src/classification/dataloader.py
index 904e06f..758e9f3 100644
--- a/src/classification/dataloader.py
+++ b/src/classification/dataloader.py
@@ -57,7 +57,9 @@
 
 
 # TODO: Add return type of this function
-def _get_transforms(input_size: int, is_training: bool, preprocess_mode: str = "torch"):
+def _get_transforms(
+    input_size: int, is_training: bool, preprocess_mode: str = "torch"
+) -> transforms.Compose:
     """Transformation applied to each image"""
 
     # Add square padding
diff --git a/src/classification/train.py b/src/classification/train.py
index d4c76b9..23549f6 100644
--- a/src/classification/train.py
+++ b/src/classification/train.py
@@ -10,7 +10,13 @@
 import torch
 
 from src.classification.dataloader import build_webdataset_pipeline
-from src.classification.utils import build_model, get_optimizer, set_random_seeds
+from src.classification.utils import (
+    build_model,
+    get_learning_rate_scheduler,
+    get_loss_function,
+    get_optimizer,
+    set_random_seeds,
+)
 
 
 def _train_model_for_one_epoch() -> None:
@@ -22,6 +28,8 @@ def train_model(
     model_type: str,
     num_classes: int,
     existing_weights: Optional[str],
+    total_epochs: int,
+    warmup_epochs: int,
     train_webdataset: str,
     val_webdataset: str,
     test_webdataset: str,
@@ -30,7 +38,10 @@ def train_model(
     preprocess_mode: str,
     optimizer_type: str,
     learning_rate: float,
+    learning_rate_scheduler_type: str,
     weight_decay: float,
+    loss_function_type: str,
+    label_smoothing: float,
 ) -> None:
     """Main training function"""
 
@@ -60,9 +71,25 @@ def train_model(
 
     # Other training ingredients
     optimizer = get_optimizer(optimizer_type, model, learning_rate, weight_decay)
-    # learning_rate_scheduler = 
get_learning_rate_scheduler() - # loss = ... - print(optimizer, training_dataloader, validation_dataloader, test_dataloader) + steps_per_epoch = ... + learning_rate_scheduler = get_learning_rate_scheduler( + optimizer, + learning_rate_scheduler_type, + total_epochs, + steps_per_epoch, + warmup_epochs, + ) + loss_function = get_loss_function( + loss_function_type, label_smoothing=label_smoothing + ) + print( + loss_function, + learning_rate_scheduler, + optimizer, + training_dataloader, + validation_dataloader, + test_dataloader, + ) # Model training _train_model_for_one_epoch() - - -if __name__ == "__main__": - train_model( - random_seed=42, - model_type="vit_base_patch16_128_in21k", - num_classes=29176, - existing_weights=None, - train_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/train/train450-000000.tar", - val_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/val/val450-000000.tar", - test_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/test/test450-000000.tar", - image_input_size=128, - batch_size=1, - preprocess_mode="torch", - optimizer_type="adamw", - learning_rate=0.001, - weight_decay=1e-5, - ) diff --git a/src/classification/utils.py b/src/classification/utils.py index d5914e0..3d12f9f 100644 --- a/src/classification/utils.py +++ b/src/classification/utils.py @@ -10,8 +10,14 @@ import numpy as np import timm import torch +from timm.scheduler import CosineLRScheduler -from src.classification.constants import AVAILABLE_MODELS, VIT_B16_128 +from src.classification.constants import ( + AVAILABLE_MODELS, + COSINE_LR_SCHEDULER, + CROSS_ENTROPY_LOSS, + VIT_B16_128, +) def set_random_seeds(random_seed: int) -> None: @@ -48,9 +54,40 @@ def get_optimizer( raise RuntimeError(f"{optimizer_type} optimizer is not implemented.") -def get_learning_rate_scheduler() -> None: - """Scheduler definitions""" - pass +def get_learning_rate_scheduler( + optimizer: torch.optim.Optimizer, + lr_scheduler_type: str, + total_epochs: int, + steps_per_epoch: int, + warmup_epochs: int, +) -> tp.Any: + """Learning rate scheduler definitions""" + + total_steps = int(total_epochs * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + if lr_scheduler_type == COSINE_LR_SCHEDULER: + return CosineLRScheduler( + optimizer, + t_initial=(total_steps - warmup_steps), + warmup_t=warmup_steps, + warmup_prefix=True, + cycle_limit=1, + t_in_epochs=False, + ) + else: + raise RuntimeError( + f"{lr_scheduler_type} learning rate scheduler is not implemented." 
+        )
+
+
+def get_loss_function(loss_function_name: str, label_smoothing: float = 0.0) -> tp.Any:
+    """Loss function definitions"""
+
+    if loss_function_name == CROSS_ENTROPY_LOSS:
+        return torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+    else:
+        raise RuntimeError(f"{loss_function_name} loss is not implemented.")
 
 
 def build_model(

From b388e71034427dc9f99baf0ccce747183d47763e Mon Sep 17 00:00:00 2001
From: adityajain07
Date: Tue, 12 Nov 2024 13:56:30 -0500
Subject: [PATCH 17/25] Add and test LR Scheduler

---
 src/classification/cli.py   | 10 +++++++---
 src/classification/train.py | 21 ++++++++++++---------
 src/classification/utils.py | 20 ++++++++++++++++++++
 3 files changed, 39 insertions(+), 12 deletions(-)

diff --git a/src/classification/cli.py b/src/classification/cli.py
index 3c01bd6..eae3389 100644
--- a/src/classification/cli.py
+++ b/src/classification/cli.py
@@ -32,7 +32,6 @@
     AVAILABLE_LR_SCHEDULERS,
     AVAILABLE_MODELS,
     AVAILABLE_OPTIMIZERS,
-    COSINE_LR_SCHEDULER,
     CROSS_ENTROPY_LOSS,
 )
 
@@ -103,7 +102,12 @@
     default=30,
     help="Total number of training epochs",
 )
-@click.option("--warmup_epochs", type=int, default=2, help="Number of warmup epochs")
+@click.option(
+    "--warmup_epochs",
+    type=int,
+    default=0,
+    help="Number of warmup epochs, if using a learning rate scheduler",
+)
 @click.option(
     "--train_webdataset",
     type=str,
@@ -155,7 +159,7 @@
 @click.option(
     "--learning_rate_scheduler_type",
     type=click.Choice(tp.get_args(SupportedLearningRateSchedulers)),
-    default=COSINE_LR_SCHEDULER,
+    default=None,
     help="Learning rate scheduler",
 )
 @click.option(
diff --git a/src/classification/train.py b/src/classification/train.py
index 23549f6..6ca3964 100644
--- a/src/classification/train.py
+++ b/src/classification/train.py
@@ -15,6 +15,7 @@
     get_learning_rate_scheduler,
     get_loss_function,
     get_optimizer,
+    get_webdataset_length,
     set_random_seeds,
 )
 
@@ -38,7 +39,7 @@ def train_model(
     preprocess_mode: str,
     optimizer_type: str,
     learning_rate: float,
-    learning_rate_scheduler_type: str,
+    learning_rate_scheduler_type: Optional[str],
     weight_decay: float,
     loss_function_type: str,
     label_smoothing: float,
@@ -71,14 +72,16 @@ def train_model(
 
     # Other training ingredients
     optimizer = get_optimizer(optimizer_type, model, learning_rate, weight_decay)
-    steps_per_epoch = ...
- learning_rate_scheduler = get_learning_rate_scheduler( - optimizer, - learning_rate_scheduler_type, - total_epochs, - steps_per_epoch, - warmup_epochs, - ) + if learning_rate_scheduler_type: + train_data_length = get_webdataset_length(train_webdataset) + steps_per_epoch = int((train_data_length - 1) / batch_size) + 1 + learning_rate_scheduler = get_learning_rate_scheduler( + optimizer, + learning_rate_scheduler_type, + total_epochs, + steps_per_epoch, + warmup_epochs, + ) loss_function = get_loss_function( loss_function_type, label_smoothing=label_smoothing ) diff --git a/src/classification/utils.py b/src/classification/utils.py index 3d12f9f..d2b25c9 100644 --- a/src/classification/utils.py +++ b/src/classification/utils.py @@ -5,8 +5,10 @@ """ import random +import tarfile import typing as tp +import braceexpand import numpy as np import timm import torch @@ -90,6 +92,24 @@ def get_loss_function(loss_function_name: str, label_smoothing: float = 0.0) -> raise RuntimeError(f"{loss_function_name} loss is not implemented.") +def _count_files_from_tar(tar_filename: str, ext="jpg") -> int: + """Count the number of images in a single tar archive""" + + tar = tarfile.open(tar_filename) + files = [f for f in tar.getmembers() if f.name.endswith(ext)] + count_files = len(files) + tar.close() + return count_files + + +def get_webdataset_length(sharedurl: str) -> int: + """Get the total number of images in all webdataset files for a given dataset""" + + tar_filenames = list(braceexpand.braceexpand(sharedurl)) + counts = [_count_files_from_tar(tar_f) for tar_f in tar_filenames] + return int(sum(counts)) + + def build_model( device: str, model_type: str, From 3f5b5e93d3c174599063a984ef21ab91308231e5 Mon Sep 17 00:00:00 2001 From: adityajain07 Date: Wed, 13 Nov 2024 14:04:13 -0500 Subject: [PATCH 18/25] Basic training loop --- src/classification/train.py | 123 +++++++++++++++++++++++++++++++----- src/classification/utils.py | 6 +- 2 files changed, 111 insertions(+), 18 deletions(-) diff --git a/src/classification/train.py b/src/classification/train.py index 6ca3964..e698daa 100644 --- a/src/classification/train.py +++ b/src/classification/train.py @@ -5,6 +5,7 @@ """ Main script for training classification models """ +import typing as tp from typing import Optional import torch @@ -19,9 +20,71 @@ set_random_seeds, ) +total_train_steps = 0 -def _train_model_for_one_epoch() -> None: + +def _save_model_checkpoint(model: torch.nn.Module, model_path: str) -> None: + """Save model to disk""" + pass + + +def _train_model_for_one_epoch( + model: torch.nn.Module, + device: str, + optimizer: torch.optim.Optimizer, + loss_function: torch.nn.Module, + train_dataloader: torch.utils.data.DataLoader, + learning_rate_scheduler: Optional[tp.Any], +) -> None: """Training model for one epoch""" + global total_train_steps + + model.train() + for batch_data in train_dataloader: + images, labels = batch_data + images, labels = images.to(device, non_blocking=True), labels.to( + device, non_blocking=True + ) + + optimizer.zero_grad() + outputs = model(images) + loss = loss_function(outputs, labels) + loss.backward() + optimizer.step() + + # Learning rate scheduler step + if learning_rate_scheduler: + total_train_steps += 1 + learning_rate_scheduler.step_update(num_updates=total_train_steps) + + # TODO: Calculate accuracy metrics + # TODO: Take loss average before returning + + return loss + + +def _validate_model( + model: torch.nn.Module, + device: str, + loss_function: torch.nn.Module, + val_dataloader: 
torch.utils.data.DataLoader, +) -> None: + """Validate model after one epoch""" + + model.eval() + for batch_data in val_dataloader: + images, labels = batch_data + images, labels = images.to(device, non_blocking=True), labels.to( + device, non_blocking=True + ) + + with torch.no_grad(): + outputs = model(images) + loss = loss_function(outputs, labels) + + # TODO: Take loss average before returning + + return loss def train_model( @@ -53,25 +116,25 @@ def train_model( device = "cuda" if torch.cuda.is_available() else "cpu" print(f"The available device is {device}.") model = build_model(device, model_type, num_classes, existing_weights) - print(model) # Setup dataloaders - training_dataloader = build_webdataset_pipeline( + train_dataloader = build_webdataset_pipeline( train_webdataset, image_input_size, batch_size, preprocess_mode, is_training=True, ) - validation_dataloader = build_webdataset_pipeline( + val_dataloader = build_webdataset_pipeline( val_webdataset, image_input_size, batch_size, preprocess_mode ) - test_dataloader = build_webdataset_pipeline( - test_webdataset, image_input_size, batch_size, preprocess_mode - ) + # test_dataloader = build_webdataset_pipeline( + # test_webdataset, image_input_size, batch_size, preprocess_mode + # ) # Other training ingredients optimizer = get_optimizer(optimizer_type, model, learning_rate, weight_decay) + learning_rate_scheduler = None if learning_rate_scheduler_type: train_data_length = get_webdataset_length(train_webdataset) steps_per_epoch = int((train_data_length - 1) / batch_size) + 1 @@ -85,14 +148,42 @@ def train_model( loss_function = get_loss_function( loss_function_type, label_smoothing=label_smoothing ) - print( - loss_function, - learning_rate_scheduler, - optimizer, - training_dataloader, - validation_dataloader, - test_dataloader, - ) # Model training - _train_model_for_one_epoch() + for epoch in range(1, total_epochs + 1): + _train_model_for_one_epoch( + model, + device, + optimizer, + loss_function, + train_dataloader, + learning_rate_scheduler, + ) + _validate_model(model, device, loss_function, val_dataloader) + + # TODO: Save model checkpoint + # TODO: Calculate accuracy metrics + # TODO: Receive epoch-level metrics and upload to W&B + + +if __name__ == "__main__": + train_model( + random_seed=42, + model_type="resnet50", + num_classes=29176, + existing_weights=None, + total_epochs=10, + warmup_epochs=1, + train_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/train/train450-{000000..000001}.tar", + val_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/val/val450-000000.tar", + test_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/test/test450-000000.tar", + image_input_size=128, + batch_size=16, + preprocess_mode="torch", + optimizer_type="adamw", + learning_rate=0.001, + learning_rate_scheduler_type="cosine", + weight_decay=1e-5, + loss_function_type="cross_entropy", + label_smoothing=0.1, + ) diff --git a/src/classification/utils.py b/src/classification/utils.py index d2b25c9..6421bb8 100644 --- a/src/classification/utils.py +++ b/src/classification/utils.py @@ -83,7 +83,9 @@ def get_learning_rate_scheduler( ) -def get_loss_function(loss_function_name: str, label_smoothing: float = 0.0) -> tp.Any: +def get_loss_function( + loss_function_name: str, label_smoothing: float = 0.0 +) -> torch.nn.Module: """Loss function definitions""" if loss_function_name == CROSS_ENTROPY_LOSS: @@ -116,7 +118,7 @@ def build_model( num_classes: int, existing_weights: tp.Optional[str], 
pretrained: bool = True, -): +) -> torch.nn.Module: """Model builder""" if model_type not in AVAILABLE_MODELS: From d11345796728bb1649c2cf96baf90a6c3264dabf Mon Sep 17 00:00:00 2001 From: adityajain07 Date: Wed, 13 Nov 2024 16:35:22 -0500 Subject: [PATCH 19/25] Add train and val loops --- src/classification/dataloader.py | 1 - src/classification/train.py | 44 +++++++++++++++++++++----------- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/classification/dataloader.py b/src/classification/dataloader.py index 758e9f3..888bb5d 100644 --- a/src/classification/dataloader.py +++ b/src/classification/dataloader.py @@ -56,7 +56,6 @@ def _random_resize(image: PIL.Image.Image, full_size: int) -> PIL.Image.Image: return image -# TODO: Add return type of this function def _get_transforms( input_size: int, is_training: bool, preprocess_mode: str = "torch" ) -> transforms.Compose: diff --git a/src/classification/train.py b/src/classification/train.py index e698daa..2fc7584 100644 --- a/src/classification/train.py +++ b/src/classification/train.py @@ -9,6 +9,7 @@ from typing import Optional import torch +from timm.utils import AverageMeter from src.classification.dataloader import build_webdataset_pipeline from src.classification.utils import ( @@ -20,12 +21,9 @@ set_random_seeds, ) -total_train_steps = 0 - def _save_model_checkpoint(model: torch.nn.Module, model_path: str) -> None: """Save model to disk""" - pass def _train_model_for_one_epoch( @@ -35,9 +33,12 @@ def _train_model_for_one_epoch( loss_function: torch.nn.Module, train_dataloader: torch.utils.data.DataLoader, learning_rate_scheduler: Optional[tp.Any], -) -> None: + total_train_steps: int, +) -> tuple[float, int]: # TODO: First element will eventually turn into a dict """Training model for one epoch""" - global total_train_steps + + total_train_steps_current = total_train_steps + running_loss = AverageMeter() model.train() for batch_data in train_dataloader: @@ -46,21 +47,25 @@ def _train_model_for_one_epoch( device, non_blocking=True ) + # Forward pass, loss calculation, backward pass, and optimizer step optimizer.zero_grad() outputs = model(images) loss = loss_function(outputs, labels) loss.backward() optimizer.step() + # Calculate the average loss per sample + running_loss.update(loss.item()) + # Learning rate scheduler step if learning_rate_scheduler: - total_train_steps += 1 - learning_rate_scheduler.step_update(num_updates=total_train_steps) + total_train_steps_current += 1 + learning_rate_scheduler.step_update(num_updates=total_train_steps_current) # TODO: Calculate accuracy metrics # TODO: Take loss average before returning - return loss + return running_loss.avg, total_train_steps_current def _validate_model( @@ -68,9 +73,11 @@ def _validate_model( device: str, loss_function: torch.nn.Module, val_dataloader: torch.utils.data.DataLoader, -) -> None: +) -> float: """Validate model after one epoch""" + running_loss = AverageMeter() + model.eval() for batch_data in val_dataloader: images, labels = batch_data @@ -82,9 +89,9 @@ def _validate_model( outputs = model(images) loss = loss_function(outputs, labels) - # TODO: Take loss average before returning + running_loss.update(loss.item()) - return loss + return running_loss.avg def train_model( @@ -150,18 +157,25 @@ def train_model( ) # Model training + total_train_steps = 0 # total training batches processed + current_maximum_val_loss = 1e8 for epoch in range(1, total_epochs + 1): - _train_model_for_one_epoch( + current_train_loss, total_train_steps_current = 
_train_model_for_one_epoch( model, device, optimizer, loss_function, train_dataloader, learning_rate_scheduler, + total_train_steps, ) - _validate_model(model, device, loss_function, val_dataloader) + total_train_steps = total_train_steps_current + current_val_loss = _validate_model(model, device, loss_function, val_dataloader) + + if current_val_loss < current_maximum_val_loss: + # _save_model_checkpoint(model, ...) # TODO: Save model checkpoint + current_maximum_val_loss = current_val_loss - # TODO: Save model checkpoint # TODO: Calculate accuracy metrics # TODO: Receive epoch-level metrics and upload to W&B @@ -178,7 +192,7 @@ def train_model( val_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/val/val450-000000.tar", test_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/test/test450-000000.tar", image_input_size=128, - batch_size=16, + batch_size=64, preprocess_mode="torch", optimizer_type="adamw", learning_rate=0.001, From b04665357e0ab6002c6cc042e49c968bd8b61cd3 Mon Sep 17 00:00:00 2001 From: adityajain07 Date: Thu, 14 Nov 2024 15:51:18 -0500 Subject: [PATCH 20/25] Add model saving function --- src/classification/cli.py | 8 ++++++ src/classification/train.py | 53 ++++++++++++++++++++++++++++++++----- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/src/classification/cli.py b/src/classification/cli.py index eae3389..8c63110 100644 --- a/src/classification/cli.py +++ b/src/classification/cli.py @@ -180,6 +180,12 @@ default=0.1, help="Label smoothing for model regularization. No smoothing if 0.0", ) +@click.option( + "--model_save_directory", + type=str, + required=True, + help="Directory to save the trained model", +) def train_model_command( random_seed: int, model_type: str, @@ -199,6 +205,7 @@ def train_model_command( weight_decay: float, loss_function_type: str, label_smoothing: float, + model_save_directory: str, ): from src.classification.train import train_model @@ -221,6 +228,7 @@ def train_model_command( weight_decay=weight_decay, loss_function_type=loss_function_type, label_smoothing=label_smoothing, + model_save_directory=model_save_directory, ) diff --git a/src/classification/train.py b/src/classification/train.py index 2fc7584..d28e569 100644 --- a/src/classification/train.py +++ b/src/classification/train.py @@ -5,7 +5,10 @@ """ Main script for training classification models """ +import pathlib import typing as tp +from datetime import datetime +from pathlib import Path from typing import Optional import torch @@ -22,9 +25,33 @@ ) -def _save_model_checkpoint(model: torch.nn.Module, model_path: str) -> None: +def _save_model_checkpoint( + model: torch.nn.Module, + model_save_path: pathlib.Path, + optimizer: torch.optim.Optimizer, + learning_rate_scheduler: tp.Any, + epoch: int, + train_loss: float, + val_loss: float, +) -> None: """Save model to disk""" + if torch.cuda.device_count() > 1: + model_state_dict = model.module.state_dict() + else: + model_state_dict = model.state_dict() + model_checkpoint = { + "epoch": epoch, + "model_state_dict": model_state_dict, + "optimizer_state_dict": optimizer.state_dict(), + "lr_scheduler": learning_rate_scheduler.state_dict() + if learning_rate_scheduler is not None + else None, + "train_loss": train_loss, + "val_loss": val_loss, + } + torch.save(model_checkpoint, f"{model_save_path}_checkpoint.pt") + def _train_model_for_one_epoch( model: torch.nn.Module, @@ -113,6 +140,7 @@ def train_model( weight_decay: float, loss_function_type: str, label_smoothing: float, + 
model_save_directory: str, ) -> None: """Main training function""" @@ -155,10 +183,12 @@ def train_model( loss_function = get_loss_function( loss_function_type, label_smoothing=label_smoothing ) + current_date = datetime.now().date().strftime("%Y%m%d") + model_save_path = Path(model_save_directory) / f"{model_type}_{current_date}" # Model training total_train_steps = 0 # total training batches processed - current_maximum_val_loss = 1e8 + lowest_val_loss = 1e8 for epoch in range(1, total_epochs + 1): current_train_loss, total_train_steps_current = _train_model_for_one_epoch( model, @@ -172,9 +202,17 @@ def train_model( total_train_steps = total_train_steps_current current_val_loss = _validate_model(model, device, loss_function, val_dataloader) - if current_val_loss < current_maximum_val_loss: - # _save_model_checkpoint(model, ...) # TODO: Save model checkpoint - current_maximum_val_loss = current_val_loss + if current_val_loss < lowest_val_loss: + _save_model_checkpoint( + model, + model_save_path, + optimizer, + learning_rate_scheduler, + epoch, + current_train_loss, + current_val_loss, + ) + lowest_val_loss = current_val_loss # TODO: Calculate accuracy metrics # TODO: Receive epoch-level metrics and upload to W&B @@ -186,8 +224,8 @@ def train_model( model_type="resnet50", num_classes=29176, existing_weights=None, - total_epochs=10, - warmup_epochs=1, + total_epochs=1, + warmup_epochs=0, train_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/train/train450-{000000..000001}.tar", val_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/val/val450-000000.tar", test_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/test/test450-000000.tar", @@ -200,4 +238,5 @@ def train_model( weight_decay=1e-5, loss_function_type="cross_entropy", label_smoothing=0.1, + model_save_directory="/home/mila/a/aditya.jain/scratch", ) From 1f2380f2b57af612d2207400c808837f9c12b21c Mon Sep 17 00:00:00 2001 From: adityajain07 Date: Fri, 15 Nov 2024 16:43:58 -0500 Subject: [PATCH 21/25] Ready to test a full training run --- src/classification/train.py | 69 +++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 34 deletions(-) diff --git a/src/classification/train.py b/src/classification/train.py index d28e569..819f425 100644 --- a/src/classification/train.py +++ b/src/classification/train.py @@ -61,11 +61,12 @@ def _train_model_for_one_epoch( train_dataloader: torch.utils.data.DataLoader, learning_rate_scheduler: Optional[tp.Any], total_train_steps: int, -) -> tuple[float, int]: # TODO: First element will eventually turn into a dict +) -> tuple[dict, int]: # TODO: First element will eventually turn into a dict """Training model for one epoch""" total_train_steps_current = total_train_steps running_loss = AverageMeter() + running_accuracy = AverageMeter() model.train() for batch_data in train_dataloader: @@ -84,6 +85,10 @@ def _train_model_for_one_epoch( # Calculate the average loss per sample running_loss.update(loss.item()) + # Calculate the batch accuracy and update to global accuracy + _, predicted = torch.max(outputs, 1) + running_accuracy.update((predicted == labels).sum().item() / labels.size(0)) + # Learning rate scheduler step if learning_rate_scheduler: total_train_steps_current += 1 @@ -92,7 +97,9 @@ def _train_model_for_one_epoch( # TODO: Calculate accuracy metrics # TODO: Take loss average before returning - return running_loss.avg, total_train_steps_current + metrics = {"train_loss": running_loss.avg, "train_accuracy": 
running_accuracy.avg} + + return metrics, total_train_steps_current def _validate_model( @@ -100,10 +107,11 @@ def _validate_model( device: str, loss_function: torch.nn.Module, val_dataloader: torch.utils.data.DataLoader, -) -> float: +) -> dict: """Validate model after one epoch""" running_loss = AverageMeter() + running_accuracy = AverageMeter() model.eval() for batch_data in val_dataloader: @@ -116,9 +124,16 @@ def _validate_model( outputs = model(images) loss = loss_function(outputs, labels) + # Calculate the average loss per sample running_loss.update(loss.item()) - return running_loss.avg + # Calculate the batch accuracy and update to global accuracy + _, predicted = torch.max(outputs, 1) + running_accuracy.update((predicted == labels).sum().item() / labels.size(0)) + + metrics = {"val_loss": running_loss.avg, "val_accuracy": running_accuracy.avg} + + return metrics def train_model( @@ -190,7 +205,7 @@ def train_model( total_train_steps = 0 # total training batches processed lowest_val_loss = 1e8 for epoch in range(1, total_epochs + 1): - current_train_loss, total_train_steps_current = _train_model_for_one_epoch( + train_metrics, total_train_steps_current = _train_model_for_one_epoch( model, device, optimizer, @@ -200,43 +215,29 @@ def train_model( total_train_steps, ) total_train_steps = total_train_steps_current - current_val_loss = _validate_model(model, device, loss_function, val_dataloader) + val_metrics = _validate_model(model, device, loss_function, val_dataloader) - if current_val_loss < lowest_val_loss: + if val_metrics["val_loss"] < lowest_val_loss: _save_model_checkpoint( model, model_save_path, optimizer, learning_rate_scheduler, epoch, - current_train_loss, - current_val_loss, + train_metrics["train_loss"], + val_metrics["val_loss"], ) - lowest_val_loss = current_val_loss + lowest_val_loss = val_metrics["val_loss"] + + print( + f"Epoch [{epoch:02d}/{total_epochs}]: " + f"Train Loss: {train_metrics['train_loss']:.4f}, " + f"Val Loss: {val_metrics['val_loss']:.4f}, " + f"Train Accuracy: {train_metrics['train_accuracy']*100:.2f}, " + f"Val Accuracy: {val_metrics['val_accuracy']*100:.2f}, " + f"Learning rate: {optimizer.param_groups[0]['lr']:.6f}", + flush=True, + ) # TODO: Calculate accuracy metrics # TODO: Receive epoch-level metrics and upload to W&B - - -if __name__ == "__main__": - train_model( - random_seed=42, - model_type="resnet50", - num_classes=29176, - existing_weights=None, - total_epochs=1, - warmup_epochs=0, - train_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/train/train450-{000000..000001}.tar", - val_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/val/val450-000000.tar", - test_webdataset="/home/mila/a/aditya.jain/scratch/global_model/webdataset/test/test450-000000.tar", - image_input_size=128, - batch_size=64, - preprocess_mode="torch", - optimizer_type="adamw", - learning_rate=0.001, - learning_rate_scheduler_type="cosine", - weight_decay=1e-5, - loss_function_type="cross_entropy", - label_smoothing=0.1, - model_save_directory="/home/mila/a/aditya.jain/scratch", - ) From 3b4bc4c44a5682b47fe96d21186e58292da8d055 Mon Sep 17 00:00:00 2001 From: adityajain07 Date: Fri, 15 Nov 2024 17:28:24 -0500 Subject: [PATCH 22/25] Minor rearrange --- src/classification/constants.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/classification/constants.py b/src/classification/constants.py index c67214e..16f9d20 100644 --- a/src/classification/constants.py +++ 
b/src/classification/constants.py
@@ -14,14 +14,6 @@
 VIT_B16_128 = "vit_base_patch16_128_in21k"
 VIT_B16_224 = "vit_base_patch16_224_in21k"
 VIT_B16_384 = "vit_base_patch16_384"
-
-ADAMW = "adamw"
-SGD = "sgd"
-
-CROSS_ENTROPY_LOSS = "cross_entropy"
-
-COSINE_LR_SCHEDULER = "cosine"
-
 AVAILABLE_MODELS = frozenset(
     [
         EFFICIENTNETV2_B3,
@@ -37,9 +29,12 @@
     ]
 )
 
+ADAMW = "adamw"
+SGD = "sgd"
+AVAILABLE_OPTIMIZERS = frozenset([ADAMW, SGD])
 
+CROSS_ENTROPY_LOSS = "cross_entropy"
 AVAILABLE_LOSS_FUNCIONS = frozenset([CROSS_ENTROPY_LOSS])
 
-AVAILABLE_OPTIMIZERS = frozenset([ADAMW, SGD])
-
+COSINE_LR_SCHEDULER = "cosine"
 AVAILABLE_LR_SCHEDULERS = frozenset([COSINE_LR_SCHEDULER])

From 659e48384301cd1ceea4b3f7313c997faa4c1ec4 Mon Sep 17 00:00:00 2001
From: adityajain07
Date: Fri, 15 Nov 2024 17:44:30 -0500
Subject: [PATCH 23/25] Remove avoidable if-else statements

---
 src/classification/dataloader.py | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/src/classification/dataloader.py b/src/classification/dataloader.py
index 888bb5d..09ec28a 100644
--- a/src/classification/dataloader.py
+++ b/src/classification/dataloader.py
@@ -5,10 +5,10 @@
 """
 
 import os
+import random
 from functools import partial
 from typing import Any
 
-import numpy as np
 import PIL
 import torch
 import webdataset as wds
@@ -32,12 +32,13 @@ def _pad_to_square(image: PIL.Image.Image) -> PIL.Image.Image:
 def _normalization(preprocess_mode: str) -> tuple[list[float], list[float]]:
     """Get the mean and std for normalization"""
 
-    if preprocess_mode == "torch":
-        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
-    elif preprocess_mode == "tf":
-        mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
-    else:
-        mean, std = [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]
+    preprocess_params = {
+        "torch": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        "tf": ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+        "default": ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
+    }
+
+    mean, std = preprocess_params.get(preprocess_mode, preprocess_params["default"])
 
     return mean, std
 
@@ -45,12 +46,13 @@ def _normalization(preprocess_mode: str) -> tuple[list[float], list[float]]:
 def _random_resize(image: PIL.Image.Image, full_size: int) -> PIL.Image.Image:
     """Mixed resolution transformation"""
 
-    random_num = np.random.uniform()
-    if random_num <= 0.25:
-        transform = transforms.Resize((int(0.5 * full_size), int(0.5 * full_size)))
-        image = transform(image)
-    elif random_num <= 0.5:
-        transform = transforms.Resize((int(0.25 * full_size), int(0.25 * full_size)))
+    values = [0.25, 0.5, None, None]
+    random_value = random.choice(values)
+
+    if random_value:
+        transform = transforms.Resize(
+            (int(random_value * full_size), int(random_value * full_size))
+        )
         image = transform(image)
 
     return image
@@ -92,11 +94,9 @@ def build_webdataset_pipeline(
     """Main dataset builder and loader function"""
 
     # Load the webdataset
+    dataset = wds.WebDataset(sharedurl, shardshuffle=is_training)
     if is_training:
-        dataset = wds.WebDataset(sharedurl, shardshuffle=True)
         dataset = dataset.shuffle(10000)
-    else:
-        dataset = wds.WebDataset(sharedurl, shardshuffle=False)
 
     # Get image transforms
     image_transform = _get_transforms(input_size, is_training, preprocess_mode)
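Note on the `_random_resize` refactor in PATCH 23: `random.choice([0.25, 0.5, None, None])` preserves the sampling distribution of the original `np.random.uniform()` thresholds: a 25% chance of a quarter-size resize, a 25% chance of a half-size resize, and a 50% chance of leaving the image at full resolution. A minimal standalone sketch to sanity-check that equivalence (illustrative only, not part of any patch; `pick_scale` is a hypothetical helper):

import random
from collections import Counter
from typing import Optional


def pick_scale() -> Optional[float]:
    # Mirrors the patched sampling: 25% -> 0.25x, 25% -> 0.5x, 50% -> no resize
    return random.choice([0.25, 0.5, None, None])


# Empirical check: frequencies should come out near 0.25 / 0.25 / 0.50
counts = Counter(pick_scale() for _ in range(100_000))
for scale, n in sorted(counts.items(), key=lambda kv: str(kv[0])):
    print(f"scale={scale}: {n / 100_000:.3f}")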

From 8c174f0c26576e900cc71c217c6cb44a4c48ba1b Mon Sep 17 00:00:00 2001
From: adityajain07
Date: Fri, 15 Nov 2024 17:46:43 -0500
Subject: [PATCH 24/25] Minor default changes

---
 src/classification/cli.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/classification/cli.py b/src/classification/cli.py
index 8c63110..4274c42 100644
--- a/src/classification/cli.py
+++ b/src/classification/cli.py
@@ -33,6 +33,7 @@
     AVAILABLE_MODELS,
     AVAILABLE_OPTIMIZERS,
     CROSS_ENTROPY_LOSS,
+    RESNET50,
 )
 
 SupportedModels = tp.Literal[*AVAILABLE_MODELS]
@@ -81,7 +82,7 @@
 @click.option(
     "--model_type",
     type=click.Choice(tp.get_args(SupportedModels)),
-    required=True,
+    default=RESNET50,
     help="Model architecture",
 )
 @click.option(

From fc68af91c29290a028df44ae43e6acd36d232966 Mon Sep 17 00:00:00 2001
From: adityajain07
Date: Fri, 15 Nov 2024 17:47:02 -0500
Subject: [PATCH 25/25] A test training sbatch script

---
 scripts/job_train_classifier.sh | 40 +++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)
 create mode 100644 scripts/job_train_classifier.sh

diff --git a/scripts/job_train_classifier.sh b/scripts/job_train_classifier.sh
new file mode 100644
index 0000000..168a057
--- /dev/null
+++ b/scripts/job_train_classifier.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+#SBATCH --job-name=test_classifier_training_code
+#SBATCH --ntasks=1
+#SBATCH --time=14:00:00
+#SBATCH --mem=48G
+#SBATCH --partition=main
+#SBATCH --cpus-per-task=4
+#SBATCH --gres=gpu:rtx8000:1
+#SBATCH --output=test_classifier_training_code_%j.out
+
+# 1. Load the required modules
+module load miniconda/3
+
+# 2. Load your environment
+conda activate ami-ml
+
+# 3. Load the environment variables outside of python script
+set -o allexport
+source .env
+set +o allexport
+
+# Keep track of time
+SECONDS=0
+
+# 4. Copy your dataset to the compute node
+cp $SAMPLE_TRAIN_WBDS_LINUX $SLURM_TMPDIR
+cp $SAMPLE_VAL_WBDS_LINUX $SLURM_TMPDIR
+
+echo "Time taken to copy the data: $((SECONDS/60)) minutes"
+
+# 5. Launch your job #TODO:
+ami-classification train-model \
+--num_classes 29176 \
+--train_webdataset "$SLURM_TMPDIR/train450-{000000..000976}.tar" \
+--val_webdataset "$SLURM_TMPDIR/val450-{000000..000089}.tar" \
+--test_webdataset "None" \
+--model_save_directory $TEST_PATH
+
+# Print time taken to execute the script
+echo "Time taken to train the model: $((SECONDS/60)) minutes"
\ No newline at end of file
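
For reference, the sbatch script in PATCH 25 leans on the CLI defaults introduced across these patches (resnet50 architecture, adamw optimizer, cross-entropy loss, and no learning rate scheduler unless one is requested). A rough Python equivalent of that invocation, calling train_model directly, is sketched below. The dataset paths, batch size, learning rate, weight decay, and save directory are placeholders or assumptions; the remaining values mirror the CLI defaults shown in these patches, not a verified configuration:

from src.classification.train import train_model

# Hypothetical local smoke test; all paths below are placeholders
train_model(
    random_seed=42,
    model_type="resnet50",  # CLI default as of PATCH 24
    num_classes=29176,
    existing_weights=None,
    total_epochs=30,  # CLI default
    warmup_epochs=0,  # CLI default; only used with a scheduler
    train_webdataset="/path/to/train450-{000000..000976}.tar",
    val_webdataset="/path/to/val450-{000000..000089}.tar",
    test_webdataset="None",
    image_input_size=128,
    batch_size=64,  # placeholder; the CLI default is not shown in these patches
    preprocess_mode="torch",
    optimizer_type="adamw",
    learning_rate=0.001,  # assumed, per the earlier __main__ examples
    learning_rate_scheduler_type=None,  # CLI default: no scheduler
    weight_decay=1e-5,  # assumed, per the earlier __main__ examples
    loss_function_type="cross_entropy",
    label_smoothing=0.1,  # CLI default
    model_save_directory="/path/to/checkpoints",
)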