Commit
add detectron2 sampler and fix the problem of multi-gpu
1 parent 9508821, commit 15d4c37
Showing 5 changed files with 91 additions and 6 deletions.
File 1 of 5: the training script
```diff
@@ -21,6 +21,7 @@
 from visualDet3D.utils.timer import Timer
 from visualDet3D.utils.utils import LossLogger, cfg_from_file
 from visualDet3D.networks.optimizers import optimizers, schedulers
+from visualDet3D.data.dataloader import build_dataloader

 def main(config="config/config.py", experiment_name="default", world_size=1, local_rank=-1):
     """Main function for the training script.

@@ -74,11 +75,13 @@ def main(config="config/config.py", experiment_name="default", world_size=1, local_rank=-1):
     dataset_train = DATASET_DICT[cfg.data.train_dataset](cfg)
     dataset_val = DATASET_DICT[cfg.data.val_dataset](cfg, "validation")

-    dataloader_train = DataLoader(dataset_train, num_workers=cfg.data.num_workers,
-                                  batch_size=cfg.data.batch_size, collate_fn=dataset_train.collate_fn, shuffle=local_rank < 0, drop_last=True,
-                                  sampler=torch.utils.data.DistributedSampler(dataset_train, num_replicas=world_size, rank=local_rank, shuffle=True) if local_rank >= 0 else None)
-    dataloader_val = DataLoader(dataset_val, num_workers=cfg.data.num_workers,
-                                batch_size=cfg.data.batch_size, collate_fn=dataset_val.collate_fn, shuffle=False, drop_last=True)
+    dataloader_train = build_dataloader(dataset_train,
+                                        num_workers=cfg.data.num_workers,
+                                        batch_size=cfg.data.batch_size,
+                                        collate_fn=dataset_train.collate_fn,
+                                        local_rank=local_rank,
+                                        world_size=world_size,
+                                        sampler_cfg=getattr(cfg.data, 'sampler', dict()))

     ## Create the model
     detector = DETECTOR_DICT[cfg.detector.name](cfg.detector)
```
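For context, the sampler is chosen through `cfg.data.sampler` and falls back to `TrainingSampler` when the key is absent. A hypothetical config excerpt, assuming the usual EasyDict-based config style of this repo (only the `name` and `shuffle` keys are actually consumed by the code in this commit; the surrounding structure is illustrative):

```python
# Hypothetical snippet for an experiment config (config/config.py style).
from easydict import EasyDict

data = EasyDict()
data.num_workers = 4         # forwarded to DataLoader
data.batch_size = 8
data.sampler = EasyDict(
    name='TrainingSampler',  # looked up in SAMPLER_DICT by build_dataloader
    shuffle=True,            # remaining keys are passed to the sampler's __init__
)
```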
File 2 of 5: visualDet3D/data/dataloader/__init__.py (new file)
```python
from .distributed_sampler import *
from .dataloader_builder import build_dataloader
```
File 3 of 5: visualDet3D/data/dataloader/dataloader_builder.py (new file)
```python
from typing import Callable, Optional
from easydict import EasyDict
from torch.utils.data import DataLoader
from visualDet3D.networks.utils.registry import SAMPLER_DICT


def build_dataloader(dataset,
                     num_workers: int,
                     batch_size: int,
                     collate_fn: Callable,
                     local_rank: int = -1,
                     world_size: int = 1,
                     sampler_cfg: Optional[EasyDict] = None,
                     **kwargs):
    # Copy so that popping 'name' does not mutate the caller's config, and
    # default to None rather than a mutable default argument.
    sampler_cfg = dict(sampler_cfg) if sampler_cfg else dict()
    sampler_name = sampler_cfg.pop('name', 'TrainingSampler')
    # Remaining keys in sampler_cfg (e.g. shuffle) are forwarded to the sampler.
    sampler = SAMPLER_DICT[sampler_name](size=len(dataset), rank=local_rank,
                                         world_size=world_size, **sampler_cfg)

    # shuffle= is deliberately not passed: when a sampler is given, DataLoader
    # requires shuffle to stay False and delegates ordering to the sampler.
    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size,
                            collate_fn=collate_fn, sampler=sampler, drop_last=True, **kwargs)
    return dataloader
```
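Because `build_dataloader` resolves the sampler through the `SAMPLER_DICT` registry, alternative sampling strategies can be plugged in without touching the builder. A minimal sketch of registering one (the class below is illustrative and not part of the commit; it only has to accept the `size`/`rank`/`world_size` keywords the builder passes):

```python
from torch.utils.data.sampler import Sampler
from visualDet3D.networks.utils.registry import SAMPLER_DICT


@SAMPLER_DICT.register_module
class SequentialShardSampler(Sampler):  # illustrative, not part of the commit
    """Deterministic sharded sampler: rank r yields r, r + world_size, ..."""

    def __init__(self, size: int, rank: int = -1, world_size: int = 1):
        self._size = size
        self._rank = max(rank, 0)  # single-process training passes rank=-1
        self._world_size = world_size

    def __iter__(self):
        yield from range(self._rank, self._size, self._world_size)

    def __len__(self):
        # Number of indices this rank actually yields.
        return (self._size - self._rank + self._world_size - 1) // self._world_size


# Selected from a config with: sampler = EasyDict(name='SequentialShardSampler')
```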
File 4 of 5: visualDet3D/data/dataloader/distributed_sampler.py (new file, adapted from detectron2)
```python
# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import torch
from torch.utils.data.sampler import Sampler
from visualDet3D.networks.utils.registry import SAMPLER_DICT


@SAMPLER_DICT.register_module
class TrainingSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices, and all workers
    cooperate to correctly shuffle the indices and sample different indices.

    The sampler in each worker effectively produces `indices[worker_id::num_workers]`,
    where `indices` is an infinite stream consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False).

    Note that this sampler does not shard based on PyTorch DataLoader worker id.
    A sampler passed to a PyTorch DataLoader is used only with map-style datasets
    and is not executed inside workers. But if this sampler is used in a way that
    it gets executed inside a DataLoader worker, extra work is needed to shard its
    output based on worker id, so that workers do not produce identical data.
    :class:`ToIterableDataset` implements this logic. This note is true for all
    samplers in detectron2.
    """

    def __init__(self, size: int, rank: int = -1, world_size: int = 1, shuffle: bool = True):
        """
        Args:
            size (int): the total number of samples in the underlying dataset
            rank (int): distributed rank of this process; -1 for single-process training
            world_size (int): total number of distributed processes
            shuffle (bool): whether to shuffle the indices or not
        """
        if not isinstance(size, int):
            raise TypeError(f"TrainingSampler(size=) expects an int. Got type {type(size)}.")
        if size <= 0:
            raise ValueError(f"TrainingSampler(size=) expects a positive int. Got {size}.")
        self._size = size
        self._shuffle = shuffle

        self._rank = rank
        self._world_size = world_size
        # Created once per process and never reseeded: every __iter__ call (one
        # per epoch) advances its state, so each epoch sees a fresh permutation
        # without any set_epoch() call. Fresh CPU generators start from the same
        # default state, so ranks stay consistent as long as they iterate in lockstep.
        self.generator = torch.Generator()

    def __len__(self):
        # Note: this reports the full dataset size, while __iter__ yields
        # roughly size / world_size indices per pass on each rank.
        return self._size

    def __iter__(self):
        # Shard the index stream across ranks: rank r takes every world_size-th
        # index starting at r (rank -1, i.e. non-distributed, behaves like rank 0).
        start = max(self._rank, 0)
        yield from itertools.islice(self._indices(), start, None, self._world_size)

    def _indices(self):
        if self._shuffle:
            yield from torch.randperm(self._size, generator=self.generator).tolist()
        else:
            yield from torch.arange(self._size).tolist()
```
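To make the `islice` sharding concrete, here is a standalone illustration of the same pattern in plain Python (no project imports), showing how two ranks split one pass over ten indices:

```python
import itertools

size, world_size = 10, 2
indices = list(range(size))  # stand-in for one (possibly shuffled) pass

for rank in range(world_size):
    shard = list(itertools.islice(indices, rank, None, world_size))
    print(f"rank {rank}: {shard}")

# Output:
# rank 0: [0, 2, 4, 6, 8]
# rank 1: [1, 3, 5, 7, 9]
```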
File 5 of 5: (diff not captured here)
Commit comment:
The problem with the original file is that we need to call set_epoch on the distributed sampler after each epoch. Otherwise, shuffle=True produces the same index order in every epoch, which is effectively the same as setting shuffle to False. The TrainingSampler above avoids this because its generator state carries over between epochs.
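For comparison, this is the standard PyTorch pattern the commit replaces: `torch.utils.data.distributed.DistributedSampler` derives its shuffle order from an epoch counter, so the training loop must call `set_epoch` every epoch. A runnable single-process sketch (the rank and world-size values are illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Single-process stand-in for rank 0 of a 2-GPU job.
dataset = TensorDataset(torch.arange(16))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler, drop_last=True)

for epoch in range(2):
    sampler.set_epoch(epoch)  # omit this and every epoch repeats the epoch-0 order
    print(f"epoch {epoch}:", [batch[0].tolist() for batch in loader])
```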