Skip to content

Commit

Permalink
Updated how we access config version from luxonis-ml (#122)
Browse files Browse the repository at this point in the history
  • Loading branch information
klemen1999 authored Nov 8, 2024
1 parent 7eb6b9e commit c29601c
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 28 deletions.
19 changes: 12 additions & 7 deletions luxonis_train/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,11 @@ def __init__(
self.pl_trainer = create_trainer(
self.cfg.trainer,
logger=self.tracker,
callbacks=LuxonisRichProgressBar()
if self.cfg.trainer.use_rich_progress_bar
else LuxonisTQDMProgressBar(),
callbacks=(
LuxonisRichProgressBar()
if self.cfg.trainer.use_rich_progress_bar
else LuxonisTQDMProgressBar()
),
precision=self.cfg.trainer.precision,
)

Expand Down Expand Up @@ -565,9 +567,11 @@ def _objective(trial: optuna.trial.Trial) -> float:
_core=self,
)
callbacks = [
LuxonisRichProgressBar()
if cfg.trainer.use_rich_progress_bar
else LuxonisTQDMProgressBar()
(
LuxonisRichProgressBar()
if cfg.trainer.use_rich_progress_bar
else LuxonisTQDMProgressBar()
)
]

pruner_callback = PyTorchLightningPruningCallback(
Expand Down Expand Up @@ -732,6 +736,7 @@ def _mult(lst: list[float | int]) -> list[float]:
),
"reverse_channels": self.cfg.trainer.preprocessing.train_rgb,
"interleaved_to_planar": False, # TODO: make it modifiable?
"dai_type": "RGB888p",
}

inputs_dict = get_inputs(path)
Expand Down Expand Up @@ -774,7 +779,7 @@ def _mult(lst: list[float | int]) -> list[float]:
}

cfg_dict = {
"config_version": CONFIG_VERSION.__args__[-1], # type: ignore
"config_version": CONFIG_VERSION,
"model": model,
}

Expand Down
3 changes: 1 addition & 2 deletions luxonis_train/core/utils/archive_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import onnx
from luxonis_ml.nn_archive.config_building_blocks import (
DataType,
ObjectDetectionSubtypeYOLO,
)
from onnx.onnx_pb import TensorProto

Expand Down Expand Up @@ -143,7 +142,7 @@ def _get_head_specific_parameters(
ImplementedHeadsIsSoxtmaxed, head_name
).value
elif head_name == "EfficientBBoxHead":
parameters["subtype"] = ObjectDetectionSubtypeYOLO.YOLOv6.value
parameters["subtype"] = "yolov6"
head_node = nodes[head_alias]
parameters["iou_threshold"] = head_node.iou_thres
parameters["conf_threshold"] = head_node.conf_thres
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
blobconverter>=1.4.2
lightning>=2.4.0
luxonis-ml[data,tracker]>=0.4.0
luxonis-ml[data,tracker]>=0.5.0
onnx>=1.12.0
onnxruntime>=1.13.1
onnxsim>=0.4.10
Expand Down
37 changes: 19 additions & 18 deletions tests/integration/parking_lot.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
57.375
],
"reverse_channels": true,
"interleaved_to_planar": false
"interleaved_to_planar": false,
"dai_type": "RGB888p"
}
}
],
Expand Down Expand Up @@ -153,8 +154,8 @@
"metadata": {
"postprocessor_path": null,
"classes": [
"motorbike",
"car"
"car",
"motorbike"
],
"n_classes": 2,
"iou_threshold": 0.45,
Expand Down Expand Up @@ -229,29 +230,29 @@
"metadata": {
"postprocessor_path": null,
"classes": [
"Kawasaki",
"alfa-romeo",
"aprilia",
"background",
"chrysler",
"bmw",
"ducati",
"buick",
"chrysler",
"dodge",
"ducati",
"ferrari",
"infiniti",
"land-rover",
"roll-royce",
"saab",
"Kawasaki",
"moto",
"truimph",
"alfa-romeo",
"harley",
"honda",
"infiniti",
"isuzu",
"jeep",
"aprilia",
"land-rover",
"moto",
"piaggio",
"yamaha",
"buick",
"pontiac",
"isuzu"
"roll-royce",
"saab",
"truimph",
"yamaha"
],
"n_classes": 23,
"is_softmax": false
Expand Down Expand Up @@ -279,4 +280,4 @@
}
]
}
}
}

0 comments on commit c29601c

Please sign in to comment.