From 08dae474697690f9d5b770178ef40433c0fa5c16 Mon Sep 17 00:00:00 2001 From: Matteo Bunino <48362942+matbun@users.noreply.github.com> Date: Wed, 12 Jun 2024 09:29:12 +0200 Subject: [PATCH] Patch Virgo dataloader and refactor (#154) * PATCH virgo dataloader and refactor * patch version * Refactor * UPDATE tests --- pyproject.toml | 4 +- src/itwinai/loggers.py | 15 ++++-- src/itwinai/torch/config.py | 10 ++++ src/itwinai/torch/trainer.py | 24 ++++++--- tests/test_loggers.py | 10 ++-- .../torch-tutorial-0-basics/train.py | 11 +++- use-cases/virgo/config.yaml | 8 ++- use-cases/virgo/train.py | 7 ++- use-cases/virgo/trainer.py | 51 +++++++++++-------- 9 files changed, 95 insertions(+), 45 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 38f73b29..c17867ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" [project] name = "itwinai" -version = "0.2.0" +version = "0.2.1" description = "AI and ML workflows module for scientific digital twins." readme = "README.md" requires-python = ">=3.10" @@ -61,7 +61,7 @@ dev = [ "pytest-cov>=4.1.0", "ipykernel", "ipython", - "tensorflow==2.16.*", # needed by tests on tensorboard + "tensorflow==2.16.*", # needed by tests on tensorboard ] [project.urls] diff --git a/src/itwinai/loggers.py b/src/itwinai/loggers.py index 4efb1408..00528a84 100644 --- a/src/itwinai/loggers.py +++ b/src/itwinai/loggers.py @@ -420,7 +420,7 @@ def log( ) -class WanDBLogger(Logger): +class WandBLogger(Logger): """Abstraction around WandB logger. Args: @@ -434,6 +434,8 @@ class WanDBLogger(Logger): more details. Defaults to 'epoch'. """ + # TODO: add support for artifacts logging + def __init__( self, savedir: str = 'mllogs', @@ -444,14 +446,15 @@ def __init__( super().__init__(savedir=savedir, log_freq=log_freq) self.project_name = project_name self.supported_types = [ - 'watch', 'metric', 'figure', 'image', 'artifact', 'torch', 'dict', + 'watch', 'metric', 'figure', 'image', 'torch', 'dict', 'param', 'text' ] def create_logger_context(self): """Initialize logger. Init WandB run.""" + os.makedirs(os.path.join(self.savedir, 'wandb'), exist_ok=True) self.active_run = wandb.init( - dir=self.savedir, + dir=os.path.abspath(self.savedir), project=self.project_name ) @@ -484,7 +487,7 @@ def log( kind (str, optional): type of the item to be logged. Must be one among the list of ``self.supported_types``. Defaults to 'metric'. - step (Optional[int], optional): logging step. Defaults to None. + step (Optional[int], optional): ignored by ``WandBLogger``. batch_idx (Optional[int], optional): DataLoader batch counter (i.e., batch idx), if available. Defaults to None. kwargs: keyword arguments to pass to the logger. @@ -495,7 +498,9 @@ def log( if kind == 'watch': wandb.watch(item) if kind in self.supported_types[1:]: - wandb.log({identifier: item}, step=step, commit=True) + # wandb.log({identifier: item}, step=step, commit=True) + # Let WandB use its preferred step + wandb.log({identifier: item}, commit=True) class TensorBoardLogger(Logger): diff --git a/src/itwinai/torch/config.py b/src/itwinai/torch/config.py index 70fb9567..64b3dbb6 100644 --- a/src/itwinai/torch/config.py +++ b/src/itwinai/torch/config.py @@ -28,6 +28,16 @@ class TrainingConfiguration(Configuration): #: Batch size. In a distributed environment it is usually the #: per-worker batch size. Defaults to 32. batch_size: int = 32 + #: Whether to shuffle train dataset when creating a torch ``DataLoader``. + #: Defaults to False. 
+ shuffle_train: bool = False + #: Whether to shuffle validation dataset when creating a torch + #: ``DataLoader``. + #: Defaults to False. + shuffle_validation: bool = False + #: Whether to shuffle test dataset when creating a torch ``DataLoader``. + #: Defaults to False. + shuffle_test: bool = False #: Whether to pin GPU memory. Property of torch ``DataLoader``. #: Defaults to False. pin_memory: bool = False diff --git a/src/itwinai/torch/trainer.py b/src/itwinai/torch/trainer.py index 6d2ecd96..67f2b5da 100644 --- a/src/itwinai/torch/trainer.py +++ b/src/itwinai/torch/trainer.py @@ -262,29 +262,31 @@ def create_dataloaders( # may be interested to override! # ################################### - # TODO: improve robustness of getting from config self.train_dataloader = self.strategy.create_dataloader( dataset=train_dataset, batch_size=self.config.batch_size, num_workers=self.config.num_workers, pin_memory=self.config.pin_memory, - generator=self.torch_rng + generator=self.torch_rng, + shuffle=self.config.shuffle_train ) if validation_dataset is not None: self.validation_dataloader = self.strategy.create_dataloader( - dataset=train_dataset, + dataset=validation_dataset, batch_size=self.config.batch_size, num_workers=self.config.num_workers, pin_memory=self.config.pin_memory, - generator=self.torch_rng + generator=self.torch_rng, + shuffle=self.config.shuffle_validation ) if test_dataset is not None: self.test_dataloader = self.strategy.create_dataloader( - dataset=train_dataset, + dataset=test_dataset, batch_size=self.config.batch_size, num_workers=self.config.num_workers, pin_memory=self.config.pin_memory, - generator=self.torch_rng + generator=self.torch_rng, + shuffle=self.config.shuffle_test ) def _setup_metrics(self): @@ -343,6 +345,14 @@ def _set_epoch_dataloaders(self, epoch: int): if self.test_dataloader is not None: self.test_dataloader.sampler.set_epoch(epoch) + def set_epoch(self, epoch: int) -> None: + """Set current epoch at the beginning of training. + + Args: + epoch (int): epoch number, from 0 to ``epochs-1``. 
+ """ + self._set_epoch_dataloaders(epoch) + def log( self, item: Union[Any, List[Any]], @@ -429,7 +439,7 @@ def train(self): best_loss = float('inf') for epoch in range(self.epochs): epoch_n = epoch + 1 - self._set_epoch_dataloaders(epoch) + self.set_epoch(epoch) self.train_epoch() if self.validation_every and epoch_n % self.validation_every == 0: val_loss = self.validation_epoch() diff --git a/tests/test_loggers.py b/tests/test_loggers.py index 7f13ab45..2a1d6428 100644 --- a/tests/test_loggers.py +++ b/tests/test_loggers.py @@ -6,7 +6,7 @@ from itwinai.loggers import ( ConsoleLogger, MLFlowLogger, - WanDBLogger, + WandBLogger, TensorBoardLogger, LoggersCollection ) @@ -31,7 +31,7 @@ def mlflow_logger(): @pytest.fixture(scope="module") def wandb_logger(): - yield WanDBLogger(savedir='/tmp/wandb/test_mllogs', + yield WandBLogger(savedir='/tmp/wandb/test_mllogs', project_name='test_project') shutil.rmtree('/tmp/wandb/test_mllogs', ignore_errors=True) @@ -77,9 +77,9 @@ def test_wandb_logger_log(wandb_logger): with patch('wandb.init') as mock_init, patch('wandb.log') as mock_log: mock_init.return_value = MagicMock() wandb_logger.create_logger_context() - wandb_logger.log(0.5, 'test_metric', kind='metric', step=1) + wandb_logger.log(0.5, 'test_metric', kind='metric') mock_log.assert_called_once_with( - {'test_metric': 0.5}, step=1, commit=True) + {'test_metric': 0.5}, commit=True) def test_tensorboard_logger_log_tf(tensorboard_logger_tf): @@ -160,6 +160,6 @@ def test_loggers_collection_log(loggers_collection): mock_log_metric.assert_called_once_with( key='test_metric', value=0.5, step=1) mock_wandb_log.assert_called_once_with( - {'test_metric': 0.5}, step=1, commit=True) + {'test_metric': 0.5}, commit=True) loggers_collection.destroy_logger_context() diff --git a/tutorials/distributed-ml/torch-tutorial-0-basics/train.py b/tutorials/distributed-ml/torch-tutorial-0-basics/train.py index 7c124ec3..a2a4c6e3 100644 --- a/tutorials/distributed-ml/torch-tutorial-0-basics/train.py +++ b/tutorials/distributed-ml/torch-tutorial-0-basics/train.py @@ -79,16 +79,23 @@ def training_fn( model, optim, lr_scheduler=None, **distribute_kwargs ) - # Data + # Dataset train_set = UniformRndDataset(x_size=3, y_size=4) # Distributed dataloader train_loader = strategy.create_dataloader( - train_set, batch_size=args.batch_size, num_workers=1) + train_set, + batch_size=args.batch_size, + num_workers=1, + shuffle=True + ) # Device allocated for this worker device = strategy.device() for epoch in range(2): + # IMPORTANT: set current epoch ID in distributed sampler + train_loader.sampler.set_epoch(epoch) + for (x, y) in train_loader: # print(f"tensor to cuda:{device}") x = x.to(device) diff --git a/use-cases/virgo/config.yaml b/use-cases/virgo/config.yaml index da14342e..7866ce49 100644 --- a/use-cases/virgo/config.yaml +++ b/use-cases/virgo/config.yaml @@ -16,15 +16,16 @@ training_pipeline: init_args: train_proportion: 0.9 rnd_seed: 42 - images_dataset: data/Image_dataset_synthetic_64x64.pkl + images_dataset: ${data_root}/Image_dataset_synthetic_64x64.pkl - class_path: data.TimeSeriesProcessor - class_path: trainer.NoiseGeneratorTrainer init_args: - generator: unet + generator: simple #unet batch_size: ${batch_size} num_epochs: ${epochs} strategy: ${strategy} checkpoint_path: ${checkpoint_path} + random_seed: 17 logger: class_path: itwinai.loggers.LoggersCollection init_args: @@ -36,4 +37,7 @@ training_pipeline: init_args: experiment_name: Noise simulator (Virgo) log_freq: batch + - class_path: 
itwinai.loggers.WandBLogger + init_args: + log_freq: batch diff --git a/use-cases/virgo/train.py b/use-cases/virgo/train.py index 503f0990..2e616767 100644 --- a/use-cases/virgo/train.py +++ b/use-cases/virgo/train.py @@ -1,4 +1,7 @@ """ + +THIS SCRIPT IS DEPRECATED! Follow the instructions on README. + Simplified training script: data generation + training in one procedural script. This is an INTERMEDIATE step of integration in itwinai. """ @@ -22,8 +25,8 @@ from src.utils import init_weights, calculate_iou_2d # Global parameters -DATA_ROOT = "data" -LOAD_DATASET = True +DATA_ROOT = "tmp_data" +LOAD_DATASET = False BATCH_SIZE = 20 LR = 0.00005 SAVE_CHECKPOINT = 'choose_your_path.checkpoint_epoch_{}.pth' diff --git a/use-cases/virgo/trainer.py b/use-cases/virgo/trainer.py index ea4dae28..70332dc7 100644 --- a/use-cases/virgo/trainer.py +++ b/use-cases/virgo/trainer.py @@ -32,6 +32,7 @@ def __init__( checkpoint_path: str = "checkpoints/epoch_{}.pth", save_best: bool = True, logger: Optional[Logger] = None, + random_seed: Optional[int] = None, name: str | None = None ) -> None: super().__init__( @@ -39,6 +40,7 @@ def __init__( config={}, strategy=strategy, logger=logger, + random_seed=random_seed, name=name ) self.save_parameters(**self.locals2params(locals())) @@ -49,12 +51,12 @@ def __init__( self._loss = loss self.checkpoint_path = checkpoint_path os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) - # Global configuration - _config = dict( + # Global training configuration + self.config = TrainingConfiguration( batch_size=batch_size, - save_best=save_best + save_best=save_best, + shuffle_train=True ) - self.config = TrainingConfiguration(**_config) def create_model_loss_optimizer(self) -> None: # Select generator @@ -118,10 +120,13 @@ def train(self): acc_plot = [] val_acc_plot = [] best_val_loss = float('inf') - for epoch in tqdm(range(1, self.num_epochs+1)): + for epoch in tqdm(range(self.num_epochs)): + # itwinai - IMPORTANT: set current epoch ID + self.set_epoch(epoch) + st = time.time() epoch_loss = [] - epoch_acc = [] + # epoch_acc = [] for i, batch in enumerate(self.train_dataloader): # batch= transform(batch) target = batch[:, 0].unsqueeze(1).to(self.device) @@ -137,6 +142,7 @@ def train(self): loss.backward() self.optimizer.step() epoch_loss.append(loss.detach().cpu().numpy()) + # itwinai - log loss as metric self.log(loss.detach().cpu().numpy(), 'epoch_loss_batch', kind='metric', @@ -145,8 +151,8 @@ def train(self): # acc=accuracy(generated.detach().cpu().numpy(),target.detach().cpu().numpy(),20) # epoch_acc.append(acc) val_loss = [] - val_acc = [] - for batch in (self.validation_dataloader): + # val_acc = [] + for i, batch in enumerate(self.validation_dataloader): # batch= transform(batch) target = batch[:, 0].unsqueeze(1).to(self.device) target = target.float() @@ -155,20 +161,21 @@ def train(self): generated = self.model(input.float()) # generated=normalize_(generated,1) loss = self.loss(generated, target) - val_loss.append(loss.detach().cpu().numpy()) - self.log(loss.detach().cpu().numpy(), - 'val_loss_batch', - kind='metric', - step=epoch*len(self.validation_dataloader) + i, - batch_idx=i) - # acc=accuracy(generated.detach().cpu().numpy(),target.detach().cpu().numpy(),20) - # val_acc.append(acc) + val_loss.append(loss.detach().cpu().numpy()) + # itwinai -log loss as metric + self.log(loss.detach().cpu().numpy(), + 'val_loss_batch', + kind='metric', + step=epoch*len(self.validation_dataloader) + i, + batch_idx=i) + # 
acc=accuracy(generated.detach().cpu().numpy(),target.detach().cpu().numpy(),20) + # val_acc.append(acc) loss_plot.append(np.mean(epoch_loss)) val_loss_plot.append(np.mean(val_loss)) - acc_plot.append(np.mean(epoch_acc)) - val_acc_plot.append(np.mean(val_acc)) + # acc_plot.append(np.mean(epoch_acc)) + # val_acc_plot.append(np.mean(val_acc)) - # Log metrics/losses + # itwinai - Log metrics/losses self.log(np.mean(epoch_loss), 'epoch_loss', kind='metric', step=epoch) self.log(np.mean(val_loss), 'val_loss', @@ -182,12 +189,13 @@ def train(self): # accuracy: {}'.format(epoch,loss_plot[-1],val_loss_plot[-1], # acc_plot[-1],val_acc_plot[-1])) et = time.time() + # itwinai - print() in a multi-worker context (distributed) if self.strategy.is_main_worker: print('epoch: {} loss: {} val loss: {} time:{}s'.format( epoch, loss_plot[-1], val_loss_plot[-1], et-st)) # Save checkpoint every 100 epochs - if (epoch+1) % 1 == 0: + if epoch % 1 == 0: # uncomment the following if you want to save checkpoint every # 100 epochs regardless of the performance of the model # checkpoint = { @@ -203,6 +211,7 @@ def train(self): # torch.save(checkpoint, checkpoint_filename) # Average loss among all workers + # itwinai - gather local loss from all the workers worker_val_losses = self.strategy.gather_obj(val_loss_plot[-1]) if self.strategy.is_main_worker: # Save only in the main worker @@ -227,6 +236,7 @@ def train(self): checkpoint_filename = self.checkpoint_path.format( epoch) torch.save(checkpoint, checkpoint_filename) + # itwinai - log checkpoint as artifact self.log(checkpoint_filename, os.path.basename(checkpoint_filename), kind='artifact') @@ -237,6 +247,7 @@ def train(self): self.checkpoint_path.format('best') ) torch.save(checkpoint, best_checkpoint_filename) + # itwinai - log checkpoint as artifact self.log(best_checkpoint_filename, os.path.basename(best_checkpoint_filename), kind='artifact')
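
The patch adds three shuffle switches to TrainingConfiguration and fixes create_dataloaders() so the validation and test loaders wrap their own datasets instead of reusing the training set. A minimal sketch of how the new fields might be set (field names come from the diff; the chosen values are illustrative):

# Sketch: enabling the new per-split shuffle flags.
from itwinai.torch.config import TrainingConfiguration

config = TrainingConfiguration(
    batch_size=64,
    shuffle_train=True,        # shuffle only the training split
    shuffle_validation=False,  # keep evaluation splits in a fixed order
    shuffle_test=False,
    pin_memory=True,
)
# TorchTrainer.create_dataloaders() now forwards each flag to
# strategy.create_dataloader() for the matching dataset split.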
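
The logger rename from WanDBLogger to WandBLogger also changes step handling: the caller's step argument is now ignored and WandB keeps its own counter, while 'artifact' is dropped from the supported kinds pending the TODO above. A minimal usage sketch, assuming destroy_logger_context() from the base Logger API and a configured wandb account:

# Sketch only: method names follow the diff; destroy_logger_context is assumed.
from itwinai.loggers import WandBLogger

logger = WandBLogger(savedir='mllogs', project_name='demo_project', log_freq='batch')
logger.create_logger_context()  # creates <savedir>/wandb and calls wandb.init()

for batch_idx, loss in enumerate([0.9, 0.7, 0.5]):
    # 'step' is now ignored: WandB advances its own step counter
    logger.log(item=loss, identifier='train_loss', kind='metric', batch_idx=batch_idx)

logger.destroy_logger_context()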
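
The new TorchTrainer.set_epoch() hook and the sampler.set_epoch(epoch) call added to the tutorial exist because DistributedSampler seeds its shuffle with the epoch number; without the call, every epoch iterates the same per-worker ordering. A plain-PyTorch sketch of the pattern, assuming an initialized torch.distributed process group:

# Illustrative dataset; the point is the set_epoch() call before each pass.
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.randn(100, 3), torch.randn(100, 4))
sampler = DistributedSampler(dataset, shuffle=True)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # re-shuffle differently at every epoch
    for x, y in loader:
        pass  # training step goes here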
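
The Virgo config.yaml now appends a WandBLogger to the LoggersCollection next to the existing MLFlowLogger. A rough Python equivalent of that logger block, assuming LoggersCollection accepts a plain list of loggers as the tests suggest:

# The YAML class_path/init_args entries map onto these constructors (sketch).
from itwinai.loggers import LoggersCollection, MLFlowLogger, WandBLogger

logger = LoggersCollection([
    MLFlowLogger(experiment_name='Noise simulator (Virgo)', log_freq='batch'),
    WandBLogger(log_freq='batch'),
])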
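
In the refactored NoiseGeneratorTrainer.train(), the per-worker validation loss is gathered with strategy.gather_obj() and the best checkpoint is saved and logged as an artifact on the main worker only. A hypothetical helper condensing that end-of-epoch pattern (the function name, signature, and averaging step are illustrative):

import os
import torch

def save_best_checkpoint(strategy, log_fn, checkpoint_path, model, epoch,
                         local_val_loss, best_val_loss):
    # One value per worker; call on every worker so the collective completes
    worker_val_losses = strategy.gather_obj(local_val_loss)
    if strategy.is_main_worker:
        avg_val_loss = sum(worker_val_losses) / len(worker_val_losses)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            checkpoint = {'epoch': epoch, 'model': model.state_dict()}
            filename = checkpoint_path.format('best')
            torch.save(checkpoint, filename)
            # Register the checkpoint file as an artifact with the logger(s)
            log_fn(filename, os.path.basename(filename), kind='artifact')
    return best_val_loss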