diff --git a/setup.cfg b/setup.cfg index 1bd817c..afe20ae 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.3.0 +current_version = 0.4.0 [bumpversion:file:setup.py] search = version='{current_version}' diff --git a/setup.py b/setup.py index 3046f1a..aa9632d 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,6 @@ test_suite='tests', tests_require=test_requirements, url='https://github.com/leaprovenzano/hearth', - version='0.3.0', + version='0.4.0', zip_safe=False, ) diff --git a/src/hearth/__init__.py b/src/hearth/__init__.py index 81fe3ef..905c8be 100644 --- a/src/hearth/__init__.py +++ b/src/hearth/__init__.py @@ -4,4 +4,4 @@ __author__ = """Lea Provenzano""" __email__ = 'leaprovenzano@gmail.com' -__version__ = '0.3.0' +__version__ = '0.4.0' diff --git a/src/hearth/losses.py b/src/hearth/losses.py index 6435002..e4be045 100644 --- a/src/hearth/losses.py +++ b/src/hearth/losses.py @@ -1,6 +1,9 @@ from typing import Optional, Union, Dict, Mapping, Callable import torch -from torch import nn +from torch import nn, Tensor +from torch.nn.modules.loss import _Loss + + from hearth.containers import TensorDict, NumberDict from hearth._multihead import _MultiHeadFunc @@ -359,3 +362,34 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: def extra_repr(self) -> str: parent_args = super().extra_repr() return f'alpha={self.alpha!r}, gamma={self.gamma}, {parent_args}' + + +class MaskedMSELoss(_Loss): + """MSELoss with support for masked targets. + + Args: + mask_target_value: ignore targets with this value. Defaults to -inf. + reduction: Defaults to 'mean'. + + Example: + >>> import torch + >>> _ = torch.manual_seed(0) + >>> from hearth.losses import MaskedMSELoss + >>> + >>> ninf = -float('inf') + >>> loss = MaskedMSELoss() + >>> inputs = torch.rand(3, 5) # (batch, timesteps) + >>> targets = torch.tensor([[ 1.1721, 0.3909, -5.2731, ninf, ninf], + ... [ 2.4388, 2.5159, -1.0815, -1.9472, -0.5450], + ... [ 4.0665, -2.5141, ninf, ninf, ninf]]) + >>> loss(inputs, targets) + tensor(7.0101) + """ + + def __init__(self, mask_target_value=-float('inf'), reduction: str = 'mean'): + super().__init__(reduction=reduction) + self.mask_target_value = mask_target_value + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + mask = target != self.mask_target_value + return nn.functional.mse_loss(input[mask], target[mask], reduction=self.reduction)