Commit

Merge pull request #62 from leaprovenzano/feature/angular_penalty_soft_margin_losses

Add Angular penalty soft margin losses
leaprovenzano authored Feb 6, 2022
2 parents 66ed295 + 96d6a8d commit 33a174e
Showing 2 changed files with 203 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/hearth/loop.py
@@ -54,6 +54,9 @@ def optimizer(self):
def optimizer(self, optimizer):
if isinstance(optimizer, LazyOptimizer) and not optimizer.initialized:
optimizer.add_model(self.model)
            # add parameters from the loss function, if any exist...
            # self.loss_fn will be a Running wrapper, so we access the inner function via .fn...
optimizer.add_model(self.loss_fn.fn)
self._optimizer = optimizer

def to(self, device: Union[torch.device, str]):
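The angular losses added below carry a trainable class-projection weight, which is why the optimizer hook above pulls parameters off the loss function. Outside of hearth's LazyOptimizer machinery, the equivalent wiring with a plain torch optimizer might look like this (a minimal sketch; the Linear stand-in model, sizes, and learning rate are illustrative assumptions, not part of the commit):

import torch
from hearth.losses import AdditiveAngularMarginLoss

model = torch.nn.Linear(32, 128)  # stand-in embedding network (illustrative)
loss_fn = AdditiveAngularMarginLoss(embedding_features=128, n_classes=10)

# the loss's class-projection weight is an nn.Parameter, so hand it to the
# optimizer alongside the model's own parameters or it will never be updated.
opt = torch.optim.Adam(
    list(model.parameters()) + list(loss_fn.parameters()), lr=1e-3
)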
200 changes: 200 additions & 0 deletions src/hearth/losses.py
@@ -393,3 +393,203 @@ def __init__(self, mask_target_value=-float('inf'), reduction: str = 'mean'):
def forward(self, input: Tensor, target: Tensor) -> Tensor:
mask = target != self.mask_target_value
return nn.functional.mse_loss(input[mask], target[mask], reduction=self.reduction)


class _AngularPenaltySoftMarginLoss(_BaseLoss):
"""base class for angular penalty soft margin losses."""

def __init__(
self,
embedding_features: int,
n_classes: int,
scale: float,
margin: float,
eps: float = 1e-7,
reduction='mean',
):
super().__init__(reduction=reduction)
self.embedding_features = embedding_features
self.n_classes = n_classes
self.scale = scale
self.margin = margin
self.eps = eps
self.weight = nn.Parameter(torch.FloatTensor(self.n_classes, self.embedding_features))
nn.init.xavier_uniform_(self.weight)

def _get_cos(self, inputs: Tensor) -> Tensor:
return nn.functional.linear(
nn.functional.normalize(inputs, p=2, dim=-1),
nn.functional.normalize(self.weight, p=2, dim=-1),
)

    def _compute_unscaled_numerator(self, pos: Tensor) -> Tensor:
        raise NotImplementedError

def forward(self, inputs: Tensor, targets: Tensor) -> Tensor:
cos = self._get_cos(inputs)
n_classes = cos.shape[-1]
one_hot = torch.nn.functional.one_hot(targets, n_classes)
numerator = self.scale * self._compute_unscaled_numerator((cos * one_hot).sum(-1))
neg = cos.masked_fill(one_hot == 1, -float('inf'))
denominator = numerator.exp() + (self.scale * neg).exp().sum(-1)
err = -(numerator - torch.log(denominator))
return self._reduce(err)

def extra_repr(self) -> str:
out = [
f'embedding_features={self.embedding_features}',
f'n_classes={self.n_classes}',
f'scale={self.scale}',
f'margin={self.margin}',
f'eps={self.eps}',
f'reduction={self.reduction!r}',
]

return ', '.join(out)
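In equation form (a summary of the code above, not text from the commit), with s the scale and ψ the margin transform each subclass supplies via _compute_unscaled_numerator, the base class computes

\cos\theta_j = \left\langle \frac{x}{\lVert x \rVert_2}, \frac{W_j}{\lVert W_j \rVert_2} \right\rangle, \qquad L_i = -\log \frac{e^{s\,\psi(\cos\theta_{y_i})}}{e^{s\,\psi(\cos\theta_{y_i})} + \sum_{j \neq y_i} e^{s\cos\theta_j}}

The masked_fill with -inf in forward realizes the j ≠ y_i sum, since exp(-inf) = 0.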


class AdditiveAngularMarginLoss(_AngularPenaltySoftMarginLoss):
"""AKA ArcFaceLoss.
Reference:
        `Deng et al.: ArcFace: Additive Angular Margin Loss for Deep Face Recognition\
<https://arxiv.org/abs/1801.07698>`_
Note:
- this loss has weights and requires an optimizer.
- :attr:`margin` is already expressed in radians.
Args:
        embedding_features: number of embedding features; inputs are expected to have
            shape (batch, embedding_features).
        n_classes: number of classes in the projection.
        scale: scaling factor applied to the cosine logits. Defaults to 64.0.
        margin: margin in radians. Defaults to 0.5 (28.6 degrees).
        eps: epsilon for clamping. Defaults to 1e-7.
        reduction: batch reduction for this loss; should be one of 'mean', 'sum',
            or 'none'. Defaults to 'mean'.
Example:
>>> import torch
>>> from hearth.losses import AdditiveAngularMarginLoss
>>> _ = torch.manual_seed(666)
>>>
>>>
>>> # this would be embeddings coming out of your model...
>>> emb = torch.normal(0, 1, size=(5, 128))
>>> targets = torch.randint(0, 10, size=(5,))
>>> loss = AdditiveAngularMarginLoss(128, 10)
>>> loss(emb, targets)
tensor(41.8191, grad_fn=<MeanBackward0>)
"""

def __init__(
self,
embedding_features: int,
n_classes: int,
scale: float = 64.0,
margin: float = 0.5,
**kwargs,
):
super().__init__(embedding_features, n_classes, scale=scale, margin=margin, **kwargs)

def _compute_unscaled_numerator(self, pos: Tensor) -> Tensor:
clipped = torch.clamp(pos, -1.0 + self.eps, 1 - self.eps)
return torch.cos(clipped.acos() + self.margin)
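So ArcFace adds its margin in angle space, i.e.

\psi(\cos\theta) = \cos(\theta + m)

and the clamp to [-1 + eps, 1 - eps] keeps acos away from ±1, where its derivative is unbounded.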


class LargeMarginCosineLoss(_AngularPenaltySoftMarginLoss):
"""AKA CosFaceLoss.
Reference:
        `Wang et al.: CosFace: Large Margin Cosine Loss for Deep Face Recognition\
<https://arxiv.org/abs/1801.09414>`_
Note:
- this loss has weights and requires an optimizer.
Args:
        embedding_features: number of embedding features; inputs are expected to have
            shape (batch, embedding_features).
        n_classes: number of classes in the projection.
        scale: scaling factor applied to the cosine logits. Defaults to 30.0.
        margin: margin subtracted from the target cosine. Defaults to 0.4.
        eps: epsilon for clamping. Defaults to 1e-7.
        reduction: batch reduction for this loss; should be one of 'mean', 'sum',
            or 'none'. Defaults to 'mean'.
Example:
>>> import torch
>>> from hearth.losses import LargeMarginCosineLoss
>>> _ = torch.manual_seed(666)
>>>
>>>
>>> # this would be embeddings coming out of your model...
>>> emb = torch.normal(0, 1, size=(5, 128))
>>> targets = torch.randint(0, 10, size=(5,))
>>> loss = LargeMarginCosineLoss(128, 10)
>>> loss(emb, targets)
tensor(17.5245, grad_fn=<MeanBackward0>)
"""

def __init__(
self,
embedding_features: int,
n_classes: int,
scale: float = 30.0,
margin: float = 0.4,
**kwargs,
):
super().__init__(embedding_features, n_classes, scale=scale, margin=margin, **kwargs)

def _compute_unscaled_numerator(self, pos: Tensor) -> Tensor:
return pos - self.margin
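CosFace instead applies its margin directly in cosine space:

\psi(\cos\theta) = \cos\theta - m

Since no acos is taken here, this subclass needs no eps-clamping.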


class SphereEmbeddingLoss(_AngularPenaltySoftMarginLoss):
"""AKA SphereFaceLoss.
Reference:
        `Liu et al.: SphereFace: Deep Hypersphere Embedding for Face Recognition\
<https://arxiv.org/abs/1704.08063>`_
Note:
- this loss has weights and requires an optimizer.
Args:
        embedding_features: number of embedding features; inputs are expected to have
            shape (batch, embedding_features).
        n_classes: number of classes in the projection.
        scale: scaling factor applied to the cosine logits. Defaults to 64.0.
        margin: multiplicative angular margin. Defaults to 1.35.
        eps: epsilon for clamping. Defaults to 1e-7.
        reduction: batch reduction for this loss; should be one of 'mean', 'sum',
            or 'none'. Defaults to 'mean'.
Example:
>>> import torch
>>> from hearth.losses import SphereEmbeddingLoss
>>> _ = torch.manual_seed(666)
>>>
>>>
>>> # this would be embeddings coming out of your model...
>>> emb = torch.normal(0, 1, size=(5, 128))
>>> targets = torch.randint(0, 10, size=(5,))
>>> loss = SphereEmbeddingLoss(128, 10)
>>> loss(emb, targets)
tensor(44.7385, grad_fn=<MeanBackward0>)
"""

def __init__(
self,
embedding_features: int,
n_classes: int,
scale: float = 64.0,
margin: float = 1.35,
**kwargs,
):
super().__init__(embedding_features, n_classes, scale=scale, margin=margin, **kwargs)

def _compute_unscaled_numerator(self, pos: Tensor) -> Tensor:
clipped = torch.clamp(pos, -1.0 + self.eps, 1 - self.eps)
return torch.cos(self.margin * clipped.acos())
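SphereFace makes the margin multiplicative in angle space:

\psi(\cos\theta) = \cos(m\,\theta)

All three transforms recover the plain softmax numerator at the identity setting: m = 0 for the additive variants, m = 1 here.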
