diff --git a/luxonis_train/attached_modules/base_attached_module.py b/luxonis_train/attached_modules/base_attached_module.py index e52f5790..c84d8575 100644 --- a/luxonis_train/attached_modules/base_attached_module.py +++ b/luxonis_train/attached_modules/base_attached_module.py @@ -327,7 +327,7 @@ def prepare( set(self.supported_tasks) & set(self.node_tasks) ) x = self.get_input_tensors(inputs) - if labels is None: + if labels is None or len(labels) == 0: return x, None # type: ignore label, task_type = self._get_label(labels) if task_type in [TaskType.CLASSIFICATION, TaskType.SEGMENTATION]: diff --git a/luxonis_train/core/utils/infer_utils.py b/luxonis_train/core/utils/infer_utils.py index e7c3a074..b4996751 100644 --- a/luxonis_train/core/utils/infer_utils.py +++ b/luxonis_train/core/utils/infer_utils.py @@ -6,10 +6,13 @@ import cv2 import numpy as np import torch +import torch.utils.data as torch_data +from luxonis_ml.data import LuxonisDataset from torch import Tensor import luxonis_train from luxonis_train.attached_modules.visualizers import get_denormalized_images +from luxonis_train.loaders import LuxonisLoaderTorch IMAGE_FORMATS = { ".bmp", @@ -122,6 +125,61 @@ def infer_from_video( writer.release() +def infer_from_loader( + model: "luxonis_train.core.LuxonisModel", + loader: torch_data.DataLoader, + save_dir: Path | None, + img_paths: list[Path] | None = None, +) -> None: + """Runs inference on images from the dataset. + + @type model: L{LuxonisModel} + @param model: The model to use for inference. + @type loader: torch_data.DataLoader + @param loader: The loader to use for inference. + @type save_dir: str | Path | None + @param save_dir: The directory to save the visualizations to. + @type img_paths: list[Path] | None + @param img_paths: The paths to the images. + """ + + predictions = model.pl_trainer.predict(model.lightning_module, loader) + + broken = False + if predictions is None: + return + + for i, outputs in enumerate(predictions): + if broken: # pragma: no cover + break + visualizations = outputs.visualizations # type: ignore + batch_size = next( + iter(next(iter(visualizations.values())).values()) + ).shape[0] + renders = process_visualizations( + visualizations, + batch_size=batch_size, + ) + for j in range(batch_size): + for (node_name, viz_name), visualizations in renders.items(): + viz = visualizations[j] + if save_dir is not None: + if img_paths is not None: + img_path = img_paths[i * batch_size + j] + name = f"{img_path.stem}_{node_name}_{viz_name}" + else: + name = f"{node_name}_{viz_name}_{i * batch_size + j}" + cv2.imwrite(str(save_dir / f"{name}.png"), viz) + else: + cv2.imshow(f"{node_name}/{viz_name}", viz) + + if not save_dir and window_closed(): # pragma: no cover + broken = True + break + + cv2.destroyAllWindows() + + def infer_from_directory( model: "luxonis_train.core.LuxonisModel", img_paths: Iterable[Path], @@ -136,27 +194,34 @@ def infer_from_directory( @type save_dir: Path | None @param save_dir: The directory to save the visualizations to. """ - for img_path in img_paths: - img = cv2.imread(str(img_path)) - outputs = prepare_and_infer_image(model, img) - renders = process_visualizations(outputs.visualizations, batch_size=1) + img_paths = list(img_paths) + + def generator(): + for img_path in img_paths: + yield { + "file": img_path, + } + + dataset_name = "infer_from_directory" + dataset = LuxonisDataset(dataset_name=dataset_name, delete_existing=True) + dataset.add(generator()) + dataset.make_splits( + {"train": 0.0, "val": 0.0, "test": 1.0}, replace_old_splits=True + ) - for (node_name, viz_name), [viz] in renders.items(): - if save_dir is not None: - cv2.imwrite( - str( - save_dir - / f"{img_path.stem}_{node_name}_{viz_name}.png" - ), - viz, - ) - else: # pragma: no cover - cv2.imshow(f"{node_name}/{viz_name}", viz) + loader = LuxonisLoaderTorch( + dataset_name=dataset_name, + image_source="image", + view="test", + augmentations=model.val_augmentations, + ) + loader = torch_data.DataLoader( + loader, batch_size=model.cfg.trainer.batch_size + ) - if not save_dir and window_closed(): # pragma: no cover - break + infer_from_loader(model, loader, save_dir, img_paths) - cv2.destroyAllWindows() + dataset.delete_dataset() def infer_from_dataset( @@ -173,33 +238,6 @@ def infer_from_dataset( @type save_dir: str | Path | None @param save_dir: The directory to save the visualizations to. """ - broken = False - for i, (inputs, labels) in enumerate(model.pytorch_loaders[view]): - if broken: # pragma: no cover - break - - images = get_denormalized_images(model.cfg, inputs) - batch_size = images.shape[0] - outputs = model.lightning_module.forward( - inputs, labels, images=images, compute_visualizations=True - ) - renders = process_visualizations( - outputs.visualizations, - batch_size=batch_size, - ) - for j in range(batch_size): - for (node_name, viz_name), visualizations in renders.items(): - viz = visualizations[j] - if save_dir is not None: - name = f"{node_name}_{viz_name}" - cv2.imwrite( - str(save_dir / f"{name}_{i * batch_size + j}.png"), viz - ) - else: - cv2.imshow(f"{node_name}/{viz_name}", viz) - - if not save_dir and window_closed(): # pragma: no cover - broken = True - break - cv2.destroyAllWindows() + loader = model.pytorch_loaders[view] + infer_from_loader(model, loader, save_dir) diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py index ea797abc..011c3983 100644 --- a/luxonis_train/models/luxonis_lightning.py +++ b/luxonis_train/models/luxonis_lightning.py @@ -686,6 +686,22 @@ def test_step( """Performs one step of testing with provided batch.""" return self._evaluation_step("test", test_batch) + def predict_step( + self, batch: tuple[dict[str, Tensor], Labels] + ) -> LuxonisOutput: + """Performs one step of prediction with provided batch.""" + inputs, labels = batch + images = get_denormalized_images(self.cfg, inputs) + outputs = self.forward( + inputs, + labels, + images=images, + compute_visualizations=True, + compute_loss=False, + compute_metrics=False, + ) + return outputs + def on_train_epoch_end(self) -> None: """Performs train epoch end operations.""" epoch_train_losses = self._average_losses(self.training_step_outputs)