
Commit

Merge branch 'fix_vision_lora' into compile_break
Felipe Mello committed Nov 19, 2024
2 parents 4c8f8c5 + 34b9ceb commit d9f417d
Showing 4 changed files with 3 additions and 24 deletions.
21 changes: 1 addition & 20 deletions torchtune/models/llama3_2_vision/_component_builders.py
@@ -338,7 +338,6 @@ def lora_llama3_2_vision_encoder(
     fusion_lora: bool,
     lora_attn_modules: List[LORA_ATTN_MODULES],
     apply_lora_to_mlp: bool = False,
-    apply_lora_to_output: bool = False,
     *,
     # clip encoder parameters
     patch_size: int,
@@ -377,8 +376,6 @@ def lora_llama3_2_vision_encoder(
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
-        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
-            Default: False
        patch_size (int): The size of each patch. Used to divide the tiles into patches.
            E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches
            with shape (40, 40) each.
@@ -412,7 +409,6 @@ def lora_llama3_2_vision_encoder(
     lora_options = {
         "lora_modules": lora_attn_modules,
         "apply_lora_to_mlp": apply_lora_to_mlp,
-        "apply_lora_to_output": apply_lora_to_output,
         "lora_rank": lora_rank,
         "lora_alpha": lora_alpha,
         "lora_dropout": lora_dropout,
@@ -679,7 +675,6 @@ def lora_llama3_2_vision_projection_head(
     num_hidden_inputs: int,
     # LoRA args
     apply_lora_to_mlp: bool,
-    apply_lora_to_output: bool,
     lora_rank: int,
     lora_alpha: float,
     lora_dropout: float = 0.0,
@@ -701,8 +696,6 @@ def lora_llama3_2_vision_projection_head(
        num_hidden_inputs (int): number of hidden inputs to the projection head.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
-        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
-            Default: False
        lora_rank (int): rank of each low-rank approximation
        lora_alpha (float): scaling factor for the low-rank approximation
        lora_dropout (float): LoRA dropout probability. Default: 0.0
@@ -773,19 +766,7 @@ def lora_llama3_2_vision_projection_head(
     # cross encoding
     # TODO: quantize_base is not applied to final output_proj currently.
     proj_in = clip_embed_dim * (num_hidden_inputs + 1)
-    adapter_cls = DoRALinear if use_dora else LoRALinear
-    output_proj = (
-        adapter_cls(
-            proj_in,
-            decoder_embed_dim,
-            rank=lora_rank,
-            alpha=lora_alpha,
-            dropout=lora_dropout,
-            use_bias=True,
-        )
-        if apply_lora_to_output
-        else nn.Linear(proj_in, decoder_embed_dim)
-    )
+    output_proj = nn.Linear(proj_in, decoder_embed_dim)
     return Llama3VisionProjectionHead(
         layers=layers,
         output=output_proj,
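
Net effect of this file's changes: the projection head's output projection is no longer conditionally swapped for a LoRA/DoRA adapter and is always a plain nn.Linear, and apply_lora_to_output disappears from the encoder builder's signature. A minimal sketch of how one might confirm this on a built head follows; the attribute name head.output is assumed from the Llama3VisionProjectionHead(..., output=output_proj) call above, and the helper itself is illustrative, not part of torchtune.

import torch.nn as nn

def output_proj_has_lora(head: nn.Module) -> bool:
    # torchtune's LoRALinear/DoRALinear register extra parameters whose names
    # contain "lora" (e.g. lora_a.weight, lora_b.weight); a plain nn.Linear
    # exposes only "weight" and "bias", so this returns False after this commit.
    return any("lora" in name for name, _ in head.output.named_parameters())
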
2 changes: 0 additions & 2 deletions torchtune/models/llama3_2_vision/_model_builders.py
@@ -172,7 +172,6 @@ def lora_llama3_2_vision_11b(
         fusion_lora=fusion_type == LoRATrainable.LORA,
         lora_attn_modules=lora_attn_modules,
         apply_lora_to_mlp=apply_lora_to_mlp,
-        apply_lora_to_output=apply_lora_to_output,
         patch_size=14,
         num_heads=16,
         clip_embed_dim=1280,
@@ -330,7 +329,6 @@ def lora_llama3_2_vision_90b(
         fusion_lora=fusion_type == LoRATrainable.LORA,
         lora_attn_modules=lora_attn_modules,
         apply_lora_to_mlp=apply_lora_to_mlp,
-        apply_lora_to_output=apply_lora_to_output,
         patch_size=14,
         num_heads=16,
         clip_embed_dim=1280,
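
For context, the fusion_lora=fusion_type == LoRATrainable.LORA line kept at the top of both hunks is how the model builders turn a per-component trainable setting into the boolean flags the encoder builder expects; only apply_lora_to_output stops being forwarded to the encoder. A small sketch of that pattern, with enum members assumed rather than taken from this diff:

from enum import Enum

class LoRATrainable(Enum):
    # Assumed members, standing in for torchtune's per-component trainable settings.
    FULL = "full"
    LORA = "lora"
    FROZEN = "frozen"

fusion_type = LoRATrainable.LORA
# A component receives LoRA flags only when its trainable setting is LORA.
fusion_lora = fusion_type == LoRATrainable.LORA  # True
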
2 changes: 1 addition & 1 deletion torchtune/modules/peft/dora.py
@@ -65,7 +65,7 @@ def __init__(
         self.use_bias = use_bias
         self._quantize_base = quantize_base

-        if not self._quantize_base and quantization_kwargs:
+        if not self._quantize_base and any([v for v in quantization_kwargs.values()]):
             raise ValueError(
                 f"``quantize_base`` is False, but received the following quantization arguments: {quantization_kwargs}"
             )
2 changes: 1 addition & 1 deletion torchtune/modules/peft/lora.py
@@ -65,7 +65,7 @@ def __init__(
         self.use_bias = use_bias
         self._quantize_base = quantize_base

-        if not self._quantize_base and quantization_kwargs:
+        if not self._quantize_base and any([v for v in quantization_kwargs.values()]):
             raise ValueError(
                 f"``quantize_base`` is False, but received the following quantization arguments: {quantization_kwargs}"
             )
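
The dora.py and lora.py changes are identical: the guard now raises only when a quantization argument actually carries a truthy value, not whenever the kwargs dict is non-empty. A short self-contained sketch of the difference (the key names are illustrative, not taken from torchtune):

# With quantize_base=False, a caller may still forward quantization kwargs whose
# values were never set. The dict itself is truthy, but none of its values are.
quantization_kwargs = {"block_size": None, "scaler_block_size": None}

old_check = bool(quantization_kwargs)                       # True  -> spurious ValueError
new_check = any([v for v in quantization_kwargs.values()])  # False -> no error

print(old_check, new_check)  # True False
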
