Skip to content

Commit

Permalink
feat: distributed mode (resemble-ai#55)
Browse files: browse the repository at this point in the history
  • Loading branch information
Mikhail Burnaev committed Dec 29, 2024
1 parent 8e97814 commit d8fe041
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion resemble_enhance/enhancer/__main__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,11 +4,12 @@
from pathlib import Path

import torch
import torch.distributed as dist
import torchaudio
from tqdm import tqdm

from .inference import denoise, enhance

from ..utils.distributed import local_rank, fix_unset_envs

@torch.inference_mode()
def main():
Expand Down Expand Up @@ -68,6 +69,11 @@ def main():
action="store_true",
help="Shuffle the audio paths and skip the existing ones, enabling multiple jobs to run in parallel",
)
parser.add_argument(
"--distributed_mode",
action="store_true",
help="Enable distributed training across multiple GPUs",
)

args = parser.parse_args()

Expand All @@ -86,6 +92,14 @@ def main():
if args.parallel_mode:
random.shuffle(paths)

if args.distributed_mode:
fix_unset_envs()
dist.init_process_group(backend='nccl' if device == "cuda" else "gloo")
torch.cuda.set_device(local_rank())
num_processed = dist.get_world_size()
rank = dist.get_rank()
paths = paths[rank::num_processed]

if len(paths) == 0:
print(f"No {args.suffix} files found in the following path: {args.in_dir}")
return
Expand Down

0 comments on commit d8fe041

Please sign in to comment.