Skip to content

Commit

Permalink
Merge pull request #61 from leaprovenzano/feature/add_masked_mse_loss
Browse files Browse the repository at this point in the history
add masked mse loss
  • Loading branch information
leaprovenzano authored Jan 30, 2022
2 parents ec758c9 + eab155e commit 66ed295
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.3.0
current_version = 0.4.0

[bumpversion:file:setup.py]
search = version='{current_version}'
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@
test_suite='tests',
tests_require=test_requirements,
url='https://github.com/leaprovenzano/hearth',
version='0.3.0',
version='0.4.0',
zip_safe=False,
)
2 changes: 1 addition & 1 deletion src/hearth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

__author__ = """Lea Provenzano"""
__email__ = '[email protected]'
__version__ = '0.3.0'
__version__ = '0.4.0'
36 changes: 35 additions & 1 deletion src/hearth/losses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Optional, Union, Dict, Mapping, Callable
import torch
from torch import nn
from torch import nn, Tensor
from torch.nn.modules.loss import _Loss


from hearth.containers import TensorDict, NumberDict
from hearth._multihead import _MultiHeadFunc

Expand Down Expand Up @@ -359,3 +362,34 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
def extra_repr(self) -> str:
parent_args = super().extra_repr()
return f'alpha={self.alpha!r}, gamma={self.gamma}, {parent_args}'


class MaskedMSELoss(_Loss):
"""MSELoss with support for masked targets.
Args:
mask_target_value: ignore targets with this value. Defaults to -inf.
reduction: Defaults to 'mean'.
Example:
>>> import torch
>>> _ = torch.manual_seed(0)
>>> from hearth.losses import MaskedMSELoss
>>>
>>> ninf = -float('inf')
>>> loss = MaskedMSELoss()
>>> inputs = torch.rand(3, 5) # (batch, timesteps)
>>> targets = torch.tensor([[ 1.1721, 0.3909, -5.2731, ninf, ninf],
... [ 2.4388, 2.5159, -1.0815, -1.9472, -0.5450],
... [ 4.0665, -2.5141, ninf, ninf, ninf]])
>>> loss(inputs, targets)
tensor(7.0101)
"""

def __init__(self, mask_target_value=-float('inf'), reduction: str = 'mean'):
super().__init__(reduction=reduction)
self.mask_target_value = mask_target_value

def forward(self, input: Tensor, target: Tensor) -> Tensor:
mask = target != self.mask_target_value
return nn.functional.mse_loss(input[mask], target[mask], reduction=self.reduction)

0 comments on commit 66ed295

Please sign in to comment.