Skip to content

Commit

Permalink
Reformat code
Browse files Browse the repository at this point in the history
  • Loading branch information
matbun committed Nov 21, 2024
1 parent cd95494 commit d41f6dd
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 110 deletions.
9 changes: 9 additions & 0 deletions .github/linters/.hadolint.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <[email protected]> - CERN
# --------------------------------------------------------------------------------------

failure-threshold: warning
ignored:
- DL3008 # Pin versions in apt-get install.
Expand Down
9 changes: 9 additions & 0 deletions ci/src/main/k8s.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <[email protected]> - CERN
# --------------------------------------------------------------------------------------

import copy
import logging
import re
Expand Down
11 changes: 11 additions & 0 deletions src/itwinai/torch/distributed.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <[email protected]> - CERN
# - Jarl Sondre Sæther <[email protected]> - CERN
# - Henry Mutegeki <[email protected]> - CERN
# --------------------------------------------------------------------------------------

import abc
import functools
import os
Expand Down
19 changes: 14 additions & 5 deletions tests/torch/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <[email protected]> - CERN
# --------------------------------------------------------------------------------------

from typing import Generator

import pytest
Expand All @@ -11,25 +20,25 @@
)


@pytest.fixture(scope='package')
@pytest.fixture(scope="package")
def ddp_strategy() -> Generator[TorchDistributedStrategy, None, None]:
"""Instantiate Torch's DistributedDataParallel strategy."""
strategy = TorchDDPStrategy(backend='nccl' if torch.cuda.is_available() else 'gloo')
strategy = TorchDDPStrategy(backend="nccl" if torch.cuda.is_available() else "gloo")
strategy.init()
yield strategy
strategy.clean_up()


@pytest.fixture(scope='package')
@pytest.fixture(scope="package")
def deepspeed_strategy() -> Generator[DeepSpeedStrategy, None, None]:
"""Instantiate DeepSpeed strategy."""
strategy = DeepSpeedStrategy(backend='nccl' if torch.cuda.is_available() else 'gloo')
strategy = DeepSpeedStrategy(backend="nccl" if torch.cuda.is_available() else "gloo")
strategy.init()
yield strategy
strategy.clean_up()


@pytest.fixture(scope='package')
@pytest.fixture(scope="package")
def horovod_strategy() -> Generator[HorovodStrategy, None, None]:
"""Instantiate Horovod strategy."""
strategy = HorovodStrategy()
Expand Down
75 changes: 51 additions & 24 deletions tests/torch/distribtued_decorator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <[email protected]> - CERN
# --------------------------------------------------------------------------------------

"""
Test @distributed function decorator. To run this script, use the following
command:
Expand All @@ -19,7 +28,6 @@


class Net(nn.Module):

def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
Expand Down Expand Up @@ -48,9 +56,15 @@ def train(model, device, train_loader, optimizer, epoch):
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
print(
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
epoch,
batch_idx * len(data),
len(train_loader.dataset),
100.0 * batch_idx / len(train_loader),
loss.item(),
)
)


def test(model, device, test_loader):
Expand All @@ -62,50 +76,63 @@ def test(model, device, test_loader):
data, target = data.to(device), target.to(device)
output = model(data)
# sum up batch loss
test_loss += F.nll_loss(output, target, reduction='sum').item()
test_loss += F.nll_loss(output, target, reduction="sum").item()
# get the index of the max log-probability
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)

print(
'\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
"\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
test_loss,
correct,
len(test_loader.dataset),
100.0 * correct / len(test_loader.dataset),
)
)


@distributed
def train_func(
model, train_dataloader, validation_dataloader, device,
optimizer, scheduler, epochs=10
model, train_dataloader, validation_dataloader, device, optimizer, scheduler, epochs=10
):
for epoch in range(1, epochs + 1):
train(model, device, train_dataloader, optimizer, epoch)
test(model, device, validation_dataloader)
scheduler.step()


if __name__ == '__main__':

if __name__ == "__main__":
train_set = datasets.MNIST(
'.tmp/', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
".tmp/",
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
)
val_set = datasets.MNIST(
'.tmp/', train=False, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
".tmp/",
train=False,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
)
model = Net()
train_dataloader = DataLoader(train_set, batch_size=32, pin_memory=True)
validation_dataloader = DataLoader(val_set, batch_size=32, pin_memory=True)
optimizer = optim.Adadelta(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=1, gamma=0.9)

# Train distributed
train_func(model, train_dataloader, validation_dataloader, 'cuda',
optimizer, scheduler=scheduler, epochs=1)
train_func(
model,
train_dataloader,
validation_dataloader,
"cuda",
optimizer,
scheduler=scheduler,
epochs=1,
)
21 changes: 13 additions & 8 deletions tests/torch/test_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <[email protected]> - CERN
# --------------------------------------------------------------------------------------

import pytest
from pydantic import ValidationError

Expand All @@ -6,22 +15,18 @@

def test_values_parsing():
"""Check dynamic override and creation of new entries."""
cfg = TrainingConfiguration(
batch_size='11',
param_abc='11',
param_xyz=1.1
)
cfg = TrainingConfiguration(batch_size="11", param_abc="11", param_xyz=1.1)
assert cfg.batch_size == 11
assert cfg.param_abc == '11'
assert cfg.param_abc == "11"
assert cfg.param_xyz == 1.1
assert isinstance(cfg.pin_gpu_memory, bool)

# Check dict-like getitem
assert cfg['batch_size'] == 11
assert cfg["batch_size"] == 11


def test_illegal_override():
"""Test that illegal type override fails."""
with pytest.raises(ValidationError) as exc_info:
TrainingConfiguration(batch_size='hello')
TrainingConfiguration(batch_size="hello")
assert "batch_size" in str(exc_info.value)
Loading

0 comments on commit d41f6dd

Please sign in to comment.