From cb825d7ba31baf49d983286327a1206e937cd0ca Mon Sep 17 00:00:00 2001 From: JakobEliasWagner <42122260+JakobEliasWagner@users.noreply.github.com> Date: Wed, 31 Jul 2024 12:38:40 +0200 Subject: [PATCH] Feature: Masked Operator (#147) * add masked operator base class --------- Co-authored-by: Samuel Burbulla --- CHANGELOG.md | 1 + src/continuiti/operators/operator.py | 32 ++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d3e8c5c..08641556 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## 0.2.0 - Add `Attention` base class, `MultiHeadAttention`, and `ScaledDotProductAttention` classes. +- Add `MaskedOperator` base class. ## 0.1.0 diff --git a/src/continuiti/operators/operator.py b/src/continuiti/operators/operator.py index 9ba828ba..1be3d6a3 100644 --- a/src/continuiti/operators/operator.py +++ b/src/continuiti/operators/operator.py @@ -74,3 +74,35 @@ def num_params(self) -> int: def __str__(self): """Return string representation of the operator.""" return self.__class__.__name__ + + +class MaskedOperator(Operator, ABC): + """Masked operator base class. + + A masked operator can apply masks during the forward pass to selectively use or ignore parts of the input. Masked + operators allow for different numbers of sensors in addition to the common property of being able to handle + varying numbers of evaluations. + + """ + + @abstractmethod + def forward( + self, + x: torch.Tensor, + u: torch.Tensor, + y: torch.Tensor, + sensor_mask: Optional[torch.Tensor] = None, + eval_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass through the operator. + + Args: + x: Sensor positions of shape (batch_size, x_dim, num_sensors...). + u: Input function values of shape (batch_size, u_dim, num_sensors...). + y: Evaluation coordinates of shape (batch_size, y_dim, num_evaluations...). + sensor_mask: Boolean mask for x and u of shape (batch_size, 1, num_sensors...). + eval_mask: Boolean mask for y of shape (batch_size, 1, num_evaluations...). + + Returns: + Evaluations of the mapped function with shape (batch_size, v_dim, num_evaluations...). + """