Merge pull request #108 from aai-institute/feature/multi-head-attention
Feature: Multi-Head Attention
Samuel Burbulla authored Jun 27, 2024
2 parents 03f8918 + d5f7ac3 commit f07fe22
Showing 7 changed files with 449 additions and 2 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -1,6 +1,10 @@
 # CHANGELOG
 
-## 0.1
+## 0.2.0
+
+- Add `Attention` base class, `MultiHeadAttention`, and `ScaledDotProductAttention` classes.
+
+## 0.1.0
 
 - Move all content of `__init__.py` files to sub-modules.
 - Add `Trainer` class to replace `operator.fit` method.
9 changes: 8 additions & 1 deletion src/continuiti/networks/__init__.py
@@ -6,5 +6,12 @@
 
 from .fully_connected import FullyConnected
 from .deep_residual_network import DeepResidualNetwork
+from .multi_head_attention import MultiHeadAttention
+from .scaled_dot_product_attention import ScaledDotProductAttention
 
-__all__ = ["FullyConnected", "DeepResidualNetwork"]
+__all__ = [
+    "FullyConnected",
+    "DeepResidualNetwork",
+    "MultiHeadAttention",
+    "ScaledDotProductAttention",
+]
44 changes: 44 additions & 0 deletions src/continuiti/networks/attention.py
@@ -0,0 +1,44 @@
"""
`continuiti.networks.attention`
Attention base class in continuiti.
"""

from abc import abstractmethod
import torch.nn as nn
import torch


class Attention(nn.Module):
"""Base class for various attention implementations.
Attention assigns different parts of an input varying importance without set
kernels. The importance of different components is designated using "soft"
weights. These weights are assigned according to specific algorithms (e.g.
scaled-dot-product attention).
"""

def __init__(self):
super().__init__()

@abstractmethod
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor = None,
) -> torch.Tensor:
"""Calculates the attention scores.
Args:
query: query tensor; shape (batch_size, target_seq_length, hidden_dim)
key: key tensor; shape (batch_size, source_seq_length, hidden_dim)
value: value tensor; shape (batch_size, source_seq_length, hidden_dim)
attn_mask: tensor indicating which values are used to calculate the output;
shape (batch_size, target_seq_length, source_seq_length)
Returns:
tensor containing the outputs of the attention implementation;
shape (batch_size, target_seq_length, hidden_dim)
"""
138 changes: 138 additions & 0 deletions src/continuiti/networks/multi_head_attention.py
@@ -0,0 +1,138 @@
"""
`continuiti.networks.multi_head_attention`
Multi-Head-Attention in continuiti.
"""

import torch
import torch.nn as nn

from .attention import Attention
from .scaled_dot_product_attention import ScaledDotProductAttention


class MultiHeadAttention(Attention):
r"""Multi-Head Attention module.
Module as described in the paper [Attention is All you
Need](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
with optional bias for the projections. This implementation allows to use
attention implementations other than the standard scaled dot product
attention implemented by the MultiheadAttention PyTorch module.
$$MultiHead(Q,K,V)=Concat(head_1,\dots,head_n)W^O + b^O$$
where
$$head_i=Attention(QW_i^Q+b_i^Q, KW_i^K+b_i^K, VW_i^V+b_i^V).$$
Args:
hidden_dim: dimension of the hidden layers (embedding dimension).
n_heads: number of attention heads.
attention: implementation of attention (defaults to scaled dot product attention). Needs to have the arguments
`query`, `key`, `value`, `attn_mask`, and `dropout_p`.
dropout_p: dropout probability.
bias: If True, then the projection onto the different heads is performed with bias.
"""

def __init__(
self,
hidden_dim: int,
n_heads: int,
attention: Attention = None,
dropout_p: float = 0,
bias: bool = True,
):
super().__init__()

self.hidden_dim = hidden_dim
self.n_heads = n_heads
self.dropout_p = dropout_p
self.bias = bias

if attention is None:
attention = ScaledDotProductAttention()
self.attention = attention

self.head_dim = hidden_dim // n_heads
assert (
self.head_dim * n_heads == hidden_dim
), "hidden_dim must be divisible by n_heads"

# projection networks
self.query_project = nn.Linear(hidden_dim, hidden_dim, bias=bias)
self.key_project = nn.Linear(hidden_dim, hidden_dim, bias=bias)
self.value_project = nn.Linear(hidden_dim, hidden_dim, bias=bias)
self.out_project = nn.Linear(hidden_dim, hidden_dim, bias=bias)

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor = None,
) -> torch.Tensor:
r"""Compute the attention scores.
Args:
query: Query tensor of shape (batch_size, target_sequence_length, hidden_dim).
key: Key tensor of shape (batch_size, source_sequence_length, hidden_dim).
value: Value tensor of shape (batch_size, source_sequence_length, hidden_dim).
attn_mask: Attention mask of shape (batch_size, target_sequence_length, source_sequence_length).
Returns:
Attention scores of shape (batch_size, target_sequence_length, hidden_dim).
"""
        assert query.ndim == key.ndim == value.ndim == 3, (
            "Query, key, and value need to have three dimensions (batch_size, ..., hidden_dim). This format ensures "
            "that the module can correctly apply the multi-head attention mechanism, which includes splitting embeddings "
            "into multiple heads, applying the internal attention implementation for each head, concatenating and "
            "projecting results, while ensuring that the attention mask is applied correctly."
        )
assert (
query.size(0) == key.size(0) == value.size(0)
), "Batch size does not match for input tensors"
assert (
query.size(-1) == key.size(-1) == value.size(-1)
), "Embedding/hidden dimension does not match for input tensors"

batch_size = query.size(0)
src_len = key.size(1)
tgt_len = query.size(1)

# project values
query = self.query_project(query)
key = self.key_project(key)
value = self.value_project(value)

# form individual heads
query = query.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)

# reshape attention mask to match heads
if attn_mask is not None:
assert (
attn_mask.size(0) == batch_size
), "Attention mask batch size does not match input tensors."
assert (
attn_mask.size(1) == tgt_len
), "First dimension of the attention mask needs to match target length."
assert (
attn_mask.size(2) == src_len
), "Second dimension of the attention mask needs to match source length."

attn_mask = attn_mask.unsqueeze(1) # mask for a single head
attn_mask = attn_mask.repeat(1, self.n_heads, 1, 1) # mask for every head

# perform attention
attn_out = self.attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
)
attn_out = attn_out.transpose(1, 2).reshape(batch_size, -1, self.hidden_dim)

# output projection
return self.out_project(attn_out)
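A usage sketch (not part of this diff), assuming the public imports added to `continuiti.networks.__init__` above and a boolean attention mask in which `True` marks allowed positions:

import torch

from continuiti.networks import MultiHeadAttention

batch_size, tgt_len, src_len, hidden_dim = 4, 10, 12, 32

# 8 heads, each of dimension 32 // 8 = 4
mha = MultiHeadAttention(hidden_dim=hidden_dim, n_heads=8)

query = torch.rand(batch_size, tgt_len, hidden_dim)
key = torch.rand(batch_size, src_len, hidden_dim)
value = torch.rand(batch_size, src_len, hidden_dim)

# all-True mask: every target position may attend to every source position
attn_mask = torch.ones(batch_size, tgt_len, src_len, dtype=torch.bool)

out = mha(query, key, value, attn_mask=attn_mask)
assert out.shape == (batch_size, tgt_len, hidden_dim)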
41 changes: 41 additions & 0 deletions src/continuiti/networks/scaled_dot_product_attention.py
@@ -0,0 +1,41 @@
"""
`continuiti.networks.scaled_dot_product_attention`
Scaled dot product attention module.
"""
import torch

from .attention import Attention
from torch.nn.functional import scaled_dot_product_attention


class ScaledDotProductAttention(Attention):
"""Scaled dot product attention module.
This module is a wrapper for the torch implementation of the scaled dot
product attention mechanism as described in the paper "Attention Is All You
Need" by Vaswani et al. (2017). This attention mechanism computes the
attention weights based on the dot product of the query and key matrices,
scaled by the square root of the dimension of the key vectors. The weights
are then applied to the value vectors to obtain the final output.
"""

def __init__(self, dropout_p: float = 0.0):
super().__init__()
self.dropout_p = dropout_p

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor = None,
) -> torch.Tensor:
dropout_p = self.dropout_p if self.training else 0.0
return scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
)
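A sketch (not part of this diff) of plugging a non-default per-head attention into the multi-head wrapper; the wrapper forwards each head's projected query, key, and value to the configured attention module:

import torch

from continuiti.networks import MultiHeadAttention, ScaledDotProductAttention

# 10% dropout on the attention weights (applied only while the module is in training mode)
attention = ScaledDotProductAttention(dropout_p=0.1)
mha = MultiHeadAttention(hidden_dim=64, n_heads=4, attention=attention)

x = torch.rand(2, 16, 64)  # self-attention: query, key, and value share one tensor
out = mha(x, x, x)
assert out.shape == x.shape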