From c275a8c92af7b7ffb3516d66a3c55d41906aea56 Mon Sep 17 00:00:00 2001
From: Matteo Bunino
Date: Wed, 8 Jan 2025 10:09:07 +0100
Subject: [PATCH] fix linter

---
 src/itwinai/torch/distributed.py                 | 12 ++++++------
 .../distributed-ml/torch-kubeflow-1/train-cpu.py |  3 ++-
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/itwinai/torch/distributed.py b/src/itwinai/torch/distributed.py
index 452d0902..adcdec23 100644
--- a/src/itwinai/torch/distributed.py
+++ b/src/itwinai/torch/distributed.py
@@ -494,9 +494,9 @@ def local_world_size(self) -> int:
             return torch.cuda.device_count()
         if "LOCAL_WORLD_SIZE" not in os.environ:
             raise RuntimeError(
-                    "Could not retrieve local world size as CUDA is unavailable and there is "
-                    "no 'LOCAL_WORLD_SIZE' environment variable."
-                )
+                "Could not retrieve local world size as CUDA is unavailable and there is "
+                "no 'LOCAL_WORLD_SIZE' environment variable."
+            )
         return int(os.environ["LOCAL_WORLD_SIZE"])
 
     @check_initialized
@@ -676,9 +676,9 @@ def local_world_size(self) -> int:
             return torch.cuda.device_count()
         if "LOCAL_WORLD_SIZE" not in os.environ:
             raise RuntimeError(
-                    "Could not retrieve local world size as CUDA is unavailable and there is "
-                    "no 'LOCAL_WORLD_SIZE' environment variable."
-                )
+                "Could not retrieve local world size as CUDA is unavailable and there is "
+                "no 'LOCAL_WORLD_SIZE' environment variable."
+            )
         return int(os.environ["LOCAL_WORLD_SIZE"])
 
     @check_initialized
diff --git a/tutorials/distributed-ml/torch-kubeflow-1/train-cpu.py b/tutorials/distributed-ml/torch-kubeflow-1/train-cpu.py
index 1cb818cc..6a6cd703 100644
--- a/tutorials/distributed-ml/torch-kubeflow-1/train-cpu.py
+++ b/tutorials/distributed-ml/torch-kubeflow-1/train-cpu.py
@@ -114,10 +114,11 @@ def main():
         open("DATASET_READY", "w")
     else:
         import time
+
         while not os.path.exists("DATASET_READY"):
             # Wait for the dataset to be downloaded
             time.sleep(1)
-        
+
     # Dataset creation
     train_dataset = datasets.MNIST("data", train=True, download=False, transform=transform)
     validation_dataset = datasets.MNIST(
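
[Editor's note] The two distributed.py hunks only re-indent the multi-line
RuntimeError message; behavior is unchanged. For context, a minimal sketch of
the fallback logic the patched method implements (written as a standalone
function here for illustration; in the source it is a class method):

    import os

    import torch


    def local_world_size() -> int:
        # Prefer the GPU count when CUDA is available: one worker per GPU.
        if torch.cuda.is_available():
            return torch.cuda.device_count()
        # On CPU-only nodes, fall back to the variable exported by torchrun.
        if "LOCAL_WORLD_SIZE" not in os.environ:
            raise RuntimeError(
                "Could not retrieve local world size as CUDA is unavailable and there is "
                "no 'LOCAL_WORLD_SIZE' environment variable."
            )
        return int(os.environ["LOCAL_WORLD_SIZE"])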
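
[Editor's note] The train-cpu.py hunk is likewise whitespace-only, but the
surrounding context shows the tutorial's sentinel-file handshake: rank 0
downloads MNIST and creates DATASET_READY, while the other workers poll for
it. A minimal sketch of that pattern; the `global_rank` parameter, the
`prepare_dataset` name, and the context manager around open() are assumptions,
the rest mirrors the tutorial:

    import os
    import time

    from torchvision import datasets


    def prepare_dataset(global_rank: int) -> None:
        if global_rank == 0:
            # Only one worker downloads the shared dataset.
            datasets.MNIST("data", train=True, download=True)
            # The context manager avoids the file handle leaked by the
            # tutorial's bare open("DATASET_READY", "w") call.
            with open("DATASET_READY", "w"):
                pass
        else:
            while not os.path.exists("DATASET_READY"):
                # Wait for the dataset to be downloaded
                time.sleep(1)

Polling a file on a shared volume acts as a simple cross-pod barrier in the
Kubeflow setting, where workers may not yet have a process group initialized
at download time.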