initialize_weights configurable param
JSabadin committed Nov 19, 2024
1 parent 4c61f38 commit 2951676
Showing 5 changed files with 46 additions and 44 deletions.
59 changes: 31 additions & 28 deletions luxonis_train/nodes/README.md
@@ -82,16 +82,17 @@ Adapted from [here](https://arxiv.org/pdf/2209.02976.pdf).

**Parameters:**

-| Key | Type | Default value | Description |
-| ------------------ | ----------------------------------------------------------------- | --------------------------- | -------------------------------------------------------------------------- |
-| `variant` | `Literal["n", "nano", "s", "small", "m", "medium", "l", "large"]` | `"nano"` | Variant of the network |
-| `channels_list` | `list[int]` | \[64, 128, 256, 512, 1024\] | List of number of channels for each block |
-| `n_repeats` | `list[int]` | \[1, 6, 12, 18, 6\] | List of number of repeats of `RepVGGBlock` |
-| `depth_mul` | `float` | `0.33` | Depth multiplier |
-| `width_mul` | `float` | `0.25` | Width multiplier |
-| `block` | `Literal["RepBlock", "CSPStackRepBlock"]` | `"RepBlock"` | Base block used |
-| `csp_e` | `float` | `0.5` | Factor for intermediate channels when block is set to `"CSPStackRepBlock"` |
-| `download_weights` | `bool` | `True` | If True download weights from COCO (if available for specified variant) |
+| Key | Type | Default value | Description |
+| -------------------- | ----------------------------------------------------------------- | --------------------------- | -------------------------------------------------------------------------- |
+| `variant` | `Literal["n", "nano", "s", "small", "m", "medium", "l", "large"]` | `"nano"` | Variant of the network |
+| `channels_list` | `list[int]` | \[64, 128, 256, 512, 1024\] | List of number of channels for each block |
+| `n_repeats` | `list[int]` | \[1, 6, 12, 18, 6\] | List of number of repeats of `RepVGGBlock` |
+| `depth_mul` | `float` | `0.33` | Depth multiplier |
+| `width_mul` | `float` | `0.25` | Width multiplier |
+| `block` | `Literal["RepBlock", "CSPStackRepBlock"]` | `"RepBlock"` | Base block used |
+| `csp_e` | `float` | `0.5` | Factor for intermediate channels when block is set to `"CSPStackRepBlock"` |
+| `download_weights` | `bool` | `True` | If True download weights from COCO (if available for specified variant) |
+| `initialize_weights` | `bool` | `True` | If True, initialize weights. |

### RexNetV1_lite

@@ -175,17 +176,18 @@ Adapted from [here](https://arxiv.org/pdf/2209.02976.pdf).

**Parameters:**

-| Key | Type | Default value | Description |
-| ------------------ | ----------------------------------------------------------------- | -------------------------------- | ------------------------------------------------------------------------------- |
-| `variant` | `Literal["n", "nano", "s", "small", "m", "medium", "l", "large"]` | `"nano"` | Variant of the network |
-| `n_heads` | `Literal[2,3,4]` | `3` | Number of output heads. Should be same also on the connected head in most cases |
-| `channels_list` | `list[int]` | `[256, 128, 128, 256, 256, 512]` | List of number of channels for each block |
-| `n_repeats` | `list[int]` | `[12, 12, 12, 12]` | List of number of repeats of `RepVGGBlock` |
-| `depth_mul` | `float` | `0.33` | Depth multiplier |
-| `width_mul` | `float` | `0.25` | Width multiplier |
-| `block` | `Literal["RepBlock", "CSPStackRepBlock"]` | `"RepBlock"` | Base block used |
-| `csp_e` | `float` | `0.5` | Factor for intermediate channels when block is set to `"CSPStackRepBlock"` |
-| `download_weights` | `bool` | `False` | If True download weights from COCO (if available for specified variant) |
+| Key | Type | Default value | Description |
+| -------------------- | ----------------------------------------------------------------- | -------------------------------- | ------------------------------------------------------------------------------- |
+| `variant` | `Literal["n", "nano", "s", "small", "m", "medium", "l", "large"]` | `"nano"` | Variant of the network |
+| `n_heads` | `Literal[2,3,4]` | `3` | Number of output heads. Should be same also on the connected head in most cases |
+| `channels_list` | `list[int]` | `[256, 128, 128, 256, 256, 512]` | List of number of channels for each block |
+| `n_repeats` | `list[int]` | `[12, 12, 12, 12]` | List of number of repeats of `RepVGGBlock` |
+| `depth_mul` | `float` | `0.33` | Depth multiplier |
+| `width_mul` | `float` | `0.25` | Width multiplier |
+| `block` | `Literal["RepBlock", "CSPStackRepBlock"]` | `"RepBlock"` | Base block used |
+| `csp_e` | `float` | `0.5` | Factor for intermediate channels when block is set to `"CSPStackRepBlock"` |
+| `download_weights` | `bool` | `False` | If True download weights from COCO (if available for specified variant) |
+| `initialize_weights` | `bool` | `True` | If True, initialize weights. |

## Heads

@@ -217,13 +219,14 @@ Adapted from [here](https://arxiv.org/pdf/2209.02976.pdf).

**Parameters:**

-| Key | Type | Default value | Description |
-| ------------------ | ------- | ------------- | --------------------------------------------------------------------- |
-| `n_heads` | `bool` | `3` | Number of output heads |
-| `conf_thres` | `float` | `0.25` | Confidence threshold for non-maxima-suppression (used for evaluation) |
-| `iou_thres` | `float` | `0.45` | `IoU` threshold for non-maxima-suppression (used for evaluation) |
-| `max_det` | `int` | `300` | Maximum number of detections retained after NMS |
-| `download_weights` | `bool` | `False` | If True download weights from COCO |
+| Key | Type | Default value | Description |
+| -------------------- | ------- | ------------- | --------------------------------------------------------------------- |
+| `n_heads` | `bool` | `3` | Number of output heads |
+| `conf_thres` | `float` | `0.25` | Confidence threshold for non-maxima-suppression (used for evaluation) |
+| `iou_thres` | `float` | `0.45` | `IoU` threshold for non-maxima-suppression (used for evaluation) |
+| `max_det` | `int` | `300` | Maximum number of detections retained after NMS |
+| `download_weights` | `bool` | `False` | If True download weights from COCO |
+| `initialize_weights` | `bool` | `True` | If True, initialize weights. |
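
For illustration, a configuration that switches the new flag off for the three affected nodes might look roughly like the sketch below. It is written as a plain Python structure. The `name`/`params` layout, the node name `EfficientBBoxHead`, and the helper `resolve_flag` are assumptions made for this example and are not part of the diff; only the parameter names and the documented default of `True` come from the tables.

```python
# Hypothetical node list. The "name"/"params" layout is an assumption;
# the parameter names and defaults are taken from the README tables.
nodes = [
    {
        "name": "EfficientRep",
        "params": {"variant": "nano", "initialize_weights": False},
    },
    {"name": "RepPANNeck", "params": {"initialize_weights": False}},
    {"name": "EfficientBBoxHead", "params": {"initialize_weights": False}},
]

# Documented default of the flag for all three nodes.
DEFAULT_INITIALIZE_WEIGHTS = True


def resolve_flag(node: dict) -> bool:
    """Return the effective `initialize_weights` value for one node entry."""
    params = node.get("params", {})
    return params.get("initialize_weights", DEFAULT_INITIALIZE_WEIGHTS)


effective = {node["name"]: resolve_flag(node) for node in nodes}
```

A node entry that omits the key falls back to the documented default, so existing configurations keep their behavior.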

### `EfficientKeypointBBoxHead`

6 changes: 5 additions & 1 deletion luxonis_train/nodes/backbones/efficientrep/efficientrep.py
@@ -30,6 +30,7 @@ def __init__(
         block: Literal["RepBlock", "CSPStackRepBlock"] | None = None,
         csp_e: float | None = None,
         download_weights: bool = True,
+        initialize_weights: bool = True,
         **kwargs: Any,
     ):
         """Implementation of the EfficientRep backbone. Supports the
@@ -65,6 +66,8 @@ def __init__(
             overrides the variant value.
         @type download_weights: bool
         @param download_weights: If True download weights from COCO (if available for specified variant). Defaults to True.
+        @type initialize_weights: bool
+        @param initialize_weights: If True, initialize weights of the model.
         """
         super().__init__(**kwargs)

@@ -125,7 +128,8 @@ def __init__(
                 )
             )

-        self.initialize_weights()
+        if initialize_weights:
+            self.initialize_weights()

         if download_weights and var.weights_path:
             self.load_checkpoint(var.weights_path)
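
The ordering in this hunk (optional initialization first, then optional checkpoint loading) can be sketched in isolation. The class below is a hypothetical stand-in, not the library's node class: `SketchNode`, its `steps` list, and the `weights_path` argument are invented for the example, and the two methods only record that they were called.

```python
from typing import Optional


class SketchNode:
    """Hypothetical stand-in for a node; records which setup steps ran."""

    def __init__(
        self,
        initialize_weights: bool = True,
        download_weights: bool = False,
        weights_path: Optional[str] = None,
    ):
        self.steps = []

        # Same gating as in the diff: initialization is now optional.
        if initialize_weights:
            self.initialize_weights()

        # Checkpoint loading stays independent of the new flag.
        if download_weights and weights_path:
            self.load_checkpoint(weights_path)

    def initialize_weights(self) -> None:
        self.steps.append("initialize_weights")

    def load_checkpoint(self, path: str) -> None:
        self.steps.append(f"load_checkpoint:{path}")


default_node = SketchNode()
skipped_node = SketchNode(initialize_weights=False)
```

With the defaults, only the initialization step runs. Passing `initialize_weights=False` skips it while leaving checkpoint loading under the control of `download_weights`.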
13 changes: 0 additions & 13 deletions luxonis_train/nodes/blocks/blocks.py
@@ -56,19 +56,6 @@ def __init__(self, n_classes: int, in_channels: int):

         prior_prob = 1e-2
         self._initialize_weights_and_biases(prior_prob)
-        self.initialize_weights()
-
-    def initialize_weights(self):
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                pass
-            elif isinstance(m, nn.BatchNorm2d):
-                m.eps = 0.001
-                m.momentum = 0.03
-            elif isinstance(
-                m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
-            ):
-                m.inplace = True

     def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
         out_feature = self.decoder(x)
6 changes: 5 additions & 1 deletion luxonis_train/nodes/heads/efficient_bbox_head.py
@@ -30,6 +30,7 @@ def __init__(
         iou_thres: float = 0.45,
         max_det: int = 300,
         download_weights: bool = False,
+        initialize_weights: bool = True,
         **kwargs: Any,
     ):
         """Head for object detection.
@@ -51,6 +52,8 @@ def __init__(
         @type download_weights: bool
         @param download_weights: If True download weights from COCO.
             Defaults to False.
+        @type initialize_weights: bool
+        @param initialize_weights: If True, initialize weights.
         """
         super().__init__(**kwargs)

@@ -95,7 +98,8 @@ def __init__(
             f"output{i+1}_yolov6r2" for i in range(self.n_heads)
         ]

-        self.initialize_weights()
+        if initialize_weights:
+            self.initialize_weights()

         if download_weights:
             # TODO: Handle variants of head in a nicer way
6 changes: 5 additions & 1 deletion luxonis_train/nodes/necks/reppan_neck/reppan_neck.py
@@ -27,6 +27,7 @@ def __init__(
         block: Literal["RepBlock", "CSPStackRepBlock"] | None = None,
         csp_e: float | None = None,
         download_weights: bool = False,
+        initialize_weights: bool = True,
         **kwargs: Any,
     ):
         """Implementation of the RepPANNeck module. Supports the version
@@ -65,6 +66,8 @@ def __init__(
             overrides the variant value.
         @type download_weights: bool
         @param download_weights: If True download weights from COCO (if available for specified variant). Defaults to False.
+        @type initialize_weights: bool
+        @param initialize_weights: If True, initialize weights of the model.
         """

         super().__init__(**kwargs)
@@ -165,7 +168,8 @@ def __init__(
             out_channels = channels_list_down_blocks[2 * i + 1]
             curr_n_repeats = n_repeats_down_blocks[i]

-        self.initialize_weights()
+        if initialize_weights:
+            self.initialize_weights()

         if download_weights and var.weights_path:
             self.load_checkpoint(var.weights_path)
