Custom optimizer #132

Merged · 21 commits · Nov 19, 2024
Changes from 9 commits
1 change: 1 addition & 0 deletions luxonis_train/__init__.py
@@ -10,6 +10,7 @@
    from .nodes import *
    from .optimizers import *
    from .schedulers import *
    from .strategies import *
    from .utils import *
except ImportError as e:
    warnings.warn(
3 changes: 3 additions & 0 deletions luxonis_train/callbacks/__init__.py
@@ -25,6 +25,7 @@
from .metadata_logger import MetadataLogger
from .module_freezer import ModuleFreezer
from .test_on_train_end import TestOnTrainEnd
from .training_manager import TrainingManager
from .upload_checkpoint import UploadCheckpoint

CALLBACKS.register_module(module=EarlyStopping)
@@ -38,6 +39,7 @@
CALLBACKS.register_module(module=ModelPruning)
CALLBACKS.register_module(module=GradCamCallback)
CALLBACKS.register_module(module=EMACallback)
CALLBACKS.register_module(module=TrainingManager)


__all__ = [
@@ -53,4 +55,5 @@
    "GPUStatsMonitor",
    "GradCamCallback",
    "EMACallback",
    "TrainingManager",
]
28 changes: 28 additions & 0 deletions luxonis_train/callbacks/training_manager.py
@@ -0,0 +1,28 @@
import pytorch_lightning as pl

from luxonis_train.strategies.base_strategy import BaseTrainingStrategy


class TrainingManager(pl.Callback):
    def __init__(self, strategy: BaseTrainingStrategy | None = None):
        """Training manager callback that updates the parameters of the
        training strategy.

        @type strategy: BaseTrainingStrategy
        @param strategy: The strategy to be used.
        """
        self.strategy = strategy

    def on_after_backward(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ):
        """PyTorch Lightning hook that is called after the backward
        pass.

        @type trainer: pl.Trainer
        @param trainer: The trainer object.
        @type pl_module: pl.LightningModule
        @param pl_module: The pl_module object.
        """
        if self.strategy is not None:
            self.strategy.update_parameters(pl_module)
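
The callback is deliberately thin: Lightning invokes `on_after_backward` once per optimization step, and the manager just forwards the hook to the strategy. A minimal sketch of wiring it up by hand, with a hypothetical duck-typed `PrintStrategy` standing in for a real `BaseTrainingStrategy` subclass (in practice the callback is attached automatically in `configure_callbacks`, shown further below):

```python
import pytorch_lightning as pl

from luxonis_train.callbacks import TrainingManager


class PrintStrategy:
    """Duck-typed stand-in for a BaseTrainingStrategy (illustration only)."""

    def update_parameters(self, pl_module: pl.LightningModule) -> None:
        # TrainingManager calls this once per step, right after backward().
        print("step", pl_module.global_step, "finished")


# Attach to any LightningModule run; a type checker will want a real
# BaseTrainingStrategy subclass here rather than a duck-typed stub.
trainer = pl.Trainer(callbacks=[TrainingManager(strategy=PrintStrategy())])
```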
6 changes: 6 additions & 0 deletions luxonis_train/config/config.py
@@ -336,6 +336,11 @@ class SchedulerConfig(BaseModelExtraForbid):
    params: Params = {}


class TrainingStrategyConfig(BaseModelExtraForbid):
    name: str = "TripleLRSGDStrategy"
    params: Params = {}


class TrainerConfig(BaseModelExtraForbid):
    preprocessing: PreprocessingConfig = PreprocessingConfig()
    use_rich_progress_bar: bool = True
@@ -382,6 +387,7 @@ class TrainerConfig(BaseModelExtraForbid):

    optimizer: OptimizerConfig = OptimizerConfig()
    scheduler: SchedulerConfig = SchedulerConfig()
    training_strategy: TrainingStrategyConfig = TrainingStrategyConfig()

    @model_validator(mode="after")
    def validate_deterministic(self) -> Self:
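
For reference, the new config block can be exercised directly. The parameter names below are hypothetical, since `params` is an opaque dict forwarded to whatever strategy class is looked up by `name` in the `STRATEGIES` registry. Note that `LuxonisLightningModule` (next file) only instantiates a strategy when `params` is non-empty, so the default `TrainingStrategyConfig()` stays inert:

```python
from luxonis_train.config.config import TrainingStrategyConfig

# Hypothetical parameter values; `params` is forwarded verbatim to the
# strategy class resolved from the STRATEGIES registry.
cfg = TrainingStrategyConfig(
    name="TripleLRSGDStrategy",
    params={"warmup_epochs": 3, "lr": 0.02},
)
print(cfg.name, cfg.params)
```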
23 changes: 22 additions & 1 deletion luxonis_train/models/luxonis_lightning.py
@@ -25,7 +25,11 @@
    combine_visualizations,
    get_denormalized_images,
)
from luxonis_train.callbacks import (
    BaseLuxonisProgressBar,
    ModuleFreezer,
    TrainingManager,
)
from luxonis_train.config import AttachedModuleConfig, Config
from luxonis_train.nodes import BaseNode
from luxonis_train.utils import (
@@ -42,6 +46,7 @@
    CALLBACKS,
    OPTIMIZERS,
    SCHEDULERS,
    STRATEGIES,
    Registry,
)

@@ -268,6 +273,16 @@ def __init__(

        self.load_checkpoint(self.cfg.model.weights)

        if self.cfg.trainer.training_strategy.params:
            self.training_strategy = STRATEGIES.get(
                self.cfg.trainer.training_strategy.name
            )(
                pl_module=self,
                params=self.cfg.trainer.training_strategy.params,
            )
        else:
            self.training_strategy = None

    @property
    def core(self) -> "luxonis_train.core.LuxonisModel":
        """Returns the core model."""
@@ -849,6 +864,9 @@ def configure_callbacks(self) -> list[pl.Callback]:
                CALLBACKS.get(callback.name)(**callback.params)
            )

        if self.training_strategy is not None:
            callbacks.append(TrainingManager(strategy=self.training_strategy))  # type: ignore

        return callbacks

    def configure_optimizers(
@@ -858,6 +876,9 @@ def configure_optimizers(
        list[torch.optim.lr_scheduler.LRScheduler],
    ]:
        """Configures model optimizers and schedulers."""
        if self.training_strategy is not None:
            return self.training_strategy.configure_optimizers()

        cfg_optimizer = self.cfg.trainer.optimizer
        cfg_scheduler = self.cfg.trainer.scheduler
14 changes: 14 additions & 0 deletions luxonis_train/nodes/backbones/efficientrep/efficientrep.py
@@ -125,9 +125,23 @@ def __init__(
                )
            )

        self.initialize_weights()

        if download_weights and var.weights_path:
            self.load_checkpoint(var.weights_path)

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                pass  # conv weights keep their existing initialization
            elif isinstance(m, nn.BatchNorm2d):
                m.eps = 0.001
                m.momentum = 0.03
            elif isinstance(
                m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
            ):
                m.inplace = True

    def set_export_mode(self, mode: bool = True) -> None:
        """Reparametrizes instances of L{RepVGGBlock} in the network.

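
The same `initialize_weights` helper recurs verbatim in the decoupled block, bbox head, and RepPAN neck below: it leaves conv weights untouched, overrides the BatchNorm `eps`/`momentum` defaults, and switches activations to in-place. A standalone sketch of its effect on a toy module, assuming only `torch`:

```python
import torch.nn as nn


def initialize_weights(module: nn.Module) -> None:
    # Mirrors the method above, extracted here for demonstration.
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            pass  # conv weights keep their existing initialization
        elif isinstance(m, nn.BatchNorm2d):
            m.eps = 0.001
            m.momentum = 0.03
        elif isinstance(m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)):
            m.inplace = True


toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
initialize_weights(toy)
assert toy[1].eps == 0.001 and toy[1].momentum == 0.03
assert toy[2].inplace
```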
13 changes: 13 additions & 0 deletions luxonis_train/nodes/blocks/blocks.py
@@ -56,6 +56,19 @@ def __init__(self, n_classes: int, in_channels: int):

        prior_prob = 1e-2
        self._initialize_weights_and_biases(prior_prob)
        self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                pass  # conv weights keep their existing initialization
            elif isinstance(m, nn.BatchNorm2d):
                m.eps = 0.001
                m.momentum = 0.03
            elif isinstance(
                m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
            ):
                m.inplace = True

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        out_feature = self.decoder(x)
14 changes: 14 additions & 0 deletions luxonis_train/nodes/heads/efficient_bbox_head.py
@@ -95,12 +95,26 @@ def __init__(
            f"output{i+1}_yolov6r2" for i in range(self.n_heads)
        ]

        self.initialize_weights()

        if download_weights:
            # TODO: Handle variants of head in a nicer way
            if self.in_channels == [32, 64, 128]:
                weights_path = "https://github.com/luxonis/luxonis-train/releases/download/v0.1.0-beta/efficientbbox_head_n_coco.ckpt"
                self.load_checkpoint(weights_path, strict=False)

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                pass  # conv weights keep their existing initialization
            elif isinstance(m, nn.BatchNorm2d):
                m.eps = 0.001
                m.momentum = 0.03
            elif isinstance(
                m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
            ):
                m.inplace = True

    def forward(
        self, inputs: list[Tensor]
    ) -> tuple[list[Tensor], list[Tensor], list[Tensor]]:
14 changes: 14 additions & 0 deletions luxonis_train/nodes/necks/reppan_neck/reppan_neck.py
@@ -165,9 +165,23 @@ def __init__(
            out_channels = channels_list_down_blocks[2 * i + 1]
            curr_n_repeats = n_repeats_down_blocks[i]

        self.initialize_weights()

        if download_weights and var.weights_path:
            self.load_checkpoint(var.weights_path)

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                pass  # conv weights keep their existing initialization
            elif isinstance(m, nn.BatchNorm2d):
                m.eps = 0.001
                m.momentum = 0.03
            elif isinstance(
                m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
            ):
                m.inplace = True

    def forward(self, inputs: list[Tensor]) -> list[Tensor]:
        x = inputs[-1]
        up_block_outs: list[Tensor] = []
5 changes: 5 additions & 0 deletions luxonis_train/strategies/__init__.py
@@ -0,0 +1,5 @@
from .triple_lr_sgd import TripleLRScheduler

__all__ = [
    "TripleLRScheduler",
]
27 changes: 27 additions & 0 deletions luxonis_train/strategies/base_strategy.py
@@ -0,0 +1,27 @@
from abc import ABC, abstractmethod

import pytorch_lightning as pl
from luxonis_ml.utils.registry import AutoRegisterMeta
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from luxonis_train.utils.registry import STRATEGIES


class BaseTrainingStrategy(
    ABC,
    metaclass=AutoRegisterMeta,
    register=False,
    registry=STRATEGIES,
):
    def __init__(self, pl_module: pl.LightningModule):
        self.pl_module = pl_module

    @abstractmethod
    def configure_optimizers(self) -> tuple[list[Optimizer], list[_LRScheduler]]:
        pass

    @abstractmethod
    def update_parameters(self, *args, **kwargs):
        pass
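
Since only the base class passes `register=False`, concrete subclasses should be picked up automatically by `AutoRegisterMeta` and registered in `STRATEGIES` under their class name, making them addressable via `training_strategy.name` in the config. A minimal sketch, assuming the constructor signature used by `LuxonisLightningModule` (`pl_module` plus a `params` dict); the schedule itself is made up and is not the PR's `TripleLRSGDStrategy`:

```python
import pytorch_lightning as pl
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, _LRScheduler

from luxonis_train.strategies.base_strategy import BaseTrainingStrategy


class ConstantSGDStrategy(BaseTrainingStrategy):
    """Toy strategy: plain SGD with a flat schedule (illustration only)."""

    def __init__(self, pl_module: pl.LightningModule, params: dict | None = None):
        super().__init__(pl_module)
        self.lr = (params or {}).get("lr", 0.01)  # hypothetical parameter

    def configure_optimizers(self) -> tuple[list[Optimizer], list[_LRScheduler]]:
        optimizer = torch.optim.SGD(self.pl_module.parameters(), lr=self.lr)
        scheduler = LambdaLR(optimizer, lr_lambda=lambda _: 1.0)
        return [optimizer], [scheduler]

    def update_parameters(self, *args, **kwargs):
        pass  # invoked by TrainingManager after every backward pass
```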