diff --git a/resemble_enhance/enhancer/__main__.py b/resemble_enhance/enhancer/__main__.py index 0c1ad5c..8d619a6 100644 --- a/resemble_enhance/enhancer/__main__.py +++ b/resemble_enhance/enhancer/__main__.py @@ -4,11 +4,12 @@ from pathlib import Path import torch +import torch.distributed as dist import torchaudio from tqdm import tqdm from .inference import denoise, enhance - +from ..utils.distributed import local_rank, fix_unset_envs @torch.inference_mode() def main(): @@ -68,6 +69,11 @@ def main(): action="store_true", help="Shuffle the audio paths and skip the existing ones, enabling multiple jobs to run in parallel", ) + parser.add_argument( + "--distributed_mode", + action="store_true", + help="Enable distributed training across multiple GPUs", + ) args = parser.parse_args() @@ -86,6 +92,14 @@ def main(): if args.parallel_mode: random.shuffle(paths) + if args.distributed_mode: + fix_unset_envs() + dist.init_process_group(backend='nccl' if device == "cuda" else "gloo") + torch.cuda.set_device(local_rank()) + num_processed = dist.get_world_size() + rank = dist.get_rank() + paths = paths[rank::num_processed] + if len(paths) == 0: print(f"No {args.suffix} files found in the following path: {args.in_dir}") return