add detectron2 sampler and fix the problem of multi-gpu
Owen-Liuyuxuan committed Jul 19, 2022
1 parent 9508821 commit 15d4c37
Showing 5 changed files with 91 additions and 6 deletions.
13 changes: 8 additions & 5 deletions scripts/train.py
@@ -21,6 +21,7 @@
from visualDet3D.utils.timer import Timer
from visualDet3D.utils.utils import LossLogger, cfg_from_file
from visualDet3D.networks.optimizers import optimizers, schedulers
from visualDet3D.data.dataloader import build_dataloader

def main(config="config/config.py", experiment_name="default", world_size=1, local_rank=-1):
"""Main function for the training script.
@@ -74,11 +75,13 @@ def main(config="config/config.py", experiment_name="default", world_size=1, local_rank=-1):
    dataset_train = DATASET_DICT[cfg.data.train_dataset](cfg)
    dataset_val = DATASET_DICT[cfg.data.val_dataset](cfg, "validation")

    dataloader_train = DataLoader(dataset_train, num_workers=cfg.data.num_workers,
                                  batch_size=cfg.data.batch_size, collate_fn=dataset_train.collate_fn, shuffle=local_rank<0, drop_last=True,
                                  sampler=torch.utils.data.DistributedSampler(dataset_train, num_replicas=world_size, rank=local_rank, shuffle=True) if local_rank >= 0 else None)
    dataloader_val = DataLoader(dataset_val, num_workers=cfg.data.num_workers,
                                batch_size=cfg.data.batch_size, collate_fn=dataset_val.collate_fn, shuffle=False, drop_last=True)

Owen-Liuyuxuan (Author, Owner) commented on Jul 19, 2022:
The problem with the original file is that we need to call `set_epoch` on the distributed sampler after each epoch. Otherwise the "shuffle" produces the same indices in every epoch, which is effectively the same as setting shuffle to False.
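For reference, this is roughly what the conventional fix looks like if one keeps `torch.utils.data.DistributedSampler`; `num_epochs` and `train_one_epoch` are placeholders, not code from this repository:

```python
from torch.utils.data import DataLoader, DistributedSampler

sampler = DistributedSampler(dataset_train, num_replicas=world_size,
                             rank=local_rank, shuffle=True)
dataloader_train = DataLoader(dataset_train, batch_size=cfg.data.batch_size,
                              collate_fn=dataset_train.collate_fn,
                              sampler=sampler, drop_last=True)

for epoch in range(num_epochs):
    # Without this call, every epoch replays the same permutation,
    # which across epochs behaves like shuffle=False.
    sampler.set_epoch(epoch)
    train_one_epoch(dataloader_train)
```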

    dataloader_train = build_dataloader(dataset_train,
                                        num_workers=cfg.data.num_workers,
                                        batch_size=cfg.data.batch_size,
                                        collate_fn=dataset_train.collate_fn,
                                        local_rank=local_rank,

Owen-Liuyuxuan (Author, Owner) commented on Jul 19, 2022:

Now we turn to detectron2's implementation of data samplers, which handles this set_epoch problem automatically.

It also makes writing custom data samplers more convenient (see the registration sketch after the registry changes below).

                                        world_size=world_size,
                                        sampler_cfg=getattr(cfg.data, 'sampler', dict()))

    ## Create the model
    detector = DETECTOR_DICT[cfg.detector.name](cfg.detector)
2 changes: 2 additions & 0 deletions visualDet3D/data/dataloader/__init__.py
@@ -0,0 +1,2 @@
from .distributed_sampler import *
from .dataloader_builder import build_dataloader
19 changes: 19 additions & 0 deletions visualDet3D/data/dataloader/dataloader_builder.py
@@ -0,0 +1,19 @@
from typing import Callable, Optional
from easydict import EasyDict
from torch.utils.data import DataLoader
from visualDet3D.networks.utils.registry import SAMPLER_DICT

def build_dataloader(dataset,
                     num_workers: int,
                     batch_size: int,
                     collate_fn: Callable,
                     local_rank: int = -1,
                     world_size: int = 1,
                     sampler_cfg: Optional[EasyDict] = dict(),
                     **kwargs):
    sampler_name = sampler_cfg.pop('name') if 'name' in sampler_cfg else 'TrainingSampler'
    sampler = SAMPLER_DICT[sampler_name](size=len(dataset), rank=local_rank, world_size=world_size, **sampler_cfg)

    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn,
                            sampler=sampler, **kwargs, drop_last=True)
    return dataloader
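A quick usage sketch of `build_dataloader` (the numeric values are illustrative; `dataset_train`, `local_rank` and `world_size` are as in scripts/train.py above): the `name` key picks the sampler class out of `SAMPLER_DICT`, and every remaining key in `sampler_cfg` is forwarded to that sampler's constructor.

```python
from easydict import EasyDict
from visualDet3D.data.dataloader import build_dataloader

sampler_cfg = EasyDict(name='TrainingSampler', shuffle=True)
dataloader_train = build_dataloader(dataset_train,
                                    num_workers=4,
                                    batch_size=8,
                                    collate_fn=dataset_train.collate_fn,
                                    local_rank=local_rank,
                                    world_size=world_size,
                                    sampler_cfg=sampler_cfg)
```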
60 changes: 60 additions & 0 deletions visualDet3D/data/dataloader/distributed_sampler.py
@@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import torch
from torch.utils.data.sampler import Sampler
from visualDet3D.networks.utils.registry import SAMPLER_DICT

@SAMPLER_DICT.register_module
class TrainingSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices and
    all workers cooperate to correctly shuffle the indices and sample different indices.
    The samplers in each worker effectively produce `indices[worker_id::num_workers]`
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False).
    Note that this sampler does not shard based on pytorch DataLoader worker id.
    A sampler passed to pytorch DataLoader is used only with map-style datasets
    and will not be executed inside workers.
    But if this sampler is used in a way that it gets executed inside a dataloader
    worker, then extra work needs to be done to shard its outputs based on worker id.
    This is required so that workers don't produce identical data.
    :class:`ToIterableDataset` implements this logic.
    This note is true for all samplers in detectron2.

    In this adapted version, each call to ``__iter__`` draws a fresh permutation from
    ``self.generator``, so the shuffle changes every epoch without any ``set_epoch`` call.
    """

    def __init__(self, size: int, rank: int = -1, world_size: int = 1, shuffle: bool = True):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            rank (int): rank of the current process; -1 means non-distributed training
            world_size (int): total number of processes taking part in training
            shuffle (bool): whether to shuffle the indices or not
        """
        if not isinstance(size, int):
            raise TypeError(f"TrainingSampler(size=) expects an int. Got type {type(size)}.")
        if size <= 0:
            raise ValueError(f"TrainingSampler(size=) expects a positive int. Got {size}.")
        self._size = size
        self._shuffle = shuffle

        self._rank = rank
        self._world_size = world_size
        self.generator = torch.Generator()

    def __len__(self):
        return self._size

    def __iter__(self):
        # Rank r consumes indices[r::world_size]; rank -1 (non-distributed) behaves like rank 0.
        start = max(self._rank, 0)
        yield from itertools.islice(self._indices(), start, None, self._world_size)

    def _indices(self):
        if self._shuffle:
            yield from torch.randperm(self._size, generator=self.generator).tolist()
        else:
            yield from torch.arange(self._size).tolist()
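A small sketch of the sharding behaviour: each rank starts at its own offset and strides by `world_size`, so rank r effectively consumes `indices[r::world_size]`. The manual seeding below is for the demo only, so that both simulated ranks draw the same permutation.

```python
from visualDet3D.data.dataloader.distributed_sampler import TrainingSampler

world_size = 2
for rank in range(world_size):
    sampler = TrainingSampler(size=10, rank=rank, world_size=world_size)
    sampler.generator.manual_seed(0)  # demo only: both ranks draw the same permutation
    print(rank, list(sampler))
# The two printed lists are disjoint and together cover range(10) exactly once,
# i.e. rank r sees shuffled_indices[r::world_size].
```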
3 changes: 2 additions & 1 deletion visualDet3D/networks/utils/registry.py
@@ -46,4 +46,5 @@ def register_module(self, cls=None):
BACKBONE_DICT = Registry("backbones")
DETECTOR_DICT = Registry("detectors")
PIPELINE_DICT = Registry("pipelines")
AUGMENTATION_DICT = Registry("augmentation")
AUGMENTATION_DICT = Registry("augmentation")
SAMPLER_DICT = Registry("sampler")
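
To illustrate the comment above about custom samplers: a hedged sketch of how an additional sampler could be registered through `SAMPLER_DICT` and then selected purely from the config. The `RepeatSampler` class and its `repeat` option are hypothetical, not part of this commit; `build_dataloader` would fill in `size`, `rank` and `world_size` automatically.

```python
import torch
from torch.utils.data.sampler import Sampler
from visualDet3D.networks.utils.registry import SAMPLER_DICT

@SAMPLER_DICT.register_module
class RepeatSampler(Sampler):
    """Toy sampler that yields every index `repeat` times (hypothetical example)."""
    def __init__(self, size: int, rank: int = -1, world_size: int = 1, repeat: int = 2):
        self._size = size
        self._rank = max(rank, 0)
        self._world_size = world_size
        self._repeat = repeat

    def __len__(self):
        return self._size * self._repeat

    def __iter__(self):
        # Repeat every index, then shard the resulting list across ranks.
        indices = torch.arange(self._size).repeat_interleave(self._repeat).tolist()
        yield from indices[self._rank::self._world_size]

# Selecting it from the experiment config then only needs:
# data = dict(sampler=dict(name='RepeatSampler', repeat=2))
```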
