diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 025b9db15..e46b1ceec 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -110,7 +110,7 @@ def set_torch_num_threads() -> None: things like CPU affinity is set. """ num_threads = os.cpu_count() // ( - torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + torch.cuda.device_count() if torch.cuda.is_available() else 1 ) torch.set_num_threads(num_threads) _log.info(f"Set intra op parallelism no. of threads to {num_threads}")