
Commit 44a26c6: add tensor parallelism with LoRA
matthew-hippocraticai committed Dec 24, 2024
1 parent aa8f365 commit 44a26c6
Showing 3 changed files with 245 additions and 28 deletions.
102 changes: 97 additions & 5 deletions recipes/lora_finetune_distributed.py
@@ -15,7 +15,8 @@
from omegaconf import DictConfig, ListConfig

from torch import nn
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed import destroy_process_group, DeviceMesh, init_process_group
from torch.distributed.device_mesh import init_device_mesh

from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
@@ -116,6 +117,7 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
Args:
cfg (DictConfig): OmegaConf object parsed from yaml file
device_mesh (DeviceMesh): device mesh describing the device topology (data-parallel and tensor-parallel dimensions)
Raises:
ValueError: If ``dtype`` is set to fp16.
@@ -126,9 +128,10 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
"""

def __init__(self, cfg: DictConfig) -> None:
def __init__(self, cfg: DictConfig, device_mesh: DeviceMesh) -> None:
self._device = utils.get_device(device=cfg.device)
self._dtype = training.get_dtype(cfg.dtype, device=self._device)
self.device_mesh = device_mesh

if self._dtype == torch.float16:
raise ValueError(
@@ -268,6 +271,34 @@ def setup(self, cfg: DictConfig) -> None:
checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
self._compile = cfg.get("compile", False)

# Helper to remap the base_model_state_dict so the LoRALinear modules can be sharded with tensor parallelism:
# since DTensor currently only shards nn.Linear and nn.Embedding (not LoRALinear),
# the original nn.Linear weights are remapped onto the `.linear` submodule inside each LoRALinear.
def remap_base_model_state_dict(base_model_state_dict):
new_state_dict = {}
for k, v in base_model_state_dict.items():
if "q_proj.bias" in k or "output_proj.bias" in k or "v_proj.bias" in k:
new_state_dict[k.replace(".bias", ".linear.bias")] = v
elif (
"q_proj.weight" in k
or "output_proj.weight" in k
or "v_proj.weight" in k
):
new_state_dict[k.replace(".weight", ".linear.weight")] = v
elif "w1.bias" in k or "w2.bias" in k or "w3.bias" in k:
new_state_dict[k.replace(".bias", ".linear.bias")] = v
elif "w1.weight" in k or "w2.weight" in k or "w3.weight" in k:
new_state_dict[k.replace(".weight", ".linear.weight")] = v
else:
new_state_dict[k] = v
return new_state_dict

# Remap the base model state dict
base_model_state_dict = remap_base_model_state_dict(
checkpoint_dict[training.MODEL_KEY]
)
checkpoint_dict.update({training.MODEL_KEY: base_model_state_dict})
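# Illustrative mapping (editorial sketch; the key names below are hypothetical examples,
# not taken from a real checkpoint):
#   "layers.0.attn.q_proj.weight" -> "layers.0.attn.q_proj.linear.weight"
#   "layers.0.mlp.w1.weight"      -> "layers.0.mlp.w1.linear.weight"
#   "layers.0.attn.k_proj.weight" -> unchanged (k_proj is not remapped)
#   "tok_embeddings.weight"       -> unchanged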

self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=self._enable_activation_checkpointing,
@@ -444,7 +475,7 @@ def _setup_model(

utils.log_rank_zero(
log,
"FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
"FSDP and TP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
)
init_start = time.perf_counter()

@@ -470,6 +501,7 @@
]
training.shard_model(
model=model,
device_mesh=self.device_mesh,
shard_conditions=fsdp_shard_conditions,
cpu_offload=fsdp_cpu_offload,
reshard_after_forward=reshard_after_forward,
@@ -602,8 +634,17 @@ def _setup_data(
raise RuntimeError("left_pad_sequence collator is only for inference.")
collate_fn = _get_component_from_path(collate_fn)

# Look up the data-parallel rank and degree from the device mesh
dp_mesh = self.device_mesh["dp"]
dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()

# Create the DistributedSampler over data-parallel ranks only, so each dp group reads a distinct shard
sampler = DistributedSampler(
ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
ds,
num_replicas=dp_degree, # number of dp ranks
rank=dp_rank, # use dp_rank, not the global rank
shuffle=shuffle,
seed=0,
)
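# Illustrative note (editorial sketch, assuming init_device_mesh's default row-major rank
# layout): with world_size=16 and tp_size=8, a global rank r has dp_rank = r // 8, so
# ranks 0-7 share dp_rank 0 and ranks 8-15 share dp_rank 1. The eight tensor-parallel
# workers within a dp group therefore read the same data shard, while the two dp groups
# read disjoint shards.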

dataloader = DataLoader(
@@ -657,6 +698,43 @@ def save_checkpoint(
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
state_dict = self._model.state_dict()

# Helper to undo the remapping that was applied for tensor-parallel sharding,
# so the state_dict matches the key names the checkpointer expects when saving.
def unmap_base_model_state_dict(base_model_state_dict):
new_state_dict = {}
for k, v in base_model_state_dict.items():
if (
"q_proj.linear.bias" in k
or "output_proj.linear.bias" in k
or "v_proj.linear.bias" in k
):
new_state_dict[k.replace(".linear.bias", ".bias")] = v
elif (
"q_proj.linear.weight" in k
or "output_proj.linear.weight" in k
or "v_proj.linear.weight" in k
):
new_state_dict[k.replace(".linear.weight", ".weight")] = v
elif (
"w1.linear.bias" in k
or "w2.linear.bias" in k
or "w3.linear.bias" in k
):
new_state_dict[k.replace(".linear.bias", ".bias")] = v
elif (
"w1.linear.weight" in k
or "w2.linear.weight" in k
or "w3.linear.weight" in k
):
new_state_dict[k.replace(".linear.weight", ".weight")] = v
else:
new_state_dict[k] = v
return new_state_dict

# Unmap the state dict so we can save the checkpoint
state_dict = unmap_base_model_state_dict(state_dict)

if self._save_adapter_weights_only:
state_dict = get_adapter_state_dict(state_dict, device=None)

@@ -921,9 +999,23 @@ def recipe_main(cfg: DictConfig) -> None:
# speed up when benchmarking fused AdamW on CPU
training.set_torch_num_threads()

# Get world size and rank to initialize the device mesh
world_size, rank = training.get_world_size_and_rank()
tp_size = 8
if world_size % tp_size != 0:
raise ValueError(
f"World size {world_size} must be divisible by tensor parallel size {tp_size}"
)
dp_size = world_size // tp_size

# Initialize device mesh
device_mesh = init_device_mesh(
"cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=("dp", "tp")
)
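# Worked example (editorial): with 16 GPUs, dp_size = 16 // 8 = 2 and the mesh has shape
# (2, 8); device_mesh["dp"] indexes the two data-parallel replicas and device_mesh["tp"]
# the eight tensor-parallel shards within each replica. On a single 8-GPU node, dp_size = 1.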

config.log_config(recipe_name="LoRAFinetuneRecipeDistributed", cfg=cfg)

recipe = LoRAFinetuneRecipeDistributed(cfg=cfg)
recipe = LoRAFinetuneRecipeDistributed(cfg=cfg, device_mesh=device_mesh)
recipe.setup(cfg=cfg)
recipe.train()
recipe.cleanup()
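Editorial sketch (not part of this commit): training.shard_model now receives the 2-D device mesh, and in stock PyTorch such a mesh is typically consumed by applying a tensor-parallel plan on the "tp" sub-mesh and FSDP on the "dp" sub-mesh. The helper name shard_2d, the module paths, and the exact plan below are assumptions for illustration only, not torchtune's actual shard_model implementation; the paths assume the projections this commit remaps are LoRA-wrapped, so their base weights live in an inner ".linear" nn.Linear.

import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

def shard_2d(model: nn.Module, device_mesh) -> None:
    """Hypothetical helper: TP on the 'tp' sub-mesh, then FSDP on the 'dp' sub-mesh."""
    tp_mesh, dp_mesh = device_mesh["tp"], device_mesh["dp"]
    for layer in model.layers:  # assumes a decoder exposing a .layers ModuleList
        # Target the plain nn.Linear inside each LoRALinear; DTensor cannot shard the
        # LoRALinear wrapper itself, which is why the checkpoint keys were remapped
        # to "*.linear.weight" in the recipe above.
        tp_plan = {
            "attn.q_proj.linear": ColwiseParallel(),
            "attn.k_proj": ColwiseParallel(),
            "attn.v_proj.linear": ColwiseParallel(),
            "attn.output_proj.linear": RowwiseParallel(),
            "mlp.w1.linear": ColwiseParallel(),
            "mlp.w3.linear": ColwiseParallel(),
            "mlp.w2.linear": RowwiseParallel(),
        }
        parallelize_module(layer, tp_mesh, tp_plan)
        fully_shard(layer, mesh=dp_mesh)  # parameters are sharded over the dp dimension
    fully_shard(model, mesh=dp_mesh)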
26 changes: 6 additions & 20 deletions torchtune/modules/peft/lora.py
@@ -11,13 +11,12 @@

from torch import nn

from torchao.dtypes.nf4tensor import linear_nf4, to_nf4
from torchtune.modules.low_precision import _register_nf4_dispatch_ops # noqa: F401
from torchtune.modules.peft import AdapterModule


class LoRALinear(nn.Module, AdapterModule):
"""LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`_.
"""Modified LoRA Linear config to support Tensor Parallel Sharding with DTensors
(which currently only support sharding nn.Linear and nn.Embedding layers)
LoRA perturbs a given layer via a low-rank approximation where only
the rank decomposition matrices are trainable. In a linear layer instead of
@@ -70,23 +69,15 @@ def __init__(
f"``quantize_base`` is False, but received the following quantization arguments: {quantization_kwargs}"
)

# Setup weight and bias
linear = nn.Linear(in_features=in_dim, out_features=out_dim, bias=self.use_bias)
weight = (
linear.weight
if not self._quantize_base
else to_nf4(linear.weight, **quantization_kwargs)
# Set up the base weight and bias as a plain nn.Linear (the original weights loaded from the state_dict)
self.linear = nn.Linear(
in_features=in_dim, out_features=out_dim, bias=self.use_bias
)
bias = linear.bias if self.use_bias else None

# 'self.disabled' is a flag showing whether to turn off LoRA adapters,
# this can be used in DPO for treating the lora adapters as the policy model
# and disabling it to treat the base model as the reference model
self.disabled = False
self.register_parameter("weight", nn.Parameter(weight))
self.register_parameter(
"bias", nn.Parameter(bias) if bias is not None else None
)
self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
@@ -126,12 +117,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.Tensor: output tensor with shape ``(..., out_dim)``
"""
if self._quantize_base:
out = linear_nf4(input=x, weight=self.weight)
if self.use_bias:
out = out + self.bias
else:
out = F.linear(x, self.weight, self.bias)
out = self.linear(x)
if self.disabled:
return out
lora_out = self.lora_a(self.dropout(x))
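A small usage sketch (editorial; the dimensions are arbitrary) of the revised LoRALinear: the frozen base path now runs through the named nn.Linear submodule, which is what lets a DTensor tensor-parallel plan target e.g. q_proj.linear directly, while the low-rank branch is unchanged.

import torch
from torchtune.modules.peft import LoRALinear

lora = LoRALinear(in_dim=512, out_dim=512, rank=8, alpha=16)
x = torch.randn(2, 512)
y = lora(x)  # linear(x) + (alpha / rank) * lora_b(lora_a(dropout(x)))
assert y.shape == (2, 512)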