diff --git a/docs/en/3_exist_data_new_model.md b/docs/en/3_exist_data_new_model.md index c69afd20386..7201ffdecb6 100644 --- a/docs/en/3_exist_data_new_model.md +++ b/docs/en/3_exist_data_new_model.md @@ -41,6 +41,12 @@ mmdetection ``` +Or you can set your dataset root through +```bash +export MMDET_DATASETS=$data_root +``` +We will replace dataset root with `$MMDET_DATASETS`, so you don't have to modify the corresponding path in config files. + The cityscapes annotations have to be converted into the coco format using `tools/dataset_converters/cityscapes.py`: ```shell diff --git a/docs/zh_cn/3_exist_data_new_model.md b/docs/zh_cn/3_exist_data_new_model.md index 5ac09c01afb..a9c19ca9428 100644 --- a/docs/zh_cn/3_exist_data_new_model.md +++ b/docs/zh_cn/3_exist_data_new_model.md @@ -40,6 +40,13 @@ mmdetection │ │ ├── VOC2012 ``` +你也可以通过如下方式设定数据集根路径 +```bash +export MMDET_DATASETS=$data_root +``` +我们将会使用环境便变量 `$MMDET_DATASETS` 作为数据集的根目录,因此你无需再修改相应配置文件的路径信息。 + + 你需要使用脚本 `tools/dataset_converters/cityscapes.py` 将 cityscapes 标注转化为 coco 标注格式。 ```shell diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py index 3873ec09c67..a6635d3c0f2 100644 --- a/mmdet/utils/__init__.py +++ b/mmdet/utils/__init__.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .collect_env import collect_env from .logger import get_caller_name, get_root_logger, log_img_scale -from .misc import find_latest_checkpoint +from .misc import find_latest_checkpoint, update_data_root from .setup_env import setup_multi_processes __all__ = [ 'get_root_logger', 'collect_env', 'find_latest_checkpoint', - 'setup_multi_processes', 'get_caller_name', 'log_img_scale' + 'update_data_root', 'setup_multi_processes', 'get_caller_name', + 'log_img_scale' ] diff --git a/mmdet/utils/misc.py b/mmdet/utils/misc.py index f5c425300e4..4113672acfb 100644 --- a/mmdet/utils/misc.py +++ b/mmdet/utils/misc.py @@ -1,8 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import glob +import os import os.path as osp import warnings +import mmcv +from mmcv.utils import print_log + def find_latest_checkpoint(path, suffix='pth'): """Find the latest checkpoint from the working directory. @@ -36,3 +40,37 @@ def find_latest_checkpoint(path, suffix='pth'): latest = count latest_path = checkpoint return latest_path + + +def update_data_root(cfg, logger=None): + """Update data root according to env MMDET_DATASETS. + + If set env MMDET_DATASETS, update cfg.data_root according to + MMDET_DATASETS. Otherwise, using cfg.data_root as default. + + Args: + cfg (mmcv.Config): The model config need to modify + logger (logging.Logger | str | None): the way to print msg + """ + assert isinstance(cfg, mmcv.Config), \ + f'cfg got wrong type: {type(cfg)}, expected mmcv.Config' + + if 'MMDET_DATASETS' in os.environ: + dst_root = os.environ['MMDET_DATASETS'] + print_log(f'MMDET_DATASETS has been set to be {dst_root}.' + f'Using {dst_root} as data root.') + else: + return + + assert isinstance(cfg, mmcv.Config), \ + f'cfg got wrong type: {type(cfg)}, expected mmcv.Config' + + def update(cfg, src_str, dst_str): + for k, v in cfg.items(): + if isinstance(v, mmcv.ConfigDict): + update(cfg[k], src_str, dst_str) + if isinstance(v, str) and src_str in v: + cfg[k] = v.replace(src_str, dst_str) + + update(cfg.data, cfg.data_root, dst_root) + cfg.data_root = dst_root diff --git a/tools/analysis_tools/analyze_results.py b/tools/analysis_tools/analyze_results.py index cb79587a65c..15db07e41c7 100644 --- a/tools/analysis_tools/analyze_results.py +++ b/tools/analysis_tools/analyze_results.py @@ -9,6 +9,7 @@ from mmdet.core.evaluation import eval_map from mmdet.core.visualization import imshow_gt_det_bboxes from mmdet.datasets import build_dataset, get_loading_pipeline +from mmdet.utils import update_data_root def bbox_map_eval(det_result, annotation): @@ -186,6 +187,10 @@ def main(): mmcv.check_file_exist(args.prediction_path) cfg = Config.fromfile(args.config) + + # update data root according to MMDET_DATASETS + update_data_root(cfg) + if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) cfg.data.test.test_mode = True diff --git a/tools/analysis_tools/benchmark.py b/tools/analysis_tools/benchmark.py index 91f34c74063..2be2d14d7b4 100644 --- a/tools/analysis_tools/benchmark.py +++ b/tools/analysis_tools/benchmark.py @@ -13,6 +13,7 @@ from mmdet.datasets import (build_dataloader, build_dataset, replace_ImageToTensor) from mmdet.models import build_detector +from mmdet.utils import update_data_root def parse_args(): @@ -170,6 +171,10 @@ def main(): args = parse_args() cfg = Config.fromfile(args.config) + + # update data root according to MMDET_DATASETS + update_data_root(cfg) + if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) diff --git a/tools/analysis_tools/confusion_matrix.py b/tools/analysis_tools/confusion_matrix.py index a2531ba81f9..224b93b314a 100644 --- a/tools/analysis_tools/confusion_matrix.py +++ b/tools/analysis_tools/confusion_matrix.py @@ -10,6 +10,7 @@ from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps from mmdet.datasets import build_dataset +from mmdet.utils import update_data_root def parse_args(): @@ -230,6 +231,10 @@ def main(): args = parse_args() cfg = Config.fromfile(args.config) + + # update data root according to MMDET_DATASETS + update_data_root(cfg) + if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) diff --git a/tools/analysis_tools/eval_metric.py b/tools/analysis_tools/eval_metric.py index 1fcdc1c1fa9..a074c9e1850 100644 --- a/tools/analysis_tools/eval_metric.py +++ b/tools/analysis_tools/eval_metric.py @@ -5,6 +5,7 @@ from mmcv import Config, DictAction from mmdet.datasets import build_dataset +from mmdet.utils import update_data_root def parse_args(): @@ -48,6 +49,10 @@ def main(): args = parse_args() cfg = Config.fromfile(args.config) + + # update data root according to MMDET_DATASETS + update_data_root(cfg) + assert args.eval or args.format_only, ( 'Please specify at least one operation (eval/format the results) with ' 'the argument "--eval", "--format-only"') diff --git a/tools/analysis_tools/optimize_anchors.py b/tools/analysis_tools/optimize_anchors.py index d0da0cbc61d..acf72acb26c 100644 --- a/tools/analysis_tools/optimize_anchors.py +++ b/tools/analysis_tools/optimize_anchors.py @@ -29,7 +29,7 @@ from mmdet.core import bbox_cxcywh_to_xyxy, bbox_overlaps, bbox_xyxy_to_cxcywh from mmdet.datasets import build_dataset -from mmdet.utils import get_root_logger +from mmdet.utils import get_root_logger, update_data_root def parse_args(): @@ -325,6 +325,9 @@ def main(): cfg = args.config cfg = Config.fromfile(cfg) + # update data root according to MMDET_DATASETS + update_data_root(cfg) + input_shape = args.input_shape assert len(input_shape) == 2 diff --git a/tools/misc/browse_dataset.py b/tools/misc/browse_dataset.py index 3e70c8b8741..14db64ee050 100644 --- a/tools/misc/browse_dataset.py +++ b/tools/misc/browse_dataset.py @@ -11,6 +11,7 @@ from mmdet.core.utils import mask2ndarray from mmdet.core.visualization import imshow_det_bboxes from mmdet.datasets.builder import build_dataset +from mmdet.utils import update_data_root def parse_args(): @@ -55,6 +56,10 @@ def skip_pipeline_steps(config): ] cfg = Config.fromfile(config_path) + + # update data root according to MMDET_DATASETS + update_data_root(cfg) + if cfg_options is not None: cfg.merge_from_dict(cfg_options) train_data_cfg = cfg.data.train diff --git a/tools/misc/print_config.py b/tools/misc/print_config.py index 1b2cb30c24c..7bb20fa60de 100644 --- a/tools/misc/print_config.py +++ b/tools/misc/print_config.py @@ -4,6 +4,8 @@ from mmcv import Config, DictAction +from mmdet.utils import update_data_root + def parse_args(): parser = argparse.ArgumentParser(description='Print the whole config') @@ -42,6 +44,10 @@ def main(): args = parse_args() cfg = Config.fromfile(args.config) + + # update data root according to MMDET_DATASETS + update_data_root(cfg) + if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) print(f'Config:\n{cfg.pretty_text}') diff --git a/tools/test.py b/tools/test.py index dfbc425869e..baa149d8418 100644 --- a/tools/test.py +++ b/tools/test.py @@ -17,7 +17,7 @@ from mmdet.datasets import (build_dataloader, build_dataset, replace_ImageToTensor) from mmdet.models import build_detector -from mmdet.utils import setup_multi_processes +from mmdet.utils import setup_multi_processes, update_data_root def parse_args(): @@ -133,6 +133,10 @@ def main(): raise ValueError('The output file must be a pkl file.') cfg = Config.fromfile(args.config) + + # update data root according to MMDET_DATASETS + update_data_root(cfg) + if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) diff --git a/tools/train.py b/tools/train.py index 2ccc1c88f84..5f184608f1c 100644 --- a/tools/train.py +++ b/tools/train.py @@ -17,7 +17,8 @@ from mmdet.apis import init_random_seed, set_random_seed, train_detector from mmdet.datasets import build_dataset from mmdet.models import build_detector -from mmdet.utils import collect_env, get_root_logger, setup_multi_processes +from mmdet.utils import (collect_env, get_root_logger, setup_multi_processes, + update_data_root) def parse_args(): @@ -103,6 +104,10 @@ def main(): args = parse_args() cfg = Config.fromfile(args.config) + + # update data root according to MMDET_DATASETS + update_data_root(cfg) + if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options)