diff --git a/src/hearth/loop.py b/src/hearth/loop.py
index 314df7c..02fd7b5 100644
--- a/src/hearth/loop.py
+++ b/src/hearth/loop.py
@@ -54,6 +54,9 @@ def optimizer(self):
     def optimizer(self, optimizer):
         if isinstance(optimizer, LazyOptimizer) and not optimizer.initialized:
             optimizer.add_model(self.model)
+            # add parameters from the loss function, if any exist...
+            # self.loss_fn is a Running object, so we access the wrapped loss via .fn
+            optimizer.add_model(self.loss_fn.fn)
         self._optimizer = optimizer
 
     def to(self, device: Union[torch.device, str]):
diff --git a/src/hearth/losses.py b/src/hearth/losses.py
index e4be045..7c77c54 100644
--- a/src/hearth/losses.py
+++ b/src/hearth/losses.py
@@ -393,3 +393,203 @@ def __init__(self, mask_target_value=-float('inf'), reduction: str = 'mean'):
     def forward(self, input: Tensor, target: Tensor) -> Tensor:
         mask = target != self.mask_target_value
         return nn.functional.mse_loss(input[mask], target[mask], reduction=self.reduction)
+
+
+class _AngularPenaltySoftMarginLoss(_BaseLoss):
+    """base class for angular penalty soft margin losses."""
+
+    def __init__(
+        self,
+        embedding_features: int,
+        n_classes: int,
+        scale: float,
+        margin: float,
+        eps: float = 1e-7,
+        reduction: str = 'mean',
+    ):
+        super().__init__(reduction=reduction)
+        self.embedding_features = embedding_features
+        self.n_classes = n_classes
+        self.scale = scale
+        self.margin = margin
+        self.eps = eps
+        self.weight = nn.Parameter(torch.FloatTensor(self.n_classes, self.embedding_features))
+        nn.init.xavier_uniform_(self.weight)
+
+    def _get_cos(self, inputs: Tensor) -> Tensor:
+        # cosine similarity between l2-normalized inputs and l2-normalized class weights
+        return nn.functional.linear(
+            nn.functional.normalize(inputs, p=2, dim=-1),
+            nn.functional.normalize(self.weight, p=2, dim=-1),
+        )
+
+    def _compute_unscaled_numerator(self, pos: Tensor) -> Tensor:
+        raise NotImplementedError('subclasses must implement _compute_unscaled_numerator.')
+
+    def forward(self, inputs: Tensor, targets: Tensor) -> Tensor:
+        cos = self._get_cos(inputs)
+        n_classes = cos.shape[-1]
+        one_hot = torch.nn.functional.one_hot(targets, n_classes)
+        # margin-penalized, scaled logit for the target class...
+        numerator = self.scale * self._compute_unscaled_numerator((cos * one_hot).sum(-1))
+        # mask target logits out of the denominator with -inf, since exp(-inf) == 0
+        neg = cos.masked_fill(one_hot == 1, -float('inf'))
+        denominator = numerator.exp() + (self.scale * neg).exp().sum(-1)
+        err = -(numerator - torch.log(denominator))
+        return self._reduce(err)
+
+    def extra_repr(self) -> str:
+        out = [
+            f'embedding_features={self.embedding_features}',
+            f'n_classes={self.n_classes}',
+            f'scale={self.scale}',
+            f'margin={self.margin}',
+            f'eps={self.eps}',
+            f'reduction={self.reduction!r}',
+        ]
+
+        return ', '.join(out)
+
+
+class AdditiveAngularMarginLoss(_AngularPenaltySoftMarginLoss):
+    """AKA ArcFaceLoss.
+
+    Reference:
+        `Deng et al. : ArcFace: Additive Angular Margin Loss for Deep Face Recognition \
+        <https://arxiv.org/abs/1801.07698>`_
+
+    Note:
+        - this loss has weights and requires an optimizer.
+        - :attr:`margin` is already expressed in radians.
+
+    Args:
+        embedding_features: number of embedding features; inputs are expected to be
+            (batch, embedding_features).
+        n_classes: number of classes in the projection.
+        scale: Defaults to 64.0.
+        margin: margin in radians. Defaults to 0.5 (28.6 degrees).
+        eps: epsilon for clamping. Defaults to 1e-7.
+        reduction: batch reduction for this loss; should be one of 'mean', 'sum', 'none'.
+            Defaults to 'mean'.
+
+    Example:
+        >>> import torch
+        >>> from hearth.losses import AdditiveAngularMarginLoss
+        >>> _ = torch.manual_seed(666)
+        >>>
+        >>>
+        >>> # this would be embeddings coming out of your model...
+        >>> emb = torch.normal(0, 1, size=(5, 128))
+        >>> targets = torch.randint(0, 10, size=(5,))
+        >>> loss = AdditiveAngularMarginLoss(128, 10)
+        >>> loss(emb, targets)
+        tensor(41.8191, grad_fn=<MeanBackward0>)
+    """
+
+    def __init__(
+        self,
+        embedding_features: int,
+        n_classes: int,
+        scale: float = 64.0,
+        margin: float = 0.5,
+        **kwargs,
+    ):
+        super().__init__(embedding_features, n_classes, scale=scale, margin=margin, **kwargs)
+
+    def _compute_unscaled_numerator(self, pos: Tensor) -> Tensor:
+        # clamp into the valid domain of acos, then add the angular margin: cos(theta + m)
+        clipped = torch.clamp(pos, -1.0 + self.eps, 1 - self.eps)
+        return torch.cos(clipped.acos() + self.margin)
+
+
+class LargeMarginCosineLoss(_AngularPenaltySoftMarginLoss):
+    """AKA CosFaceLoss.
+
+    Reference:
+        `Wang et al. : CosFace: Large Margin Cosine Loss for Deep Face Recognition \
+        <https://arxiv.org/abs/1801.09414>`_
+
+    Note:
+        - this loss has weights and requires an optimizer.
+
+    Args:
+        embedding_features: number of embedding features; inputs are expected to be
+            (batch, embedding_features).
+        n_classes: number of classes in the projection.
+        scale: Defaults to 30.0.
+        margin: Defaults to 0.4.
+        eps: epsilon for clamping. Defaults to 1e-7.
+        reduction: batch reduction for this loss; should be one of 'mean', 'sum', 'none'.
+            Defaults to 'mean'.
+
+    Example:
+        >>> import torch
+        >>> from hearth.losses import LargeMarginCosineLoss
+        >>> _ = torch.manual_seed(666)
+        >>>
+        >>>
+        >>> # this would be embeddings coming out of your model...
+        >>> emb = torch.normal(0, 1, size=(5, 128))
+        >>> targets = torch.randint(0, 10, size=(5,))
+        >>> loss = LargeMarginCosineLoss(128, 10)
+        >>> loss(emb, targets)
+        tensor(17.5245, grad_fn=<MeanBackward0>)
+    """
+
+    def __init__(
+        self,
+        embedding_features: int,
+        n_classes: int,
+        scale: float = 30.0,
+        margin: float = 0.4,
+        **kwargs,
+    ):
+        super().__init__(embedding_features, n_classes, scale=scale, margin=margin, **kwargs)
+
+    def _compute_unscaled_numerator(self, pos: Tensor) -> Tensor:
+        # subtract the cosine margin directly: cos(theta) - m
+        return pos - self.margin
+
+
+class SphereEmbeddingLoss(_AngularPenaltySoftMarginLoss):
+    """AKA SphereFaceLoss.
+
+    Reference:
+        `Liu et al. : SphereFace: Deep Hypersphere Embedding for Face Recognition \
+        <https://arxiv.org/abs/1704.08063>`_
+
+    Note:
+        - this loss has weights and requires an optimizer.
+
+    Args:
+        embedding_features: number of embedding features; inputs are expected to be
+            (batch, embedding_features).
+        n_classes: number of classes in the projection.
+        scale: Defaults to 64.0.
+        margin: Defaults to 1.35.
+        eps: epsilon for clamping. Defaults to 1e-7.
+        reduction: batch reduction for this loss; should be one of 'mean', 'sum', 'none'.
+            Defaults to 'mean'.
+
+    Example:
+        >>> import torch
+        >>> from hearth.losses import SphereEmbeddingLoss
+        >>> _ = torch.manual_seed(666)
+        >>>
+        >>>
+        >>> # this would be embeddings coming out of your model...
+        >>> emb = torch.normal(0, 1, size=(5, 128))
+        >>> targets = torch.randint(0, 10, size=(5,))
+        >>> loss = SphereEmbeddingLoss(128, 10)
+        >>> loss(emb, targets)
+        tensor(44.7385, grad_fn=<MeanBackward0>)
+    """
+
+    def __init__(
+        self,
+        embedding_features: int,
+        n_classes: int,
+        scale: float = 64.0,
+        margin: float = 1.35,
+        **kwargs,
+    ):
+        super().__init__(embedding_features, n_classes, scale=scale, margin=margin, **kwargs)
+
+    def _compute_unscaled_numerator(self, pos: Tensor) -> Tensor:
+        # clamp into the valid domain of acos, then multiply the angle by the margin: cos(m * theta)
+        clipped = torch.clamp(pos, -1.0 + self.eps, 1 - self.eps)
+        return torch.cos(self.margin * clipped.acos())
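
Because each of these losses owns a learnable projection (`self.weight`), its parameters must reach the optimizer alongside the model's; that is what the `loop.py` change above automates for `LazyOptimizer` via `optimizer.add_model(self.loss_fn.fn)`. Below is a minimal sketch of wiring this up by hand, for setups that do not go through hearth's loop. It is not part of the diff: the encoder and the choice of `AdamW` are purely illustrative, and it assumes `_BaseLoss` is an `nn.Module` subclass (as the diff implies), so `loss_fn.parameters()` yields the class-weight matrix.

import torch
from torch import nn

from hearth.losses import AdditiveAngularMarginLoss

# hypothetical embedding model, just for illustration
encoder = nn.Sequential(
    nn.Linear(32, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
)
loss_fn = AdditiveAngularMarginLoss(embedding_features=128, n_classes=10)

# include BOTH parameter sets: omitting loss_fn.parameters() would leave the
# loss's class-weight projection frozen at its xavier initialization.
optimizer = torch.optim.AdamW(
    [{'params': encoder.parameters()}, {'params': loss_fn.parameters()}],
    lr=1e-3,
)

x = torch.normal(0, 1, size=(5, 32))
targets = torch.randint(0, 10, size=(5,))

# one training step
optimizer.zero_grad()
emb = encoder(x)  # (batch, embedding_features)
loss = loss_fn(emb, targets)
loss.backward()
optimizer.step()

Forgetting the second parameter group is the failure mode the `loop.py` change guards against: training would still run, but the margins would be learned against a fixed random projection.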