Fixes consolidation issue when TP is enabled (#739)
Copy tensor model attributes when initializing
michaelbenayoun authored Nov 22, 2024
Parent: dd60749 · Commit: d78c7c7
Showing 1 changed file with 3 additions and 0 deletions.
optimum/neuron/distributed/base.py: 3 additions & 0 deletions

@@ -330,6 +330,7 @@ def _initialize_or_load_weights(
         from neuronx_distributed import parallel_layers
         from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear
         from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_rank
+        from neuronx_distributed.parallel_layers.utils import copy_tensor_model_parallel_attributes

         weight_map = getattr(model, "_weight_map", {})
         with torch.no_grad():
@@ -389,6 +390,7 @@ def _initialize_or_load_weights(
                         if device is not None:
                             weight_data = weight_data.to(device)
                         new_parameter = torch.nn.Parameter(weight_data)
+                        copy_tensor_model_parallel_attributes(new_parameter, parameter)
                     elif parameter.device != torch.device("meta") and (
                         was_already_initialized_during_parallelization(parameter)
                         or not parameter_can_be_initialized(model, module, attribute_name)
@@ -401,6 +403,7 @@
                         # We first create the module on CPU, initialize it and then move it on device if needed.
                         device = torch.device("cpu")
                         new_parameter = torch.nn.Parameter(torch.empty_like(parameter, device=device))
+                        copy_tensor_model_parallel_attributes(new_parameter, parameter)
                         modules_to_initialize[module].append(attribute_name)

                     setattr(
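
For context: the parallel layers in neuronx_distributed tag their sharded weights with tensor-model-parallel metadata (whether the tensor is partitioned, along which dimension, and with what stride), and checkpoint consolidation relies on those tags to reassemble full weights from the TP shards. Re-creating a parameter with torch.nn.Parameter(...) yields a plain tensor without the tags, which is what broke consolidation when TP was enabled. The sketch below only illustrates the idea; copy_tp_attributes and the attribute names are hypothetical stand-ins modeled on Megatron-style conventions, not the actual neuronx_distributed internals.

import torch

# Hypothetical sketch of what copying tensor-model-parallel attributes achieves.
# The attribute names below are assumptions (Megatron-style conventions), not
# the exact fields used by neuronx_distributed.
def copy_tp_attributes(destination: torch.nn.Parameter, source: torch.nn.Parameter) -> None:
    for attr in ("tensor_model_parallel", "partition_dim", "partition_stride"):
        if hasattr(source, attr):
            setattr(destination, attr, getattr(source, attr))

# A column-parallel layer would tag its sharded weight roughly like this:
sharded = torch.nn.Parameter(torch.empty(1024, 128))
sharded.tensor_model_parallel = True
sharded.partition_dim = 0  # partitioned along the output dimension
sharded.partition_stride = 1

# Re-creating the parameter from loaded weight data drops those tags...
new_parameter = torch.nn.Parameter(torch.empty_like(sharded))
assert not hasattr(new_parameter, "tensor_model_parallel")

# ...so they have to be copied over explicitly, which is what the added
# copy_tensor_model_parallel_attributes calls do in the diff above.
copy_tp_attributes(new_parameter, sharded)
assert new_parameter.partition_dim == 0

Without the copied attributes, the consolidation step cannot tell that the re-created parameter was sharded, so it cannot correctly gather the full weight across tensor-parallel ranks.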
