Merge pull request #108 from aai-institute/feature/multi-head-attention
Feature: Multi-Head Attention
Showing 7 changed files with 449 additions and 2 deletions.
continuiti/networks/attention.py (new file, 44 added lines)
""" | ||
`continuiti.networks.attention` | ||
Attention base class in continuiti. | ||
""" | ||
|
||
from abc import abstractmethod | ||
import torch.nn as nn | ||
import torch | ||
|
||
|
||
class Attention(nn.Module): | ||
"""Base class for various attention implementations. | ||
Attention assigns different parts of an input varying importance without set | ||
kernels. The importance of different components is designated using "soft" | ||
weights. These weights are assigned according to specific algorithms (e.g. | ||
scaled-dot-product attention). | ||
""" | ||
|
||
def __init__(self): | ||
super().__init__() | ||
|
||
@abstractmethod | ||
def forward( | ||
self, | ||
query: torch.Tensor, | ||
key: torch.Tensor, | ||
value: torch.Tensor, | ||
attn_mask: torch.Tensor = None, | ||
) -> torch.Tensor: | ||
"""Calculates the attention scores. | ||
Args: | ||
query: query tensor; shape (batch_size, target_seq_length, hidden_dim) | ||
key: key tensor; shape (batch_size, source_seq_length, hidden_dim) | ||
value: value tensor; shape (batch_size, source_seq_length, hidden_dim) | ||
attn_mask: tensor indicating which values are used to calculate the output; | ||
shape (batch_size, target_seq_length, source_seq_length) | ||
Returns: | ||
tensor containing the outputs of the attention implementation; | ||
shape (batch_size, target_seq_length, hidden_dim) | ||
""" |
continuiti/networks/multi_head_attention.py (new file, 138 added lines)
""" | ||
`continuiti.networks.multi_head_attention` | ||
Multi-Head-Attention in continuiti. | ||
""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from .attention import Attention | ||
from .scaled_dot_product_attention import ScaledDotProductAttention | ||
|
||
|
||
class MultiHeadAttention(Attention): | ||
r"""Multi-Head Attention module. | ||
Module as described in the paper [Attention is All you | ||
Need](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf) | ||
with optional bias for the projections. This implementation allows to use | ||
attention implementations other than the standard scaled dot product | ||
attention implemented by the MultiheadAttention PyTorch module. | ||
$$MultiHead(Q,K,V)=Concat(head_1,\dots,head_n)W^O + b^O$$ | ||
where | ||
$$head_i=Attention(QW_i^Q+b_i^Q, KW_i^K+b_i^K, VW_i^V+b_i^V).$$ | ||
Args: | ||
hidden_dim: dimension of the hidden layers (embedding dimension). | ||
n_heads: number of attention heads. | ||
attention: implementation of attention (defaults to scaled dot product attention). Needs to have the arguments | ||
`query`, `key`, `value`, `attn_mask`, and `dropout_p`. | ||
dropout_p: dropout probability. | ||
bias: If True, then the projection onto the different heads is performed with bias. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
hidden_dim: int, | ||
n_heads: int, | ||
attention: Attention = None, | ||
dropout_p: float = 0, | ||
bias: bool = True, | ||
): | ||
super().__init__() | ||
|
||
self.hidden_dim = hidden_dim | ||
self.n_heads = n_heads | ||
self.dropout_p = dropout_p | ||
self.bias = bias | ||
|
||
if attention is None: | ||
attention = ScaledDotProductAttention() | ||
self.attention = attention | ||
|
||
self.head_dim = hidden_dim // n_heads | ||
assert ( | ||
self.head_dim * n_heads == hidden_dim | ||
), "hidden_dim must be divisible by n_heads" | ||
|
||
# projection networks | ||
self.query_project = nn.Linear(hidden_dim, hidden_dim, bias=bias) | ||
self.key_project = nn.Linear(hidden_dim, hidden_dim, bias=bias) | ||
self.value_project = nn.Linear(hidden_dim, hidden_dim, bias=bias) | ||
self.out_project = nn.Linear(hidden_dim, hidden_dim, bias=bias) | ||
|
||
def forward( | ||
self, | ||
query: torch.Tensor, | ||
key: torch.Tensor, | ||
value: torch.Tensor, | ||
attn_mask: torch.Tensor = None, | ||
) -> torch.Tensor: | ||
r"""Compute the attention scores. | ||
Args: | ||
query: Query tensor of shape (batch_size, target_sequence_length, hidden_dim). | ||
key: Key tensor of shape (batch_size, source_sequence_length, hidden_dim). | ||
value: Value tensor of shape (batch_size, source_sequence_length, hidden_dim). | ||
attn_mask: Attention mask of shape (batch_size, target_sequence_length, source_sequence_length). | ||
Returns: | ||
Attention scores of shape (batch_size, target_sequence_length, hidden_dim). | ||
""" | ||
assert query.ndim == key.ndim == value.ndim == 3, ( | ||
"Query, key, and value need to have three dimensions (batch_size, ..., hidden_dim). This format ensures that" | ||
"the module can correctly apply the multi-head attention mechanism, which includes splitting embeddings " | ||
"into multiple heads, applying the internal attention implementation for each head, concatenating and " | ||
"projecting results, while ensuring that the attention mask is applied correctly." | ||
) | ||
assert ( | ||
query.size(0) == key.size(0) == value.size(0) | ||
), "Batch size does not match for input tensors" | ||
assert ( | ||
query.size(-1) == key.size(-1) == value.size(-1) | ||
), "Embedding/hidden dimension does not match for input tensors" | ||
|
||
batch_size = query.size(0) | ||
src_len = key.size(1) | ||
tgt_len = query.size(1) | ||
|
||
# project values | ||
query = self.query_project(query) | ||
key = self.key_project(key) | ||
value = self.value_project(value) | ||
|
||
# form individual heads | ||
query = query.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2) | ||
key = key.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2) | ||
value = value.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2) | ||
|
||
# reshape attention mask to match heads | ||
if attn_mask is not None: | ||
assert ( | ||
attn_mask.size(0) == batch_size | ||
), "Attention mask batch size does not match input tensors." | ||
assert ( | ||
attn_mask.size(1) == tgt_len | ||
), "First dimension of the attention mask needs to match target length." | ||
assert ( | ||
attn_mask.size(2) == src_len | ||
), "Second dimension of the attention mask needs to match source length." | ||
|
||
attn_mask = attn_mask.unsqueeze(1) # mask for a single head | ||
attn_mask = attn_mask.repeat(1, self.n_heads, 1, 1) # mask for every head | ||
|
||
# perform attention | ||
attn_out = self.attention( | ||
query=query, | ||
key=key, | ||
value=value, | ||
attn_mask=attn_mask, | ||
) | ||
attn_out = attn_out.transpose(1, 2).reshape(batch_size, -1, self.hidden_dim) | ||
|
||
# output projection | ||
return self.out_project(attn_out) |
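
A usage sketch for the new module (not part of the diff). The import path follows the module docstring `continuiti.networks.multi_head_attention` and may differ if the class is re-exported elsewhere in the package.

import torch

from continuiti.networks.multi_head_attention import MultiHeadAttention

batch_size, tgt_len, src_len, hidden_dim = 2, 5, 7, 32
attn = MultiHeadAttention(hidden_dim=hidden_dim, n_heads=4)

query = torch.rand(batch_size, tgt_len, hidden_dim)
key = torch.rand(batch_size, src_len, hidden_dim)
value = torch.rand(batch_size, src_len, hidden_dim)

# Boolean mask of shape (batch_size, tgt_len, src_len); True marks positions that may attend.
attn_mask = torch.ones(batch_size, tgt_len, src_len, dtype=torch.bool)

out = attn(query, key, value, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 5, 32])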
continuiti/networks/scaled_dot_product_attention.py (new file, 41 added lines)
""" | ||
`continuiti.networks.scaled_dot_product_attention` | ||
Scaled dot product attention module. | ||
""" | ||
import torch | ||
|
||
from .attention import Attention | ||
from torch.nn.functional import scaled_dot_product_attention | ||
|
||
|
||
class ScaledDotProductAttention(Attention): | ||
"""Scaled dot product attention module. | ||
This module is a wrapper for the torch implementation of the scaled dot | ||
product attention mechanism as described in the paper "Attention Is All You | ||
Need" by Vaswani et al. (2017). This attention mechanism computes the | ||
attention weights based on the dot product of the query and key matrices, | ||
scaled by the square root of the dimension of the key vectors. The weights | ||
are then applied to the value vectors to obtain the final output. | ||
""" | ||
|
||
def __init__(self, dropout_p: float = 0.0): | ||
super().__init__() | ||
self.dropout_p = dropout_p | ||
|
||
def forward( | ||
self, | ||
query: torch.Tensor, | ||
key: torch.Tensor, | ||
value: torch.Tensor, | ||
attn_mask: torch.Tensor = None, | ||
) -> torch.Tensor: | ||
dropout_p = self.dropout_p if self.training else 0.0 | ||
return scaled_dot_product_attention( | ||
query=query, | ||
key=key, | ||
value=value, | ||
attn_mask=attn_mask, | ||
dropout_p=dropout_p, | ||
) |
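
As an illustration of the wrapped mechanism (again, not part of the diff), the module's output in eval mode should match a manual computation of softmax(QK^T / sqrt(d_k))V, since dropout is disabled outside training:

import math

import torch

from continuiti.networks.scaled_dot_product_attention import ScaledDotProductAttention

q = torch.rand(2, 5, 16)
k = torch.rand(2, 7, 16)
v = torch.rand(2, 7, 16)

attn = ScaledDotProductAttention().eval()  # eval mode disables dropout
out = attn(q, k, v)

# Manual reference: dot products scaled by sqrt(d_k), softmax over the source dimension.
weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(k.size(-1)), dim=-1)
reference = weights @ v

print(torch.allclose(out, reference, atol=1e-6))  # expected: True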