Modify the logic for assigning the optimizer and scheduler
JSabadin committed Nov 18, 2024
1 parent 75096da commit 23dda47
Showing 2 changed files with 38 additions and 20 deletions.
luxonis_train/config/config.py (44 changes: 31 additions & 13 deletions)
```diff
@@ -327,12 +327,12 @@ class CallbackConfig(BaseModelExtraForbid):


 class OptimizerConfig(BaseModelExtraForbid):
-    name: str = "Adam"
+    name: str
     params: Params = {}


 class SchedulerConfig(BaseModelExtraForbid):
-    name: str = "ConstantLR"
+    name: str
     params: Params = {}


@@ -385,8 +385,8 @@ class TrainerConfig(BaseModelExtraForbid):

     callbacks: list[CallbackConfig] = []

-    optimizer: OptimizerConfig = OptimizerConfig()
-    scheduler: SchedulerConfig = SchedulerConfig()
+    optimizer: OptimizerConfig | None = None
+    scheduler: SchedulerConfig | None = None
     training_strategy: TrainingStrategyConfig | None = None

     @model_validator(mode="after")
@@ -536,16 +536,34 @@ def smart_auto_populate(cls, instance: "Config") -> None:
         """Automatically populates config fields based on rules, with
         warnings."""

+        # Rule: Set default optimizer and scheduler if training_strategy is not defined and optimizer and scheduler are None
+        if instance.trainer.training_strategy is None:
+            if instance.trainer.optimizer is None:
+                instance.trainer.optimizer = OptimizerConfig(
+                    name="Adam", params={}
+                )
+                logger.warning(
+                    "Optimizer not specified. Automatically set to `Adam`."
+                )
+            if instance.trainer.scheduler is None:
+                instance.trainer.scheduler = SchedulerConfig(
+                    name="ConstantLR", params={}
+                )
+                logger.warning(
+                    "Scheduler not specified. Automatically set to `ConstantLR`."
+                )
+
         # Rule: CosineAnnealingLR should have T_max set to the number of epochs if not provided
-        scheduler = instance.trainer.scheduler
-        if (
-            scheduler.name == "CosineAnnealingLR"
-            and "T_max" not in scheduler.params
-        ):
-            scheduler.params["T_max"] = instance.trainer.epochs
-            logger.warning(
-                "`T_max` was not set for `CosineAnnealingLR`. Automatically set `T_max` to number of epochs."
-            )
+        if instance.trainer.scheduler is not None:
+            scheduler = instance.trainer.scheduler
+            if (
+                scheduler.name == "CosineAnnealingLR"
+                and "T_max" not in scheduler.params
+            ):
+                scheduler.params["T_max"] = instance.trainer.epochs
+                logger.warning(
+                    "`T_max` was not set for `CosineAnnealingLR`. Automatically set `T_max` to number of epochs."
+                )

         # Rule: Mosaic4 should have out_width and out_height matching train_image_size if not provided
         for augmentation in instance.trainer.preprocessing.augmentations:
```
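The net effect of the config changes: `optimizer` and `scheduler` are now optional, and defaults (`Adam`, `ConstantLR`) are injected by `smart_auto_populate` only when no training strategy is configured. Below is a minimal, self-contained sketch of that rule, using simplified dataclass stand-ins rather than the real pydantic models from `luxonis_train`:

```python
# Sketch of the new defaulting rule, with simplified stand-ins for the
# real pydantic config classes in luxonis_train/config/config.py.
import logging
from dataclasses import dataclass, field

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


@dataclass
class OptimizerConfig:
    name: str
    params: dict = field(default_factory=dict)


@dataclass
class SchedulerConfig:
    name: str
    params: dict = field(default_factory=dict)


@dataclass
class TrainerConfig:
    epochs: int = 100
    optimizer: OptimizerConfig | None = None
    scheduler: SchedulerConfig | None = None
    training_strategy: object | None = None


def smart_auto_populate(trainer: TrainerConfig) -> None:
    # Defaults are applied only when no training strategy is configured.
    if trainer.training_strategy is None:
        if trainer.optimizer is None:
            trainer.optimizer = OptimizerConfig(name="Adam")
            logger.warning("Optimizer not specified. Automatically set to `Adam`.")
        if trainer.scheduler is None:
            trainer.scheduler = SchedulerConfig(name="ConstantLR")
            logger.warning("Scheduler not specified. Automatically set to `ConstantLR`.")

    # The T_max rule now has to guard against the scheduler being None.
    scheduler = trainer.scheduler
    if scheduler is not None and scheduler.name == "CosineAnnealingLR":
        if "T_max" not in scheduler.params:
            scheduler.params["T_max"] = trainer.epochs


trainer = TrainerConfig()
smart_auto_populate(trainer)
assert trainer.optimizer is not None and trainer.optimizer.name == "Adam"
assert trainer.scheduler is not None and trainer.scheduler.name == "ConstantLR"
```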
luxonis_train/models/luxonis_lightning.py (14 changes: 7 additions & 7 deletions)
```diff
@@ -274,13 +274,13 @@ def __init__(
             self.load_checkpoint(self.cfg.model.weights)

         if self.cfg.trainer.training_strategy is not None:
-            if self.cfg.trainer.optimizer is not None:
-                logger.warning(
-                    "Training strategy is active; the specified optimizer will be ignored."
-                )
-            if self.cfg.trainer.scheduler is not None:
-                logger.warning(
-                    "Training strategy is active; the specified scheduler will be ignored."
+            if (
+                self.cfg.trainer.optimizer is not None
+                or self.cfg.trainer.scheduler is not None
+            ):
+                raise ValueError(
+                    "Training strategy is defined, but optimizer or scheduler is also defined. "
+                    "Please remove optimizer and scheduler from the config."
                 )
             self.training_strategy = STRATEGIES.get(
                 self.cfg.trainer.training_strategy.name
```
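This change turns a warn-and-ignore path into a fail-fast check: a config may declare a training strategy or an explicit optimizer/scheduler, but not both. A small sketch of the resulting contract, using hypothetical `SimpleNamespace` stand-ins (including the placeholder strategy name) rather than the real `LuxonisLightningModule` and config objects:

```python
# Sketch of the new fail-fast contract: a config may define either a
# training strategy or an explicit optimizer/scheduler, never both.
from types import SimpleNamespace


def check_strategy_conflict(trainer_cfg: SimpleNamespace) -> None:
    if trainer_cfg.training_strategy is not None:
        if trainer_cfg.optimizer is not None or trainer_cfg.scheduler is not None:
            raise ValueError(
                "Training strategy is defined, but optimizer or scheduler is also defined. "
                "Please remove optimizer and scheduler from the config."
            )


# Previously this combination only produced warnings; now it is rejected.
# "SomeStrategy" is a placeholder, not a real strategy name.
cfg = SimpleNamespace(
    training_strategy=SimpleNamespace(name="SomeStrategy"),
    optimizer=SimpleNamespace(name="Adam"),
    scheduler=None,
)
try:
    check_strategy_conflict(cfg)
except ValueError as err:
    print(f"rejected: {err}")
```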
