
Custom optimizer #132

Merged · 21 commits · Nov 19, 2024
2 changes: 2 additions & 0 deletions README.md
@@ -567,6 +567,7 @@ model.tune()
- [**Callbacks**](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/callbacks/README.md): Allow custom code to be executed at different stages of training.
- [**Optimizers**](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#optimizer): Control how the model's weights are updated.
- [**Schedulers**](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#scheduler): Adjust the learning rate during training.
- [**Training Strategy**](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#training-strategy): Specify a custom combination of optimizer and scheduler to tailor the training process for specific use cases.

**Creating Custom Components:**

@@ -581,6 +582,7 @@ Registered components can be referenced in the config file. Custom components need to inherit from their respective base classes:
- **Callbacks** - [`lightning.pytorch.callbacks.Callback`](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html), requires manual registration to the `CALLBACKS` registry
- **Optimizers** - [`torch.optim.Optimizer`](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer), requires manual registration to the `OPTIMIZERS` registry
- **Schedulers** - [`torch.optim.lr_scheduler.LRScheduler`](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate), requires manual registration to the `SCHEDULERS` registry
- **Training Strategy** - [`BaseTrainingStrategy`](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/strategies/base_strategy.py)
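
For illustration, here is a minimal sketch of registering a custom optimizer so it can be referenced by `name` in the config. The decorator form of `register_module` and the `luxonis_train.utils.registry` import path are assumptions based on how the built-in callbacks are registered (`CALLBACKS.register_module(module=...)`); a custom scheduler or training strategy would be registered to the `SCHEDULERS` or `STRATEGIES` registry in the same way.

```python
import torch
from torch.optim import Optimizer

# Assumed import path for the registry objects; adjust to wherever the
# project actually exposes OPTIMIZERS.
from luxonis_train.utils.registry import OPTIMIZERS


@OPTIMIZERS.register_module()
class PlainSGD(Optimizer):
    """Toy optimizer: vanilla SGD, registered so `name: "PlainSGD"` resolves from the config."""

    def __init__(self, params, lr: float = 0.01):
        super().__init__(params, defaults={"lr": lr})

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    # Plain gradient descent update: p <- p - lr * grad.
                    p.add_(p.grad, alpha=-group["lr"])
        return loss
```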

**Examples:**

31 changes: 31 additions & 0 deletions configs/README.md
@@ -376,6 +376,37 @@ trainer:
eta_min: 0
```

### Training Strategy

Defines the training strategy to be used. Currently, only the `TripleLRSGDStrategy` is supported, but more strategies will be added in the future.

| Key | Type | Default value | Description |
| ----------------- | ------- | ----------------------- | ---------------------------------------------- |
| `name` | `str` | `"TripleLRSGDStrategy"` | Name of the training strategy |
| `warmup_epochs` | `int` | `3` | Number of epochs for the warmup phase |
| `warmup_bias_lr` | `float` | `0.1` | Learning rate for bias during the warmup phase |
| `warmup_momentum` | `float` | `0.8` | Momentum value during the warmup phase |
| `lr` | `float` | `0.02` | Initial learning rate |
| `lre` | `float` | `0.0002` | End learning rate |
| `momentum` | `float` | `0.937` | Momentum for the optimizer |
| `weight_decay` | `float` | `0.0005` | Weight decay value |
| `nesterov`        | `bool`  | `true`                  | Whether to use Nesterov momentum               |

**Example:**

```yaml
training_strategy:
name: "TripleLRSGDStrategy"
warmup_epochs: 3
warmup_bias_lr: 0.1
warmup_momentum: 0.8
lr: 0.02
lre: 0.0002
momentum: 0.937
weight_decay: 0.0005
nesterov: true
```
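
To make the warmup-related values above more concrete, the following is a generic, standalone illustration of the kind of per-step interpolation a warmup-based strategy typically performs. It is not the `TripleLRSGDStrategy` implementation; the linear ramps and the linear decay toward `lre` are assumptions chosen only to show how the parameters relate.

```python
# Generic illustration (not the actual TripleLRSGDStrategy code) of warmup-style
# hyperparameter interpolation driven by the config values above.
warmup_epochs, warmup_bias_lr, warmup_momentum = 3, 0.1, 0.8
lr, lre, momentum = 0.02, 0.0002, 0.937
epochs, steps_per_epoch = 100, 500  # assumed training length for the example


def lerp(a: float, b: float, t: float) -> float:
    return a + (b - a) * t


def bias_lr_and_momentum(step: int) -> tuple[float, float]:
    warmup_steps = warmup_epochs * steps_per_epoch
    if step < warmup_steps:
        t = step / warmup_steps
        # During warmup, the bias LR ramps from warmup_bias_lr down to lr and
        # the momentum ramps from warmup_momentum up to its final value.
        return lerp(warmup_bias_lr, lr, t), lerp(warmup_momentum, momentum, t)
    # Afterwards, decay the LR toward lre over the remaining steps.
    t = (step - warmup_steps) / max(1, epochs * steps_per_epoch - warmup_steps)
    return lerp(lr, lre, t), momentum


print(bias_lr_and_momentum(0))      # (0.1, 0.8) -> start of warmup
print(bias_lr_and_momentum(1500))   # (0.02, 0.937) -> end of warmup
print(bias_lr_and_momentum(50000))  # about (0.0002, 0.937) -> end of training
```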

## Exporter

Here you can define configuration for exporting.
1 change: 1 addition & 0 deletions luxonis_train/__init__.py
@@ -10,6 +10,7 @@
from .nodes import *
from .optimizers import *
from .schedulers import *
from .strategies import *
from .utils import *
except ImportError as e:
warnings.warn(
@@ -57,6 +57,12 @@ def draw_predictions(
prediction = predictions[i]
mask = prediction[..., 2] < visibility_threshold
visible_kpts = prediction[..., :2] * (~mask).unsqueeze(-1).float()
visible_kpts[..., 0] = torch.clamp(
visible_kpts[..., 0], 0, canvas.size(-1) - 1
)
visible_kpts[..., 1] = torch.clamp(
visible_kpts[..., 1], 0, canvas.size(-2) - 1
)
viz[i] = draw_keypoints(
canvas[i].clone(),
visible_kpts[..., :2],
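For context on the change above: predicted keypoint coordinates can fall outside the image, so the x and y values are clamped to the canvas width and height before being passed to `draw_keypoints`. A standalone sketch of the same idea, with made-up shapes and values:

```python
import torch

# Standalone illustration of the clamping above; H x W canvas, (x, y, visibility) keypoints.
H, W = 256, 320
kpts = torch.tensor([[10.0, 5.0, 0.9],      # inside the canvas
                     [400.0, -3.0, 0.8]])   # outside on both axes
xy = kpts[..., :2].clone()
xy[..., 0] = xy[..., 0].clamp(0, W - 1)  # x into [0, W - 1]
xy[..., 1] = xy[..., 1].clamp(0, H - 1)  # y into [0, H - 1]
print(xy)  # tensor([[ 10.,   5.], [319.,   0.]])
```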
3 changes: 3 additions & 0 deletions luxonis_train/callbacks/__init__.py
@@ -25,6 +25,7 @@
from .metadata_logger import MetadataLogger
from .module_freezer import ModuleFreezer
from .test_on_train_end import TestOnTrainEnd
from .training_manager import TrainingManager
from .upload_checkpoint import UploadCheckpoint

CALLBACKS.register_module(module=EarlyStopping)
@@ -38,6 +39,7 @@
CALLBACKS.register_module(module=ModelPruning)
CALLBACKS.register_module(module=GradCamCallback)
CALLBACKS.register_module(module=EMACallback)
CALLBACKS.register_module(module=TrainingManager)


__all__ = [
@@ -53,4 +55,5 @@
"GPUStatsMonitor",
"GradCamCallback",
"EMACallback",
"TrainingManager",
]
28 changes: 28 additions & 0 deletions luxonis_train/callbacks/training_manager.py
@@ -0,0 +1,28 @@
import pytorch_lightning as pl

from luxonis_train.strategies.base_strategy import BaseTrainingStrategy


class TrainingManager(pl.Callback):
    def __init__(self, strategy: BaseTrainingStrategy | None = None):
        """Training manager callback that updates the parameters of the
        training strategy.

        @type strategy: BaseTrainingStrategy | None
        @param strategy: The strategy to be used.
        """
        self.strategy = strategy

    def on_after_backward(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ):
        """PyTorch Lightning hook that is called after the backward
        pass.

        @type trainer: pl.Trainer
        @param trainer: The trainer object.
        @type pl_module: pl.LightningModule
        @param pl_module: The pl_module object.
        """
        if self.strategy is not None:
            self.strategy.update_parameters(pl_module)

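For context, here is a rough sketch of what a strategy driven by this callback might look like. The constructor signature (`pl_module`, `params`), the `configure_optimizers()` return value, and the `update_parameters(pl_module)` hook are inferred from how the strategy is constructed and called elsewhere in this PR; the actual `BaseTrainingStrategy` interface in `luxonis_train/strategies/base_strategy.py` and the registry import path may differ.

```python
import pytorch_lightning as pl
import torch

# Assumed import path and decorator form for the registry; the constructor below
# is inferred from STRATEGIES.get(name)(pl_module=..., params=...) in
# luxonis_lightning.py and may not match the real base class exactly.
from luxonis_train.strategies.base_strategy import BaseTrainingStrategy
from luxonis_train.utils.registry import STRATEGIES


@STRATEGIES.register_module()
class ConstantSGDStrategy(BaseTrainingStrategy):
    """Toy strategy: plain SGD with a fixed LR and a linear momentum warmup."""

    def __init__(self, pl_module: pl.LightningModule, params: dict | None = None):
        # super().__init__ is intentionally omitted because the base-class
        # signature is not shown in this diff; adapt to the real interface.
        params = params or {}
        self.pl_module = pl_module
        self.lr = params.get("lr", 0.01)
        self.momentum = params.get("momentum", 0.9)
        self.warmup_steps = params.get("warmup_steps", 500)
        self._step = 0
        self.optimizer = torch.optim.SGD(
            pl_module.parameters(), lr=self.lr, momentum=0.0
        )

    def configure_optimizers(self):
        # Lightning accepts ([optimizers], [schedulers]); a no-op LambdaLR keeps the LR fixed.
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer, lr_lambda=lambda _: 1.0
        )
        return [self.optimizer], [scheduler]

    def update_parameters(self, pl_module: pl.LightningModule) -> None:
        # Called by TrainingManager.on_after_backward(); ramp momentum during warmup.
        self._step += 1
        frac = min(1.0, self._step / self.warmup_steps)
        for group in self.optimizer.param_groups:
            group["momentum"] = frac * self.momentum
```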
51 changes: 38 additions & 13 deletions luxonis_train/config/config.py
@@ -327,12 +327,17 @@ class CallbackConfig(BaseModelExtraForbid):


class OptimizerConfig(BaseModelExtraForbid):
name: str = "Adam"
name: str
params: Params = {}


class SchedulerConfig(BaseModelExtraForbid):
name: str = "ConstantLR"
name: str
params: Params = {}


class TrainingStrategyConfig(BaseModelExtraForbid):
name: str
params: Params = {}


@@ -380,8 +385,9 @@ class TrainerConfig(BaseModelExtraForbid):

callbacks: list[CallbackConfig] = []

optimizer: OptimizerConfig = OptimizerConfig()
scheduler: SchedulerConfig = SchedulerConfig()
optimizer: OptimizerConfig | None = None
scheduler: SchedulerConfig | None = None
training_strategy: TrainingStrategyConfig | None = None

@model_validator(mode="after")
def validate_deterministic(self) -> Self:
@@ -511,6 +517,7 @@ def get_config(
) -> "Config":
instance = super().get_config(cfg, overrides)
if not isinstance(cfg, str):
cls.smart_auto_populate(instance)
return instance
fs = LuxonisFileSystem(cfg)
if fs.is_mlflow:
@@ -530,16 +537,34 @@ def smart_auto_populate(cls, instance: "Config") -> None:
"""Automatically populates config fields based on rules, with
warnings."""

# Rule: Set default optimizer and scheduler if training_strategy is not defined and optimizer and scheduler are None
if instance.trainer.training_strategy is None:
if instance.trainer.optimizer is None:
instance.trainer.optimizer = OptimizerConfig(
name="Adam", params={}
)
logger.warning(
"Optimizer not specified. Automatically set to `Adam`."
)
if instance.trainer.scheduler is None:
instance.trainer.scheduler = SchedulerConfig(
name="ConstantLR", params={}
)
logger.warning(
"Scheduler not specified. Automatically set to `ConstantLR`."
)

# Rule: CosineAnnealingLR should have T_max set to the number of epochs if not provided
scheduler = instance.trainer.scheduler
if (
scheduler.name == "CosineAnnealingLR"
and "T_max" not in scheduler.params
):
scheduler.params["T_max"] = instance.trainer.epochs
logger.warning(
"`T_max` was not set for `CosineAnnealingLR`. Automatically set `T_max` to number of epochs."
)
if instance.trainer.scheduler is not None:
scheduler = instance.trainer.scheduler
if (
scheduler.name == "CosineAnnealingLR"
and "T_max" not in scheduler.params
):
scheduler.params["T_max"] = instance.trainer.epochs
logger.warning(
"`T_max` was not set for `CosineAnnealingLR`. Automatically set `T_max` to number of epochs."
)

# Rule: Mosaic4 should have out_width and out_height matching train_image_size if not provided
for augmentation in instance.trainer.preprocessing.augmentations:
36 changes: 35 additions & 1 deletion luxonis_train/models/luxonis_lightning.py
@@ -25,7 +25,11 @@
combine_visualizations,
get_denormalized_images,
)
from luxonis_train.callbacks import BaseLuxonisProgressBar, ModuleFreezer
from luxonis_train.callbacks import (
BaseLuxonisProgressBar,
ModuleFreezer,
TrainingManager,
)
from luxonis_train.config import AttachedModuleConfig, Config
from luxonis_train.nodes import BaseNode
from luxonis_train.utils import (
@@ -42,6 +46,7 @@
CALLBACKS,
OPTIMIZERS,
SCHEDULERS,
STRATEGIES,
Registry,
)

@@ -268,6 +273,24 @@

self.load_checkpoint(self.cfg.model.weights)

if self.cfg.trainer.training_strategy is not None:
if (
self.cfg.trainer.optimizer is not None
or self.cfg.trainer.scheduler is not None
):
raise ValueError(
"Training strategy is defined, but optimizer or scheduler is also defined. "
"Please remove optimizer and scheduler from the config."
)
self.training_strategy = STRATEGIES.get(
self.cfg.trainer.training_strategy.name
)(
pl_module=self,
params=self.cfg.trainer.training_strategy.params, # type: ignore
)
else:
self.training_strategy = None

@property
def core(self) -> "luxonis_train.core.LuxonisModel":
"""Returns the core model."""
@@ -849,6 +872,9 @@
CALLBACKS.get(callback.name)(**callback.params)
)

if self.training_strategy is not None:
callbacks.append(TrainingManager(strategy=self.training_strategy)) # type: ignore

return callbacks

def configure_optimizers(
Expand All @@ -858,9 +884,17 @@
list[torch.optim.lr_scheduler.LRScheduler],
]:
"""Configures model optimizers and schedulers."""
if self.training_strategy is not None:
return self.training_strategy.configure_optimizers()

cfg_optimizer = self.cfg.trainer.optimizer
cfg_scheduler = self.cfg.trainer.scheduler

if cfg_optimizer is None or cfg_scheduler is None:
raise ValueError(
"Optimizer and scheduler configuration must not be None."
)

optim_params = cfg_optimizer.params | {
"params": filter(lambda p: p.requires_grad, self.parameters()),
}