Pass quantization_kwargs to CLIP builders (#1994)
joecummings authored Nov 13, 2024
1 parent 912af64 commit 51b31c8
Showing 4 changed files with 44 additions and 7 deletions.
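In short, the CLIP component builders now accept `**quantization_kwargs` and forward them to `FrozenNF4Linear`, so NF4 options such as `scaler_block_size` can be chosen per model instead of falling back to torchao's defaults. Below is a minimal sketch of what that enables, not code from this PR: the dimensions and activation are illustrative, `scaler_block_size=200` mirrors the value the Llama 3.2 vision builders pass further down, and the full `clip_mlp` signature is assumed to match the hunks shown.

```python
import torch
from torch import nn
from torchao.dtypes.nf4tensor import NF4Tensor
from torchtune.models.clip._component_builders import clip_mlp

# Build a quantized CLIP MLP, threading an NF4 option through **quantization_kwargs.
mlp = clip_mlp(
    in_dim=1280,            # illustrative dims; real values come from the vision config
    out_dim=1280,
    hidden_dim=5120,
    activation=nn.GELU(),
    quantize_base=True,
    scaler_block_size=200,  # forwarded via **quantization_kwargs to FrozenNF4Linear
)

# The base projections should now be NF4-quantized with the custom scaler block size.
for name, param in mlp.named_parameters():
    if isinstance(param, NF4Tensor):
        print(name, param.scaler_block_size)  # expect 200 rather than the NF4 default
```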
33 changes: 27 additions & 6 deletions torchtune/models/clip/_component_builders.py
@@ -163,19 +163,20 @@ def clip_mlp(
hidden_dim: int,
activation: nn.Module,
quantize_base: bool = False,
**quantization_kwargs,
) -> FeedForward:
"""
Build the MLP layer associated with the clip model.
"""
gate_proj = (
nn.Linear(in_dim, hidden_dim)
if not quantize_base
else FrozenNF4Linear(in_dim, hidden_dim, bias=True)
else FrozenNF4Linear(in_dim, hidden_dim, bias=True, **quantization_kwargs)
)
down_proj = (
nn.Linear(hidden_dim, out_dim)
if not quantize_base
else FrozenNF4Linear(hidden_dim, out_dim, bias=True)
else FrozenNF4Linear(hidden_dim, out_dim, bias=True, **quantization_kwargs)
)
return FeedForward(
gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation
@@ -210,6 +211,7 @@ def lora_clip_vision_encoder(
lora_dropout: float = 0.0,
use_dora: bool = False,
quantize_base: bool = False,
**quantization_kwargs,
) -> VisionTransformer:
"""
Build a LoRA implementation of the CLIP vision encoder.
@@ -277,6 +279,7 @@ def lora_clip_vision_encoder(
lora_dropout=lora_dropout,
use_dora=use_dora,
quantize_base=quantize_base,
**quantization_kwargs,
)
if apply_lora_to_mlp:
mlp = lora_clip_mlp(
@@ -289,6 +292,7 @@
quantize_base=quantize_base,
lora_dropout=lora_dropout,
use_dora=use_dora,
**quantization_kwargs,
)
else:
mlp = clip_mlp(
@@ -297,6 +301,7 @@
out_dim=embed_dim,
activation=activation(),
quantize_base=quantize_base,
**quantization_kwargs,
)
transformer_layer = TransformerSelfAttentionLayer(
attn=self_attn,
@@ -367,6 +372,7 @@ def lora_clip_attention(
lora_dropout: float = 0.0,
use_dora: bool = False,
quantize_base: bool = False,
**quantization_kwargs,
) -> MultiHeadAttention:
"""
Return an instance of :func:`~torchtune.modules.MultiHeadAttention` with LoRA
@@ -414,12 +420,15 @@
alpha=lora_alpha,
dropout=lora_dropout,
quantize_base=quantize_base,
**quantization_kwargs,
)
if "q_proj" in lora_modules
else (
nn.Linear(embed_dim, num_heads * head_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False)
else FrozenNF4Linear(
embed_dim, num_heads * head_dim, bias=False, **quantization_kwargs
)
)
)
k_proj = (
@@ -430,12 +439,15 @@
alpha=lora_alpha,
dropout=lora_dropout,
quantize_base=quantize_base,
**quantization_kwargs,
)
if "k_proj" in lora_modules
else (
nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False)
else FrozenNF4Linear(
embed_dim, num_kv_heads * head_dim, bias=False, **quantization_kwargs
)
)
)
v_proj = (
@@ -446,12 +458,15 @@
alpha=lora_alpha,
dropout=lora_dropout,
quantize_base=quantize_base,
**quantization_kwargs,
)
if "v_proj" in lora_modules
else (
nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False)
else FrozenNF4Linear(
embed_dim, num_kv_heads * head_dim, bias=False, **quantization_kwargs
)
)
)
output_proj = (
@@ -462,12 +477,15 @@
alpha=lora_alpha,
dropout=lora_dropout,
quantize_base=quantize_base,
**quantization_kwargs,
)
if "output_proj" in lora_modules
else (
nn.Linear(embed_dim, embed_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(embed_dim, embed_dim, bias=False)
else FrozenNF4Linear(
embed_dim, embed_dim, bias=False, **quantization_kwargs
)
)
)

@@ -497,6 +515,7 @@ def lora_clip_mlp(
lora_dropout: float = 0.0,
use_dora: bool = False,
quantize_base: bool = False,
**quantization_kwargs,
) -> FeedForward:
"""
Build the MLP layer with LoRA applied to the gate and down projections.
@@ -510,6 +529,7 @@
dropout=lora_dropout,
quantize_base=quantize_base,
use_bias=True,
**quantization_kwargs,
)
down_proj = adapter_cls(
in_dim=hidden_dim,
@@ -519,6 +539,7 @@
dropout=lora_dropout,
quantize_base=quantize_base,
use_bias=True,
**quantization_kwargs,
)
return FeedForward(
gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation
6 changes: 6 additions & 0 deletions torchtune/models/llama3_2_vision/_component_builders.py
@@ -359,6 +359,7 @@ def lora_llama3_2_vision_encoder(
lora_dropout: float = 0.0,
use_dora: bool = False,
quantize_base: bool = False,
**quantization_kwargs,
) -> Llama3VisionEncoder:
"""
Build the Llama 3.2 vision encoder by combining the CLIP image model with an additional
@@ -417,6 +418,7 @@ def lora_llama3_2_vision_encoder(
"lora_dropout": lora_dropout,
"use_dora": use_dora,
"quantize_base": quantize_base,
**quantization_kwargs,
}

# clip encoder
@@ -683,6 +685,7 @@ def lora_llama3_2_vision_projection_head(
lora_dropout: float = 0.0,
use_dora: bool = False,
quantize_base: bool = False,
**quantization_kwargs,
) -> Llama3VisionProjectionHead:
"""
Build the Llama 3.2 Vision Projection Head with LoRA applied to a subset of the layers.
@@ -729,6 +732,7 @@ def lora_llama3_2_vision_projection_head(
lora_dropout=lora_dropout,
use_dora=use_dora,
quantize_base=quantize_base,
**quantization_kwargs,
)

if apply_lora_to_mlp:
@@ -742,6 +746,7 @@
quantize_base=quantize_base,
lora_dropout=lora_dropout,
use_dora=use_dora,
**quantization_kwargs,
)
else:
mlp = clip_mlp(
@@ -750,6 +755,7 @@
out_dim=clip_embed_dim,
activation=nn.GELU(),
quantize_base=quantize_base,
**quantization_kwargs,
)

layer = TransformerSelfAttentionLayer(
6 changes: 6 additions & 0 deletions torchtune/models/llama3_2_vision/_model_builders.py
@@ -193,6 +193,9 @@ def lora_llama3_2_vision_11b(
lora_dropout=lora_dropout,
use_dora=use_dora,
quantize_base=quantize_base,
# Update scaler block size to ensure that weights can be quantized evenly across 1, 2, 4, 6, 8 GPUs.
# This is dependent on ``clip_embed_dim`` so if that is updated, this variable should be as well
scaler_block_size=200 if quantize_base else None,
)
decoder = lora_llama3_2_vision_decoder(
decoder_lora=decoder_type == LoRATrainable.LORA,
@@ -348,6 +351,9 @@ def lora_llama3_2_vision_90b(
lora_dropout=lora_dropout,
use_dora=use_dora,
quantize_base=quantize_base,
# Update scaler block size to ensure that weights can be quantized evenly across 1, 2, 4, 6, 8 GPUs.
# This is dependent on ``clip_embed_dim`` so if that is updated, this variable should be as well
scaler_block_size=200 if quantize_base else None,
)
decoder = lora_llama3_2_vision_decoder(
decoder_lora=decoder_type == LoRATrainable.LORA,
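The comment added to both model builders above pins `scaler_block_size=200` so the CLIP weights quantize and shard evenly across 1, 2, 4, 6, and 8 GPUs, and ties the value to `clip_embed_dim`. The sketch below only illustrates the tensor-level constraint that makes such a value legal at all; the requirement that the element count divide evenly into scaler blocks is a reading of torchao's NF4 conversion rather than something this PR states, and the CLIP-MLP-shaped weight is illustrative.

```python
import torch
from torchao.dtypes.nf4tensor import to_nf4

clip_embed_dim = 1280                                  # the dimension the comment ties this to
w = torch.randn(4 * clip_embed_dim, clip_embed_dim)    # an illustrative CLIP-MLP-shaped weight

# 5120 * 1280 = 6_553_600 elements; 6_553_600 / 200 = 32_768, so 200 divides evenly here.
assert w.numel() % 200 == 0

nf4_w = to_nf4(w, block_size=64, scaler_block_size=200)
print(nf4_w.block_size, nf4_w.scaler_block_size)       # 64 200
```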
6 changes: 5 additions & 1 deletion torchtune/training/_distributed.py
@@ -300,7 +300,11 @@ def load_from_full_model_state_dict(
if hasattr(sharded_meta_param, "_local_tensor") and isinstance(
sharded_meta_param._local_tensor, NF4Tensor
):
full_tensor = to_nf4(full_tensor)
block_size = sharded_meta_param._local_tensor.block_size
scaler_block_size = sharded_meta_param._local_tensor.scaler_block_size
full_tensor = to_nf4(
full_tensor, block_size=block_size, scaler_block_size=scaler_block_size
)
# replicating logic from `_fsdp_param.py`` `_init_sharded_param`
# otherwise `distribute_tensor(DTensor(local=NF4))`
# requires dispatching `c10d.scatter_``
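The `_distributed.py` change makes `load_from_full_model_state_dict` re-quantize an incoming full weight with the block sizes read off the existing NF4 parameter rather than `to_nf4`'s defaults, so a model built with `scaler_block_size=200` survives a checkpoint round trip. A small sketch of the idea, where `existing_nf4` and `full_tensor` are hypothetical stand-ins for `sharded_meta_param._local_tensor` and the incoming full weight, and the shapes are illustrative:

```python
import torch
from torchao.dtypes.nf4tensor import to_nf4

existing_nf4 = to_nf4(torch.randn(5120, 1280), scaler_block_size=200)
full_tensor = torch.randn(5120, 1280)

# Before this commit the call was simply `to_nf4(full_tensor)`, which silently
# fell back to torchao's default block sizes; now the original parameter's
# block_size and scaler_block_size are propagated.
full_tensor = to_nf4(
    full_tensor,
    block_size=existing_nf4.block_size,
    scaler_block_size=existing_nf4.scaler_block_size,
)
assert full_tensor.scaler_block_size == 200
```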
