Merge pull request #7 from dirac-institute/awo/implement-resnet50-model
Initial attempt at implementing resnet50 for use with CIFAR data.
drewoldag authored Oct 11, 2024
2 parents d5515ed + 3f9d533 commit 6a7bf63
Showing 5 changed files with 1,781 additions and 9 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -150,4 +150,7 @@ _html/
.initialize_new_project.sh

# Model files
**/*.pth
**/*.pth

# Run results
results/
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -16,6 +16,7 @@ repos:
name: Clear output from Jupyter notebooks
description: Clear output from Jupyter notebooks.
files: \.ipynb$
exclude: ^docs/pre_executed
stages: [commit]
language: system
entry: jupyter nbconvert --clear-output
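
For context on the new exclude line: pre-commit treats files and exclude as Python regular expressions matched against repo-relative paths, so ^docs/pre_executed keeps the output-clearing hook away from the pre-executed notebook added in this PR while other notebooks are still cleared. Below is a minimal sketch of that matching logic; the second path is hypothetical, and the use of re.search reflects an assumption about pre-commit's behavior rather than anything in this diff.

import re

# Patterns from the hook above. Assumption: pre-commit matches these with
# re.search against repo-relative paths, so the leading ^ anchors the
# exclude pattern to the start of the path.
files_re = re.compile(r"\.ipynb$")
exclude_re = re.compile(r"^docs/pre_executed")

# The second path is purely illustrative and does not exist in this PR.
for path in ["docs/pre_executed/CNN_filter.ipynb", "docs/notebooks/intro_notebook.ipynb"]:
    selected = bool(files_re.search(path)) and not exclude_re.search(path)
    print(path, "-> output cleared" if selected else "-> skipped")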
1,701 changes: 1,701 additions & 0 deletions docs/pre_executed/CNN_filter.ipynb

Large diffs are not rendered by default.

18 changes: 10 additions & 8 deletions example_config.toml
@@ -17,6 +17,7 @@ log_level = "info" # Emit informational messages, warnings and all errors
# log_level = "debug" # Very verbose, emit all log messages.

data_dir = "/home/drew/code/fibad/data/"
results_dir = "./results" # Results get named <verb>-<timestamp> under this directory

[download]
sw = "22asec"
@@ -52,20 +52,20 @@ mask = false
[model]
# The name of the built-in model to use or the libpath to an external model
# e.g. "user_package.submodule.ExternalModel" or "ExampleAutoencoder"
name = "kbmod_ml.models.cnn.CNN"
name = "kbmod_ml.models.resnet50.RESNET50"

weights_filepath = "example_model.pth"
weights_filepath = "resnet50.pth"
epochs = 10

base_channel_size = 32
latent_dim =64
num_classes = 10

[data_loader]

[data_set]
# Name of the built-in data loader to use or the libpath to an external data loader
# e.g. "user_package.submodule.ExternalDataLoader" or "HSCDataLoader"
name = "CifarDataLoader"


[data_loader]
# Pixel dimensions used to crop all images prior to loading. Will prune any images that are too small.
#
# If not provided by user, the default of 'false' scans the directory for the smallest dimensioned files, and
@@ -83,9 +83,10 @@ crop_to = false
filters = false

# Default PyTorch DataLoader parameters
batch_size = 4
batch_size = 10
shuffle = true
num_workers = 2
num_workers = 10

[predict]
model_weights_file = false
batch_size = 32
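
To make the mapping between the updated config and the new model concrete, here is a minimal sketch of reading the [model] values with Python's standard tomllib. Assumptions: Python 3.11+ and the file sitting in the working directory; fibad has its own config handling, which may differ.

import tomllib  # Python 3.11+ standard library; `tomli`/`toml` work on older versions

# Parse the example config and pull out the values the new RESNET50 model reads.
# RESNET50.__init__ indexes the parsed config the same way: config["model"]["num_classes"].
with open("example_config.toml", "rb") as f:
    config = tomllib.load(f)

print(config["model"]["name"])              # kbmod_ml.models.resnet50.RESNET50
print(config["model"]["weights_filepath"])  # resnet50.pth
print(config["model"]["epochs"])            # 10
print(config["model"]["num_classes"])       # 10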
65 changes: 65 additions & 0 deletions src/kbmod_ml/models/resnet50.py
@@ -0,0 +1,65 @@
# ruff: noqa: D101, D102

# The training-step structure below follows the PyTorch CIFAR10 tutorial; the network
# itself is torchvision's ResNet-50:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-convolutional-neural-network
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F # noqa N812
import torch.optim as optim
from fibad.models.model_registry import fibad_model
from torchvision.models import resnet50

logger = logging.getLogger(__name__)


@fibad_model
class RESNET50(nn.Module):
    def __init__(self, model_config, shape):
        logger.info("This is an external model, not in FIBAD!!!")
        super().__init__()

        self.config = model_config

        self.model = resnet50(pretrained=False, num_classes=self.config["model"]["num_classes"])

        # Optimizer and criterion could be set directly, i.e. `self.optimizer = optim.SGD(...)`
        # but we define them as methods as a way to allow for more flexibility in the future.
        self.optimizer = self._optimizer()
        self.criterion = self._criterion()

    def forward(self, x):
        return self.model(x)

    def train_step(self, batch):
        """Run a single training step, i.e. the body of the inner loop of an ML
        training process.

        Parameters
        ----------
        batch : tuple
            A tuple containing the inputs and labels for the current batch.

        Returns
        -------
        dict
            A dictionary containing the loss value for the current batch.
        """
        inputs, labels = batch

        self.optimizer.zero_grad()
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}

    def _criterion(self):
        return nn.CrossEntropyLoss()

    def _optimizer(self):
        return optim.SGD(self.parameters(), lr=0.001, momentum=0.9)

    def save(self):
        torch.save(self.state_dict(), self.config.get("weights_filepath"))
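
As a quick sanity check of how this class is meant to be driven, here is a minimal usage sketch. It assumes the kbmod_ml package is importable, that the @fibad_model decorator registers the class without changing its constructor signature, and it hand-builds the one config key the constructor reads instead of going through fibad's config loading. On recent torchvision releases the pretrained=False argument still works but may emit a deprecation warning in favor of weights=None.

import torch

from kbmod_ml.models.resnet50 import RESNET50

# Hand-built config with the single key the constructor reads; in practice
# this would come from example_config.toml via fibad.
config = {"model": {"num_classes": 10}}

# `shape` is accepted but unused by the constructor above; (3, 32, 32)
# matches CIFAR-10 images.
model = RESNET50(config, shape=(3, 32, 32))

# One training step on a random CIFAR-10-sized batch of 10 images.
inputs = torch.randn(10, 3, 32, 32)
labels = torch.randint(0, 10, (10,))
result = model.train_step((inputs, labels))
print(result["loss"])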
