diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3bf60e9d..4d3e8c5c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,10 @@
 # CHANGELOG
 
-## 0.1
+## 0.2.0
+
+- Add `Attention` base class, `MultiHeadAttention`, and `ScaledDotProductAttention` classes.
+
+## 0.1.0
 
 - Move all content of `__init__.py` files to sub-modules.
 - Add `Trainer` class to replace `operator.fit` method.
@@ -24,7 +28,6 @@
 - Add `benchmarks` infrastructure.
 - An `Operator` now takes a `device` argument.
 - Add `QuantileScaler` class.
-- Add `Attention` base class, `MultiHeadAttention`, and `ScaledDotProductAttention` classes.
 
 ## 0.0.0 (2024-02-22)
 
diff --git a/src/continuiti/networks/attention.py b/src/continuiti/networks/attention.py
index 82e24dee..0fa6aea2 100644
--- a/src/continuiti/networks/attention.py
+++ b/src/continuiti/networks/attention.py
@@ -12,10 +12,10 @@
 class Attention(nn.Module):
     """Base class for various attention implementations.
 
-    Attention assigns different parts of an input varying importance without set kernels. The importance of different
-    components is designated using "soft" weights. These weights are assigned according to specific algorithms (e.g.
+    Attention assigns different parts of an input varying importance without set
+    kernels. The importance of different components is designated using "soft"
+    weights. These weights are assigned according to specific algorithms (e.g.
     scaled-dot-product attention).
-
     """
 
     def __init__(self):
@@ -26,7 +26,7 @@ def forward(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
-        value: torch,
+        value: torch.Tensor,
         attn_mask: torch.Tensor = None,
     ) -> torch.Tensor:
         """Calculates the attention scores.
diff --git a/src/continuiti/networks/multi_head_attention.py b/src/continuiti/networks/multi_head_attention.py
index d07fc0c2..e6eacb6f 100644
--- a/src/continuiti/networks/multi_head_attention.py
+++ b/src/continuiti/networks/multi_head_attention.py
@@ -14,9 +14,11 @@
 class MultiHeadAttention(Attention):
     r"""Multi-Head Attention module.
 
-    Module as described in the paper [Attention is All you Need](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
-    with optional bias for the projections. This implementation allows to use attention implementations other than the
-    standard scaled dot product attention implemented by the MultiheadAttention PyTorch module.
+    Module as described in the paper [Attention is All you
+    Need](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
+    with optional bias for the projections. This implementation allows to use
+    attention implementations other than the standard scaled dot product
+    attention implemented by the MultiheadAttention PyTorch module.
 
     $$MultiHead(Q,K,V)=Concat(head_1,\dots,head_n)W^O + b^O$$
 
@@ -67,7 +69,7 @@ def forward(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
-        value: torch,
+        value: torch.Tensor,
         attn_mask: torch.Tensor = None,
     ) -> torch.Tensor:
         r"""Compute the attention scores.
diff --git a/src/continuiti/networks/scaled_dot_product_attention.py b/src/continuiti/networks/scaled_dot_product_attention.py
index 64f60bd3..752fb765 100644
--- a/src/continuiti/networks/scaled_dot_product_attention.py
+++ b/src/continuiti/networks/scaled_dot_product_attention.py
@@ -12,10 +12,12 @@
 class ScaledDotProductAttention(Attention):
     """Scaled dot product attention module.
 
-    This module is a wrapper for the torch implementation of the scaled dot product attention mechanism as described in
-    the paper "Attention Is All You Need" by Vaswani et al. (2017). This attention mechanism computes the attention
-    weights based on the dot product of the query and key matrices, scaled by the square root of the dimension of the
-    key vectors. The weights are then applied to the value vectors to obtain the final output.
+    This module is a wrapper for the torch implementation of the scaled dot
+    product attention mechanism as described in the paper "Attention Is All You
+    Need" by Vaswani et al. (2017). This attention mechanism computes the
+    attention weights based on the dot product of the query and key matrices,
+    scaled by the square root of the dimension of the key vectors. The weights
+    are then applied to the value vectors to obtain the final output.
     """
 
     def __init__(self, dropout_p: float = 0.0):
@@ -26,13 +28,10 @@ def forward(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
-        value: torch,
+        value: torch.Tensor,
         attn_mask: torch.Tensor = None,
     ) -> torch.Tensor:
-        if self.training:
-            dropout_p = self.dropout_p
-        else:
-            dropout_p = 0.0
+        dropout_p = self.dropout_p if self.training else 0.0
         return scaled_dot_product_attention(
             query=query,
             key=key,
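Below is a minimal usage sketch of the `ScaledDotProductAttention` class added in 0.2.0. It is not part of the patch: the import path is assumed to mirror `src/continuiti/networks/scaled_dot_product_attention.py`, and the `(batch, sequence, embedding)` tensor layout is the usual convention of `torch.nn.functional.scaled_dot_product_attention`, which this class wraps.

```python
# Usage sketch (assumption: the module path mirrors the file location shown in the diff).
import torch

from continuiti.networks.scaled_dot_product_attention import ScaledDotProductAttention

# Constructor signature taken from the diff: __init__(self, dropout_p: float = 0.0).
attn = ScaledDotProductAttention(dropout_p=0.1)
attn.eval()  # outside training, forward() forces dropout_p to 0.0

batch, seq, dim = 2, 16, 32
query = torch.rand(batch, seq, dim)
key = torch.rand(batch, seq, dim)
value = torch.rand(batch, seq, dim)

# forward(query, key, value, attn_mask=None) -> torch.Tensor, per the diff.
out = attn(query=query, key=key, value=value)
print(out.shape)  # torch.Size([2, 16, 32]); output matches the query/value shape
```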