From 5b895d741ce89106d12d8c33a333b31a4bdefa2f Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Fri, 22 Apr 2022 11:03:19 +0800 Subject: [PATCH] feat(exp): get_trainer method, add pre-commit (#1263) --- .pre-commit-config.yaml | 43 +++++++++++++++++++++++++++++++++++++++++ tools/train.py | 8 ++++---- yolox/core/trainer.py | 3 ++- yolox/exp/yolox_base.py | 10 +++++++--- 4 files changed, 56 insertions(+), 8 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..5120983f9 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,43 @@ +repos: + - repo: https://github.com/pycqa/flake8 + rev: 3.8.3 + hooks: + - id: flake8 + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.1.0 + hooks: + - id: check-added-large-files + - id: check-docstring-first + - id: check-executables-have-shebangs + - id: check-json + - id: check-yaml + args: ["--unsafe"] + - id: debug-statements + - id: end-of-file-fixer + - id: requirements-txt-fixer + - id: trailing-whitespace + - repo: https://github.com/jorisroovers/gitlint + rev: v0.15.1 + hooks: + - id: gitlint + - repo: https://github.com/pycqa/isort + rev: 4.3.21 + hooks: + - id: isort + + - repo: https://github.com/PyCQA/autoflake + rev: v1.4 + hooks: + - id: autoflake + name: Remove unused variables and imports + entry: autoflake + language: python + args: + [ + "--in-place", + "--remove-all-unused-imports", + "--remove-unused-variables", + "--expand-star-imports", + "--ignore-init-module-imports", + ] + files: \.py$ diff --git a/tools/train.py b/tools/train.py index 60102ee08..aeab4a681 100644 --- a/tools/train.py +++ b/tools/train.py @@ -10,8 +10,8 @@ import torch import torch.backends.cudnn as cudnn -from yolox.core import Trainer, launch -from yolox.exp import get_exp +from yolox.core import launch +from yolox.exp import Exp, get_exp from yolox.utils import configure_module, configure_nccl, configure_omp, get_num_devices @@ -97,7 +97,7 @@ def make_parser(): @logger.catch -def main(exp, args): +def main(exp: Exp, args): if exp.seed is not None: random.seed(exp.seed) torch.manual_seed(exp.seed) @@ -113,7 +113,7 @@ def main(exp, args): configure_omp() cudnn.benchmark = True - trainer = Trainer(exp, args) + trainer = exp.get_trainer(args) trainer.train() diff --git a/yolox/core/trainer.py b/yolox/core/trainer.py index a9ee2a681..17acd6e1e 100644 --- a/yolox/core/trainer.py +++ b/yolox/core/trainer.py @@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter from yolox.data import DataPrefetcher +from yolox.exp import Exp from yolox.utils import ( MeterBuffer, ModelEMA, @@ -33,7 +34,7 @@ class Trainer: - def __init__(self, exp, args): + def __init__(self, exp: Exp, args): # init function only defines some basic attr, other attrs like model, optimizer are built in # before_train methods. self.exp = exp diff --git a/yolox/exp/yolox_base.py b/yolox/exp/yolox_base.py index 6c0388dce..9ef932c89 100644 --- a/yolox/exp/yolox_base.py +++ b/yolox/exp/yolox_base.py @@ -127,9 +127,7 @@ def init_yolo(M): self.model.train() return self.model - def get_data_loader( - self, batch_size, is_distributed, no_aug=False, cache_img=False - ): + def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=False): from yolox.data import ( COCODataset, TrainTransform, @@ -314,5 +312,11 @@ def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False) ) return evaluator + def get_trainer(self, args): + from yolox.core import Trainer + trainer = Trainer(self, args) + # NOTE: trainer shouldn't be an attribute of exp object + return trainer + def eval(self, model, evaluator, is_distributed, half=False): return evaluator.evaluate(model, is_distributed, half)