Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/inference #82

Merged
merged 3 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions luxonis_train/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from importlib.metadata import version
from pathlib import Path
from typing import Annotated, Optional
from typing import Annotated

import typer
import yaml
Expand All @@ -25,7 +25,7 @@ class _ViewType(str, Enum):


ConfigType = Annotated[
Optional[str],
str | None,
typer.Option(
help="Path to the configuration file.",
show_default=False,
Expand All @@ -34,7 +34,7 @@ class _ViewType(str, Enum):
]

OptsType = Annotated[
Optional[list[str]],
list[str] | None,
typer.Argument(
help="A list of optional CLI overrides of the config file.",
show_default=False,
Expand All @@ -46,16 +46,23 @@ class _ViewType(str, Enum):
]

SaveDirType = Annotated[
Optional[Path],
Path | None,
typer.Option(help="Where to save the inference results."),
]

ImgPathType = Annotated[
str | None,
typer.Option(
help="Path to an image file or a directory containing images for inference."
),
]


@app.command()
def train(
config: ConfigType = None,
resume: Annotated[
Optional[str],
str | None,
typer.Option(help="Resume training from this checkpoint."),
] = None,
opts: OptsType = None,
Expand Down Expand Up @@ -99,12 +106,15 @@ def infer(
config: ConfigType = None,
view: ViewType = _ViewType.VAL,
save_dir: SaveDirType = None,
img_path: ImgPathType = None,
opts: OptsType = None,
):
"""Run inference."""
from luxonis_train.core import LuxonisModel

LuxonisModel(config, opts).infer(view=view.value, save_dir=save_dir)
LuxonisModel(config, opts).infer(
view=view.value, save_dir=save_dir, img_path=img_path
)


@app.command()
Expand Down Expand Up @@ -200,7 +210,7 @@ def common(
),
] = False,
source: Annotated[
Optional[Path],
Path | None,
typer.Option(
help="Path to a python file with custom components. "
"Will be sourced before running the command.",
Expand Down
29 changes: 21 additions & 8 deletions luxonis_train/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from luxonis_ml.utils import LuxonisFileSystem, reset_logging, setup_logging
from typeguard import typechecked

from luxonis_train.attached_modules.visualizers import get_unnormalized_images
from luxonis_train.callbacks import (
LuxonisRichProgressBar,
LuxonisTQDMProgressBar,
Expand All @@ -35,7 +34,10 @@
replace_weights,
try_onnx_simplify,
)
from .utils.infer_utils import render_visualizations
from .utils.infer_utils import (
process_dataset_images,
process_images,
)
from .utils.train_utils import create_trainer

logger = getLogger(__name__)
Expand Down Expand Up @@ -419,6 +421,7 @@
self,
view: Literal["train", "val", "test"] = "val",
save_dir: str | Path | None = None,
img_path: str | None = None,
) -> None:
"""Runs inference.

Expand All @@ -429,15 +432,25 @@
@param save_dir: Directory where to save the visualizations. If
not specified, visualizations will be rendered on the
screen.
@type img_path: str | None
@param img_path: Path to the image file or directory for inference.
If None, defaults to using dataset images.
"""
self.lightning_module.eval()

for inputs, labels in self.pytorch_loaders[view]:
images = get_unnormalized_images(self.cfg, inputs)
outputs = self.lightning_module.forward(
inputs, labels, images=images, compute_visualizations=True
)
render_visualizations(outputs.visualizations, save_dir)
if img_path:
img_path_obj = Path(img_path)
if img_path_obj.is_file():
process_images(self, [img_path_obj], view, save_dir)
elif img_path_obj.is_dir():
image_files = [

Check warning on line 446 in luxonis_train/core/core.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/core/core.py#L442-L446

Added lines #L442 - L446 were not covered by tests
f
for f in img_path_obj.iterdir()
if f.suffix.lower() in {".png", ".jpg", ".jpeg"}
]
process_images(self, image_files, view, save_dir)

Check warning on line 451 in luxonis_train/core/core.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/core/core.py#L451

Added line #L451 was not covered by tests
else:
process_dataset_images(self, view, save_dir)

def tune(self) -> None:
"""Runs Optuna tunning of hyperparameters."""
Expand Down
68 changes: 68 additions & 0 deletions luxonis_train/core/utils/infer_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from pathlib import Path

import cv2
import torch
from torch import Tensor

from luxonis_train.attached_modules.visualizers import get_unnormalized_images
from luxonis_train.enums import TaskType


def render_visualizations(
visualizations: dict[str, dict[str, Tensor]], save_dir: str | Path | None
) -> None:
"""Render or save visualizations."""
save_dir = Path(save_dir) if save_dir is not None else None
if save_dir is not None:
save_dir.mkdir(exist_ok=True, parents=True)
Expand All @@ -28,3 +33,66 @@
if save_dir is None:
if cv2.waitKey(0) == ord("q"):
exit()


def process_images(
model, img_paths: list[Path], view: str, save_dir: str | Path | None
) -> None:
"""Handles inference on one or more images."""
first_image = cv2.cvtColor(

Check warning on line 42 in luxonis_train/core/utils/infer_utils.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/core/utils/infer_utils.py#L42

Added line #L42 was not covered by tests
cv2.imread(str(img_paths[0])), cv2.COLOR_BGR2RGB
)
labels = create_dummy_labels(model, view, first_image.shape)
for img_path in img_paths:
img = cv2.cvtColor(cv2.imread(str(img_path)), cv2.COLOR_BGR2RGB)
img, _ = (

Check warning on line 48 in luxonis_train/core/utils/infer_utils.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/core/utils/infer_utils.py#L45-L48

Added lines #L45 - L48 were not covered by tests
model.train_augmentations([(img, {})])
if view == "train"
else model.val_augmentations([(img, {})])
)

inputs = {

Check warning on line 54 in luxonis_train/core/utils/infer_utils.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/core/utils/infer_utils.py#L54

Added line #L54 was not covered by tests
"image": torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2).float()
}
images = get_unnormalized_images(model.cfg, inputs)

Check warning on line 57 in luxonis_train/core/utils/infer_utils.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/core/utils/infer_utils.py#L57

Added line #L57 was not covered by tests

outputs = model.lightning_module.forward(

Check warning on line 59 in luxonis_train/core/utils/infer_utils.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/core/utils/infer_utils.py#L59

Added line #L59 was not covered by tests
inputs, labels, images=images, compute_visualizations=True
)
render_visualizations(outputs.visualizations, save_dir)

Check warning on line 62 in luxonis_train/core/utils/infer_utils.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/core/utils/infer_utils.py#L62

Added line #L62 was not covered by tests


def process_dataset_images(
model, view: str, save_dir: str | Path | None
) -> None:
"""Handles the inference on dataset images."""
for inputs, labels in model.pytorch_loaders[view]:
images = get_unnormalized_images(model.cfg, inputs)
outputs = model.lightning_module.forward(
inputs, labels, images=images, compute_visualizations=True
)
render_visualizations(outputs.visualizations, save_dir)


def create_dummy_labels(model, view: str, img_shape: tuple) -> dict:
"""Prepares the labels for different tasks (classification,
keypoints, etc.)."""
tasks = list(model.loaders["train"].get_classes().keys())
h, w, _ = img_shape
labels = {}
nk = model.loaders[view].get_n_keypoints()["keypoints"]

Check warning on line 83 in luxonis_train/core/utils/infer_utils.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/core/utils/infer_utils.py#L80-L83

Added lines #L80 - L83 were not covered by tests

for task in tasks:
if task == "classification":
labels[task] = [-1, TaskType.CLASSIFICATION]
elif task == "keypoints":
labels[task] = [torch.zeros((1, nk * 3 + 2)), TaskType.KEYPOINTS]
elif task == "segmentation":
labels[task] = [torch.zeros((1, h, w)), TaskType.SEGMENTATION]
elif task == "boundingbox":
labels[task] = [

Check warning on line 93 in luxonis_train/core/utils/infer_utils.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/core/utils/infer_utils.py#L85-L93

Added lines #L85 - L93 were not covered by tests
torch.tensor([[-1, 0, 0, 0, 0, 0]]),
TaskType.BOUNDINGBOX,
]

return labels

Check warning on line 98 in luxonis_train/core/utils/infer_utils.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/core/utils/infer_utils.py#L98

Added line #L98 was not covered by tests