From 0f90c31d913d64122458c21dd4b25cdfce361522 Mon Sep 17 00:00:00 2001
From: Eugen Hotaj
Date: Tue, 17 Dec 2024 09:53:02 -0500
Subject: [PATCH] [EZ] Fix set_torch_num_threads in multi-node.

The current code assumes a single node and divides the core count by the
global world size, setting num_threads incorrectly. It fails entirely once
world_size exceeds os.cpu_count(), since the floor division then yields 0
threads.
---
 torchtune/training/_distributed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index 025b9db15..e46b1ceec 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -110,7 +110,7 @@ def set_torch_num_threads() -> None:
     things like CPU affinity is set.
     """
     num_threads = os.cpu_count() // (
-        torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
+        torch.cuda.device_count() if torch.cuda.is_available() else 1
     )
     torch.set_num_threads(num_threads)
     _log.info(f"Set intra op parallelism no. of threads to {num_threads}")
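
For context, below is a sketch of the full helper after this patch. The
imports and the _log logger are assumptions reconstructed from the diff
context, not copied from the file. The point of the change: ranks on one
host share only that host's cores, and with the usual one-process-per-GPU
launch, torch.cuda.device_count() counts exactly the ranks on this node,
while torch.distributed.get_world_size() counts ranks across all nodes.
On a 4-node job with 8 GPUs and 64 cores per node, the old code gave each
rank 64 // 32 = 2 threads, and 0 once world_size exceeds 64 (which
torch.set_num_threads rejects); the new code gives 64 // 8 = 8 on every
node.

    import logging
    import os

    import torch

    _log = logging.getLogger(__name__)


    def set_torch_num_threads() -> None:
        # Each rank on a host shares that host's cores. With the usual
        # one-process-per-GPU launch, torch.cuda.device_count() is the
        # number of ranks on *this* node, while the global world size
        # also counts ranks on every other node.
        num_threads = os.cpu_count() // (
            torch.cuda.device_count() if torch.cuda.is_available() else 1
        )
        torch.set_num_threads(num_threads)
        _log.info(f"Set intra op parallelism no. of threads to {num_threads}")

One caveat, as an assumption rather than a claim from the patch: the
division by the local GPU count presumes one process per GPU; a CPU-only
or multi-rank-per-GPU launch would need a different local-rank count.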