Commit

fix
zhenglongjiepheonix committed Sep 3, 2024
1 parent d68df89 commit c752e29
Showing 5 changed files with 53 additions and 40 deletions.
14 changes: 4 additions & 10 deletions optimum/fx/parallelization/backend/base.py
@@ -13,13 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, Tuple
from typing import Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.fx import GraphModule

from ..core import Config, ParallelExecutionCtx, ParameterMeta
from ..distributed import scatter
from ..parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from ..passes import (
@@ -30,10 +31,6 @@
)


if TYPE_CHECKING:
from ..core import Config, ParallelExecutionCtx, ParameterMeta


class BackEnd(ABC):
@abstractmethod
def create_column_parallel_linear(
@@ -67,7 +64,6 @@ def create_parallel_embedding(
) -> nn.Module:
raise NotImplementedError

@abstractmethod
def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule:
"""
Mark tie information right before we run passes because dynamo tracing will alter the parameter name while our
@@ -83,11 +79,9 @@ def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", co
param_meta.tied_to = parameter_mp[key]
return graph_module

@abstractmethod
def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> nn.Module:
return graph_module

@abstractmethod
def init_parallelization_pass_pipeline(
self,
) -> PassPipeline:
@@ -165,7 +159,7 @@ def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", c

param_meta: "ParameterMeta" = getattr(param, "meta")
# skip tied parameters for now
if param_meta.tied_to is not None:
if param_meta.tied_to is not None and param_meta.tied_to != name:
continue

shape = [
@@ -234,7 +228,7 @@ def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", c
# take care of tied parameters
for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)):
param_meta: "ParameterMeta" = getattr(param, "meta")
if param_meta.tied_to is not None:
if param_meta.tied_to is not None and param_meta.tied_to != name:
new_parameters.append((name, param_cache[param_meta.tied_to]))

for name, new_param in new_parameters:
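
The substantive change in base.py, beyond making the core imports unconditional and dropping the abstractmethod decorators from methods that now carry default implementations, is how DefaultBackend.post_process treats tied parameters: a parameter whose tied_to field points back to its own name is now the canonical owner and still gets processed, while only parameters tied to a different name are skipped and later re-bound to the cached owner. A minimal, self-contained sketch of that two-pass pattern (the TieInfo class and the doubling "work" below are illustrative stand-ins, not the library's actual types):

from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class TieInfo:
    # name of the canonical parameter this one is tied to (possibly its own name)
    tied_to: Optional[str] = None

def resolve_tied(params: Dict[str, float], meta: Dict[str, TieInfo]) -> Dict[str, float]:
    resolved: Dict[str, float] = {}
    cache: Dict[str, float] = {}

    # first pass: process everything that is not tied to another parameter
    for name, value in sorted(params.items()):
        if meta[name].tied_to is not None and meta[name].tied_to != name:
            continue  # genuine duplicate, handled in the second pass
        cache[name] = value * 2  # stand-in for the real sharding/initialization work
        resolved[name] = cache[name]

    # second pass: duplicates simply reuse the canonical (cached) parameter
    for name in sorted(params):
        if meta[name].tied_to is not None and meta[name].tied_to != name:
            resolved[name] = cache[meta[name].tied_to]
    return resolved

params = {"embed.weight": 1.0, "lm_head.weight": 1.0}
meta = {
    "embed.weight": TieInfo(tied_to="embed.weight"),    # tied to itself: still processed
    "lm_head.weight": TieInfo(tied_to="embed.weight"),  # tied to another name: reuses the cache
}
assert resolve_tied(params, meta) == {"embed.weight": 2.0, "lm_head.weight": 2.0}
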
46 changes: 31 additions & 15 deletions optimum/fx/parallelization/backend/nanotron.py
@@ -13,36 +13,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Nanotron specific imports
import importlib.util
from collections import defaultdict
from typing import TYPE_CHECKING, Optional, Tuple

import torch.distributed as dist
import torch.nn as nn
from nanotron.config import Config as NanotronConfig
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from nanotron.parallel.tied_parameters import tie_parameters
from torch.fx import GraphModule

from ..core import Config, ParallelExecutionCtx, ParameterMeta
from .base import BackEnd


# Check if nanotron is installed
_nanotron_available = importlib.util.find_spec("nanotron") is not None

if TYPE_CHECKING:
from ..core import Config, ParallelExecutionCtx, ParameterMeta
from nanotron.config import Config as NanotronConfig
from nanotron.parallel import ParallelContext
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)


class NanotronBackend(BackEnd):
"""
Backend class which glues optimum fx parallelization context and nanotron context.
"""

def __init__(self, nanotron_config: NanotronConfig, nanotron_context: ParallelContext) -> None:
def __init__(self, nanotron_config: "NanotronConfig", nanotron_context: "ParallelContext") -> None:
if not _nanotron_available:
raise ImportError("Nanotron is not installed. Please install it to use NanotronBackend.")

self.config = nanotron_config
self.context = nanotron_context

@@ -53,7 +57,10 @@ def create_column_parallel_linear(
sequence_parallel: bool,
gather_output: bool,
contiguous_chunks: Optional[Tuple[int]] = None,
) -> TensorParallelColumnLinear:
) -> "TensorParallelColumnLinear":
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.nn import TensorParallelColumnLinear

if gather_output:
raise ValueError(
"Nanotron backend does not support `gather_output=True` in `TensorParallelColumnLinear` yet"
@@ -84,7 +91,10 @@ def create_row_parallel_linear(
sequence_parallel: bool,
input_is_parallel: bool,
contiguous_chunks: Optional[Tuple[int]] = None,
) -> TensorParallelRowLinear:
) -> "TensorParallelRowLinear":
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.nn import TensorParallelRowLinear

if not input_is_parallel:
raise ValueError(
"Nanotron backend does not support `input_is_parallel=True` in `TensorParallelRowLinear` yet"
@@ -114,7 +124,10 @@ def create_parallel_embedding(
parallel_ctx: "ParallelExecutionCtx",
sequence_parallel: bool,
contiguous_chunks: Optional[Tuple[int]] = None,
) -> TensorParallelEmbedding:
) -> "TensorParallelEmbedding":
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding

if sequence_parallel and self.config.parallelism.tp_mode != TensorParallelLinearMode.REDUCE_SCATTER:
raise ValueError(
"`sequence_parallel` can not be activated when `tp_mode` is not set to `REDUCE_SCATTER` in nanotron backend"
@@ -139,6 +152,9 @@ def create_parallel_embedding(
def post_process(
self, graph_module: GraphModule, parallel_ctx: "ParallelExecutionCtx", config: "Config"
) -> nn.Module:
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.tied_parameters import tie_parameters

param_cache, tied_parameter_groups = parallel_ctx.param_cache, defaultdict(list)
for name, param in graph_module.named_parameters():
param_meta: "ParameterMeta" = getattr(param, "meta")
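
The nanotron backend now avoids a hard dependency: a module-level importlib.util.find_spec check records whether nanotron is installed, the heavyweight imports are deferred into the methods that actually need them, and TYPE_CHECKING-only imports keep the annotations resolvable for type checkers. A stripped-down sketch of the same optional-dependency pattern (the class and method shapes here are illustrative; the nanotron import paths are the ones used in the diff):

import importlib.util
from typing import TYPE_CHECKING

# find_spec only queries the import machinery and never imports the package,
# so this line is safe even when nanotron is absent
_nanotron_available = importlib.util.find_spec("nanotron") is not None

if TYPE_CHECKING:
    # evaluated by type checkers only, never at runtime
    from nanotron.parallel import ParallelContext


class OptionalNanotronBackend:
    def __init__(self, context: "ParallelContext") -> None:
        if not _nanotron_available:
            raise ImportError("Nanotron is not installed. Please install it to use NanotronBackend.")
        self.context = context

    def create_column_parallel_linear(self):
        # deferred import: executed only when the backend is actually used
        from nanotron.parallel.tensor_parallel.nn import TensorParallelColumnLinear

        return TensorParallelColumnLinear  # construction of the layer is elided in this sketch
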
16 changes: 12 additions & 4 deletions optimum/fx/parallelization/core.py
@@ -14,13 +14,15 @@
# limitations under the License.
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn

from .backend import BackEnd, DefaultBackend

if TYPE_CHECKING:
from .backend import BackEnd


class HashableSlice:
@@ -113,7 +115,7 @@ class ParallelExecutionCtx:
- current_device (`torch.device`):
Device correpsonding to the current process.
- backend (`BackEnd`, defaults to `DefaultBackEnd`):
- backend (`Optional[BackEnd]`, defaults to `None`):
Backend instance which converts layers into their parallelized counterparts.
- example_inputs (`List[Any]`):
@@ -144,14 +146,20 @@ class ParallelExecutionCtx:

tp_group: dist.ProcessGroup
current_device: torch.device
backend: BackEnd = DefaultBackend()
backend: Optional["BackEnd"] = None
example_inputs: List[Any] = field(default_factory=list)
parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict)
param_cache: Dict[str, nn.Parameter] = field(default_factory=dict)
weight_map: Dict[str, str] = field(default_factory=dict)
last_optimized_module: Optional[nn.Module] = None
compile_times: int = 0

def __post_init__(self):
if self.backend is None:
from .backend import DefaultBackend

self.backend = DefaultBackend()


@dataclass
class Config:
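
In core.py the eager field default backend: BackEnd = DefaultBackend() becomes Optional["BackEnd"] = None with a __post_init__ fallback. This removes the runtime import of the backend package from core.py (which backend/base.py itself now imports, so the local import sidesteps a circular import) and stops every context from sharing the single DefaultBackend instance that a dataclass field default would otherwise create. A self-contained sketch of the pattern, with toy Backend classes standing in for the real ones:

from dataclasses import dataclass, field
from typing import List, Optional


class Backend:  # stand-in for the abstract BackEnd
    pass


class DefaultBackend(Backend):  # stand-in for the concrete default backend
    pass


@dataclass
class ExecutionCtx:
    device: str
    backend: Optional[Backend] = None                 # resolved lazily, not at class-definition time
    example_inputs: List[object] = field(default_factory=list)

    def __post_init__(self) -> None:
        if self.backend is None:
            # in the real code this is a local `from .backend import DefaultBackend`
            self.backend = DefaultBackend()


ctx_a, ctx_b = ExecutionCtx(device="cuda:0"), ExecutionCtx(device="cuda:1")
assert ctx_a.backend is not ctx_b.backend             # each context now owns its backend instance
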
7 changes: 1 addition & 6 deletions optimum/fx/parallelization/parallel_layers/linear.py
@@ -12,13 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from ..core import ParallelExecutionCtx
from ..distributed import (
differentiable_all_gather,
differentiable_all_reduce_sum,
@@ -28,10 +27,6 @@
from ..utils import ensure_divisibility


if TYPE_CHECKING:
from ..core import ParallelExecutionCtx


class ColumnParallelLinear(nn.Module):
"""
Linear layer with column parallelism.
10 changes: 5 additions & 5 deletions tests/fx/parallelization/test_tensor_parallel.py
@@ -80,7 +80,7 @@ def run_test_all_rank_results_match(rank: int, world_size: int, model_id: str, m
device = torch.device(type="cuda", index=torch.cuda.current_device())
ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)

model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs)
inputs = prepare_dummy_inputs(model.config)
logits = model(**inputs)[0]
tensors = gather_at_main_process(tensor=logits, group=tp_group, rank=rank, world_size=world_size)
@@ -106,7 +106,7 @@ def run_test_parameters_persist_bewteen_recompile(
device = torch.device(type="cuda", index=torch.cuda.current_device())
ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)

model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs)
inputs = prepare_dummy_inputs(model.config)

# different shape to trigger recompile
@@ -141,7 +141,7 @@ def run_test_parallel_results_matches_non_parallel(
device = torch.device(type="cuda", index=torch.cuda.current_device())
ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)

model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs)
inputs = prepare_dummy_inputs(model.config)

set_seed(SEED)
@@ -153,7 +153,7 @@
tp_group = dist.new_group()
set_seed(SEED)
ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)
model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs)
parallel_logits = model(**inputs)[0]

torch.testing.assert_close(logits.cpu(), parallel_logits.cpu(), rtol=1e-4, atol=1e-4)
@@ -169,7 +169,7 @@ def run_test_tie_word_embeddings(rank: int, world_size: int, model_id: str, mode
# prepare config and context
device = torch.device(type="cuda", index=torch.cuda.current_device())
ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)
model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs)

inputs = prepare_dummy_inputs(model.config)
model(**inputs)
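
The tests now pass the execution context as the first positional argument and the checkpoint through the model_id_or_path keyword. A hedged usage sketch of the updated calling convention; it assumes parallelize_model and ParallelExecutionCtx are importable from optimum.fx.parallelization, that the script runs under a distributed launcher such as torchrun with CUDA available, and that "gpt2" is only an example checkpoint:

import os

import torch
import torch.distributed as dist

from optimum.fx.parallelization import ParallelExecutionCtx, parallelize_model

# expected launch: torchrun --nproc_per_node=2 example.py
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

tp_group = dist.new_group()
device = torch.device(type="cuda", index=torch.cuda.current_device())
ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)

# context first, checkpoint as a keyword argument (the calling convention used in these tests)
model = parallelize_model(ctx, model_id_or_path="gpt2", skip_load_weights=True)
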
