Skip to content

Commit

Permalink
feat(exp): get_trainer method, add pre-commit (#1263)
Browse files Browse the repository at this point in the history
  • Loading branch information
FateScript authored Apr 22, 2022
1 parent 68408b4 commit 5b895d7
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 8 deletions.
43 changes: 43 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
repos:
- repo: https://github.com/pycqa/flake8
rev: 3.8.3
hooks:
- id: flake8
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.1.0
hooks:
- id: check-added-large-files
- id: check-docstring-first
- id: check-executables-have-shebangs
- id: check-json
- id: check-yaml
args: ["--unsafe"]
- id: debug-statements
- id: end-of-file-fixer
- id: requirements-txt-fixer
- id: trailing-whitespace
- repo: https://github.com/jorisroovers/gitlint
rev: v0.15.1
hooks:
- id: gitlint
- repo: https://github.com/pycqa/isort
rev: 4.3.21
hooks:
- id: isort

- repo: https://github.com/PyCQA/autoflake
rev: v1.4
hooks:
- id: autoflake
name: Remove unused variables and imports
entry: autoflake
language: python
args:
[
"--in-place",
"--remove-all-unused-imports",
"--remove-unused-variables",
"--expand-star-imports",
"--ignore-init-module-imports",
]
files: \.py$
8 changes: 4 additions & 4 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import torch
import torch.backends.cudnn as cudnn

from yolox.core import Trainer, launch
from yolox.exp import get_exp
from yolox.core import launch
from yolox.exp import Exp, get_exp
from yolox.utils import configure_module, configure_nccl, configure_omp, get_num_devices


Expand Down Expand Up @@ -97,7 +97,7 @@ def make_parser():


@logger.catch
def main(exp, args):
def main(exp: Exp, args):
if exp.seed is not None:
random.seed(exp.seed)
torch.manual_seed(exp.seed)
Expand All @@ -113,7 +113,7 @@ def main(exp, args):
configure_omp()
cudnn.benchmark = True

trainer = Trainer(exp, args)
trainer = exp.get_trainer(args)
trainer.train()


Expand Down
3 changes: 2 additions & 1 deletion yolox/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch.utils.tensorboard import SummaryWriter

from yolox.data import DataPrefetcher
from yolox.exp import Exp
from yolox.utils import (
MeterBuffer,
ModelEMA,
Expand All @@ -33,7 +34,7 @@


class Trainer:
def __init__(self, exp, args):
def __init__(self, exp: Exp, args):
# init function only defines some basic attr, other attrs like model, optimizer are built in
# before_train methods.
self.exp = exp
Expand Down
10 changes: 7 additions & 3 deletions yolox/exp/yolox_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,7 @@ def init_yolo(M):
self.model.train()
return self.model

def get_data_loader(
self, batch_size, is_distributed, no_aug=False, cache_img=False
):
def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=False):
from yolox.data import (
COCODataset,
TrainTransform,
Expand Down Expand Up @@ -314,5 +312,11 @@ def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False)
)
return evaluator

def get_trainer(self, args):
from yolox.core import Trainer
trainer = Trainer(self, args)
# NOTE: trainer shouldn't be an attribute of exp object
return trainer

def eval(self, model, evaluator, is_distributed, half=False):
return evaluator.evaluate(model, is_distributed, half)

0 comments on commit 5b895d7

Please sign in to comment.