Fix LoRA parallel composition (#752)
Currently, many model implementations don't correctly handle the batch-size (bsz) replication caused by parallel composition for LoRAs on the attention matrices. This PR loosens the bsz reshaping to enable this and adds a test case for LoRA Parallel composition.

Resolves #744.
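
For context, a minimal sketch of the composition pattern this commit targets, assuming the public adapters API (AutoAdapterModel, LoRAConfig, Parallel); the checkpoint name and LoRA hyperparameters below are illustrative only:

import torch
from adapters import AutoAdapterModel, LoRAConfig
from adapters.composition import Parallel

model = AutoAdapterModel.from_pretrained("facebook/bart-base")
model.add_adapter("a", config=LoRAConfig(r=8))
model.add_adapter("b", config=LoRAConfig(r=8))

# Parallel composition replicates the input batch once per adapter, so the
# effective batch size inside the attention layers becomes 2 * bsz.
model.active_adapters = Parallel("a", "b")

input_ids = torch.randint(0, model.config.vocab_size, (4, 16))  # bsz = 4
outputs = model(input_ids)  # could fail or misalign for attention-matrix LoRAs before this fix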
calpt authored Oct 30, 2024
1 parent 0c1039b commit bcace97
Showing 12 changed files with 99 additions and 38 deletions.
26 changes: 23 additions & 3 deletions src/adapters/models/bart/modeling_bart.py
@@ -19,7 +19,13 @@
import torch.utils.checkpoint
from torch import nn

from transformers.models.bart.modeling_bart import BartAttention, BartDecoderLayer, BartEncoderLayer
from transformers.models.bart.modeling_bart import (
BartAttention,
BartDecoderLayer,
BartEncoderLayer,
BartFlashAttention2,
BartSdpaAttention,
)
from transformers.utils import logging

from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel
@@ -32,6 +38,10 @@
class BartAttentionWithAdapters(BartAttentionAdaptersMixin, BartAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

# Loosen constraint on batch_size to allow parallel adapter composition
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

def forward(
self,
hidden_states: torch.Tensor,
@@ -164,7 +174,12 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value


class BartFlashAttention2WithAdapters(BartAttentionAdaptersMixin, BartAttention):
class BartFlashAttention2WithAdapters(BartAttentionAdaptersMixin, BartFlashAttention2):

# Loosen constraint on batch_size to allow parallel adapter composition
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim)

def forward(
self,
hidden_states: torch.Tensor,
@@ -275,7 +290,12 @@ def forward(
return attn_output, attn_weights, past_key_value


class BartSdpaAttentionWithAdapters(BartAttentionAdaptersMixin, BartAttention):
class BartSdpaAttentionWithAdapters(BartAttentionAdaptersMixin, BartSdpaAttention):

# Loosen constraint on batch_size to allow parallel adapter composition
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

def forward(
self,
hidden_states: torch.Tensor,
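The pattern in these attention overrides is the same across models: derive the batch dimension from the tensor itself instead of the caller-supplied bsz argument, since Parallel composition may already have replicated the batch. A conceptual illustration in plain torch (not library code; the sizes are made up):

import torch

bsz, seq_len, num_heads, head_dim = 2, 5, 4, 8
n_parallel = 3  # batch already replicated by Parallel composition
hidden = torch.randn(n_parallel * bsz, seq_len, num_heads * head_dim)

# Old behaviour: viewing with the caller-supplied bsz no longer matches the
# number of elements and raises a RuntimeError.
# hidden.view(bsz, seq_len, num_heads, head_dim)

# New behaviour: take the batch dimension from the tensor itself.
shaped = hidden.view(hidden.shape[0], seq_len, num_heads, head_dim).transpose(1, 2).contiguous()
print(shaped.shape)  # torch.Size([6, 4, 5, 8])
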
3 changes: 2 additions & 1 deletion src/adapters/models/distilbert/modeling_distilbert.py
@@ -61,7 +61,8 @@ def forward(

def shape(x: torch.Tensor) -> torch.Tensor:
"""separate heads"""
return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
# keep first dim due to parallel composition
return x.view(x.shape[0], -1, self.n_heads, dim_per_head).transpose(1, 2)

def unshape(x: torch.Tensor) -> torch.Tensor:
"""group heads"""
22 changes: 13 additions & 9 deletions src/adapters/models/llama/modeling_llama.py
@@ -85,9 +85,10 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# Loosen constraint on batch_size to allow parallel adapter composition
query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# >>> START AH Changes <<<
query_states, key_states, value_states = match_attn_matrices_for_parallel(
@@ -188,9 +189,11 @@ def forward(
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# Loosen constraint on batch_size to allow parallel adapter composition
query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# >>> START AH Changes <<<
query_states, key_states, value_states = match_attn_matrices_for_parallel(
@@ -320,9 +323,10 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# Loosen constraint on batch_size to allow parallel adapter composition
query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# >>> START AH Changes <<<
query_states, key_states, value_states = match_attn_matrices_for_parallel(
4 changes: 4 additions & 0 deletions src/adapters/models/mbart/modeling_mbart.py
@@ -28,6 +28,10 @@
class MBartAttentionWithAdapters(BartAttentionAdaptersMixin, MBartAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

# Loosen constraint on batch_size to allow parallel adapter composition
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

def forward(
self,
hidden_states: torch.Tensor,
33 changes: 18 additions & 15 deletions src/adapters/models/mistral/modeling_mistral.py
@@ -68,15 +68,16 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# >>> START AH Changes <<<
# Loosen constraint on batch_size to allow parallel adapter composition
query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

query_states, key_states, value_states = match_attn_matrices_for_parallel(
query_states, key_states, value_states
)
(attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
(attention_mask, position_ids) = adjust_tensors_for_parallel(query_states, attention_mask, position_ids)
# >>> END AH Changes <<<

cos, sin = self.rotary_emb(value_states, position_ids)
@@ -153,15 +154,16 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# >>> START AH Changes <<<
# Loosen constraint on batch_size to allow parallel adapter composition
query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

query_states, key_states, value_states = match_attn_matrices_for_parallel(
query_states, key_states, value_states
)
(attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
(attention_mask, position_ids) = adjust_tensors_for_parallel(query_states, attention_mask, position_ids)
# >>> END AH Changes <<<

kv_seq_len = key_states.shape[-2]
@@ -310,15 +312,16 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# >>> START AH Changes <<<
# Loosen constraint on batch_size to allow parallel adapter composition
query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

query_states, key_states, value_states = match_attn_matrices_for_parallel(
query_states, key_states, value_states
)
(attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
(attention_mask, position_ids) = adjust_tensors_for_parallel(query_states, attention_mask, position_ids)
# >>> END AH Changes <<<

cos, sin = self.rotary_emb(value_states, position_ids)
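The Mistral hunks additionally route position_ids through adjust_tensors_for_parallel: once the query/key/value batch dimension has grown, every per-batch tensor entering attention has to be replicated to match. A simplified stand-in for what these helpers conceptually do (the real implementations live in adapters.composition and handle more cases):

import torch

def replicate_to_match(reference: torch.Tensor, *tensors):
    # Repeat each tensor along dim 0 so its batch size matches `reference`.
    out = []
    for t in tensors:
        if t is not None and t.shape[0] != reference.shape[0]:
            t = t.repeat_interleave(reference.shape[0] // t.shape[0], dim=0)
        out.append(t)
    return tuple(out)

bsz, q_len = 2, 7
query_states = torch.randn(3 * bsz, 4, q_len, 16)          # batch replicated 3x
attention_mask = torch.ones(bsz, 1, q_len, q_len)
position_ids = torch.arange(q_len).unsqueeze(0).expand(bsz, -1)

attention_mask, position_ids = replicate_to_match(query_states, attention_mask, position_ids)
print(attention_mask.shape, position_ids.shape)  # (6, 1, 7, 7) (6, 7)
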
3 changes: 2 additions & 1 deletion src/adapters/models/mt5/modeling_mt5.py
@@ -84,7 +84,8 @@ def forward(

def shape(states):
"""projection"""
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
# keep first dim due to parallel composition
return states.view(states.shape[0], -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

def unshape(states):
"""reshape"""
14 changes: 14 additions & 0 deletions src/adapters/models/plbart/modeling_plbart.py
@@ -36,6 +36,10 @@
class PLBartAttentionWithAdapters(PLBartAttentionAdaptersMixin, PLBartAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

# Loosen constraint on batch_size to allow parallel adapter composition
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

def forward(
self,
hidden_states: torch.Tensor,
@@ -169,6 +173,11 @@ def forward(


class PLBartFlashAttention2WithAdapters(PLBartAttentionAdaptersMixin, PLBartAttention):

# Loosen constraint on batch_size to allow parallel adapter composition
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim)

def forward(
self,
hidden_states: torch.Tensor,
@@ -280,6 +289,11 @@ def forward(


class PLBartSdpaAttentionWithAdapters(PLBartAttentionAdaptersMixin, PLBartAttention):

# Loosen constraint on batch_size to allow parallel adapter composition
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(tensor.shape[0], seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

def forward(
self,
hidden_states: torch.Tensor,
3 changes: 2 additions & 1 deletion src/adapters/models/t5/modeling_t5.py
@@ -84,7 +84,8 @@ def forward(

def shape(states):
"""projection"""
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
# keep first dim due to parallel composition
return states.view(states.shape[0], -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

def unshape(states):
"""reshape"""
20 changes: 15 additions & 5 deletions tests/composition/test_parallel.py
@@ -3,7 +3,14 @@

import torch

from adapters import ADAPTER_MODEL_MAPPING, AutoAdapterModel, PrefixTuningConfig, SeqBnConfig, T5AdapterModel
from adapters import (
ADAPTER_MODEL_MAPPING,
AutoAdapterModel,
LoRAConfig,
PrefixTuningConfig,
SeqBnConfig,
T5AdapterModel,
)
from adapters.composition import BatchSplit, Parallel
from adapters.models.bert_generation.adapter_model import BertGenerationAdapterModel
from transformers import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, Trainer, TrainingArguments
@@ -276,13 +283,16 @@ def run_parallel_training_equivalent_to_single(self, adapter_config):
self.assertTrue(torch.allclose(v, state_dict[k.replace(b1, b2)], atol=1e-5))

def test_parallel_training_bottleneck(self):
self.run_parallel_training_test(SeqBnConfig(), "adapters.{}")
self.run_parallel_training_test(SeqBnConfig(reduction_factor=48), "adapters.{}")

def test_parallel_training_lora(self):
self.run_parallel_training_test(LoRAConfig(r=1), "loras.{}")

def test_parallel_training_prefix_tuning(self):
self.run_parallel_training_test(PrefixTuningConfig(), "prefix_tunings.{}")

def test_parallel_training_equivalent_to_single_bottleneck(self):
self.run_parallel_training_equivalent_to_single(SeqBnConfig())
self.run_parallel_training_equivalent_to_single(SeqBnConfig(reduction_factor=48))

def test_parallel_training_equivalent_to_single_prefix_tuning(self):
self.run_parallel_training_equivalent_to_single(PrefixTuningConfig())
@@ -291,8 +301,8 @@ def test_parallel_training_single_forward_pass(self):
model = AutoAdapterModel.from_config(self.config())
model.eval()

a1, a2 = self.create_twin_adapters(model, "a", SeqBnConfig())
b1, b2 = self.create_twin_adapters(model, "b", SeqBnConfig())
a1, a2 = self.create_twin_adapters(model, "a", SeqBnConfig(reduction_factor=48))
b1, b2 = self.create_twin_adapters(model, "b", SeqBnConfig(reduction_factor=48))

state_dict = model.state_dict()
for k, v in state_dict.items():
3 changes: 2 additions & 1 deletion tests/test_deberta.py
@@ -42,7 +42,8 @@ class DebertaAdapterTest(
DebertaAdapterTestBase,
unittest.TestCase,
):
pass
def test_parallel_training_lora(self):
self.skipTest("Not supported for DeBERTa")


@require_torch
3 changes: 2 additions & 1 deletion tests/test_gpt2.py
@@ -54,7 +54,8 @@ class GPT2AdapterTest(
GPT2AdapterTestBase,
unittest.TestCase,
):
pass
def test_parallel_training_lora(self):
self.skipTest("Not supported for GPT2")


@require_torch
3 changes: 2 additions & 1 deletion tests/test_whisper.py
@@ -59,7 +59,8 @@ class WhisperAdapterTest(
WhisperAdapterTestBase,
unittest.TestCase,
):
pass
def test_parallel_training_lora(self):
self.skipTest("Not supported for Whisper")


@require_torch
