-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
175 additions
and
110 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,12 @@ | ||
# -------------------------------------------------------------------------------------- | ||
# Part of the interTwin Project: https://www.intertwin.eu/ | ||
# | ||
# Created by: Matteo Bunino | ||
# | ||
# Credit: | ||
# - Matteo Bunino <[email protected]> - CERN | ||
# -------------------------------------------------------------------------------------- | ||
|
||
failure-threshold: warning | ||
ignored: | ||
- DL3008 # Pin versions in apt get install. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,12 @@ | ||
# -------------------------------------------------------------------------------------- | ||
# Part of the interTwin Project: https://www.intertwin.eu/ | ||
# | ||
# Created by: Matteo Bunino | ||
# | ||
# Credit: | ||
# - Matteo Bunino <[email protected]> - CERN | ||
# -------------------------------------------------------------------------------------- | ||
|
||
import copy | ||
import logging | ||
import re | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,14 @@ | ||
# -------------------------------------------------------------------------------------- | ||
# Part of the interTwin Project: https://www.intertwin.eu/ | ||
# | ||
# Created by: Matteo Bunino | ||
# | ||
# Credit: | ||
# - Matteo Bunino <[email protected]> - CERN | ||
# - Jarl Sondre Sæther <[email protected]> - CERN | ||
# - Henry Mutegeki <[email protected]> - CERN | ||
# -------------------------------------------------------------------------------------- | ||
|
||
import abc | ||
import functools | ||
import os | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,12 @@ | ||
# -------------------------------------------------------------------------------------- | ||
# Part of the interTwin Project: https://www.intertwin.eu/ | ||
# | ||
# Created by: Matteo Bunino | ||
# | ||
# Credit: | ||
# - Matteo Bunino <[email protected]> - CERN | ||
# -------------------------------------------------------------------------------------- | ||
|
||
from typing import Generator | ||
|
||
import pytest | ||
|
@@ -11,25 +20,25 @@ | |
) | ||
|
||
|
||
@pytest.fixture(scope='package') | ||
@pytest.fixture(scope="package") | ||
def ddp_strategy() -> Generator[TorchDistributedStrategy, None, None]: | ||
"""Instantiate Torch's DistributedDataParallel strategy.""" | ||
strategy = TorchDDPStrategy(backend='nccl' if torch.cuda.is_available() else 'gloo') | ||
strategy = TorchDDPStrategy(backend="nccl" if torch.cuda.is_available() else "gloo") | ||
strategy.init() | ||
yield strategy | ||
strategy.clean_up() | ||
|
||
|
||
@pytest.fixture(scope='package') | ||
@pytest.fixture(scope="package") | ||
def deepspeed_strategy() -> Generator[DeepSpeedStrategy, None, None]: | ||
"""Instantiate DeepSpeed strategy.""" | ||
strategy = DeepSpeedStrategy(backend='nccl' if torch.cuda.is_available() else 'gloo') | ||
strategy = DeepSpeedStrategy(backend="nccl" if torch.cuda.is_available() else "gloo") | ||
strategy.init() | ||
yield strategy | ||
strategy.clean_up() | ||
|
||
|
||
@pytest.fixture(scope='package') | ||
@pytest.fixture(scope="package") | ||
def horovod_strategy() -> Generator[HorovodStrategy, None, None]: | ||
"""Instantiate Horovod strategy.""" | ||
strategy = HorovodStrategy() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,12 @@ | ||
# -------------------------------------------------------------------------------------- | ||
# Part of the interTwin Project: https://www.intertwin.eu/ | ||
# | ||
# Created by: Matteo Bunino | ||
# | ||
# Credit: | ||
# - Matteo Bunino <[email protected]> - CERN | ||
# -------------------------------------------------------------------------------------- | ||
|
||
""" | ||
Test @distributed function decorator. To run this script, use the following | ||
command: | ||
|
@@ -19,7 +28,6 @@ | |
|
||
|
||
class Net(nn.Module): | ||
|
||
def __init__(self): | ||
super(Net, self).__init__() | ||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5) | ||
|
@@ -48,9 +56,15 @@ def train(model, device, train_loader, optimizer, epoch): | |
loss.backward() | ||
optimizer.step() | ||
if batch_idx % 100 == 0: | ||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( | ||
epoch, batch_idx * len(data), len(train_loader.dataset), | ||
100. * batch_idx / len(train_loader), loss.item())) | ||
print( | ||
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( | ||
epoch, | ||
batch_idx * len(data), | ||
len(train_loader.dataset), | ||
100.0 * batch_idx / len(train_loader), | ||
loss.item(), | ||
) | ||
) | ||
|
||
|
||
def test(model, device, test_loader): | ||
|
@@ -62,50 +76,63 @@ def test(model, device, test_loader): | |
data, target = data.to(device), target.to(device) | ||
output = model(data) | ||
# sum up batch loss | ||
test_loss += F.nll_loss(output, target, reduction='sum').item() | ||
test_loss += F.nll_loss(output, target, reduction="sum").item() | ||
# get the index of the max log-probability | ||
pred = output.argmax(dim=1, keepdim=True) | ||
correct += pred.eq(target.view_as(pred)).sum().item() | ||
|
||
test_loss /= len(test_loader.dataset) | ||
|
||
print( | ||
'\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( | ||
test_loss, correct, len(test_loader.dataset), | ||
100. * correct / len(test_loader.dataset))) | ||
"\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( | ||
test_loss, | ||
correct, | ||
len(test_loader.dataset), | ||
100.0 * correct / len(test_loader.dataset), | ||
) | ||
) | ||
|
||
|
||
@distributed | ||
def train_func( | ||
model, train_dataloader, validation_dataloader, device, | ||
optimizer, scheduler, epochs=10 | ||
model, train_dataloader, validation_dataloader, device, optimizer, scheduler, epochs=10 | ||
): | ||
for epoch in range(1, epochs + 1): | ||
train(model, device, train_dataloader, optimizer, epoch) | ||
test(model, device, validation_dataloader) | ||
scheduler.step() | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
if __name__ == "__main__": | ||
train_set = datasets.MNIST( | ||
'.tmp/', train=True, download=True, | ||
transform=transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.1307,), (0.3081,)) | ||
])) | ||
".tmp/", | ||
train=True, | ||
download=True, | ||
transform=transforms.Compose( | ||
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] | ||
), | ||
) | ||
val_set = datasets.MNIST( | ||
'.tmp/', train=False, download=True, | ||
transform=transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.1307,), (0.3081,)) | ||
])) | ||
".tmp/", | ||
train=False, | ||
download=True, | ||
transform=transforms.Compose( | ||
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] | ||
), | ||
) | ||
model = Net() | ||
train_dataloader = DataLoader(train_set, batch_size=32, pin_memory=True) | ||
validation_dataloader = DataLoader(val_set, batch_size=32, pin_memory=True) | ||
optimizer = optim.Adadelta(model.parameters(), lr=1e-3) | ||
scheduler = StepLR(optimizer, step_size=1, gamma=0.9) | ||
|
||
# Train distributed | ||
train_func(model, train_dataloader, validation_dataloader, 'cuda', | ||
optimizer, scheduler=scheduler, epochs=1) | ||
train_func( | ||
model, | ||
train_dataloader, | ||
validation_dataloader, | ||
"cuda", | ||
optimizer, | ||
scheduler=scheduler, | ||
epochs=1, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,12 @@ | ||
# -------------------------------------------------------------------------------------- | ||
# Part of the interTwin Project: https://www.intertwin.eu/ | ||
# | ||
# Created by: Matteo Bunino | ||
# | ||
# Credit: | ||
# - Matteo Bunino <[email protected]> - CERN | ||
# -------------------------------------------------------------------------------------- | ||
|
||
import pytest | ||
from pydantic import ValidationError | ||
|
||
|
@@ -6,22 +15,18 @@ | |
|
||
def test_values_parsing(): | ||
"""Check dynamic override and creation of new entries.""" | ||
cfg = TrainingConfiguration( | ||
batch_size='11', | ||
param_abc='11', | ||
param_xyz=1.1 | ||
) | ||
cfg = TrainingConfiguration(batch_size="11", param_abc="11", param_xyz=1.1) | ||
assert cfg.batch_size == 11 | ||
assert cfg.param_abc == '11' | ||
assert cfg.param_abc == "11" | ||
assert cfg.param_xyz == 1.1 | ||
assert isinstance(cfg.pin_gpu_memory, bool) | ||
|
||
# Check dict-like getitem | ||
assert cfg['batch_size'] == 11 | ||
assert cfg["batch_size"] == 11 | ||
|
||
|
||
def test_illegal_override(): | ||
"""Test that illegal type override fails.""" | ||
with pytest.raises(ValidationError) as exc_info: | ||
TrainingConfiguration(batch_size='hello') | ||
TrainingConfiguration(batch_size="hello") | ||
assert "batch_size" in str(exc_info.value) |
Oops, something went wrong.