Main updated to work with the new trainer implementation; closes #1 (#4)
I think the folder structure is clear by now and the refactor is complete, so I consider these issues closed.
pab1s committed Mar 19, 2024
1 parent 4f10560 commit a9f52ac
Showing 3 changed files with 42 additions and 111 deletions.
14 changes: 0 additions & 14 deletions evaluate.py

This file was deleted.

101 changes: 42 additions & 59 deletions main.py
@@ -1,62 +1,45 @@
 import torch
-import torch.nn as nn
-import torchvision.models as models
-from torchvision import datasets
-from torch.utils.data import DataLoader
-from train import train
-from evaluate import evaluate as eval
-
-num_classes = 10
-num_epochs = 10
-batch_size = 32
-
-# Transforms
-transform = models.EfficientNet_B0_Weights.DEFAULT.transforms()
-
-train_data = datasets.CIFAR10(
-    root=".",
-    train=True,
-    download=True,
-    transform=transform,
-)
-
-test_data = datasets.CIFAR10(
-    root=".",
-    train=False,
-    download=True,
-    transform=transform,
-)
-
-train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
-test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
-
-# Load the pre-trained EfficientNet-B0 model
-model = models.efficientnet_b0(weights="DEFAULT")
-
-# Freeze all the parameters in the pre-trained model
-for param in model.parameters():
-    param.requires_grad = False
-
-# Replace the last fully connected layer with a new one
-num_ftrs = model.classifier[1].in_features
-model.classifier[1] = nn.Linear(num_ftrs, num_classes)
-
-# Define the number of classes
-num_classes = 10
-
-# Print the model architecture
-print(model)
-
-# Move the model to GPU if available
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model = model.to(device)
-
-# Define the loss function
-criterion = nn.CrossEntropyLoss()
-
-# Define the optimizer
-optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+import yaml
+from datetime import datetime
+from utils.data_utils import get_dataloaders
+from models import get_model
+from trainers import get_trainer
+from os import path
+
+def main(config_path):
+    with open(config_path, 'r') as file:
+        config = yaml.safe_load(file)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    train_loader, test_loader = get_dataloaders(config)
+
+    # Loads the specified model based on the configuration
+    model = get_model(
+        config['model']['name'],
+        config['model']['num_classes'],
+        pretrained=config['model']['pretrained']
+    ).to(device)
+
+    criterion = torch.nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=config['training']['learning_rate'])
+
+    # Prepare filenames for logging and plotting
+    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+    model_dataset_time = f"{config['model']['name']}_{config['data']['name']}_{current_time}"
+    log_filename = path.join(config['paths']['log_path'], f"log_{model_dataset_time}.csv")
+    plot_filename = path.join(config['paths']['plot_path'], f"plot_{model_dataset_time}.png")
+
+    trainer = get_trainer(config['trainer'], model=model, device=device)
+
+    trainer.build(criterion=criterion, optimizer=optimizer)
+
+    trainer.train(
+        train_loader=train_loader,
+        num_epochs=config['training']['num_epochs'],
+        log_path=log_filename,
+        plot_path=plot_filename
+    )
+    trainer.evaluate(test_loader=test_loader)
 
 if __name__ == "__main__":
-    train(num_epochs, model, device, criterion, optimizer, train_loader, "logs/log.csv", "outputs/figures/plot.png")
-    eval(model, device)
+    main("config/config.yaml")
38 changes: 0 additions & 38 deletions train.py

This file was deleted.
