Patch Virgo dataloader and refactor (#154)
* PATCH virgo dataloader and refactor

* patch version

* Refactor

* UPDATE tests
matbun authored Jun 12, 2024
1 parent a5dcb89 commit 08dae47
Showing 9 changed files with 95 additions and 45 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "itwinai"
-version = "0.2.0"
+version = "0.2.1"
 description = "AI and ML workflows module for scientific digital twins."
 readme = "README.md"
 requires-python = ">=3.10"
@@ -61,7 +61,7 @@ dev = [
     "pytest-cov>=4.1.0",
     "ipykernel",
     "ipython",
-    "tensorflow==2.16.*", # needed by tests on tensorboard
+    "tensorflow==2.16.*",  # needed by tests on tensorboard
 ]

 [project.urls]
15 changes: 10 additions & 5 deletions src/itwinai/loggers.py
@@ -420,7 +420,7 @@ def log(
         )


-class WanDBLogger(Logger):
+class WandBLogger(Logger):
    """Abstraction around WandB logger.

    Args:
@@ -434,6 +434,8 @@ class WanDBLogger(Logger):
            more details. Defaults to 'epoch'.
    """

+    # TODO: add support for artifacts logging
+
    def __init__(
        self,
        savedir: str = 'mllogs',
@@ -444,14 +446,15 @@ def __init__(
        super().__init__(savedir=savedir, log_freq=log_freq)
        self.project_name = project_name
        self.supported_types = [
-            'watch', 'metric', 'figure', 'image', 'artifact', 'torch', 'dict',
+            'watch', 'metric', 'figure', 'image', 'torch', 'dict',
            'param', 'text'
        ]

    def create_logger_context(self):
        """Initialize logger. Init WandB run."""
+        os.makedirs(os.path.join(self.savedir, 'wandb'), exist_ok=True)
        self.active_run = wandb.init(
-            dir=self.savedir,
+            dir=os.path.abspath(self.savedir),
            project=self.project_name
        )

@@ -484,7 +487,7 @@ def log(
            kind (str, optional): type of the item to be logged. Must be
                one among the list of ``self.supported_types``.
                Defaults to 'metric'.
-            step (Optional[int], optional): logging step. Defaults to None.
+            step (Optional[int], optional): ignored by ``WandBLogger``.
            batch_idx (Optional[int], optional): DataLoader batch counter
                (i.e., batch idx), if available. Defaults to None.
            kwargs: keyword arguments to pass to the logger.
@@ -495,7 +498,9 @@ def log(
        if kind == 'watch':
            wandb.watch(item)
        if kind in self.supported_types[1:]:
-            wandb.log({identifier: item}, step=step, commit=True)
+            # wandb.log({identifier: item}, step=step, commit=True)
+            # Let WandB use its preferred step
+            wandb.log({identifier: item}, commit=True)


class TensorBoardLogger(Logger):
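For context, a minimal usage sketch of the renamed logger, based only on the signatures visible in this diff; the project name is illustrative and not part of this commit:

    from itwinai.loggers import WandBLogger

    # Creates <savedir>/wandb and calls wandb.init with an absolute dir
    logger = WandBLogger(savedir='mllogs', project_name='my_project')
    logger.create_logger_context()

    # 'step' is now ignored: WandB commits each entry with its own step counter
    logger.log(0.5, 'train_loss', kind='metric')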
10 changes: 10 additions & 0 deletions src/itwinai/torch/config.py
@@ -28,6 +28,16 @@ class TrainingConfiguration(Configuration):
    #: Batch size. In a distributed environment it is usually the
    #: per-worker batch size. Defaults to 32.
    batch_size: int = 32
+    #: Whether to shuffle train dataset when creating a torch ``DataLoader``.
+    #: Defaults to False.
+    shuffle_train: bool = False
+    #: Whether to shuffle validation dataset when creating a torch
+    #: ``DataLoader``.
+    #: Defaults to False.
+    shuffle_validation: bool = False
+    #: Whether to shuffle test dataset when creating a torch ``DataLoader``.
+    #: Defaults to False.
+    shuffle_test: bool = False
    #: Whether to pin GPU memory. Property of torch ``DataLoader``.
    #: Defaults to False.
    pin_memory: bool = False
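The new flags are ordinary TrainingConfiguration fields, so they can be set like batch_size or pin_memory — a short sketch, assuming keyword construction works as for the existing fields:

    from itwinai.torch.config import TrainingConfiguration

    config = TrainingConfiguration(
        batch_size=64,
        shuffle_train=True,        # reshuffle training data every epoch
        shuffle_validation=False,  # keep evaluation order deterministic
        shuffle_test=False
    )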
24 changes: 17 additions & 7 deletions src/itwinai/torch/trainer.py
@@ -262,29 +262,31 @@ def create_dataloaders(
        # may be interested to override! #
        ###################################

+        # TODO: improve robustness of getting from config
        self.train_dataloader = self.strategy.create_dataloader(
            dataset=train_dataset,
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers,
            pin_memory=self.config.pin_memory,
-            generator=self.torch_rng
+            generator=self.torch_rng,
+            shuffle=self.config.shuffle_train
        )
        if validation_dataset is not None:
            self.validation_dataloader = self.strategy.create_dataloader(
-                dataset=train_dataset,
+                dataset=validation_dataset,
                batch_size=self.config.batch_size,
                num_workers=self.config.num_workers,
                pin_memory=self.config.pin_memory,
-                generator=self.torch_rng
+                generator=self.torch_rng,
+                shuffle=self.config.shuffle_validation
            )
        if test_dataset is not None:
            self.test_dataloader = self.strategy.create_dataloader(
-                dataset=train_dataset,
+                dataset=test_dataset,
                batch_size=self.config.batch_size,
                num_workers=self.config.num_workers,
                pin_memory=self.config.pin_memory,
-                generator=self.torch_rng
+                generator=self.torch_rng,
+                shuffle=self.config.shuffle_test
            )

    def _setup_metrics(self):
@@ -343,6 +345,14 @@ def _set_epoch_dataloaders(self, epoch: int):
        if self.test_dataloader is not None:
            self.test_dataloader.sampler.set_epoch(epoch)

+    def set_epoch(self, epoch: int) -> None:
+        """Set current epoch at the beginning of training.
+
+        Args:
+            epoch (int): epoch number, from 0 to ``epochs-1``.
+        """
+        self._set_epoch_dataloaders(epoch)
+
    def log(
        self,
        item: Union[Any, List[Any]],
@@ -429,7 +439,7 @@ def train(self):
        best_loss = float('inf')
        for epoch in range(self.epochs):
            epoch_n = epoch + 1
-            self._set_epoch_dataloaders(epoch)
+            self.set_epoch(epoch)
            self.train_epoch()
            if self.validation_every and epoch_n % self.validation_every == 0:
                val_loss = self.validation_epoch()
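The new set_epoch hook matters in distributed runs because torch's DistributedSampler reuses the same shuffled order every epoch unless it is reseeded. A plain-PyTorch sketch of the pattern the trainer now automates (standalone, so num_replicas and rank are passed explicitly instead of being read from an initialized process group):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.randn(100, 3))
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
    loader = DataLoader(dataset, batch_size=10, sampler=sampler)

    for epoch in range(2):
        # Reseed the sampler so each epoch yields a fresh permutation
        sampler.set_epoch(epoch)
        for (x,) in loader:
            pass  # training step goes here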
10 changes: 5 additions & 5 deletions tests/test_loggers.py
@@ -6,7 +6,7 @@
 from itwinai.loggers import (
     ConsoleLogger,
     MLFlowLogger,
-    WanDBLogger,
+    WandBLogger,
     TensorBoardLogger,
     LoggersCollection
 )
@@ -31,7 +31,7 @@ def mlflow_logger():

 @pytest.fixture(scope="module")
 def wandb_logger():
-    yield WanDBLogger(savedir='/tmp/wandb/test_mllogs',
+    yield WandBLogger(savedir='/tmp/wandb/test_mllogs',
                      project_name='test_project')
    shutil.rmtree('/tmp/wandb/test_mllogs', ignore_errors=True)

@@ -77,9 +77,9 @@ def test_wandb_logger_log(wandb_logger):
    with patch('wandb.init') as mock_init, patch('wandb.log') as mock_log:
        mock_init.return_value = MagicMock()
        wandb_logger.create_logger_context()
-        wandb_logger.log(0.5, 'test_metric', kind='metric', step=1)
+        wandb_logger.log(0.5, 'test_metric', kind='metric')
        mock_log.assert_called_once_with(
-            {'test_metric': 0.5}, step=1, commit=True)
+            {'test_metric': 0.5}, commit=True)


def test_tensorboard_logger_log_tf(tensorboard_logger_tf):
@@ -160,6 +160,6 @@ def test_loggers_collection_log(loggers_collection):
        mock_log_metric.assert_called_once_with(
            key='test_metric', value=0.5, step=1)
        mock_wandb_log.assert_called_once_with(
-            {'test_metric': 0.5}, step=1, commit=True)
+            {'test_metric': 0.5}, commit=True)

    loggers_collection.destroy_logger_context()
11 changes: 9 additions & 2 deletions tutorials/distributed-ml/torch-tutorial-0-basics/train.py
@@ -79,16 +79,23 @@ def training_fn(
        model, optim, lr_scheduler=None, **distribute_kwargs
    )

-    # Data
+    # Dataset
    train_set = UniformRndDataset(x_size=3, y_size=4)
+    # Distributed dataloader
    train_loader = strategy.create_dataloader(
-        train_set, batch_size=args.batch_size, num_workers=1)
+        train_set,
+        batch_size=args.batch_size,
+        num_workers=1,
+        shuffle=True
+    )

    # Device allocated for this worker
    device = strategy.device()

    for epoch in range(2):
+        # IMPORTANT: set current epoch ID in distributed sampler
+        train_loader.sampler.set_epoch(epoch)
+
        for (x, y) in train_loader:
            # print(f"tensor to cuda:{device}")
            x = x.to(device)
8 changes: 6 additions & 2 deletions use-cases/virgo/config.yaml
@@ -16,15 +16,16 @@ training_pipeline:
      init_args:
        train_proportion: 0.9
        rnd_seed: 42
-        images_dataset: data/Image_dataset_synthetic_64x64.pkl
+        images_dataset: ${data_root}/Image_dataset_synthetic_64x64.pkl
    - class_path: data.TimeSeriesProcessor
    - class_path: trainer.NoiseGeneratorTrainer
      init_args:
-        generator: unet
+        generator: simple #unet
        batch_size: ${batch_size}
        num_epochs: ${epochs}
        strategy: ${strategy}
        checkpoint_path: ${checkpoint_path}
+        random_seed: 17
        logger:
          class_path: itwinai.loggers.LoggersCollection
          init_args:
@@ -36,4 +37,7 @@ training_pipeline:
              init_args:
                experiment_name: Noise simulator (Virgo)
                log_freq: batch
+            - class_path: itwinai.loggers.WandBLogger
+              init_args:
+                log_freq: batch
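For reference, the same logger stack can be assembled directly in Python — a sketch assuming LoggersCollection wraps a plain list of loggers (its exact constructor is not shown in this diff) and reusing the MLFlowLogger arguments from the YAML above:

    from itwinai.loggers import LoggersCollection, MLFlowLogger, WandBLogger

    logger = LoggersCollection([
        MLFlowLogger(experiment_name='Noise simulator (Virgo)', log_freq='batch'),
        WandBLogger(log_freq='batch')
    ])
    logger.create_logger_context()
    logger.log(0.5, 'loss', kind='metric', step=1, batch_idx=0)
    logger.destroy_logger_context()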

7 changes: 5 additions & 2 deletions use-cases/virgo/train.py
@@ -1,4 +1,7 @@
 """
+THIS SCRIPT IS DEPRECATED! Follow the instructions on README.
+
 Simplified training script: data generation + training in one
 procedural script. This is an INTERMEDIATE step of integration in itwinai.
 """
@@ -22,8 +25,8 @@
 from src.utils import init_weights, calculate_iou_2d

 # Global parameters
-DATA_ROOT = "data"
-LOAD_DATASET = True
+DATA_ROOT = "tmp_data"
+LOAD_DATASET = False
 BATCH_SIZE = 20
 LR = 0.00005
 SAVE_CHECKPOINT = 'choose_your_path.checkpoint_epoch_{}.pth'