diff --git a/optimum/fx/parallelization/backend/base.py b/optimum/fx/parallelization/backend/base.py
index c1e3e914f40..9b9ad782dd4 100644
--- a/optimum/fx/parallelization/backend/base.py
+++ b/optimum/fx/parallelization/backend/base.py
@@ -13,13 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import Optional, Tuple
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.fx import GraphModule
 
+from ..core import Config, ParallelExecutionCtx, ParameterMeta
 from ..distributed import scatter
 from ..parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
 from ..passes import (
@@ -30,10 +31,6 @@
 )
 
 
-if TYPE_CHECKING:
-    from ..core import Config, ParallelExecutionCtx, ParameterMeta
-
-
 class BackEnd(ABC):
     @abstractmethod
     def create_column_parallel_linear(
@@ -67,7 +64,6 @@ def create_parallel_embedding(
     ) -> nn.Module:
         raise NotImplementedError
 
-    @abstractmethod
     def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule:
         """
         Mark tie information right before we run passes because dynamo tracing will alter the parameter name while our
@@ -83,11 +79,9 @@ def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", co
                 param_meta.tied_to = parameter_mp[key]
         return graph_module
 
-    @abstractmethod
     def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> nn.Module:
         return graph_module
 
-    @abstractmethod
     def init_parallelization_pass_pipeline(
         self,
     ) -> PassPipeline:
@@ -165,7 +159,7 @@ def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", c
             param_meta: "ParameterMeta" = getattr(param, "meta")
 
             # skip tied parameters for now
-            if param_meta.tied_to is not None:
+            if param_meta.tied_to is not None and param_meta.tied_to != name:
                 continue
 
             shape = [
@@ -234,7 +228,7 @@ def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", c
         # take care of tied parameters
         for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)):
             param_meta: "ParameterMeta" = getattr(param, "meta")
-            if param_meta.tied_to is not None:
+            if param_meta.tied_to is not None and param_meta.tied_to != name:
                 new_parameters.append((name, param_cache[param_meta.tied_to]))
 
         for name, new_param in new_parameters:
diff --git a/optimum/fx/parallelization/backend/nanotron.py b/optimum/fx/parallelization/backend/nanotron.py
index 644365e090e..ffd6a5396e6 100644
--- a/optimum/fx/parallelization/backend/nanotron.py
+++ b/optimum/fx/parallelization/backend/nanotron.py
@@ -13,28 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Nanotron specific imports
+import importlib.util
 from collections import defaultdict
 from typing import TYPE_CHECKING, Optional, Tuple
 
 import torch.distributed as dist
 import torch.nn as nn
-from nanotron.config import Config as NanotronConfig
-from nanotron.parallel import ParallelContext
-from nanotron.parallel.parameters import NanotronParameter
-from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
-from nanotron.parallel.tensor_parallel.nn import (
-    TensorParallelColumnLinear,
-    TensorParallelEmbedding,
-    TensorParallelRowLinear,
-)
-from nanotron.parallel.tied_parameters import tie_parameters
 from torch.fx import GraphModule
 
+from ..core import Config, ParallelExecutionCtx, ParameterMeta
 from .base import BackEnd
 
 
+# Check if nanotron is installed
+_nanotron_available = importlib.util.find_spec("nanotron") is not None
+
 if TYPE_CHECKING:
-    from ..core import Config, ParallelExecutionCtx, ParameterMeta
+    from nanotron.config import Config as NanotronConfig
+    from nanotron.parallel import ParallelContext
+    from nanotron.parallel.tensor_parallel.nn import (
+        TensorParallelColumnLinear,
+        TensorParallelEmbedding,
+        TensorParallelRowLinear,
+    )
 
 
 class NanotronBackend(BackEnd):
@@ -42,7 +43,10 @@ class NanotronBackend(BackEnd):
     Backend class which glues optimum fx parallelization context and nanotron context.
     """
 
-    def __init__(self, nanotron_config: NanotronConfig, nanotron_context: ParallelContext) -> None:
+    def __init__(self, nanotron_config: "NanotronConfig", nanotron_context: "ParallelContext") -> None:
+        if not _nanotron_available:
+            raise ImportError("Nanotron is not installed. Please install it to use NanotronBackend.")
+
         self.config = nanotron_config
         self.context = nanotron_context
 
@@ -53,7 +57,10 @@ def create_column_parallel_linear(
         sequence_parallel: bool,
         gather_output: bool,
         contiguous_chunks: Optional[Tuple[int]] = None,
-    ) -> TensorParallelColumnLinear:
+    ) -> "TensorParallelColumnLinear":
+        from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
+        from nanotron.parallel.tensor_parallel.nn import TensorParallelColumnLinear
+
         if gather_output:
             raise ValueError(
                 "Nanotron backend does not support `gather_output=True` in `TensorParallelColumnLinear` yet"
@@ -84,7 +91,10 @@ def create_row_parallel_linear(
         sequence_parallel: bool,
         input_is_parallel: bool,
         contiguous_chunks: Optional[Tuple[int]] = None,
-    ) -> TensorParallelRowLinear:
+    ) -> "TensorParallelRowLinear":
+        from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
+        from nanotron.parallel.tensor_parallel.nn import TensorParallelRowLinear
+
         if not input_is_parallel:
             raise ValueError(
                 "Nanotron backend does not support `input_is_parallel=True` in `TensorParallelRowLinear` yet"
@@ -114,7 +124,10 @@ def create_parallel_embedding(
         parallel_ctx: "ParallelExecutionCtx",
         sequence_parallel: bool,
         contiguous_chunks: Optional[Tuple[int]] = None,
-    ) -> TensorParallelEmbedding:
+    ) -> "TensorParallelEmbedding":
+        from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
+        from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding
+
         if sequence_parallel and self.config.parallelism.tp_mode != TensorParallelLinearMode.REDUCE_SCATTER:
             raise ValueError(
                 "`sequence_parallel` can not be activated when `tp_mode` is not set to `REDUCE_SCATTER` in nanotron backend"
@@ -139,6 +152,9 @@ def create_parallel_embedding(
     def post_process(
         self, graph_module: GraphModule, parallel_ctx: "ParallelExecutionCtx", config: "Config"
     ) -> nn.Module:
+        from nanotron.parallel.parameters import NanotronParameter
+        from nanotron.parallel.tied_parameters import tie_parameters
+
         param_cache, tied_parameter_groups = parallel_ctx.param_cache, defaultdict(list)
         for name, param in graph_module.named_parameters():
             param_meta: "ParameterMeta" = getattr(param, "meta")
diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py
index 19a70a5a7a3..0fb8517fd99 100644
--- a/optimum/fx/parallelization/core.py
+++ b/optimum/fx/parallelization/core.py
@@ -14,13 +14,15 @@
 # limitations under the License.
 from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 
-from .backend import BackEnd, DefaultBackend
+
+if TYPE_CHECKING:
+    from .backend import BackEnd
 
 
 class HashableSlice:
@@ -113,7 +115,7 @@ class ParallelExecutionCtx:
         - current_device (`torch.device`):
             Device correpsonding to the current process.
 
-        - backend (`BackEnd`, defaults to `DefaultBackEnd`):
+        - backend (`Optional[BackEnd]`, defaults to `None`):
             Backend instance which converts layers into their parallelized counterparts.
 
         - example_inputs (`List[Any]`):
@@ -144,7 +146,7 @@
 
     tp_group: dist.ProcessGroup
     current_device: torch.device
-    backend: BackEnd = DefaultBackend()
+    backend: Optional["BackEnd"] = None
    example_inputs: List[Any] = field(default_factory=list)
     parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict)
     param_cache: Dict[str, nn.Parameter] = field(default_factory=dict)
@@ -152,6 +154,12 @@
     last_optimized_module: Optional[nn.Module] = None
     compile_times: int = 0
 
+    def __post_init__(self):
+        if self.backend is None:
+            from .backend import DefaultBackend
+
+            self.backend = DefaultBackend()
+
 
 @dataclass
 class Config:
diff --git a/optimum/fx/parallelization/parallel_layers/linear.py b/optimum/fx/parallelization/parallel_layers/linear.py
index 9a995eb2784..ba44f5f15a2 100644
--- a/optimum/fx/parallelization/parallel_layers/linear.py
+++ b/optimum/fx/parallelization/parallel_layers/linear.py
@@ -12,13 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING
-
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 
+from ..core import ParallelExecutionCtx
 from ..distributed import (
     differentiable_all_gather,
     differentiable_all_reduce_sum,
@@ -28,10 +27,6 @@
 from ..utils import ensure_divisibility
 
 
-if TYPE_CHECKING:
-    from ..core import ParallelExecutionCtx
-
-
 class ColumnParallelLinear(nn.Module):
     """
     Linear layer with column parallelism.
diff --git a/tests/fx/parallelization/test_tensor_parallel.py b/tests/fx/parallelization/test_tensor_parallel.py
index 64ce80b78fb..f8aab7a912f 100644
--- a/tests/fx/parallelization/test_tensor_parallel.py
+++ b/tests/fx/parallelization/test_tensor_parallel.py
@@ -80,7 +80,7 @@ def run_test_all_rank_results_match(rank: int, world_size: int, model_id: str, m
     device = torch.device(type="cuda", index=torch.cuda.current_device())
     ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)
 
-    model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
+    model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs)
     inputs = prepare_dummy_inputs(model.config)
     logits = model(**inputs)[0]
     tensors = gather_at_main_process(tensor=logits, group=tp_group, rank=rank, world_size=world_size)
@@ -106,7 +106,7 @@ def run_test_parameters_persist_bewteen_recompile(
     device = torch.device(type="cuda", index=torch.cuda.current_device())
     ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)
 
-    model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
+    model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs)
     inputs = prepare_dummy_inputs(model.config)
 
     # different shape to trigger recompile
@@ -141,7 +141,7 @@ def run_test_parallel_results_matches_non_parallel(
     device = torch.device(type="cuda", index=torch.cuda.current_device())
     ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)
 
-    model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
+    model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs)
     inputs = prepare_dummy_inputs(model.config)
 
     set_seed(SEED)
@@ -153,7 +153,7 @@
     tp_group = dist.new_group()
     set_seed(SEED)
     ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)
-    model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
+    model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs)
     parallel_logits = model(**inputs)[0]
 
     torch.testing.assert_close(logits.cpu(), parallel_logits.cpu(), rtol=1e-4, atol=1e-4)
@@ -169,7 +169,7 @@ def run_test_tie_word_embeddings(rank: int, world_size: int, model_id: str, mode
     # prepare config and context
     device = torch.device(type="cuda", index=torch.cuda.current_device())
     ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)
-    model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs)
+    model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs)
     inputs = prepare_dummy_inputs(model.config)
 
     model(**inputs)
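
Usage note (a minimal sketch, not part of the diff): the test changes above reflect two API-facing effects of this patch, namely that `ParallelExecutionCtx` may be constructed without a backend (it falls back to `DefaultBackend` in `__post_init__`) and that `parallelize_model` now takes the context first with the model id passed as `model_id_or_path`. The top-level import path and the placeholder model id below are assumptions for illustration.

# Sketch only: assumes torch.distributed is already initialized (e.g. via init_process_group),
# as the test helpers do before calling dist.new_group().
import torch
import torch.distributed as dist

from optimum.fx.parallelization import ParallelExecutionCtx, parallelize_model

tp_group = dist.new_group()
device = torch.device(type="cuda", index=torch.cuda.current_device())

# `backend` is omitted: __post_init__ lazily builds DefaultBackend, so core.py no longer
# imports the backend module at import time (which is what allows nanotron to stay optional).
ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device)

# The model id/path is now passed as a keyword after the context ("<model-id-or-path>" is a placeholder).
model = parallelize_model(ctx, model_id_or_path="<model-id-or-path>", skip_load_weights=True)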