Skip to content

Commit

Permalink
Add remove_on_export for export handling
Browse files Browse the repository at this point in the history
  • Loading branch information
JSabadin committed Sep 27, 2024
1 parent e16cde6 commit 337f193
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
1 change: 1 addition & 0 deletions luxonis_train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class ModelNodeConfig(BaseModelExtraForbid):
inputs: list[str] = [] # From preceding nodes
input_sources: list[str] = [] # From data loader
freezing: FreezingConfig = FreezingConfig()
remove_on_export: bool = False
task: str | dict[TaskType, str] | None = None
params: Params = {}

Expand Down
8 changes: 7 additions & 1 deletion luxonis_train/models/luxonis_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,11 @@ def __init__(
}
nodes[node_name] = (
Node,
{**node_cfg.params, "_tasks": node_cfg.task},
{
**node_cfg.params,
"_tasks": node_cfg.task,
"remove_on_export": node_cfg.remove_on_export,
},
)

# Handle inputs for this node
Expand Down Expand Up @@ -373,6 +377,8 @@ def forward(
for node_name, node, input_names, unprocessed in traverse_graph(
self.graph, cast(dict[str, BaseNode], self.nodes)
):
if node.export and node.remove_on_export:
continue

Check warning on line 381 in luxonis_train/models/luxonis_lightning.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/models/luxonis_lightning.py#L381

Added line #L381 was not covered by tests
input_names += self.node_input_sources[node_name]

node_inputs: list[Packet[Tensor]] = []
Expand Down
15 changes: 15 additions & 0 deletions luxonis_train/nodes/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
n_classes: int | None = None,
n_keypoints: int | None = None,
in_sizes: Size | list[Size] | None = None,
remove_on_export: bool = False,
attach_index: AttachIndexType | None = None,
_tasks: dict[TaskType, str] | None = None,
):
Expand Down Expand Up @@ -187,6 +188,7 @@ class L{tasks} attribute. Shouldn't be provided by the user in most cases.
self._n_classes = n_classes
self._n_keypoints = n_keypoints
self._export = False
self._remove_on_export = remove_on_export
self._epoch = 0
self._in_sizes = in_sizes

Expand Down Expand Up @@ -507,6 +509,19 @@ def set_export_mode(self, mode: bool = True) -> None:
"""
self._export = mode

@property
def remove_on_export(self) -> bool:
"""Getter for the remove_on_export attribute."""
return self._remove_on_export

def set_remove_on_export_mode(self, mode: bool = True) -> None:
"""Sets the remove_on_export flag.
@type mode: bool
@param mode: Value to set remove_on_export to. Defaults to True.
"""
self._remove_on_export = mode

Check warning on line 523 in luxonis_train/nodes/base_node.py

View check run for this annotation

Codecov / codecov/patch

luxonis_train/nodes/base_node.py#L523

Added line #L523 was not covered by tests

def unwrap(self, inputs: list[Packet[Tensor]]) -> ForwardInputT:
"""Prepares inputs for the forward pass.
Expand Down

0 comments on commit 337f193

Please sign in to comment.