From 15d4c376bd84eb33c003a59e4bb52bf1aa7fdf08 Mon Sep 17 00:00:00 2001
From: ramlabserver8
Date: Tue, 19 Jul 2022 02:32:03 +0000
Subject: [PATCH] Add a detectron2-style training sampler and fix multi-GPU
 dataloading

---
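Notes: the training sampler is now chosen through `cfg.data.sampler`; the
'name' key selects a class from SAMPLER_DICT (defaulting to TrainingSampler)
and every remaining key is forwarded to that sampler's constructor. A minimal
config sketch, assuming the repo's usual EasyDict-based config layout (the
num_workers/batch_size values are placeholders for illustration):

    from easydict import EasyDict as edict

    data = edict()
    data.num_workers = 4  # placeholder value
    data.batch_size = 8   # placeholder value
    # 'name' picks the sampler class out of SAMPLER_DICT;
    # 'shuffle' is forwarded to TrainingSampler.__init__.
    data.sampler = edict(name='TrainingSampler', shuffle=True)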
 scripts/train.py                              | 13 ++--
 visualDet3D/data/dataloader/__init__.py       |  2 +
 .../data/dataloader/dataloader_builder.py     | 22 +++++
 .../data/dataloader/distributed_sampler.py    | 64 ++++++++++++++++++++
 visualDet3D/networks/utils/registry.py        |  3 +-
 5 files changed, 98 insertions(+), 6 deletions(-)
 create mode 100644 visualDet3D/data/dataloader/__init__.py
 create mode 100644 visualDet3D/data/dataloader/dataloader_builder.py
 create mode 100644 visualDet3D/data/dataloader/distributed_sampler.py

diff --git a/scripts/train.py b/scripts/train.py
index abcd5d7..ec6f706 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -21,6 +21,7 @@
 from visualDet3D.utils.timer import Timer
 from visualDet3D.utils.utils import LossLogger, cfg_from_file
 from visualDet3D.networks.optimizers import optimizers, schedulers
+from visualDet3D.data.dataloader import build_dataloader
 
 def main(config="config/config.py", experiment_name="default", world_size=1, local_rank=-1):
     """Main function for the training script.
@@ -74,11 +75,13 @@ def main(config="config/config.py", experiment_name="default", world_size=1, local_rank=-1):
     dataset_train = DATASET_DICT[cfg.data.train_dataset](cfg)
     dataset_val = DATASET_DICT[cfg.data.val_dataset](cfg, "validation")
 
-    dataloader_train = DataLoader(dataset_train, num_workers=cfg.data.num_workers,
-                                  batch_size=cfg.data.batch_size, collate_fn=dataset_train.collate_fn, shuffle=local_rank<0, drop_last=True,
-                                  sampler=torch.utils.data.DistributedSampler(dataset_train, num_replicas=world_size, rank=local_rank, shuffle=True) if local_rank >= 0 else None)
-    dataloader_val = DataLoader(dataset_val, num_workers=cfg.data.num_workers,
-                                batch_size=cfg.data.batch_size, collate_fn=dataset_val.collate_fn, shuffle=False, drop_last=True)
+    dataloader_train = build_dataloader(dataset_train,
+                                        num_workers=cfg.data.num_workers,
+                                        batch_size=cfg.data.batch_size,
+                                        collate_fn=dataset_train.collate_fn,
+                                        local_rank=local_rank,
+                                        world_size=world_size,
+                                        sampler_cfg=getattr(cfg.data, 'sampler', dict()))
 
     ## Create the model
     detector = DETECTOR_DICT[cfg.detector.name](cfg.detector)
diff --git a/visualDet3D/data/dataloader/__init__.py b/visualDet3D/data/dataloader/__init__.py
new file mode 100644
index 0000000..e17174e
--- /dev/null
+++ b/visualDet3D/data/dataloader/__init__.py
@@ -0,0 +1,2 @@
+from .distributed_sampler import *
+from .dataloader_builder import build_dataloader
diff --git a/visualDet3D/data/dataloader/dataloader_builder.py b/visualDet3D/data/dataloader/dataloader_builder.py
new file mode 100644
index 0000000..b1cd042
--- /dev/null
+++ b/visualDet3D/data/dataloader/dataloader_builder.py
@@ -0,0 +1,22 @@
+from typing import Callable, Optional
+from easydict import EasyDict
+from torch.utils.data import DataLoader
+from visualDet3D.networks.utils.registry import SAMPLER_DICT
+
+def build_dataloader(dataset,
+                     num_workers: int,
+                     batch_size: int,
+                     collate_fn: Callable,
+                     local_rank: int = -1,
+                     world_size: int = 1,
+                     sampler_cfg: Optional[EasyDict] = None,
+                     **kwargs):
+    # Copy the config so the pop() below does not mutate the caller's cfg;
+    # this also guards against the None default.
+    sampler_cfg = dict(sampler_cfg) if sampler_cfg else dict()
+    sampler_name = sampler_cfg.pop('name', 'TrainingSampler')
+    sampler = SAMPLER_DICT[sampler_name](size=len(dataset), rank=local_rank, world_size=world_size, **sampler_cfg)
+
+    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn,
+                            sampler=sampler, drop_last=True, **kwargs)
+    return dataloader
diff --git a/visualDet3D/data/dataloader/distributed_sampler.py b/visualDet3D/data/dataloader/distributed_sampler.py
new file mode 100644
index 0000000..af6d00b
--- /dev/null
+++ b/visualDet3D/data/dataloader/distributed_sampler.py
@@ -0,0 +1,64 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import itertools
+import torch
+from torch.utils.data.sampler import Sampler
+from visualDet3D.networks.utils.registry import SAMPLER_DICT
+
+@SAMPLER_DICT.register_module
+class TrainingSampler(Sampler):
+    """
+    Training sampler adapted from detectron2's TrainingSampler. Unlike the
+    detectron2 original, which yields an infinite stream of indices, this
+    adaptation is epoch-based: each call to `__iter__` yields one pass over
+    `shuffle(range(size))` (or `range(size)` if shuffle is False), sliced so
+    that the process with the given rank sees `indices[rank::world_size]`.
+    Because the generator state persists across epochs, every epoch draws a
+    new permutation, and because every rank starts its generator from the
+    same state, the per-rank slices of each permutation stay disjoint and
+    together cover the whole dataset.
+
+    Note that this sampler does not shard based on pytorch DataLoader worker
+    id. A sampler passed to pytorch DataLoader is used only with map-style
+    datasets and will not be executed inside workers. If this sampler does
+    get executed inside a dataloader worker, extra work needs to be done to
+    shard its outputs based on worker id, so that workers don't produce
+    identical data. :class:`ToIterableDataset` in detectron2 implements this
+    logic.
+    """
+
+    def __init__(self, size: int, rank: int = -1, world_size: int = 1, shuffle: bool = True):
+        """
+        Args:
+            size (int): the total number of samples in the underlying dataset
+            rank (int): rank of the current process; -1 means non-distributed
+                training and behaves like rank 0 of a world of size 1
+            world_size (int): total number of processes taking part in training
+            shuffle (bool): whether to shuffle the indices or not
+        """
+        if not isinstance(size, int):
+            raise TypeError(f"TrainingSampler(size=) expects an int. Got type {type(size)}.")
+        if size <= 0:
+            raise ValueError(f"TrainingSampler(size=) expects a positive int. Got {size}.")
+        self._size = size
+        self._shuffle = shuffle
+
+        self._rank = rank
+        self._world_size = world_size
+        # A fresh torch.Generator starts from PyTorch's fixed default seed, so
+        # all ranks draw identical permutations by default.
+        self.generator = torch.Generator()
+
+    def __len__(self):
+        # Number of indices this rank yields per epoch (not the dataset size).
+        start = max(self._rank, 0)
+        return (self._size - start + self._world_size - 1) // self._world_size
+
+    def __iter__(self):
+        start = max(self._rank, 0)  # treat rank=-1 (non-distributed) as rank 0
+        yield from itertools.islice(self._indices(), start, None, self._world_size)
+
+    def _indices(self):
+        if self._shuffle:
+            yield from torch.randperm(self._size, generator=self.generator).tolist()
+        else:
+            yield from torch.arange(self._size).tolist()
diff --git a/visualDet3D/networks/utils/registry.py b/visualDet3D/networks/utils/registry.py
index da0798e..6bb1ccd 100644
--- a/visualDet3D/networks/utils/registry.py
+++ b/visualDet3D/networks/utils/registry.py
@@ -46,4 +46,5 @@ def register_module(self, cls=None):
 BACKBONE_DICT = Registry("backbones")
 DETECTOR_DICT = Registry("detectors")
 PIPELINE_DICT = Registry("pipelines")
-AUGMENTATION_DICT = Registry("augmentation")
\ No newline at end of file
+AUGMENTATION_DICT = Registry("augmentation")
+SAMPLER_DICT = Registry("sampler")
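
Sanity check (not part of the patch): the per-rank index streams are meant to
partition each permutation, which relies on every rank's generator starting
from the same state. Seeding each sampler's generator explicitly makes that
invariant visible; a short sketch:

    from visualDet3D.data.dataloader import TrainingSampler

    size, world_size = 10, 4
    shards = []
    for rank in range(world_size):
        s = TrainingSampler(size=size, rank=rank, world_size=world_size, shuffle=True)
        s.generator.manual_seed(0)  # same seed on every rank -> identical permutations
        shards.append(list(s))      # one epoch of indices for this rank

    # together the rank shards cover every index exactly once
    assert sorted(i for shard in shards for i in shard) == list(range(size))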