From 50fcfc012a751a566b7a54f0e4dd365721ff0f37 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Wed, 14 Aug 2024 03:05:25 +0200 Subject: [PATCH 01/24] modify parallelization strategy --- optimum/fx/parallelization/core.py | 5 + optimum/fx/parallelization/decomp.py | 189 ++++++++ .../parallelization/op_registry/__init__.py | 15 + .../op_registry/op_handlers.py | 445 ++++++++++++++++++ optimum/fx/parallelization/passes.py | 350 ++++++-------- optimum/fx/parallelization/utils.py | 29 +- 6 files changed, 793 insertions(+), 240 deletions(-) create mode 100644 optimum/fx/parallelization/decomp.py create mode 100644 optimum/fx/parallelization/op_registry/__init__.py create mode 100644 optimum/fx/parallelization/op_registry/op_handlers.py diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py index cba7d454441..fafa30f2e7e 100644 --- a/optimum/fx/parallelization/core.py +++ b/optimum/fx/parallelization/core.py @@ -160,8 +160,13 @@ class Config: - weight_init_fn (`Callable`, defaults to `partial(nn.init.normal_, std=0.02)`) Initialization function of weights in `nn.Linear` and `nn.Embedding` layers, if not provided weights loading path. + + - enable_sequence_parallel (`bool`, defaults to `False`): + Whether to enable Megatron-style sequence parallelism in searching parallelization + strategies. """ lint_and_recompile: bool = True clean_markers_after_all_passes: bool = True weight_init_fn: Callable = partial(nn.init.normal_, std=0.02) + enable_sequence_parallel: bool = False diff --git a/optimum/fx/parallelization/decomp.py b/optimum/fx/parallelization/decomp.py new file mode 100644 index 00000000000..90dfc1bf129 --- /dev/null +++ b/optimum/fx/parallelization/decomp.py @@ -0,0 +1,189 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +from typing import Callable, Dict, List + +import torch +import torch.nn.functional as F +import torch.utils._pytree as pytree +from torch import SymBool, SymFloat, SymInt +from torch._decomp import core_aten_decompositions +from torch._functorch._aot_autograd.functional_utils import from_fun, to_fun +from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode, disable_functional_mode +from torch.fx import Graph, GraphModule, Interpreter, Proxy, traceback +from torch.fx.experimental.proxy_tensor import ( + ProxyTorchDispatchMode, + _ProxyTensor, + _SymNodeDict, + decompose, + disable_proxy_modes_tracing, + fetch_object_proxy, + fetch_sym_proxy, + get_proxy_slot, + track_tensor_tree, +) +from torch.fx.proxy import GraphAppendingTracer +from torch.utils.weak import WeakTensorKeyDictionary + + +def is_leaf_module(m): + return (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn")) and not isinstance( + m, torch.nn.Sequential + ) + + +@contextlib.contextmanager +def trace_decomp_origin(): + creat_node = Graph.create_node + + def create_node_(*args, **kwargs): + node = creat_node(*args, **kwargs) + node.meta["traced_from"] = traceback.get_current_meta()["from_node"] + return node + + try: + Graph.create_node = create_node_ + yield + finally: + Graph.create_node = creat_node + + +class DecompTracer(GraphAppendingTracer): + def __init__(self, graph: Graph): + super().__init__(graph) + self.tensor_tracker = WeakTensorKeyDictionary() + self.symnode_tracker = _SymNodeDict() + + +class DecompositionInterpreter(Interpreter): + def __init__( + self, module: GraphModule, new_graph: Graph, decomposition_table=None, leaf_function_targets=None, **kwargs + ): + super().__init__(module, **kwargs) + self.new_graph = new_graph + self.tracer = DecompTracer(new_graph) + + self.decomposition_table = decomposition_table + if self.decomposition_table is None: + self.decomposition_table = {} + + self.leaf_function_targets = leaf_function_targets + if self.leaf_function_targets is None: + self.leaf_function_targets = [] + + self.fun_mode = FunctionalTensorMode() + self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") + + def placeholder(self, target, args, kwargs): + out = super().placeholder(target, args, kwargs) + out = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), out) + proxy = self.tracer.create_proxy("placeholder", target, args, kwargs) + + with disable_proxy_modes_tracing(): + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + + out = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), out) + # TODO handle case where the first character of target is '*' + return out + + def call_function(self, target, args, kwargs): + if target in self.leaf_function_targets: + args = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), args) + kwargs = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), kwargs) + + with disable_proxy_modes_tracing(), disable_functional_mode(): + out = target(*args, **kwargs) + + args, kwargs = pytree.tree_map_only((torch.Tensor,), fetch_object_proxy(self.tracer), (args, kwargs)) + proxy_args, proxy_kwargs = pytree.tree_map_only( + (SymInt, SymFloat, SymBool), + fetch_sym_proxy(self.tracer), + pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (args, kwargs)), + ) + proxy = self.tracer.create_proxy("call_function", target, proxy_args, proxy_kwargs) + + with disable_proxy_modes_tracing(): + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + + out = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), out) + return out + + return super().call_function(target, args, kwargs) + + def call_module(self, target, args, kwargs): + assert isinstance(target, str) + submod = self.fetch_attr(target) + if not is_leaf_module(submod): + return super().call_module(target, args, kwargs) + + args = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), args) + kwargs = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), kwargs) + + with disable_proxy_modes_tracing(), disable_functional_mode(): + out = submod(*args, **kwargs) + + args, kwargs = pytree.tree_map_only((torch.Tensor,), fetch_object_proxy(self.tracer), (args, kwargs)) + proxy_args, proxy_kwargs = pytree.tree_map_only( + (SymInt, SymFloat, SymBool), + fetch_sym_proxy(self.tracer), + pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (args, kwargs)), + ) + proxy = self.tracer.create_proxy("call_module", target, proxy_args, proxy_kwargs) + + with disable_proxy_modes_tracing(): + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + + out = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), out) + return out + + def get_attr(self, target, args, kwargs): + out = super().get_attr(target, args, kwargs) + proxy = Proxy(self.new_graph.get_attr(target), self.tracer) + with disable_proxy_modes_tracing(): + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + return out + + def output(self, target, args, kwargs): + args = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), args) + kwargs = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), kwargs) + out = super().output(target, args, kwargs) + + def unwrap(e): + return get_proxy_slot(e, self.tracer, e, lambda x: x.proxy.node) + + self.new_graph.output(pytree.tree_map(unwrap, out)) + return out + + def run(self, *args, **kwargs): + with self.fun_mode: + args = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), args) + kwargs = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), kwargs) + with traceback.preserve_node_meta(), trace_decomp_origin(), decompose(self.decomposition_table), self.mode: + return super().run(*args, **kwargs) + + +def decompose_and_functionalize( + graph_module: GraphModule, + decomposition_table: Dict = core_aten_decompositions(), + leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention], +) -> Callable: + new_graph = Graph(owning_module=graph_module) + interp = DecompositionInterpreter(graph_module, new_graph, decomposition_table, leaf_function_targets) + + def wrapper(*args, **kwargs): + interp.run(*args, **kwargs) + return new_graph + + return wrapper diff --git a/optimum/fx/parallelization/op_registry/__init__.py b/optimum/fx/parallelization/op_registry/__init__.py new file mode 100644 index 00000000000..8f8df0f7bd0 --- /dev/null +++ b/optimum/fx/parallelization/op_registry/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .op_handlers import REGISTRY, FallbackParallelAxisPropagateHandler diff --git a/optimum/fx/parallelization/op_registry/op_handlers.py b/optimum/fx/parallelization/op_registry/op_handlers.py new file mode 100644 index 00000000000..e5113537ff3 --- /dev/null +++ b/optimum/fx/parallelization/op_registry/op_handlers.py @@ -0,0 +1,445 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import abstractmethod +from typing import Any, List, Optional + +import torch +from torch.fx import Node + +from ..core import Config +from ..utils import is_activation, is_embedding, is_linear + + +class Registry: + def __init__(self) -> None: + self.mapping = {} + + def register(self, op_types): + def wrapper(cls): + if isinstance(op_types, (list, tuple)): + for op_type in op_types: + self.mapping[op_type] = cls + else: + self.mapping[op_types] = cls + return cls + + return wrapper + + def is_supported(self, op_type) -> bool: + return op_type in self.mapping + + +REGISTRY = Registry() + + +class OpParallelAxisPropagateHandler: + def __init__(self, node: Node, meta_key: str, config: Config) -> None: + self.node = node + self.meta_key = meta_key + self.config = config + + def extract_axis(self, arg: Any) -> Optional[int]: + if not isinstance(arg, Node): + return None + return arg.meta[self.meta_key].get("parallel_axis", None) + + @abstractmethod + def propagate(self) -> List[int]: + raise NotImplementedError + + +@REGISTRY.register( + [ + torch.ops.aten.pow.Tensor_Scalar, + torch.ops.aten.rsqrt.default, + torch.ops.aten.clone.default, + torch.ops.aten.bitwise_not.default, + torch.ops.aten.abs.default, + torch.ops.aten._to_copy.default, + torch.ops.aten.acos.default, + torch.ops.aten.acosh.default, + torch.ops.aten.alias.default, + torch.ops.aten.asin.default, + torch.ops.aten.asinh.default, + torch.ops.aten.atan.default, + torch.ops.aten.atanh.default, + torch.ops.aten.ceil.default, + torch.ops.aten.clamp.default, + torch.ops.aten.cos.default, + torch.ops.aten.cosh.default, + torch.ops.aten.erf.default, + torch.ops.aten.exp.default, + torch.ops.aten.trunc.default, + torch.ops.aten.tanh.default, + torch.ops.aten.tan.default, + torch.ops.aten.add.Scalar, + torch.ops.aten.sub.Scalar, + torch.ops.aten.sqrt.default, + torch.ops.aten.sin.default, + torch.ops.aten.sinh.default, + torch.ops.aten.sign.default, + torch.ops.aten.sigmoid.default, + torch.ops.aten.round.default, + torch.ops.aten.remainder.Scalar, + torch.ops.aten.relu.default, + torch.ops.aten.reciprocal.default, + torch.ops.aten.neg.default, + torch.ops.aten.ne.Scalar, + torch.ops.aten.native_dropout.default, + torch.ops.aten.mul.Scalar, + torch.ops.aten.logical_not.default, + torch.ops.aten.lt.Scalar, + torch.ops.aten.le.Scalar, + torch.ops.aten.log.default, + torch.ops.aten.log10.default, + torch.ops.aten.log2.default, + torch.ops.aten.log1p.default, + torch.ops.aten.leaky_relu.default, + torch.ops.aten.isnan.default, + torch.ops.aten.isinf.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.gt.Scalar, + torch.ops.aten.gelu.default, + torch.ops.aten.ge.Scalar, + torch.ops.aten.fmod.Scalar, + torch.ops.aten.floor.default, + torch.ops.aten.fill.Scalar, + torch.ops.aten.div.Scalar_mode, + torch.ops.aten.div.Scalar, + torch.ops.aten.bitwise_and.Scalar, + torch.ops.aten.bitwise_or.Scalar, + torch.ops.aten.bitwise_xor.Scalar, + ] +) +class UnaryOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> bool: + arg = self.node.all_input_nodes[0] + axis = self.extract_axis(arg) + return [axis] + + +@REGISTRY.register( + [ + torch.ops.aten.atan2.default, + torch.ops.aten.add.Tensor, + torch.ops.aten.bitwise_and.Tensor, + torch.ops.aten.bitwise_or.Tensor, + torch.ops.aten.bitwise_xor.Tensor, + torch.ops.aten.div.Tensor, + torch.ops.aten.div.Tensor_mode, + torch.ops.aten.eq.Tensor, + torch.ops.aten.fmod.Tensor, + torch.ops.aten.ge.Tensor, + torch.ops.aten.gt.Tensor, + torch.ops.aten.le.Tensor, + torch.ops.aten.logical_and.default, + torch.ops.aten.logical_or.default, + torch.ops.aten.logical_xor.default, + torch.ops.aten.lt.Tensor, + torch.ops.aten.maximum.default, + torch.ops.aten.minimum.default, + torch.ops.aten.mul.Tensor, + torch.ops.aten.ne.Tensor, + torch.ops.aten.pow.Tensor_Tensor, + torch.ops.aten.remainder.Tensor, + torch.ops.aten.sub.Tensor, + ] +) +class BinaryOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + input_nodes = self.node.all_input_nodes + # only one node + if len(input_nodes) == 1: + return UnaryOpParallelAxisPropagateHandler(self.node, self.meta_key, self.config).propagate() + + assert len(input_nodes) == 2, "binary op should have exact two nodes as inputs" + lhs_shape, rhs_shape = input_nodes[0].meta["val"].shape, input_nodes[1].meta["val"].shape + lhs_axis = self.extract_axis(input_nodes[0]) + rhs_axis = self.extract_axis(input_nodes[1]) + i, j = len(lhs_shape) - 1, len(rhs_shape) - 1 + while i >= 0 and j >= 0: + k = max(lhs_shape[i], rhs_shape[j]) + assert ( + k % min(lhs_shape[i], rhs_shape[j]) == 0 + ), f"shape {lhs_shape} and {rhs_shape} are not broadcastable!" + i -= 1 + j -= 1 + + if i < 0 and lhs_axis is not None: + lhs_axis += j + 1 + if j < 0 and rhs_axis is not None: + rhs_axis += i + 1 + + if lhs_axis is None: + return [rhs_axis] + elif rhs_axis is None: + return [lhs_axis] + elif lhs_axis != rhs_axis: + return [] + return [lhs_axis] + + +@REGISTRY.register( + [ + torch.ops.aten.amax.default, + torch.ops.aten.amin.default, + torch.ops.aten.any.dim, + torch.ops.aten._log_softmax.default, + torch.ops.aten._softmax.default, + torch.ops.aten.cumsum.default, + torch.ops.aten.mean.dim, + # torch.ops.aten.min.dim, + # torch.ops.aten.max.dim, + torch.ops.aten.var.dim, + torch.ops.aten.sum.dim_IntList, + torch.ops.aten.prod.dim_int, + ] +) +class ReductionOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def extract_dims( + self, + ) -> List[int]: + ndim = self.node.meta["val"].ndim + dims = None + if "dim" in self.node.kwargs: + dims = self.node.kwargs["dim"] + elif len(self.node.args) > 1 and isinstance(self.node.args[1], (int, list)): + dims = self.node.args[1] + + if isinstance(dims, int): + dims = [dims] + if not dims: + dims = list(range(ndim)) + dims = [(dim + ndim) % ndim for dim in dims] + + keepdim = False + if "keepdim" in self.node.kwargs: + keepdim = self.node.kwargs + elif len(self.node.args) > 2 and isinstance(self.node.args[2], bool): + keepdim = self.node.args[2] + + return dims, keepdim + + def propagate(self) -> List[int]: + dims, keepdim = self.extract_dims() + arg = self.node.all_input_nodes[0] + axis = self.extract_axis(arg) + if axis in dims: + return [] + if axis is None: + return [None] + if keepdim: + return [axis] + return [axis - sum([1 if dim < axis else 0 for dim in dims])] + + +@REGISTRY.register(torch.ops.aten.view.default) +class ViewLikeOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + arg = self.node.args[0] + axis = self.extract_axis(arg) + if axis is None: + return [None] + shape_before, shape_after = arg.meta["val"].shape, self.node.meta["val"].shape + size = 1 + for i in range(len(shape_before) - 1, axis - 1, -1): + size *= shape_before[i] + + cur, i, res = 1, len(shape_after) - 1, [] + while cur <= size and i >= 0: + cur *= shape_after[i] + if cur == size: + res.append(i) + i -= 1 + + return res + + +@REGISTRY.register(torch.ops.aten.unsqueeze.default) +class UnsqueezeParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + arg, dim = self.node.args[0], self.node.args[1] + ndim = arg.meta["val"].ndim + axis = self.extract_axis(arg) + if axis is None: + return [None] + dim = (dim + ndim) % ndim + if dim <= axis: + return [axis + 1] + return [axis] + + +@REGISTRY.register( + [ + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze.dims, + ] +) +class SqueezeParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + arg, dims = self.node.args[0], self.node.args[1] + axis = self.extract_axis(arg) + if axis is None: + return [None] + + ndim = self.node.args[0].meta["val"].ndim + if isinstance(dims, int): + dims = [dims] + dims = [(dim + ndim) % ndim for dim in dims] + if axis in dims: + # being conservative + return [] + return [axis - sum([1 if dim < axis else 0 for dim in dims])] + + +@REGISTRY.register(torch.ops.aten.permute.default) +class PermuteParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + arg, dims = self.node.args[0], self.node.args[1] + ndim = arg.meta["val"].ndim + axis = self.extract_axis(arg) + if axis is None: + return [None] + + for i, dim in enumerate(dims): + if (dim + ndim) % ndim == axis: + return [i] + return [] + + +@REGISTRY.register(torch.ops.aten.slice.Tensor) +class SliceParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + arg, slice_dim = self.node.args[0], self.node.args[1] + axis = self.extract_axis(arg) + if axis is None: + return [None] + ndim = arg.meta["val"].ndim + slice_dim = (slice_dim + ndim) % ndim + if slice_dim == axis: + # slice on the parallel axis is not allowed + return [] + return [axis] + + +@REGISTRY.register(torch.ops.aten.expand.default) +class ExpandParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + arg, size = self.node.args[0], self.node.args[1] + axis = self.extract_axis(arg) + if axis is None: + return [None] + assert len(size) >= arg.meta["val"].ndim, "input size must be broadcastable to the target size in expand" + return [axis + len(size) - arg.meta["val"].ndim] + + +@REGISTRY.register(torch.ops.aten.cat.default) +class CatParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + nodes, cat_axis = self.node.all_input_nodes, self.node.args[1] + axis, ndim = self.extract_axis(nodes[0]), nodes[0].meta["val"].ndim + cat_axis = (cat_axis + ndim) % ndim + if cat_axis == axis: + return [] + for i in range(1, len(nodes)): + if self.extract_axis(nodes[i]) != axis: + return [] + return [axis] + + +@REGISTRY.register(torch.ops.aten.constant_pad_nd.default) +class PadParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + pad, ndim = self.node.args[1], self.node.args[0].meta["val"].ndim + axis = self.extract_axis(self.node.args[0]) + if axis is None: + return [None] + if axis >= ndim - pad // 2: + return [] + return [axis] + + +@REGISTRY.register(torch.ops.aten.copy.default) +class CopyParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + dst, src = self.node.all_input_nodes + axis_dst = self.extract_axis(dst) + axis_src = self.extract_axis(src) + if axis_dst != axis_src: + return [] + return [axis_dst] + + +@REGISTRY.register(torch.nn.functional.scaled_dot_product_attention) +class SpdaAttnParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + q, k, v = self.node.args[:3] + q_axis = self.extract_axis(q) + # parallel axis must be the head dimension if being parallelized + if q_axis != self.extract_axis(k) or q_axis != self.extract_axis(v) or q_axis not in {None, 1}: + return [] + return [q_axis] + + +class FallbackParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + # by default we don't parallelize inputs and constants(except parameters embeded in modules) + if self.node.op in ["placeholder", "get_attr"]: + return [None] + elif self.node.op == "output": + for node in self.node.all_input_nodes: + # TODO: allow parallelized nodes in output, and append comm ops in graph tp all-gather + # parallelized output if intructed + if self.extract_axis(node) is not None: + return [] + return [None] + elif is_linear(self.node): + input_arg = self.node.all_input_nodes[0] + axis = self.extract_axis(input_arg) + if axis is None: + # with input being not parallelized, output can be parallelized on the head dimension, + # i.e., `ColumnLinear`, or not being parallelized by all-gather at the end + return [2, None] + elif self.config.enable_sequence_parallel and axis == 1: + # with input being parallelized on sequence dimension, output can be parallelized on + # the head dimension, i.e., `ColumnLinear` with sequence parallel, or not being parallelized + # by all-gather at the end + return [2, None] + elif axis == 2: + # with input being parallelized on head dimension, output can be parallelized on the + # sequence dimension or not parallelized by all-reduce at the end, i.e., `RowLinear` + # when sp is not enabled + return [1, None] if self.config.enable_sequence_parallel else [None] + else: + return [] + elif is_embedding(self.node): + input_arg = self.node.all_input_nodes[0] + axis = self.extract_axis(input_arg) + if axis is None: + # only support the embedding parameter being parallelized on `vocab` dim or not parallelized for now, + # the output can be parallelized on sequence dim or not parallelized + return [1, None] if self.config.enable_sequence_parallel else [None] + else: + return [] + elif is_activation(self.node): + return UnaryOpParallelAxisPropagateHandler(self.node, self.meta_key, self.config).propagate() + + # last resort, if no input is being parallelized, then we make output also not parallelized, + # this will give us relief on writing policies for strange ops which don't actually need + # parallelization in most cases + if all([self.extract_axis(arg) is None for arg in self.node.all_input_nodes]): + return [None] + + raise NotImplementedError(f"don't know how to propagate axis for {self.node.target}") diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index 1b25e9e1233..aee14811627 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -23,15 +23,14 @@ from torch.fx import Graph, GraphModule, Node from .core import Config, ParallelExecutionCtx, ParameterMeta +from .decomp import decompose_and_functionalize from .distributed import scatter +from .op_registry import REGISTRY, FallbackParallelAxisPropagateHandler from .parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding from .utils import ( is_embedding, is_linear, - is_permute, is_shape_consumer, - is_shape_generator, - is_transpose, stable_topological_sort, ) @@ -135,238 +134,151 @@ def clean_all(self, graph_module: GraphModule) -> None: self.clear_marker_per_node(node) -class ParallelLayerAnnotatePass(AnalyzeBase): +class ParallelAxisSolverPass(AnalyzeBase): """ - A pass which tries to automatically identify parallel layers in the graph. Note that for simplicity - we only consider classical ways of parallelizing layers in transformers architecture for now, we are not - solving an optimization problem which tries to give a best solution of parallelizing any model under - memory/hardware constraints. - - For `nn.Embedding` layers, we parallelize them on the vocabulary dim by default, because they are often tied - to the `lm_head` of the model, which is usually a `ColumnLinear`(parallelized on vocab dim). - - For `nn.Linear` layers, we parallelize them by grouping them as `upstream` nodes and `downstream` nodes, and - `upstream` nodes are marked as `ColumnLinear`, `downstream` nodes are marked as `RowLinear`. - - Typical examples in transformer models: - - Attention Bert-style MLP Llama-style MLP - __________________________________________________________________________ - Linear Linear Linear Linear - \\ / | \\ --> upstream - Matmul Linear Activation Activation Linear - __________________________________________________________________________ - \\ / | \\ / - \\ / ___________ \\ / - Matmul / Linear \ Mul - | / \ | - _______________________________/ \___________________________ - Linear Linear --> downstream - - Note that there are some patterns that can not be clearly marked, like this one: - - Linear - | \\ - | Linear <-- which label should we mark for the intermediate linear, `upstream` or `downstream` - | / - Add - | - Linear - - For patterns like this we will be conservative and raise errors directly because we don't know how to parallelize - it. Another concern is about the correctness, it's possible that we might end up with a wrong parallelization solution - even if the pattern itself is clear, but for now we are mainly targeting on transformer models and the current solution - should work fairly well. + A pass which tries to automatically identify parallel layers in the graph. There are three steps + involved to find a possible parallel solution given the traced graph module and process group. + + - Decompostion & Functionalization + The vanilla graph traced by dynamo frontend is a high-level graph which contains high-level + pytorch ops, and there could be thousands of them, which makes graph analysis hard in order + to cover all cases. So we decompose the high-level graph into low-level graph which only + conrtains core aten ops, which is a much smaller set. And functionalization is also needed + to remove inplace ops in the graph so that we get `aten.Add` instead of `aten.Add_` in the + graph, which furthur reduces the op set we need to consider. + + - Parallel Axis Propagation + We need to write parallel axis propagation rules for aten ops in the decomposed and functionalized + graph, note that we don't need to cover every possible parallelization strategy because in general + only certain ops(usually involves computation) can be parallelized in transformer models. And we just + need to write rules for a subset of core aten op set in order to support most of the transformer models. + + - Backtracking Search + After we have defined parallel axis propagation rules for each op in the graph, we do a brute force + backtracking search to try to find a possible solution which respects the propagation rule of every + op in the graph. + + + Note that there are several practical concerns + + - Time Complexity. Although brute force backtracking introduces an exponential time complexity, we reduces + the search space by injecting human heuristics. First, we only consider parallelization on the head dimension + (for tensor parallel) or the sequence dimension(to support sequence parallel), then at any time the tensor is + parallelized on at most one dimension. Second, we only allow axis switch around certain layers(like `nn.Linear` + or `nn.Embedding), and all other ops fall into their places by the parallel axis of their input and rules we write. + + - Optimal Solution. Note that since we return the first solution we find, then it might not be optimal in terms of + memory consumption and communication overhead. But again we can adjust the order of search and try parallelize + as much as we can first before fall back to non-parallelized search paths. And we don't pay too much attention + on calculating communication overhead because in practice they are bounded by number of certain layers in the graph + under the constraint that only certain layers are allowed to communicate. + + Our goal is not to solve an optimization problem which tries to give a best solution of parallelizing any model under memory/hardware + constraints, but rather a cheap solution which relieves you from writing boilerplate code for parallelizing layers of different models. """ - def try_form_parallel_linear_groups(self, linear: Node) -> None: - """ - We try to form linears by forming closures in a greedy way, we start with an unmarked linear node, and traverses down - recusively to find all the potential `downstream` linears, note that once we have reached a linear, the recursion stops. - And the newly found `downstream` linears are used as new seeds to traverse upwards to find all the potential `upstream` - linears, the process goes on until number of linears on both sides converges. - Args: - linear (Node): the first linear node used as `upstream` node seed to form closure. - - Raises: - RuntimeError: - raises runtime error when the pattern itself is not clear, there are no clear boundaries that can be drawn. - """ - upstream_nodes, downstream_nodes = {linear}, set() - - seeds, next_seeds = [(linear, "down")], [] - - def traverse(start: Node, cur: Node, direction: str = "down"): - if is_linear(cur) and cur is not start: - if direction == "up" and cur not in upstream_nodes: - upstream_nodes.add(cur) - next_seeds.append((cur, "down")) - elif direction == "down" and cur not in downstream_nodes: - downstream_nodes.add(cur) - next_seeds.append((cur, "up")) - return - - next_nodes = cur.all_input_nodes if direction == "up" else cur.users - for node in next_nodes: - # we should ignore shape-related dependencies - if is_shape_generator(node): - continue - traverse(start, node, direction) - - while seeds: - next_seeds = [] - for node, direction in seeds: - traverse(start=node, cur=node, direction=direction) - seeds = next_seeds - - if any(self.already_executed_per_node(node) for node in (upstream_nodes | downstream_nodes)) or ( - upstream_nodes & downstream_nodes - ): - raise RuntimeError( - "Failed to automatically group and parallelize ops in graph in greedy way: " - "no clear boudaries between `upstream` and `downstream` ops." - ) - - for node in upstream_nodes: - self.place_marker_per_node(node, {"axis": "column", "gather_output": False if downstream_nodes else True}) - - for node in downstream_nodes: - self.place_marker_per_node(node, {"axis": "row", "input_is_parallel": True}) + def trace_back(self, graph_module: GraphModule, decomp_graph: Graph) -> None: + node_map = {node.name: node for node in graph_module.graph.nodes} + + for node in decomp_graph.nodes: + if "traced_from" in node.meta: + node_name, _ = node.meta["traced_from"][0] + assert node_name in node_map, f"un-recognized node origin {node_name} not in graph being traced" + orig_node = node_map[node_name] + self.clear_marker_per_node(orig_node) + self.place_marker_per_node( + orig_node, {"parallel_axis": self.get_stored_field_info(node, field="parallel_axis")} + ) def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: - graph: Graph = graph_module.graph + graph: Graph = decompose_and_functionalize(graph_module)(*ctx.example_inputs) stable_topological_sort(graph) - for node in graph.nodes: - if is_linear(node) and not self.already_executed_per_node(node): - self.try_form_parallel_linear_groups(node) - elif is_embedding(node): - # directly mark `nn.Embedding` layers - self.place_marker_per_node(node, {"axis": "vocab"}) - return graph_module + nodes = [node for node in graph.nodes] + def search(idx: int): + if idx == len(nodes): + return True -class ParallelAxisPropagationPass(AnalyzeBase): - """ - A pass which tries to track which axis is being parallelized in the dataflow. For transformer models, the - axis being paralled for tensor parallism is almost always 2, i.e., the attention head axis, except for - Q and K matrices which need to swap the sequence length axis and head axis to do the attention computation, - so we focus on operations like `transpose` or `permute` which swaps axis, and try inducting the parallel - axis after these operations. - """ + node = nodes[idx] + if node.op == "call_function" and REGISTRY.is_supported(node.target): + prop_cls = REGISTRY.mapping[node.target] + else: + prop_cls = FallbackParallelAxisPropagateHandler - def propagate_transpose(self, node: Node, parallel_axis: int) -> bool: - dims = node.meta["example_value"].dim() - if "dim0" in node.kwargs and "dim1" in node.kwargs: - dim0, dim1 = node.kwargs["dim0"], node.kwargs["dim1"] - elif len(node.args) == 3: - dim0, dim1 = node.args[1:] - - dim0 = (dim0 + dims) % dims - dim1 = (dim1 + dims) % dims - - if dim0 == parallel_axis: - self.place_marker_per_node(node, {"parallel_axis": dim1}) - return True - elif dim1 == parallel_axis: - self.place_marker_per_node(node, {"parallel_axis": dim0}) - return True - return False - - def propagate_permute(self, node: Node, parallel_axis: int) -> bool: - if "dims" in node.kwargs: - dims = node.kwargs["dims"] - else: - dims = ( - list(node.args[1]) - if isinstance(node.args[1], tuple) - else [arg for arg in node.args if isinstance(arg, int)] - ) + prop = prop_cls(node, self.meta_key(), config) + axis_candidates = prop.propagate() + for axis in axis_candidates: + self.place_marker_per_node(node, {"parallel_axis": axis}) + if search(idx + 1): + return True + self.clear_marker_per_node(node) - dim_len = node.meta["example_value"].dim() - dims = [dim + dim_len if dim < 0 else dim for dim in dims] + return False - for i, dim in enumerate(dims): - if dim == parallel_axis: - self.place_marker_per_node(node, {"parallel_axis": i}) - return True - return False - - def propagate_getitem(self, node: Node, parallel_axis: int) -> bool: - slices = node.args[1] - dims = node.meta["example_value"].dim() - assert parallel_axis < dims - inc, i, j = 0, 0, 0 - - while i < parallel_axis and j < len(slices): - if isinstance(slices[j], int): - inc -= 1 - i += 1 - elif slices[j] is None: - inc += 1 - elif slices[j] is Ellipsis: - i = dims - k = j - while k < len(slices): - if slices[k] is not Ellipsis: - i -= 1 - k += 1 - else: - i += 1 - j += 1 + if not search(0): + raise RuntimeError("Failed to find a solution to automatically parallelize ops in graph in greedy way.") - if inc != 0: - assert parallel_axis + inc < dims and parallel_axis + inc >= 0 - self.place_marker_per_node(node, {"parallel_axis": parallel_axis + inc}) - return True - return False + self.trace_back(graph_module, graph) + return graph_module + + +class ParallelLayerAnnotatePass(AnalyzeBase): + """ + This pass annotates layers which have different parallel axis(requires communication inside the layer) in their + input and output tensors. Since heuristics applied during the searching process respect traditional classical ways of + parallelizing layers(like Megatron-style `ColumnLinear` or `RowLinear`), we are guaranteed to match a valid replacement + annotation according to parallelization strategy of input and output tensors. + """ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: - g: Graph = graph_module.graph - stable_topological_sort(g) + for node in graph_module.graph.nodes: + if is_linear(node): + axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis") + axis_after = ParallelAxisSolverPass.get_stored_field_info(node, "parallel_axis") + info = {} + if axis_before is None: + info["axis"] = "column" + info["gather_output"] = True if axis_after is None else False + elif axis_before == 1: + assert ( + config.enable_sequence_parallel + ), "illegal parallel axis for sequence parallelism deactivated setting" + info["axis"] = "column" + info["sequence_parallel"] = True + info["gather_output"] = True if axis_after is None else False + elif axis_before == 2: + info["axis"] = "row" + info["input_is_parallel"] = True + if axis_after == 1: + assert ( + config.enable_sequence_parallel + ), "illegal parallel axis for sequence parallelism deactivated setting" + info["sequence_parallel"] = True + else: + info["sequence_parallel"] = False + self.place_marker_per_node(node, info) - for node in g.nodes: - if ParallelLayerAnnotatePass.already_executed_per_node(node): - # start propagating at ColumnLinear, marking the beginning of parallelized region - axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis", must_have=True) - gather_output = ParallelLayerAnnotatePass.get_stored_field_info(node, field="gather_output") - if axis == "column" and not gather_output: - self.place_marker_per_node(node, {"parallel_axis": 2}) - # stop propagating at RowLinear, concluding the ending of parallelized region + elif is_embedding(node): + axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis") + axis_after = ParallelAxisSolverPass.get_stored_field_info(node, "parallel_axis") + assert axis_before is None and axis_after in [1, None] + info = {"axis": "vocab"} + if axis_after == 1: + assert ( + config.enable_sequence_parallel + ), "illegal parallel axis for sequence parallelism deactivated setting" + info["sequence_parallel"] = True else: - continue - else: - already_marked_args, parallel_axis = [], None - for arg in node.all_input_nodes: - if not self.already_executed_per_node(arg): - continue - if parallel_axis is None: - parallel_axis = self.get_stored_field_info(arg, field="parallel_axis", must_have=True) - else: - assert parallel_axis == self.get_stored_field_info( - arg, field="parallel_axis", must_have=True - ), "`parallel_axis` should be equal for all arguments in any related ops" - already_marked_args.append(arg) - - if not already_marked_args: - continue - - marked = False - if is_transpose(node): - marked = self.propagate_transpose(node, parallel_axis) - elif is_permute(node): - marked = self.propagate_permute(node, parallel_axis) - - # fall back - if not marked: - self.place_marker_per_node(node, {"parallel_axis": parallel_axis}) + info["sequence_parallel"] = False + self.place_marker_per_node(node, info) + return graph_module class ParallelLayerReplacePass(PassBase): """ - A pass which modifies graph according to information provided by previous analytical passes, - in general it does two things for now: + A pass which modifies graph according to information provided by previous analytical passes, in general it does two things for now: 1. replaces linears and embedding layers with their parallel counterparts. 2. modifies hard-coded arguments like the number of attention heads in the graph by dividing it by parallelism level. """ @@ -453,7 +365,7 @@ def update(node: Node, new_shape: List[Any], parallel_axis: int): else: node.update_arg(parallel_axis + 1, shape[parallel_axis]) - parallel_axis = ParallelAxisPropagationPass.get_stored_field_info(node, field="parallel_axis") + parallel_axis = ParallelAxisSolverPass.get_stored_field_info(node, field="parallel_axis") if parallel_axis is None: return @@ -577,18 +489,18 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf def build_parallel_pass_pipeline() -> PassPipeline: """ Ensemble a pass pipeline which contains the following passes: - 1. `ParallelLayerAnnotatePass` to annoate which linears are `ColumnLinear`, which are `RowLinear` - 2. `ParallelAxisPropagationPass` to propate parallel axis along the data flow - 3. `ParallelLinearReplacePass` to do the actual replacement and modification of hard-coded attributes - 4. `InitializeOrLoadWeightsPass` to load or initialize weights for parameters + 1. `ParallelAxisSolverPass` to find a parallelization solution of tensors in the graph. + 2. `ParallelLayerAnnotatePass` to annotate parallelized layers according to the solution found in the first step. + 3. `ParallelLinearReplacePass` to do the actual replacement and modification of hard-coded attributes. + 4. `InitializeOrLoadWeightsPass` to load or initialize weights for parameters. Returns: PassPipeline: the pipeline used for automatic parallelism. """ return PassPipeline( [ + ParallelAxisSolverPass(), ParallelLayerAnnotatePass(), - ParallelAxisPropagationPass(), ParallelLayerReplacePass(), InitializeOrLoadWeightsPass(), ] diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index f129ffbd402..b7b1ccd41c8 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -17,7 +17,6 @@ import hashlib import importlib import json -import operator import os import re import tempfile @@ -45,6 +44,14 @@ def ensure_divisibility(numerator: int, denominator: int) -> None: ) +def is_activation(node: Node) -> bool: + # only consider leaf Module activations + if node.op != "call_module": + return False + mod = node.graph.owning_module + return getattr(mod.get_submodule(node.target), "__module__", "").startswith("torch.nn.modules.activation") + + def is_linear(node: Node) -> bool: if node.op != "call_module": return False @@ -67,26 +74,6 @@ def is_shape_consumer(node: Node) -> bool: return False -def is_transpose(node: Node) -> bool: - if node.op == "call_method": - return node.target in {"transpose", "transpose_"} - elif node.op == "call_function": - return node.target is torch.transpose - return False - - -def is_permute(node: Node) -> bool: - if node.op == "call_method": - return node.target in {"permute"} - elif node.op == "call_function": - return node.target is torch.permute - return False - - -def is_getitem(node: Node) -> bool: - return node.op == "call_function" and node.target is operator.getitem - - def is_output(node: Node) -> bool: return node.op == "output" From 4114d3bbf7396fa8925890c1031d19cf73a74df1 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Wed, 14 Aug 2024 03:14:30 +0200 Subject: [PATCH 02/24] only support model id in api now --- optimum/fx/parallelization/api.py | 70 +++++++++---------- .../op_registry/op_handlers.py | 2 +- optimum/fx/parallelization/passes.py | 2 +- 3 files changed, 36 insertions(+), 38 deletions(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index bd307bd93c1..fd38ae13e24 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -15,10 +15,11 @@ import importlib import os from functools import partial -from typing import List, Union +from typing import List import torch from torch.fx import GraphModule +from transformers import AutoConfig from .core import Config, ParallelExecutionCtx from .passes import build_parallel_pass_pipeline @@ -43,7 +44,7 @@ def parallelize_backend( def parallelize_model( - model: Union[torch.nn.Module, str], + model: str, parallel_ctx: ParallelExecutionCtx, *model_args, **kwargs, @@ -52,8 +53,8 @@ def parallelize_model( API for automatic model parallelism through Pytorch FX. Args: - model (Union[torch.nn.Module, str]): - Model to parallelize, could either be a module or a model id on the Huggingface Hub. + model (str): + Model to parallelize, a model id on the Huggingface Hub. parallel_ctx (ParallelExecutionCtx): Parallel execution context containing process groups the current process belongs to. *model_args (Any): @@ -80,44 +81,41 @@ def parallelize_model( setattr(parallel_config, k, v) kwargs.pop(k) - if isinstance(model, str): - from transformers import AutoConfig - - is_local = os.path.isdir(model) - if not is_local: - hf_folder = download_model_from_hf( - model_name_or_path=model, - cache_dir=cache_dir, - revision=revision, - local_files_only=local_files_only, - skip_download_weights=skip_load_weights, - ) - else: - hf_folder = model - - # should be able to load config using only local files - model_config, kwargs = AutoConfig.from_pretrained( - hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs + is_local = os.path.isdir(model) + if not is_local: + hf_folder = download_model_from_hf( + model_name_or_path=model, + cache_dir=cache_dir, + revision=revision, + local_files_only=local_files_only, + skip_download_weights=skip_load_weights, ) + else: + hf_folder = model - # try getting model class info from config - model_arch = model_config.architectures - model_cls = getattr(importlib.import_module("transformers"), model_arch[0]) + # should be able to load config using only local files + model_config, kwargs = AutoConfig.from_pretrained( + hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs + ) - if not skip_load_weights: - parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder) + # try getting model class info from config + model_arch = model_config.architectures + model_cls = getattr(importlib.import_module("transformers"), model_arch[0]) - torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None - if torch_dtype is not None: - dtype_orig = model_cls._set_default_torch_dtype(torch_dtype) + if not skip_load_weights: + parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder) - with MetaAwareMethodsPatcher(): - model = model_cls(model_config, *model_args, **kwargs) - # TODO: remove this once support training-time trace - model.eval() + torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None + if torch_dtype is not None: + dtype_orig = model_cls._set_default_torch_dtype(torch_dtype) - if dtype_orig is not None: - torch.set_default_dtype(dtype_orig) + with MetaAwareMethodsPatcher(): + model = model_cls(model_config, *model_args, **kwargs) + # TODO: remove this once support training-time trace + model.eval() + + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) move_model_to_device(model, device=parallel_ctx.current_device) initialize_parameter_meta(model) diff --git a/optimum/fx/parallelization/op_registry/op_handlers.py b/optimum/fx/parallelization/op_registry/op_handlers.py index e5113537ff3..faddf8ed455 100644 --- a/optimum/fx/parallelization/op_registry/op_handlers.py +++ b/optimum/fx/parallelization/op_registry/op_handlers.py @@ -439,7 +439,7 @@ def propagate(self) -> List[int]: # last resort, if no input is being parallelized, then we make output also not parallelized, # this will give us relief on writing policies for strange ops which don't actually need # parallelization in most cases - if all([self.extract_axis(arg) is None for arg in self.node.all_input_nodes]): + if all(self.extract_axis(arg) is None for arg in self.node.all_input_nodes): return [None] raise NotImplementedError(f"don't know how to propagate axis for {self.node.target}") diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index aee14811627..80d53aeddf3 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -194,7 +194,7 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf graph: Graph = decompose_and_functionalize(graph_module)(*ctx.example_inputs) stable_topological_sort(graph) - nodes = [node for node in graph.nodes] + nodes = list(graph.nodes) def search(idx: int): if idx == len(nodes): From c689402b5d193b4d4d1639ecdcfd4b01b866985e Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Thu, 15 Aug 2024 21:49:12 +0200 Subject: [PATCH 03/24] more comments --- optimum/fx/parallelization/api.py | 3 ++- optimum/fx/parallelization/decomp.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index fd38ae13e24..ea85301c4c2 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -54,7 +54,8 @@ def parallelize_model( Args: model (str): - Model to parallelize, a model id on the Huggingface Hub. + Model to parallelize, a model id on the Huggingface Hub or path to a local directory containing config and weights + of the model. parallel_ctx (ParallelExecutionCtx): Parallel execution context containing process groups the current process belongs to. *model_args (Any): diff --git a/optimum/fx/parallelization/decomp.py b/optimum/fx/parallelization/decomp.py index 90dfc1bf129..b3dd5149a8b 100644 --- a/optimum/fx/parallelization/decomp.py +++ b/optimum/fx/parallelization/decomp.py @@ -68,6 +68,17 @@ def __init__(self, graph: Graph): class DecompositionInterpreter(Interpreter): + """ + DecompositionInterpreter takes the high-level graph module, run the iternal nodes following the topo order, and decompose + high-level pytorch operators into core aten operators by utilizing torch dispatch infrastructure along the way. Note + that certain primitive layers(like `nn.Linear`, `nn.Embedding`, and activation layers) are preserved because we have specific + heuristic based parallelization strategy for them and we can conveniently replace them into their parallelized counterparts + in the orignal graph module. + + Note that the traced graph is a low-level equivalent representation of the original graph module, and is only used for + parallel axis propagation and analysis, the original graph module is still used for real execution. + """ + def __init__( self, module: GraphModule, new_graph: Graph, decomposition_table=None, leaf_function_targets=None, **kwargs ): From 252c3b7b12e9558058776a780a1b4fe76bb2b1b8 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 16 Aug 2024 22:06:02 +0200 Subject: [PATCH 04/24] more comments --- optimum/fx/parallelization/api.py | 4 ++-- optimum/fx/parallelization/decomp.py | 21 ++++++++++++++++++--- optimum/fx/parallelization/passes.py | 4 ++-- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index ea85301c4c2..47e86c28c71 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -15,7 +15,7 @@ import importlib import os from functools import partial -from typing import List +from typing import Callable, List import torch from torch.fx import GraphModule @@ -48,7 +48,7 @@ def parallelize_model( parallel_ctx: ParallelExecutionCtx, *model_args, **kwargs, -): +) -> Callable: """ API for automatic model parallelism through Pytorch FX. diff --git a/optimum/fx/parallelization/decomp.py b/optimum/fx/parallelization/decomp.py index b3dd5149a8b..a4e7b1d4261 100644 --- a/optimum/fx/parallelization/decomp.py +++ b/optimum/fx/parallelization/decomp.py @@ -72,7 +72,7 @@ class DecompositionInterpreter(Interpreter): DecompositionInterpreter takes the high-level graph module, run the iternal nodes following the topo order, and decompose high-level pytorch operators into core aten operators by utilizing torch dispatch infrastructure along the way. Note that certain primitive layers(like `nn.Linear`, `nn.Embedding`, and activation layers) are preserved because we have specific - heuristic based parallelization strategy for them and we can conveniently replace them into their parallelized counterparts + heuristic based parallelization strategy for them so that we can conveniently replace them into their parallelized counterparts in the orignal graph module. Note that the traced graph is a low-level equivalent representation of the original graph module, and is only used for @@ -106,7 +106,6 @@ def placeholder(self, target, args, kwargs): track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) out = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), out) - # TODO handle case where the first character of target is '*' return out def call_function(self, target, args, kwargs): @@ -187,9 +186,25 @@ def run(self, *args, **kwargs): def decompose_and_functionalize( graph_module: GraphModule, - decomposition_table: Dict = core_aten_decompositions(), + decomposition_table: Dict[torch._ops.OperatorBase, Callable] = core_aten_decompositions(), leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention], ) -> Callable: + """ + API to decompose and funcitonalize a high-level graph module. + + Args: + graph_module (GraphModule): + The high-level graph module to be decomposed and functionalized. + decomposition_table (Dict[torch._ops.OperatorBase, Callable], defaults to `core_aten_decompostions()`): + The lookup table which maps high-level torch op to their equivalent low-level implementation. + leaf_function_targets (List[Callable], defaults to `[F.scaled_dot_product_attention]`): + Functions which will not be traced through for convenience, `F.scaled_dot_product_attention` is + treated as a leaf function by default so that we don't have to deal with all detailed version of + sdpas in the traced graph. + + Returns: + Callable: a wrapper which returns the traced low-level graph when called with concrete arguments. + """ new_graph = Graph(owning_module=graph_module) interp = DecompositionInterpreter(graph_module, new_graph, decomposition_table, leaf_function_targets) diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index 80d53aeddf3..85ec0e7aba5 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -170,8 +170,8 @@ class ParallelAxisSolverPass(AnalyzeBase): - Optimal Solution. Note that since we return the first solution we find, then it might not be optimal in terms of memory consumption and communication overhead. But again we can adjust the order of search and try parallelize as much as we can first before fall back to non-parallelized search paths. And we don't pay too much attention - on calculating communication overhead because in practice they are bounded by number of certain layers in the graph - under the constraint that only certain layers are allowed to communicate. + on calculating communication overhead because in practice they are bounded under the constraint that only certain + layers are allowed to communicate. Our goal is not to solve an optimization problem which tries to give a best solution of parallelizing any model under memory/hardware constraints, but rather a cheap solution which relieves you from writing boilerplate code for parallelizing layers of different models. From 22d67665f4b040dff9e4562b41702807f9bec1be Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 20 Aug 2024 21:34:25 +0200 Subject: [PATCH 05/24] address comments --- optimum/fx/parallelization/api.py | 16 +++++----- optimum/fx/parallelization/decomp.py | 31 ++++++++++++------- .../op_registry/op_handlers.py | 2 +- 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index 47e86c28c71..9700b491e52 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -53,22 +53,22 @@ def parallelize_model( API for automatic model parallelism through Pytorch FX. Args: - model (str): + model (`str`): Model to parallelize, a model id on the Huggingface Hub or path to a local directory containing config and weights of the model. - parallel_ctx (ParallelExecutionCtx): + parallel_ctx (`ParallelExecutionCtx`): Parallel execution context containing process groups the current process belongs to. - *model_args (Any): + *model_args (`Any`): Additional postional arguments for intializing the model if a model id is passed. - revision (str, defaults to `main`): + revision (`str`, defaults to `main`): Model revision for weights downloading if a model id is passed. - cache_dir (Optional[str], defaults to `None`): + cache_dir (`Optional[str]`, defaults to `None`): Cache directory to store downloaded weights. Defaults to None. - local_files_only (bool, defaults to `False`): + local_files_only (`bool`, defaults to `False`): Whether to use local files only, will avoid downloading from remote if set to `True`. - skip_load_weights (bool, defaults to `False`): + skip_load_weights (`bool`, defaults to `False`): Whether to skip loading weights from disk to model. - **kwargs (Dict[str, Any]): + **kwargs (`Dict[str, Any]`): Addtional keyword arguments for overriding fields in parallel config, model config and `Model.__init__`. """ revision = kwargs.pop("revision", "main") diff --git a/optimum/fx/parallelization/decomp.py b/optimum/fx/parallelization/decomp.py index a4e7b1d4261..312249253e7 100644 --- a/optimum/fx/parallelization/decomp.py +++ b/optimum/fx/parallelization/decomp.py @@ -61,6 +61,13 @@ def create_node_(*args, **kwargs): class DecompTracer(GraphAppendingTracer): + """ + DecompTracer is a tracer class which works together with `DecompositionInterpreter`, it keeps track of tensors and their + corresponding proxy objects during execution process. When invoked with `create_proxy`, it will creates a node in the containing + graph and associate the output tensor of the node with the created proxy. + + See https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py for more details. + """ def __init__(self, graph: Graph): super().__init__(graph) self.tensor_tracker = WeakTensorKeyDictionary() @@ -70,13 +77,15 @@ def __init__(self, graph: Graph): class DecompositionInterpreter(Interpreter): """ DecompositionInterpreter takes the high-level graph module, run the iternal nodes following the topo order, and decompose - high-level pytorch operators into core aten operators by utilizing torch dispatch infrastructure along the way. Note - that certain primitive layers(like `nn.Linear`, `nn.Embedding`, and activation layers) are preserved because we have specific - heuristic based parallelization strategy for them so that we can conveniently replace them into their parallelized counterparts - in the orignal graph module. - - Note that the traced graph is a low-level equivalent representation of the original graph module, and is only used for - parallel axis propagation and analysis, the original graph module is still used for real execution. + high-level pytorch operators into core aten operators by utilizing torch dispatch infrastructure along the way. + + Notes: + - Certain primitive layers(like `nn.Linear`, `nn.Embedding`, and activation layers) are preserved because we have specific + heuristic based parallelization strategy for them so that we can conveniently replace them into their parallelized counterparts + in the orignal graph module. + + - The traced graph is a low-level equivalent representation of the original graph module, and is only used for + parallel axis propagation and analysis, the original graph module is still used for real execution. """ def __init__( @@ -190,14 +199,14 @@ def decompose_and_functionalize( leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention], ) -> Callable: """ - API to decompose and funcitonalize a high-level graph module. + API to decompose and functionalize a high-level graph module. Args: - graph_module (GraphModule): + graph_module (`GraphModule`): The high-level graph module to be decomposed and functionalized. - decomposition_table (Dict[torch._ops.OperatorBase, Callable], defaults to `core_aten_decompostions()`): + decomposition_table (`Dict[torch._ops.OperatorBase, Callable]`, defaults to `core_aten_decompostions()`): The lookup table which maps high-level torch op to their equivalent low-level implementation. - leaf_function_targets (List[Callable], defaults to `[F.scaled_dot_product_attention]`): + leaf_function_targets (`List[Callable]`, defaults to `[F.scaled_dot_product_attention]`): Functions which will not be traced through for convenience, `F.scaled_dot_product_attention` is treated as a leaf function by default so that we don't have to deal with all detailed version of sdpas in the traced graph. diff --git a/optimum/fx/parallelization/op_registry/op_handlers.py b/optimum/fx/parallelization/op_registry/op_handlers.py index faddf8ed455..61c621d5f1c 100644 --- a/optimum/fx/parallelization/op_registry/op_handlers.py +++ b/optimum/fx/parallelization/op_registry/op_handlers.py @@ -124,7 +124,7 @@ def propagate(self) -> List[int]: ] ) class UnaryOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): - def propagate(self) -> bool: + def propagate(self) -> List[int]: arg = self.node.all_input_nodes[0] axis = self.extract_axis(arg) return [axis] From febac9b0f558a0a773604f68e201e0f05f54dae8 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 20 Aug 2024 21:35:39 +0200 Subject: [PATCH 06/24] remove idle runner --- .github/workflows/test_fx_automatic_parallel.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_fx_automatic_parallel.yml b/.github/workflows/test_fx_automatic_parallel.yml index 3c913e3f7ed..d8af6e40caa 100644 --- a/.github/workflows/test_fx_automatic_parallel.yml +++ b/.github/workflows/test_fx_automatic_parallel.yml @@ -24,7 +24,7 @@ jobs: config: - name: GPU-enabled Optimum Test Suite image: nvidia/cuda:12.4.1-devel-ubuntu22.04 - gpu_target: ["nvidia-multi-gpu-l4-runners", "nvidia-multi-gpu-a10-runners"] + gpu_target: ["nvidia-multi-gpu-a10-runners"] name: ${{ matrix.config.name }} runs-on: From bf99175f992bd1836ac518be590aeb0bbfb46539 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 20 Aug 2024 21:38:21 +0200 Subject: [PATCH 07/24] fix --- optimum/fx/parallelization/decomp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/fx/parallelization/decomp.py b/optimum/fx/parallelization/decomp.py index 312249253e7..7ba18f43438 100644 --- a/optimum/fx/parallelization/decomp.py +++ b/optimum/fx/parallelization/decomp.py @@ -63,7 +63,7 @@ def create_node_(*args, **kwargs): class DecompTracer(GraphAppendingTracer): """ DecompTracer is a tracer class which works together with `DecompositionInterpreter`, it keeps track of tensors and their - corresponding proxy objects during execution process. When invoked with `create_proxy`, it will creates a node in the containing + corresponding proxy objects during execution process. When invoked with `create_proxy`, it creates a node in the containing graph and associate the output tensor of the node with the created proxy. See https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py for more details. From 4d9d036d10f896e0a4871514f850839e7061e8c3 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 20 Aug 2024 21:54:41 +0200 Subject: [PATCH 08/24] format --- optimum/fx/parallelization/decomp.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/optimum/fx/parallelization/decomp.py b/optimum/fx/parallelization/decomp.py index 7ba18f43438..26258d451bf 100644 --- a/optimum/fx/parallelization/decomp.py +++ b/optimum/fx/parallelization/decomp.py @@ -68,6 +68,7 @@ class DecompTracer(GraphAppendingTracer): See https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py for more details. """ + def __init__(self, graph: Graph): super().__init__(graph) self.tensor_tracker = WeakTensorKeyDictionary() @@ -77,8 +78,8 @@ def __init__(self, graph: Graph): class DecompositionInterpreter(Interpreter): """ DecompositionInterpreter takes the high-level graph module, run the iternal nodes following the topo order, and decompose - high-level pytorch operators into core aten operators by utilizing torch dispatch infrastructure along the way. - + high-level pytorch operators into core aten operators by utilizing torch dispatch infrastructure along the way. + Notes: - Certain primitive layers(like `nn.Linear`, `nn.Embedding`, and activation layers) are preserved because we have specific heuristic based parallelization strategy for them so that we can conveniently replace them into their parallelized counterparts From 44a87f4626c4076a56d1773a2d3dfea8da1d4d22 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Mon, 26 Aug 2024 20:45:57 +0200 Subject: [PATCH 09/24] more comments --- optimum/fx/parallelization/op_registry/op_handlers.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/optimum/fx/parallelization/op_registry/op_handlers.py b/optimum/fx/parallelization/op_registry/op_handlers.py index 61c621d5f1c..18042c3e2b9 100644 --- a/optimum/fx/parallelization/op_registry/op_handlers.py +++ b/optimum/fx/parallelization/op_registry/op_handlers.py @@ -23,6 +23,11 @@ class Registry: + """ + Registry class handles registration of parallel axis propagation handlers of different aten ops, to support a new + aten op, you need to register the corresponding handler class by decorating it with `register` function. + """ + def __init__(self) -> None: self.mapping = {} From 513d5163aa9a377bc8fb4261b35b91dc79072ea5 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 27 Aug 2024 00:36:03 +0200 Subject: [PATCH 10/24] generalize api & add backend abstraction --- optimum/fx/parallelization/api.py | 66 +++++--- .../fx/parallelization/backend/__init__.py | 1 + optimum/fx/parallelization/backend/base.py | 126 +++++++++++++++ .../fx/parallelization/backend/nanotron.py | 150 ++++++++++++++++++ 4 files changed, 318 insertions(+), 25 deletions(-) create mode 100644 optimum/fx/parallelization/backend/__init__.py create mode 100644 optimum/fx/parallelization/backend/base.py create mode 100644 optimum/fx/parallelization/backend/nanotron.py diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index 9700b491e52..0e105ae699c 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -15,11 +15,11 @@ import importlib import os from functools import partial -from typing import Callable, List +from typing import Callable, List, Optional, Type import torch from torch.fx import GraphModule -from transformers import AutoConfig +from transformers import AutoConfig, PretrainedConfig, PreTrainedModel from .core import Config, ParallelExecutionCtx from .passes import build_parallel_pass_pipeline @@ -44,22 +44,28 @@ def parallelize_backend( def parallelize_model( - model: str, parallel_ctx: ParallelExecutionCtx, *model_args, + model_path: Optional[str] = None, + model_cls: Optional[Type[PreTrainedModel]] = None, + model_config: Optional[PretrainedConfig] = None, **kwargs, ) -> Callable: """ API for automatic model parallelism through Pytorch FX. Args: - model (`str`): - Model to parallelize, a model id on the Huggingface Hub or path to a local directory containing config and weights - of the model. parallel_ctx (`ParallelExecutionCtx`): Parallel execution context containing process groups the current process belongs to. *model_args (`Any`): Additional postional arguments for intializing the model if a model id is passed. + model_path (`str`): + Model to parallelize, a model id on the Huggingface Hub or path to a local directory containing config and weights + of the model. + model_cls (`Optional[Type[PreTrainedModel]]`, defaults to `None`): + Model class in transformers library, i.e, `LlamaForCausalLM`. + model_config (`Optional[PretrainedConfig]`, defaults to `None`): + Model config to intialize the model. revision (`str`, defaults to `main`): Model revision for weights downloading if a model id is passed. cache_dir (`Optional[str]`, defaults to `None`): @@ -82,29 +88,39 @@ def parallelize_model( setattr(parallel_config, k, v) kwargs.pop(k) - is_local = os.path.isdir(model) - if not is_local: - hf_folder = download_model_from_hf( - model_name_or_path=model, - cache_dir=cache_dir, - revision=revision, - local_files_only=local_files_only, - skip_download_weights=skip_load_weights, + if model_path is not None and (model_cls is not None or model_config is not None): + raise ValueError( + "Can not accept passing in all of `model_path`, `model_cls` and `model_config`. Only specify " + "`model_path` or `model_cls` and `model_config` because there might be conflicts otherwise" ) - else: - hf_folder = model - # should be able to load config using only local files - model_config, kwargs = AutoConfig.from_pretrained( - hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs - ) + # Init model instance + if model_path is not None: + is_local = os.path.isdir(model_path) + if not is_local: + hf_folder = download_model_from_hf( + model_name_or_path=model_path, + cache_dir=cache_dir, + revision=revision, + local_files_only=local_files_only, + skip_download_weights=skip_load_weights, + ) + else: + hf_folder = model_path + + # should be able to load config using only local files + model_config, kwargs = AutoConfig.from_pretrained( + hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs + ) - # try getting model class info from config - model_arch = model_config.architectures - model_cls = getattr(importlib.import_module("transformers"), model_arch[0]) + # try getting model class info from config + model_arch = model_config.architectures + model_cls = getattr(importlib.import_module("transformers"), model_arch[0]) - if not skip_load_weights: - parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder) + if not skip_load_weights: + parallel_ctx.weight_map = try_collect_weight_map(model_path, cache_dir, hf_folder) + elif model_cls is None or model_config is None: + raise ValueError("must provide `model_cls` and `model_config` in the case of not providing `model_path`") torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None if torch_dtype is not None: diff --git a/optimum/fx/parallelization/backend/__init__.py b/optimum/fx/parallelization/backend/__init__.py new file mode 100644 index 00000000000..c41499635d7 --- /dev/null +++ b/optimum/fx/parallelization/backend/__init__.py @@ -0,0 +1 @@ +from .base import BackEnd, DefaultBackend diff --git a/optimum/fx/parallelization/backend/base.py b/optimum/fx/parallelization/backend/base.py new file mode 100644 index 00000000000..746680ce86c --- /dev/null +++ b/optimum/fx/parallelization/backend/base.py @@ -0,0 +1,126 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Optional, Tuple + +import torch.nn as nn +from torch.fx import GraphModule + +from ..parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding +from ..passes import ( + InitializeOrLoadWeightsPass, + ParallelAxisSolverPass, + ParallelLayerAnnotatePass, + ParallelLayerReplacePass, + PassPipeline, +) + + +if TYPE_CHECKING: + from ..core import ParallelExecutionCtx + + +class BackEnd(ABC): + @abstractmethod + def create_column_parallel_linear( + self, + mod: nn.Linear, + parallel_ctx: "ParallelExecutionCtx", + sequence_parallel: bool, + gather_output: bool, + contiguous_chunks: Optional[Tuple[int]] = None, + ) -> nn.Module: + raise NotImplementedError + + @abstractmethod + def create_row_parallel_linear( + self, + mod: nn.Linear, + parallel_ctx: "ParallelExecutionCtx", + sequence_parallel: bool, + input_is_parallel: bool, + contiguous_chunks: Optional[Tuple[int]] = None, + ) -> nn.Module: + raise NotImplementedError + + @abstractmethod + def create_parallel_embedding( + self, + mod: nn.Embedding, + parallel_ctx: "ParallelExecutionCtx", + sequence_parallel: bool, + contiguous_chunks: Optional[Tuple[int]] = None, + ) -> nn.Module: + raise NotImplementedError + + @abstractmethod + def post_process(self, graph_module: GraphModule) -> nn.Module: + return graph_module + + @abstractmethod + def init_parallelization_pass_pipeline( + self, + ) -> PassPipeline: + raise NotImplementedError + + +class DefaultBackend(BackEnd): + def create_column_parallel_linear( + self, + mod: nn.Linear, + parallel_ctx: "ParallelExecutionCtx", + sequence_parallel: bool, + gather_output: bool, + contiguous_chunks: Tuple[int] | None = None, + ) -> nn.Module: + if sequence_parallel or contiguous_chunks is not None: + raise NotImplementedError( + "DefaultBackend does not support `sequence_parallel=True` or specifying contiguous chunks for now" + ) + return ColumnParallelLinear(parallel_ctx, mod, gather_output) + + def create_row_parallel_linear( + self, + mod: nn.Linear, + parallel_ctx: "ParallelExecutionCtx", + sequence_parallel: bool, + input_is_parallel: bool, + contiguous_chunks: Tuple[int] | None = None, + ) -> nn.Module: + if sequence_parallel or contiguous_chunks is not None: + raise NotImplementedError( + "DefaultBackend does not support `sequence_parallel=True` or specifying contiguous chunks for now" + ) + return RowParallelLinear(parallel_ctx, mod, input_is_parallel) + + def create_parallel_embedding( + self, + mod: nn.Embedding, + parallel_ctx: "ParallelExecutionCtx", + sequence_parallel: bool, + contiguous_chunks: Tuple[int] | None = None, + ) -> nn.Module: + if sequence_parallel or contiguous_chunks is not None: + raise NotImplementedError( + "DefaultBackend does not support `sequence_parallel=True` or specifying contiguous chunks for now" + ) + + return VocabParallelEmbedding(parallel_ctx, mod) + + def init_parallelization_pass_pipeline(self): + """ + Ensemble a pass pipeline which contains the following passes: + 1. `ParallelAxisSolverPass` to find a parallelization solution of tensors in the graph. + 2. `ParallelLayerAnnotatePass` to annotate parallelized layers according to the solution found in the first step. + 3. `ParallelLinearReplacePass` to do the actual replacement and modification of hard-coded attributes. + 4. `InitializeOrLoadWeightsPass` to load or initialize weights for parameters. + + Returns: + PassPipeline: the pipeline used for automatic parallelism. + """ + return PassPipeline( + [ + ParallelAxisSolverPass(), + ParallelLayerAnnotatePass(), + ParallelLayerReplacePass(), + InitializeOrLoadWeightsPass(), + ] + ) diff --git a/optimum/fx/parallelization/backend/nanotron.py b/optimum/fx/parallelization/backend/nanotron.py new file mode 100644 index 00000000000..545540d6ee4 --- /dev/null +++ b/optimum/fx/parallelization/backend/nanotron.py @@ -0,0 +1,150 @@ +from typing import TYPE_CHECKING, Optional, Tuple + +import torch.nn as nn + +# Nanotron +from nanotron.config import Config +from nanotron.parallel.parameters import NanotronParameter +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, +) +from torch.fx import GraphModule + +from ..passes import ( + ParallelAxisSolverPass, + ParallelLayerAnnotatePass, + ParallelLayerReplacePass, + PassPipeline, +) +from .base import BackEnd + + +if TYPE_CHECKING: + from ..core import ParallelExecutionCtx + + +class NanotronBackend(BackEnd): + def __init__(self, nanotron_config: Config) -> None: + self.config = nanotron_config + + def create_column_parallel_linear( + self, + mod: nn.Linear, + parallel_ctx: "ParallelExecutionCtx", + sequence_parallel: bool, + gather_output: bool, + contiguous_chunks: Optional[Tuple[int]] = None, + ) -> TensorParallelColumnLinear: + if gather_output: + raise ValueError( + "Nanotron backend does not support `gather_output=True` in `TensorParallelColumnLinear` yet" + ) + + if sequence_parallel and self.config.parallelism.tp_mode != TensorParallelLinearMode.REDUCE_SCATTER: + raise ValueError( + "`sequence_parallel` can not be activated when `tp_mode` is not set to `REDUCE_SCATTER` in nanotron backend" + ) + + tp_mode = TensorParallelLinearMode.REDUCE_SCATTER if sequence_parallel else TensorParallelLinearMode.ALL_REDUCE + return TensorParallelColumnLinear( + in_features=mod.in_features, + out_features=mod.out_features, + pg=parallel_ctx.tp_group, + mode=tp_mode, + bias=mod.bias is not None, + device=parallel_ctx.current_device, + dtype=mod.weight.dtype, + async_communication=self.config.parallelism.tp_linear_async_communication, + contiguous_chunks=contiguous_chunks, + ) + + def create_row_parallel_linear( + self, + mod: nn.Linear, + parallel_ctx: "ParallelExecutionCtx", + sequence_parallel: bool, + input_is_parallel: bool, + contiguous_chunks: Optional[Tuple[int]] = None, + ) -> TensorParallelRowLinear: + if not input_is_parallel: + raise ValueError( + "Nanotron backend does not support `input_is_parallel=True` in `TensorParallelRowLinear` yet" + ) + + if sequence_parallel and self.config.parallelism.tp_mode != TensorParallelLinearMode.REDUCE_SCATTER: + raise ValueError( + "`sequence_parallel` can not be activated when `tp_mode` is not set to `REDUCE_SCATTER` in nanotron backend" + ) + + tp_mode = TensorParallelLinearMode.REDUCE_SCATTER if sequence_parallel else TensorParallelLinearMode.ALL_REDUCE + return TensorParallelRowLinear( + in_features=mod.in_features, + out_features=mod.out_features, + pg=parallel_ctx.tp_group, + mode=tp_mode, + bias=mod.bias is not None, + device=parallel_ctx.current_device, + dtype=mod.weight.dtype, + async_communication=self.config.parallelism.tp_linear_async_communication, + contiguous_chunks=contiguous_chunks, + ) + + def create_parallel_embedding( + self, + mod: nn.Embedding, + parallel_ctx: "ParallelExecutionCtx", + sequence_parallel: bool, + contiguous_chunks: Optional[Tuple[int]] = None, + ) -> TensorParallelEmbedding: + if sequence_parallel and self.config.parallelism.tp_mode != TensorParallelLinearMode.REDUCE_SCATTER: + raise ValueError( + "`sequence_parallel` can not be activated when `tp_mode` is not set to `REDUCE_SCATTER` in nanotron backend" + ) + + tp_mode = TensorParallelLinearMode.REDUCE_SCATTER if sequence_parallel else TensorParallelLinearMode.ALL_REDUCE + return TensorParallelEmbedding( + num_embeddings=mod.num_embeddings, + embedding_dim=mod.embedding_dim, + pg=parallel_ctx.tp_group, + mode=tp_mode, + padding_idx=mod.padding_idx, + max_norm=mod.max_norm, + norm_type=mod.norm_type, + scale_grad_by_freq=mod.scale_grad_by_freq, + sparse=mod.sparse, + device=parallel_ctx.current_device, + dtype=mod.weight.dtype, + contiguous_chunks=contiguous_chunks, + ) + + def post_process(self, graph_module: GraphModule, parallel_ctx: "ParallelExecutionCtx") -> nn.Module: + for name, param in graph_module.named_parameters(): + if not isinstance(param, NanotronParameter): + prefix_and_field = name.rsplit(".", maxsplit=1) + if len(prefix_and_field) == 2: + parent_mod = graph_module.get_submodule(prefix_and_field[0]) + field = prefix_and_field[1] + else: + parent_mod = graph_module + field = name + + assert ( + param.device == parallel_ctx.current_device + ), "all parameters should already be on the current device" + new_param = NanotronParameter(param.detach(), param.requires_grad) + setattr(parent_mod, field, new_param) + + def init_parallelization_pass_pipeline(self) -> PassPipeline: + """ + For nanotron backend, parameter initialization and checkpoint loading is handled outside. + """ + return PassPipeline( + [ + ParallelAxisSolverPass(), + ParallelLayerAnnotatePass(), + ParallelLayerReplacePass(), + ] + ) From 8335a3510b09b4790d8a5c7689a3d4b1e4985062 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 27 Aug 2024 20:32:29 +0200 Subject: [PATCH 11/24] fix --- optimum/fx/parallelization/api.py | 11 ++++++----- optimum/fx/parallelization/backend/base.py | 2 +- optimum/fx/parallelization/core.py | 12 ++++++++---- tests/fx/parallelization/test_tensor_parallel.py | 8 ++++---- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index 0e105ae699c..5d7c1fe71c4 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -18,11 +18,11 @@ from typing import Callable, List, Optional, Type import torch +import torch.nn as nn from torch.fx import GraphModule from transformers import AutoConfig, PretrainedConfig, PreTrainedModel from .core import Config, ParallelExecutionCtx -from .passes import build_parallel_pass_pipeline from .utils import ( MetaAwareMethodsPatcher, download_model_from_hf, @@ -34,13 +34,14 @@ def parallelize_backend( graph_module: GraphModule, example_inputs: List[torch.Tensor], ctx: ParallelExecutionCtx, config: Config -) -> GraphModule: +) -> nn.Module: ctx.example_inputs = example_inputs - pass_pipeline = build_parallel_pass_pipeline() + pass_pipeline = ctx.backend.init_parallelization_pass_pipeline() graph_module = pass_pipeline(graph_module=graph_module, ctx=ctx, config=config) + finalized_module = ctx.backend.post_process(graph_module, ctx) ctx.compile_times += 1 - ctx.last_optimized_graph_module = graph_module - return graph_module + ctx.last_optimized_module = finalized_module + return finalized_module def parallelize_model( diff --git a/optimum/fx/parallelization/backend/base.py b/optimum/fx/parallelization/backend/base.py index 746680ce86c..51971d4676e 100644 --- a/optimum/fx/parallelization/backend/base.py +++ b/optimum/fx/parallelization/backend/base.py @@ -52,7 +52,7 @@ def create_parallel_embedding( raise NotImplementedError @abstractmethod - def post_process(self, graph_module: GraphModule) -> nn.Module: + def post_process(self, graph_module: GraphModule, parallel_ctx: "ParallelExecutionCtx") -> nn.Module: return graph_module @abstractmethod diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py index fafa30f2e7e..7d3746e489d 100644 --- a/optimum/fx/parallelization/core.py +++ b/optimum/fx/parallelization/core.py @@ -19,7 +19,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from torch.fx import GraphModule +from .backend import BackEnd, DefaultBackend class HashableSlice: @@ -115,6 +115,9 @@ class ParallelExecutionCtx: - current_device (`torch.device`): Device correpsonding to the current process. + - backend (`BackEnd`, defaults to `DefaultBackEnd`): + Backend instance which converts layers into their parallelized counterparts. + - example_inputs (`List[Any]`): A list of tensors which are used as example inputs for graphs captured by dynamo. @@ -129,8 +132,8 @@ class ParallelExecutionCtx: Mapping between parameter names and their locations on disk, useful when loading weights from disk. - - last_optimized_graph_module (`Optional[GraphModule]`, defaults to `None`): - Optimized graph module corresponding to the latest compilation. + - last_optimized_module (`Optional[nn.Module]`, defaults to `None`): + Optimized module corresponding to the latest compilation. - compile_times (`int`, defaults to `0`): Number of compilation times happened during the whole process. @@ -138,10 +141,11 @@ class ParallelExecutionCtx: tp_group: dist.ProcessGroup current_device: torch.device + backend: BackEnd = DefaultBackend() example_inputs: List[Any] = field(default_factory=list) parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict) weight_map: Dict[str, str] = field(default_factory=dict) - last_optimized_graph_module: Optional[GraphModule] = None + last_optimized_module: Optional[nn.Module] = None compile_times: int = 0 diff --git a/tests/fx/parallelization/test_tensor_parallel.py b/tests/fx/parallelization/test_tensor_parallel.py index 9626fccec3b..64ce80b78fb 100644 --- a/tests/fx/parallelization/test_tensor_parallel.py +++ b/tests/fx/parallelization/test_tensor_parallel.py @@ -114,17 +114,17 @@ def run_test_parameters_persist_bewteen_recompile( yet_another_inputs = prepare_dummy_inputs(model.config, batch_size=2, seq_len=12) model(**inputs) - parameter_ids = {id(param) for _, param in ctx.last_optimized_graph_module.named_parameters()} + parameter_ids = {id(param) for _, param in ctx.last_optimized_module.named_parameters()} model(**another_inputs) # check second compilation has been triggered assert ctx.compile_times == 2 - parameter_ids_after_recompile = {id(param) for _, param in ctx.last_optimized_graph_module.named_parameters()} + parameter_ids_after_recompile = {id(param) for _, param in ctx.last_optimized_module.named_parameters()} assert parameter_ids == parameter_ids_after_recompile model(**yet_another_inputs) assert ctx.compile_times == 3 - parameter_ids_after_recompile = {id(param) for _, param in ctx.last_optimized_graph_module.named_parameters()} + parameter_ids_after_recompile = {id(param) for _, param in ctx.last_optimized_module.named_parameters()} assert parameter_ids == parameter_ids_after_recompile dist.barrier(tp_group) tearDown(tp_group) @@ -175,7 +175,7 @@ def run_test_tie_word_embeddings(rank: int, world_size: int, model_id: str, mode model(**inputs) embedding_weight, lm_head_weight = None, None - graph_module = ctx.last_optimized_graph_module + graph_module = ctx.last_optimized_module stable_topological_sort(graph_module.graph) for node in graph_module.graph.nodes: if node.op == "call_module": From d051217337ed4ba33ad8a57263c594316d17feb2 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 27 Aug 2024 22:35:44 +0200 Subject: [PATCH 12/24] copyright --- .../fx/parallelization/backend/__init__.py | 14 ++++++++ optimum/fx/parallelization/backend/base.py | 14 ++++++++ .../fx/parallelization/backend/nanotron.py | 14 ++++++++ optimum/fx/parallelization/core.py | 1 + optimum/fx/parallelization/passes.py | 36 +++++-------------- 5 files changed, 52 insertions(+), 27 deletions(-) diff --git a/optimum/fx/parallelization/backend/__init__.py b/optimum/fx/parallelization/backend/__init__.py index c41499635d7..abae86a6185 100644 --- a/optimum/fx/parallelization/backend/__init__.py +++ b/optimum/fx/parallelization/backend/__init__.py @@ -1 +1,15 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from .base import BackEnd, DefaultBackend diff --git a/optimum/fx/parallelization/backend/base.py b/optimum/fx/parallelization/backend/base.py index 51971d4676e..b2f428a9dc4 100644 --- a/optimum/fx/parallelization/backend/base.py +++ b/optimum/fx/parallelization/backend/base.py @@ -1,3 +1,17 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional, Tuple diff --git a/optimum/fx/parallelization/backend/nanotron.py b/optimum/fx/parallelization/backend/nanotron.py index 545540d6ee4..cda3bb1545c 100644 --- a/optimum/fx/parallelization/backend/nanotron.py +++ b/optimum/fx/parallelization/backend/nanotron.py @@ -1,3 +1,17 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import TYPE_CHECKING, Optional, Tuple import torch.nn as nn diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py index 7d3746e489d..fe258ccdbb0 100644 --- a/optimum/fx/parallelization/core.py +++ b/optimum/fx/parallelization/core.py @@ -19,6 +19,7 @@ import torch import torch.distributed as dist import torch.nn as nn + from .backend import BackEnd, DefaultBackend diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index 85ec0e7aba5..85990903b98 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -26,7 +26,6 @@ from .decomp import decompose_and_functionalize from .distributed import scatter from .op_registry import REGISTRY, FallbackParallelAxisPropagateHandler -from .parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding from .utils import ( is_embedding, is_linear, @@ -300,20 +299,23 @@ def handle_linear(node: Node, ctx: ParallelExecutionCtx) -> None: field = node.target mod: nn.Linear = graph_module.get_submodule(node.target) - key, layer_cache = node.target, ctx.parallel_layer_cache + key, layer_cache, backend = node.target, ctx.parallel_layer_cache, ctx.backend if key in layer_cache: new_mod = layer_cache[key] else: + assert ctx.compile_times == 0, "illegal path for recompilation" if axis == "column": gather_output = ParallelLayerAnnotatePass.get_stored_field_info( node, field="gather_output", must_have=True ) - new_mod = ColumnParallelLinear(ctx, mod, gather_output) + # TODO: enable sequence parallel + new_mod = backend.create_column_parallel_linear(mod, ctx, False, gather_output) else: input_is_parallel = ParallelLayerAnnotatePass.get_stored_field_info( node, field="input_is_parallel", must_have=True ) - new_mod = RowParallelLinear(ctx, mod, input_is_parallel) + # TODO: enable sequence parallel + new_mod = backend.create_row_parallel_linear(mod, ctx, False, input_is_parallel) layer_cache[key] = new_mod setattr(parent_mod, field, new_mod) @@ -334,12 +336,13 @@ def handle_embedding(node: Node, ctx: ParallelExecutionCtx) -> None: field = node.target mod: nn.Embedding = graph_module.get_submodule(node.target) - key, layer_cache = node.target, ctx.parallel_layer_cache + key, layer_cache, backend = node.target, ctx.parallel_layer_cache, ctx.backend if key in layer_cache: new_mod = layer_cache[key] else: assert ctx.compile_times == 0, "illegal path for recompilation" - new_mod = VocabParallelEmbedding(ctx, mod) + # TODO: enable sequence parallel + new_mod = backend.create_parallel_embedding(mod, ctx, False) layer_cache[key] = new_mod setattr(parent_mod, field, new_mod) @@ -486,27 +489,6 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf return graph_module -def build_parallel_pass_pipeline() -> PassPipeline: - """ - Ensemble a pass pipeline which contains the following passes: - 1. `ParallelAxisSolverPass` to find a parallelization solution of tensors in the graph. - 2. `ParallelLayerAnnotatePass` to annotate parallelized layers according to the solution found in the first step. - 3. `ParallelLinearReplacePass` to do the actual replacement and modification of hard-coded attributes. - 4. `InitializeOrLoadWeightsPass` to load or initialize weights for parameters. - - Returns: - PassPipeline: the pipeline used for automatic parallelism. - """ - return PassPipeline( - [ - ParallelAxisSolverPass(), - ParallelLayerAnnotatePass(), - ParallelLayerReplacePass(), - InitializeOrLoadWeightsPass(), - ] - ) - - class PassPipeline: """ `PassPipeline` ensembles a list of passes and execute them one by one as provided in the list, From 6b03855fef2cd126e879840afe0464b70cf8ec3d Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Thu, 29 Aug 2024 00:04:48 +0200 Subject: [PATCH 13/24] fix api --- optimum/fx/parallelization/api.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index 5d7c1fe71c4..743079494d0 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -47,7 +47,7 @@ def parallelize_backend( def parallelize_model( parallel_ctx: ParallelExecutionCtx, *model_args, - model_path: Optional[str] = None, + model_id_or_path: Optional[str] = None, model_cls: Optional[Type[PreTrainedModel]] = None, model_config: Optional[PretrainedConfig] = None, **kwargs, @@ -60,7 +60,7 @@ def parallelize_model( Parallel execution context containing process groups the current process belongs to. *model_args (`Any`): Additional postional arguments for intializing the model if a model id is passed. - model_path (`str`): + model_id_or_path (`str`): Model to parallelize, a model id on the Huggingface Hub or path to a local directory containing config and weights of the model. model_cls (`Optional[Type[PreTrainedModel]]`, defaults to `None`): @@ -89,25 +89,25 @@ def parallelize_model( setattr(parallel_config, k, v) kwargs.pop(k) - if model_path is not None and (model_cls is not None or model_config is not None): + if model_id_or_path is not None and (model_cls is not None or model_config is not None): raise ValueError( - "Can not accept passing in all of `model_path`, `model_cls` and `model_config`. Only specify " - "`model_path` or `model_cls` and `model_config` because there might be conflicts otherwise" + "Can not accept passing in all of `model_id_or_path`, `model_cls` and `model_config`. Only specify " + "`model_id_or_path` or `model_cls` and `model_config` because there might be conflicts otherwise" ) # Init model instance - if model_path is not None: - is_local = os.path.isdir(model_path) + if model_id_or_path is not None: + is_local = os.path.isdir(model_id_or_path) if not is_local: hf_folder = download_model_from_hf( - model_name_or_path=model_path, + model_name_or_path=model_id_or_path, cache_dir=cache_dir, revision=revision, local_files_only=local_files_only, skip_download_weights=skip_load_weights, ) else: - hf_folder = model_path + hf_folder = model_id_or_path # should be able to load config using only local files model_config, kwargs = AutoConfig.from_pretrained( @@ -119,9 +119,9 @@ def parallelize_model( model_cls = getattr(importlib.import_module("transformers"), model_arch[0]) if not skip_load_weights: - parallel_ctx.weight_map = try_collect_weight_map(model_path, cache_dir, hf_folder) + parallel_ctx.weight_map = try_collect_weight_map(model_id_or_path, cache_dir, hf_folder) elif model_cls is None or model_config is None: - raise ValueError("must provide `model_cls` and `model_config` in the case of not providing `model_path`") + raise ValueError("must provide `model_cls` and `model_config` in the case of not providing `model_id_or_path`") torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None if torch_dtype is not None: From b4166aca5cb59943e0fbbfa5a5ee28c11ef8020d Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 30 Aug 2024 01:18:20 +0200 Subject: [PATCH 14/24] move weights intialization inside post process --- optimum/fx/parallelization/backend/base.py | 135 +++++++++++++++--- .../fx/parallelization/backend/nanotron.py | 30 +--- optimum/fx/parallelization/passes.py | 129 ++--------------- 3 files changed, 134 insertions(+), 160 deletions(-) diff --git a/optimum/fx/parallelization/backend/base.py b/optimum/fx/parallelization/backend/base.py index b2f428a9dc4..1e72f4903d5 100644 --- a/optimum/fx/parallelization/backend/base.py +++ b/optimum/fx/parallelization/backend/base.py @@ -15,12 +15,14 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional, Tuple +import torch +import torch.distributed as dist import torch.nn as nn from torch.fx import GraphModule +from ..distributed import scatter from ..parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding from ..passes import ( - InitializeOrLoadWeightsPass, ParallelAxisSolverPass, ParallelLayerAnnotatePass, ParallelLayerReplacePass, @@ -29,7 +31,7 @@ if TYPE_CHECKING: - from ..core import ParallelExecutionCtx + from ..core import Config, ParallelExecutionCtx, ParameterMeta class BackEnd(ABC): @@ -66,14 +68,29 @@ def create_parallel_embedding( raise NotImplementedError @abstractmethod - def post_process(self, graph_module: GraphModule, parallel_ctx: "ParallelExecutionCtx") -> nn.Module: + def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> nn.Module: return graph_module @abstractmethod def init_parallelization_pass_pipeline( self, ) -> PassPipeline: - raise NotImplementedError + """ + Ensemble a pass pipeline which contains the following passes: + 1. `ParallelAxisSolverPass` to find a parallelization solution of tensors in the graph. + 2. `ParallelLayerAnnotatePass` to annotate parallelized layers according to the solution found in the first step. + 3. `ParallelLinearReplacePass` to do the actual replacement and modification of hard-coded attributes. + + Returns: + PassPipeline: the pipeline used for automatic parallelism. + """ + return PassPipeline( + [ + ParallelAxisSolverPass(), + ParallelLayerAnnotatePass(), + ParallelLayerReplacePass(), + ] + ) class DefaultBackend(BackEnd): @@ -119,22 +136,96 @@ def create_parallel_embedding( return VocabParallelEmbedding(parallel_ctx, mod) - def init_parallelization_pass_pipeline(self): - """ - Ensemble a pass pipeline which contains the following passes: - 1. `ParallelAxisSolverPass` to find a parallelization solution of tensors in the graph. - 2. `ParallelLayerAnnotatePass` to annotate parallelized layers according to the solution found in the first step. - 3. `ParallelLinearReplacePass` to do the actual replacement and modification of hard-coded attributes. - 4. `InitializeOrLoadWeightsPass` to load or initialize weights for parameters. - - Returns: - PassPipeline: the pipeline used for automatic parallelism. - """ - return PassPipeline( - [ - ParallelAxisSolverPass(), - ParallelLayerAnnotatePass(), - ParallelLayerReplacePass(), - InitializeOrLoadWeightsPass(), + def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> nn.Module: + world_size = dist.get_world_size(ctx.tp_group) + tp_rank = dist.get_rank(ctx.tp_group) + + new_parameters, tied_parameters, param_cache = [], {}, ctx.param_cache + for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)): + # skip initializing new params when recompilation happens + if name in param_cache: + new_parameters.append((name, param_cache[name])) + continue + + param_meta: "ParameterMeta" = getattr(param, "meta") + # skip already initialized/loaded tied parameters + if param_meta.is_tied and id(param) in tied_parameters: + new_parameters.append((name, tied_parameters[id(param)])) + continue + + shape = [ + param.size(dim) // world_size if dim == param_meta.dim and param_meta.is_parallel else param.size(dim) + for dim in range(param.ndim) ] - ) + + if not param_meta.is_parallel and param.device == ctx.current_device: + new_param = param + else: + new_param = nn.Parameter( + torch.zeros(*shape, dtype=param.dtype, device=ctx.current_device), + requires_grad=param.requires_grad, + ) + + # load weights if possible + for source, target in sorted(param_meta.mapping.items()): + if target.source in ctx.weight_map: + from safetensors import safe_open + + with safe_open(ctx.weight_map[target.source], framework="pt", device="cpu") as fp: + tensor_slice = fp.get_slice(target.source) + source_index = [ + source.to_slice() if dim == param_meta.dim else slice(None, None, None) + for dim in range(param.ndim) + ] + load_index = [ + target.index if dim == param_meta.dim else slice(None, None, None) + for dim in range(param.ndim) + ] + + tensor = tensor_slice[load_index].contiguous() + tensor = torch.empty_like(tensor).copy_(tensor) + with torch.no_grad(): + new_param.data[source_index].copy_(tensor) + + # weights initialization + if param_meta.need_initialize: + for source, target in sorted(param_meta.mapping.items()): + if target.source in ctx.weight_map: + continue + if not param_meta.is_parallel or tp_rank == 0: + # initialize weight on master rank + weight = torch.empty(*target.shape, dtype=param.dtype, device="cpu") + init_fn = param_meta.init_fn if param_meta.init_fn else config.weight_init_fn + init_fn(weight) + weight = weight.to(ctx.current_device) + else: + weight = None + index = [ + source.to_slice() if dim == param_meta.dim else slice(None, None, None) + for dim in range(param.ndim) + ] + with torch.no_grad(): + if param_meta.is_parallel: + scatter(ctx.tp_group, weight, new_param.data[index], scatter_dim=param_meta.dim) + else: + new_param.data[index].copy_(weight) + setattr(new_param, "meta", param_meta) + + if id(new_param) != id(param): + new_parameters.append((name, new_param)) + if param_meta.is_tied: + tied_parameters[id(param)] = new_param + + for name, new_param in new_parameters: + prefix_and_field = name.rsplit(".", maxsplit=1) + if len(prefix_and_field) == 2: + parent_mod = graph_module.get_submodule(prefix_and_field[0]) + field = prefix_and_field[1] + else: + parent_mod = graph_module + field = name + if name not in param_cache: + param_cache[name] = new_param + setattr(parent_mod, field, new_param) + + return graph_module diff --git a/optimum/fx/parallelization/backend/nanotron.py b/optimum/fx/parallelization/backend/nanotron.py index cda3bb1545c..77f12a5d4ab 100644 --- a/optimum/fx/parallelization/backend/nanotron.py +++ b/optimum/fx/parallelization/backend/nanotron.py @@ -16,8 +16,8 @@ import torch.nn as nn -# Nanotron -from nanotron.config import Config +# Nanotron specific imports +from nanotron.config import Config as NanotronConfig from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.tensor_parallel.nn import ( @@ -27,21 +27,15 @@ ) from torch.fx import GraphModule -from ..passes import ( - ParallelAxisSolverPass, - ParallelLayerAnnotatePass, - ParallelLayerReplacePass, - PassPipeline, -) from .base import BackEnd if TYPE_CHECKING: - from ..core import ParallelExecutionCtx + from ..core import Config, ParallelExecutionCtx class NanotronBackend(BackEnd): - def __init__(self, nanotron_config: Config) -> None: + def __init__(self, nanotron_config: NanotronConfig) -> None: self.config = nanotron_config def create_column_parallel_linear( @@ -134,7 +128,9 @@ def create_parallel_embedding( contiguous_chunks=contiguous_chunks, ) - def post_process(self, graph_module: GraphModule, parallel_ctx: "ParallelExecutionCtx") -> nn.Module: + def post_process( + self, graph_module: GraphModule, parallel_ctx: "ParallelExecutionCtx", config: "Config" + ) -> nn.Module: for name, param in graph_module.named_parameters(): if not isinstance(param, NanotronParameter): prefix_and_field = name.rsplit(".", maxsplit=1) @@ -150,15 +146,3 @@ def post_process(self, graph_module: GraphModule, parallel_ctx: "ParallelExecuti ), "all parameters should already be on the current device" new_param = NanotronParameter(param.detach(), param.requires_grad) setattr(parent_mod, field, new_param) - - def init_parallelization_pass_pipeline(self) -> PassPipeline: - """ - For nanotron backend, parameter initialization and checkpoint loading is handled outside. - """ - return PassPipeline( - [ - ParallelAxisSolverPass(), - ParallelLayerAnnotatePass(), - ParallelLayerReplacePass(), - ] - ) diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index c925350ca47..a11fea4b69f 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -15,16 +15,12 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import TYPE_CHECKING, Any, Dict, List -import torch -import torch.distributed as dist import torch.nn as nn from torch.fx import Graph, GraphModule, Node -from .core import Config, ParallelExecutionCtx, ParameterMeta from .decomp import decompose_and_functionalize -from .distributed import scatter from .op_registry import REGISTRY, FallbackParallelAxisPropagateHandler from .utils import ( is_embedding, @@ -34,6 +30,10 @@ ) +if TYPE_CHECKING: + from .core import Config, ParallelExecutionCtx + + class PassBase(ABC): """ Base class for parallelization targeted passes. @@ -46,7 +46,7 @@ def signature(cls) -> str: return cls.__name__ @abstractmethod - def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: + def run(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule: """ Args: graph_module (`GraphModule`): @@ -61,7 +61,7 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf """ raise NotImplementedError - def __call__(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: + def __call__(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule: # skip running when recompilation happens if not self.need_rerun_when_recompile and ctx.compile_times > 0: return graph_module @@ -189,7 +189,7 @@ def trace_back(self, graph_module: GraphModule, decomp_graph: Graph) -> None: orig_node, {"parallel_axis": self.get_stored_field_info(node, field="parallel_axis")} ) - def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: + def run(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule: graph: Graph = decompose_and_functionalize(graph_module)(*ctx.example_inputs) stable_topological_sort(graph) @@ -230,7 +230,7 @@ class ParallelLayerAnnotatePass(AnalyzeBase): annotation according to parallelization strategy of input and output tensors. """ - def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: + def run(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule: for node in graph_module.graph.nodes: if is_linear(node): axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis") @@ -283,7 +283,7 @@ class ParallelLayerReplacePass(PassBase): """ @staticmethod - def handle_linear(node: Node, ctx: ParallelExecutionCtx) -> None: + def handle_linear(node: Node, ctx: "ParallelExecutionCtx") -> None: graph_module = node.graph.owning_module axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis") if axis is None: @@ -320,7 +320,7 @@ def handle_linear(node: Node, ctx: ParallelExecutionCtx) -> None: setattr(parent_mod, field, new_mod) @staticmethod - def handle_embedding(node: Node, ctx: ParallelExecutionCtx) -> None: + def handle_embedding(node: Node, ctx: "ParallelExecutionCtx") -> None: graph_module = node.graph.owning_module axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis") if axis is None: @@ -347,7 +347,7 @@ def handle_embedding(node: Node, ctx: ParallelExecutionCtx) -> None: setattr(parent_mod, field, new_mod) @staticmethod - def handle_hard_coded_axis_param(node: Node, ctx: ParallelExecutionCtx) -> None: + def handle_hard_coded_axis_param(node: Node, ctx: "ParallelExecutionCtx") -> None: def extract_shape_from_node(node: Node) -> List[Any]: if "size" in node.kwargs: return list(node.kwargs["size"]) @@ -381,7 +381,7 @@ def update(node: Node, new_shape: List[Any], parallel_axis: int): shape[parallel_axis] = shape[parallel_axis] // world_size update(node, shape, parallel_axis) - def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: + def run(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule: for node in graph_module.graph.nodes: if is_linear(node): self.handle_linear(node, ctx) @@ -393,107 +393,6 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf return graph_module -class InitializeOrLoadWeightsPass(PassBase): - """ - Weights loading and intialization pass, will initialize parameters on current rank and load weights from disk - if necessary. - """ - - def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: - world_size = dist.get_world_size(ctx.tp_group) - tp_rank = dist.get_rank(ctx.tp_group) - - new_parameters, tied_parameters, param_cache = [], {}, ctx.param_cache - for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)): - # skip initializing new params when recompilation happens - if name in param_cache: - new_parameters.append((name, param_cache[name])) - continue - - param_meta: ParameterMeta = getattr(param, "meta") - # skip already initialized/loaded tied parameters - if param_meta.is_tied and id(param) in tied_parameters: - new_parameters.append((name, tied_parameters[id(param)])) - continue - - shape = [ - param.size(dim) // world_size if dim == param_meta.dim and param_meta.is_parallel else param.size(dim) - for dim in range(param.ndim) - ] - - if not param_meta.is_parallel and param.device == ctx.current_device: - new_param = param - else: - new_param = nn.Parameter( - torch.zeros(*shape, dtype=param.dtype, device=ctx.current_device), - requires_grad=param.requires_grad, - ) - - # load weights if possible - for source, target in sorted(param_meta.mapping.items()): - if target.source in ctx.weight_map: - from safetensors import safe_open - - with safe_open(ctx.weight_map[target.source], framework="pt", device="cpu") as fp: - tensor_slice = fp.get_slice(target.source) - source_index = [ - source.to_slice() if dim == param_meta.dim else slice(None, None, None) - for dim in range(param.ndim) - ] - load_index = [ - target.index if dim == param_meta.dim else slice(None, None, None) - for dim in range(param.ndim) - ] - - tensor = tensor_slice[load_index].contiguous() - tensor = torch.empty_like(tensor).copy_(tensor) - with torch.no_grad(): - new_param.data[source_index].copy_(tensor) - - # weights initialization - if param_meta.need_initialize: - for source, target in sorted(param_meta.mapping.items()): - if target.source in ctx.weight_map: - continue - if not param_meta.is_parallel or tp_rank == 0: - # initialize weight on master rank - weight = torch.empty(*target.shape, dtype=param.dtype, device="cpu") - init_fn = param_meta.init_fn if param_meta.init_fn else config.weight_init_fn - init_fn(weight) - weight = weight.to(ctx.current_device) - else: - weight = None - index = [ - source.to_slice() if dim == param_meta.dim else slice(None, None, None) - for dim in range(param.ndim) - ] - with torch.no_grad(): - if param_meta.is_parallel: - scatter(ctx.tp_group, weight, new_param.data[index], scatter_dim=param_meta.dim) - else: - new_param.data[index].copy_(weight) - setattr(new_param, "meta", param_meta) - - if id(new_param) != id(param): - new_parameters.append((name, new_param)) - if param_meta.is_tied: - tied_parameters[id(param)] = new_param - - for name, new_param in new_parameters: - prefix_and_field = name.rsplit(".", maxsplit=1) - if len(prefix_and_field) == 2: - parent_mod = graph_module.get_submodule(prefix_and_field[0]) - field = prefix_and_field[1] - else: - parent_mod = graph_module - field = name - if name not in param_cache: - param_cache[name] = new_param - setattr(parent_mod, field, new_param) - - return graph_module - - class PassPipeline: """ `PassPipeline` ensembles a list of passes and execute them one by one as provided in the list, @@ -511,7 +410,7 @@ def __iter__( def append(self, PASS: PassBase) -> None: self._passes.append(PASS) - def __call__(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: + def __call__(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule: for PASS in self._passes: graph_module = PASS(graph_module=graph_module, ctx=ctx, config=config) From 576104c00710bbbbc30945600d52af5feb35d311 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 30 Aug 2024 21:35:28 +0200 Subject: [PATCH 15/24] seperate meta update and parallel layer construction --- optimum/fx/parallelization/backend/base.py | 2 +- .../parallel_layers/embedding.py | 25 ++--- .../parallelization/parallel_layers/linear.py | 95 +++++-------------- optimum/fx/parallelization/passes.py | 60 ++++++++++-- 4 files changed, 85 insertions(+), 97 deletions(-) diff --git a/optimum/fx/parallelization/backend/base.py b/optimum/fx/parallelization/backend/base.py index 1e72f4903d5..5b0bfa87fb1 100644 --- a/optimum/fx/parallelization/backend/base.py +++ b/optimum/fx/parallelization/backend/base.py @@ -158,7 +158,7 @@ def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", c for dim in range(param.ndim) ] - if not param_meta.is_parallel and param.device == ctx.current_device: + if param.device == ctx.current_device: new_param = param else: new_param = nn.Parameter( diff --git a/optimum/fx/parallelization/parallel_layers/embedding.py b/optimum/fx/parallelization/parallel_layers/embedding.py index eb8cc9b2942..7c79e0c5905 100644 --- a/optimum/fx/parallelization/parallel_layers/embedding.py +++ b/optimum/fx/parallelization/parallel_layers/embedding.py @@ -17,7 +17,7 @@ import torch.nn as nn import torch.nn.functional as F -from ..core import ParallelExecutionCtx, ParameterMeta +from ..core import ParallelExecutionCtx from ..distributed import differentiable_all_reduce_sum from ..utils import ensure_divisibility @@ -48,23 +48,12 @@ def __init__(self, ctx: ParallelExecutionCtx, embedding: nn.Embedding): self.vocab_start_idx = tp_rank * num_embeddings self.vocab_end_idx = (tp_rank + 1) * num_embeddings - # modify meta information - weight_meta = getattr(embedding.weight, "meta", None) - assert isinstance( - weight_meta, ParameterMeta - ), "should have run `initialize_parameter_meta` after moving model to current device" - if weight_meta.is_modified_meta: - assert weight_meta.is_tied, "only tied parameters could already have modified meta" - else: - weight_meta.need_initialize = True - weight_meta.is_parallel = True - weight_meta.dim = 0 - for _, Slice in weight_meta.mapping.items(): - Slice.index = slice(self.vocab_start_idx, self.vocab_end_idx) - weight_meta.is_modified_meta = True - - # skip creating actual parameters - self.weight = embedding.weight + self.weight = nn.Parameter( + torch.empty( + (num_embeddings, embedding.embedding_dim), dtype=embedding.weight.dtype, device=ctx.current_device + ), + requires_grad=embedding.weight.requires_grad, + ) def forward(self, input: torch.Tensor) -> torch.Tensor: input_mask = (input < self.vocab_start_idx) | (input >= self.vocab_end_idx) diff --git a/optimum/fx/parallelization/parallel_layers/linear.py b/optimum/fx/parallelization/parallel_layers/linear.py index 62d5894dacf..9a995eb2784 100644 --- a/optimum/fx/parallelization/parallel_layers/linear.py +++ b/optimum/fx/parallelization/parallel_layers/linear.py @@ -12,15 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F -from ..core import ( - ParallelExecutionCtx, - ParameterMeta, -) from ..distributed import ( differentiable_all_gather, differentiable_all_reduce_sum, @@ -30,6 +28,10 @@ from ..utils import ensure_divisibility +if TYPE_CHECKING: + from ..core import ParallelExecutionCtx + + class ColumnParallelLinear(nn.Module): """ Linear layer with column parallelism. @@ -43,53 +45,26 @@ class ColumnParallelLinear(nn.Module): gather_output(`bool`, defaults to `True`): whether gathering output in the end of forward. """ - def __init__(self, ctx: ParallelExecutionCtx, linear: nn.Linear, gather_output: bool = True) -> None: + def __init__(self, ctx: "ParallelExecutionCtx", linear: nn.Linear, gather_output: bool = True) -> None: super(ColumnParallelLinear, self).__init__() self.process_group = ctx.tp_group world_size = dist.get_world_size(self.process_group) - tp_rank = dist.get_rank(self.process_group) ensure_divisibility(linear.out_features, world_size) out_features = linear.out_features // world_size bias = linear.bias is not None - # modify meta information - weight_meta = getattr(linear.weight, "meta", None) - assert isinstance( - weight_meta, ParameterMeta - ), "should have run `initialize_parameter_meta` after moving model to current device" - - if weight_meta.is_modified_meta: - assert weight_meta.is_tied, "only tied parameters could already have modified meta" - else: - weight_meta.need_initialize = True - weight_meta.is_parallel = True - weight_meta.dim = 0 - for _, Slice in weight_meta.mapping.items(): - Slice.index = slice(tp_rank * out_features, (tp_rank + 1) * out_features) - weight_meta.is_modified_meta = True - - # skip creating actual parameters - self.weight = linear.weight + self.weight = nn.Parameter( + torch.empty((out_features, linear.in_features), dtype=linear.weight.dtype, device=ctx.current_device), + linear.weight.requires_grad, + ) self.gather_output = gather_output if bias: - bias_meta = getattr(linear.bias, "meta", None) - assert isinstance( - bias_meta, ParameterMeta - ), "should have run `initialize_parameter_meta` after moving model to current device" - - if bias_meta.is_modified_meta: - assert bias_meta.is_tied, "only tied parameters could already have modified meta" - else: - bias_meta.need_initialize = True - bias_meta.is_parallel = True - bias_meta.init_fn = torch.zero_ - bias_meta.dim = 0 - for _, Slice in bias_meta.mapping.items(): - Slice.index = slice(tp_rank * out_features, (tp_rank + 1) * out_features) - bias_meta.is_modified_meta = True - self.bias = linear.bias + self.bias = nn.Parameter( + torch.empty((out_features,), dtype=linear.bias.dtype, device=ctx.current_device), + linear.bias.requires_grad, + ) else: self.register_parameter("bias", None) @@ -120,48 +95,26 @@ class RowParallelLinear(nn.Module): input_is_parallel(`bool`, defaults to `True`): whether the input tensor has already been parallelized. """ - def __init__(self, ctx: ParallelExecutionCtx, linear: nn.Linear, input_is_parallel: bool = False) -> None: + def __init__(self, ctx: "ParallelExecutionCtx", linear: nn.Linear, input_is_parallel: bool = False) -> None: super(RowParallelLinear, self).__init__() self.process_group = ctx.tp_group world_size = dist.get_world_size(self.process_group) - tp_rank = dist.get_rank(self.process_group) ensure_divisibility(linear.in_features, world_size) in_features = linear.in_features // world_size bias = linear.bias is not None - # modify meta information - weight_meta = getattr(linear.weight, "meta", None) - assert isinstance( - weight_meta, ParameterMeta - ), "should have run `initialize_parameter_meta` after moving model to current device" - - if weight_meta.is_modified_meta: - assert weight_meta.is_tied, "only tied parameters could already have modified meta" - else: - weight_meta.need_initialize = True - weight_meta.is_parallel = True - weight_meta.dim = 1 - for _, Slice in weight_meta.mapping.items(): - Slice.index = slice(tp_rank * in_features, (tp_rank + 1) * in_features) - weight_meta.is_modified_meta = True - - # skip creating actual parameters - self.weight = linear.weight + self.weight = nn.Parameter( + torch.empty((linear.out_features, in_features), dtype=linear.weight.dtype, device=ctx.current_device), + linear.weight.requires_grad, + ) self.input_is_parallel = input_is_parallel if bias: - bias_meta = getattr(linear.bias, "meta", None) - assert isinstance( - bias_meta, ParameterMeta - ), "should have run `initialize_parameter_meta` after moving model to current device" - if bias_meta.is_modified_meta: - assert bias_meta.is_tied, "only tied parameters could already have modified meta" - else: - bias_meta.need_initialize = True - bias_meta.init_fn = torch.zero_ - bias_meta.is_modified_meta = True - self.bias = linear.bias + self.bias = nn.Parameter( + torch.empty((linear.out_features,), dtype=linear.bias.dtype, device=ctx.current_device), + linear.bias.requires_grad, + ) else: self.register_parameter("bias", None) diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index a11fea4b69f..4c9fb7ad71c 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -14,15 +14,19 @@ # limitations under the License. from __future__ import annotations +import copy from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union +import torch +import torch.distributed as dist import torch.nn as nn from torch.fx import Graph, GraphModule, Node from .decomp import decompose_and_functionalize from .op_registry import REGISTRY, FallbackParallelAxisPropagateHandler from .utils import ( + ensure_divisibility, is_embedding, is_linear, is_shape_consumer, @@ -31,7 +35,7 @@ if TYPE_CHECKING: - from .core import Config, ParallelExecutionCtx + from .core import Config, ParallelExecutionCtx, ParameterMeta class PassBase(ABC): @@ -283,7 +287,49 @@ class ParallelLayerReplacePass(PassBase): """ @staticmethod - def handle_linear(node: Node, ctx: "ParallelExecutionCtx") -> None: + def propagate_meta( + ctx: "ParallelExecutionCtx", + mod: Union[nn.Linear, nn.Embedding], + new_mod: Union[nn.Linear, nn.Embedding], + axis: str, + ) -> None: + world_size, tp_rank = dist.get_world_size(ctx.tp_group), dist.get_rank(ctx.tp_group) + + def get_current_slice(shape: Tuple[int], axis: int = 0) -> slice: + ensure_divisibility(shape[axis], world_size) + return slice(shape[axis] // world_size * tp_rank, shape[axis] // world_size * (tp_rank + 1)) + + # modify meta information + weight_meta: "ParameterMeta" = copy.deepcopy(getattr(mod.weight, "meta")) + if weight_meta.is_modified_meta: + assert weight_meta.is_tied, "only tied parameters could already have modified meta" + else: + weight_meta.need_initialize = True + weight_meta.is_parallel = True + weight_meta.dim = 0 if axis in {"column", "vocab"} else 1 + for _, Slice in weight_meta.mapping.items(): + Slice.index = get_current_slice(Slice.shape, weight_meta.dim) + weight_meta.is_modified_meta = True + + setattr(new_mod.weight, "meta", weight_meta) + + if hasattr(new_mod, "bias") and new_mod.bias is not None: + bias_meta: "ParameterMeta" = copy.deepcopy(getattr(mod.bias, "meta")) + if bias_meta.is_modified_meta: + assert bias_meta.is_tied, "only tied parameters could already have modified meta" + else: + bias_meta.need_initialize = True + bias_meta.init_fn = torch.zero_ + bias_meta.is_modified_meta = True + + if weight_meta.dim == 0: + bias_meta.dim = 0 + bias_meta.is_parallel = True + for _, Slice in bias_meta.mapping.items(): + Slice.index = get_current_slice(Slice.shape, 0) + setattr(new_mod.bias, "meta", bias_meta) + + def handle_linear(self, node: Node, ctx: "ParallelExecutionCtx") -> None: graph_module = node.graph.owning_module axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis") if axis is None: @@ -316,11 +362,11 @@ def handle_linear(node: Node, ctx: "ParallelExecutionCtx") -> None: ) # TODO: enable sequence parallel new_mod = backend.create_row_parallel_linear(mod, ctx, False, input_is_parallel) + self.propagate_meta(ctx, mod, new_mod, axis) layer_cache[key] = new_mod setattr(parent_mod, field, new_mod) - @staticmethod - def handle_embedding(node: Node, ctx: "ParallelExecutionCtx") -> None: + def handle_embedding(self, node: Node, ctx: "ParallelExecutionCtx") -> None: graph_module = node.graph.owning_module axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis") if axis is None: @@ -343,11 +389,11 @@ def handle_embedding(node: Node, ctx: "ParallelExecutionCtx") -> None: assert ctx.compile_times == 0, "illegal path for recompilation" # TODO: enable sequence parallel new_mod = backend.create_parallel_embedding(mod, ctx, False) + self.propagate_meta(ctx, mod, new_mod, axis) layer_cache[key] = new_mod setattr(parent_mod, field, new_mod) - @staticmethod - def handle_hard_coded_axis_param(node: Node, ctx: "ParallelExecutionCtx") -> None: + def handle_hard_coded_axis_param(self, node: Node, ctx: "ParallelExecutionCtx") -> None: def extract_shape_from_node(node: Node) -> List[Any]: if "size" in node.kwargs: return list(node.kwargs["size"]) From 8bbc2e985d910fa1686d52bee6151b4680de7280 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Mon, 2 Sep 2024 15:43:07 +0200 Subject: [PATCH 16/24] move weight intialization & binding inside backend --- optimum/fx/parallelization/api.py | 3 +- optimum/fx/parallelization/backend/base.py | 35 +++++++++++++++++----- optimum/fx/parallelization/core.py | 12 +++----- optimum/fx/parallelization/passes.py | 34 ++++++++------------- optimum/fx/parallelization/utils.py | 22 +++++--------- 5 files changed, 53 insertions(+), 53 deletions(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index 743079494d0..bf4ef8fbd12 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -37,8 +37,9 @@ def parallelize_backend( ) -> nn.Module: ctx.example_inputs = example_inputs pass_pipeline = ctx.backend.init_parallelization_pass_pipeline() + graph_module = ctx.backend.pre_process(graph_module=graph_module, ctx=ctx, config=config) graph_module = pass_pipeline(graph_module=graph_module, ctx=ctx, config=config) - finalized_module = ctx.backend.post_process(graph_module, ctx) + finalized_module = ctx.backend.post_process(graph_module=graph_module, ctx=ctx, config=config) ctx.compile_times += 1 ctx.last_optimized_module = finalized_module return finalized_module diff --git a/optimum/fx/parallelization/backend/base.py b/optimum/fx/parallelization/backend/base.py index 5b0bfa87fb1..cb6609e14fd 100644 --- a/optimum/fx/parallelization/backend/base.py +++ b/optimum/fx/parallelization/backend/base.py @@ -67,6 +67,22 @@ def create_parallel_embedding( ) -> nn.Module: raise NotImplementedError + @abstractmethod + def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule: + """ + Mark tie information right before we run passes because dynamo tracing will alter the parameter name while our + passes don't. + """ + parameter_mp = {} + for name, tensor in graph_module.named_parameters(remove_duplicate=False): + key = id(tensor) + if key not in parameter_mp: + parameter_mp[key] = name + else: + param_meta: 'ParameterMeta' = getattr(tensor, 'meta') + param_meta.tied_to = parameter_mp[key] + return graph_module + @abstractmethod def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> nn.Module: return graph_module @@ -140,7 +156,7 @@ def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", c world_size = dist.get_world_size(ctx.tp_group) tp_rank = dist.get_rank(ctx.tp_group) - new_parameters, tied_parameters, param_cache = [], {}, ctx.param_cache + new_parameters, param_cache = [], ctx.param_cache for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)): # skip initializing new params when recompilation happens if name in param_cache: @@ -148,9 +164,8 @@ def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", c continue param_meta: "ParameterMeta" = getattr(param, "meta") - # skip already initialized/loaded tied parameters - if param_meta.is_tied and id(param) in tied_parameters: - new_parameters.append((name, tied_parameters[id(param)])) + # skip tied parameters for now + if param_meta.tied_to is not None: continue shape = [ @@ -158,6 +173,7 @@ def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", c for dim in range(param.ndim) ] + # if the device is correct, then we directly use the parameter if param.device == ctx.current_device: new_param = param else: @@ -213,8 +229,13 @@ def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", c if id(new_param) != id(param): new_parameters.append((name, new_param)) - if param_meta.is_tied: - tied_parameters[id(param)] = new_param + param_cache[name] = new_param + + # take care of tied parameters + for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)): + param_meta: "ParameterMeta" = getattr(param, "meta") + if param_meta.tied_to is not None: + new_parameters.append((name, param_cache[param_meta.tied_to])) for name, new_param in new_parameters: prefix_and_field = name.rsplit(".", maxsplit=1) @@ -224,8 +245,6 @@ def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", c else: parent_mod = graph_module field = name - if name not in param_cache: - param_cache[name] = new_param setattr(parent_mod, field, new_param) return graph_module diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py index 0d08adb8468..9fa254cf289 100644 --- a/optimum/fx/parallelization/core.py +++ b/optimum/fx/parallelization/core.py @@ -73,15 +73,13 @@ class ParameterMeta: Parameter meta information. Attributes: - - is_tied (`bool`, defaults to `False`): - Whether the parameter is shared accross multiple modules. + - tied_to (`Optional[str]`, defaults to `None`): + The name of host parameter the current parameter is tied to, i.e., `lm_head.weight` is tied + to `embedding_tokens.weight`. If None then it's unique and not shared across modules. - is_parallel (`bool`, defaults to `False`): Whether the parameter needs to be parallelized. - - is_modified_meta (`bool`, defaults to `False`): - Whether the meta has already been modified since initialization. - - need_initialize (`bool`, defaults to `False`): Whether need to manually initialize weights if not provided in weight map. @@ -94,10 +92,8 @@ class ParameterMeta: - mapping (`Dict[HashableSlice, ParameterSlice]`): Mapping between the current parameter and weight tensor stored in weight map. """ - - is_tied: bool = False + tied_to: Optional[str] = None is_parallel: bool = False - is_modified_meta: bool = False need_initialize: bool = False init_fn: Optional[Callable] = None dim: int = 0 diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index 4c9fb7ad71c..18c17fb26b7 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -301,32 +301,22 @@ def get_current_slice(shape: Tuple[int], axis: int = 0) -> slice: # modify meta information weight_meta: "ParameterMeta" = copy.deepcopy(getattr(mod.weight, "meta")) - if weight_meta.is_modified_meta: - assert weight_meta.is_tied, "only tied parameters could already have modified meta" - else: - weight_meta.need_initialize = True - weight_meta.is_parallel = True - weight_meta.dim = 0 if axis in {"column", "vocab"} else 1 - for _, Slice in weight_meta.mapping.items(): - Slice.index = get_current_slice(Slice.shape, weight_meta.dim) - weight_meta.is_modified_meta = True - + weight_meta.need_initialize = True + weight_meta.is_parallel = True + weight_meta.dim = 0 if axis in {"column", "vocab"} else 1 + for _, Slice in weight_meta.mapping.items(): + Slice.index = get_current_slice(Slice.shape, weight_meta.dim) setattr(new_mod.weight, "meta", weight_meta) if hasattr(new_mod, "bias") and new_mod.bias is not None: bias_meta: "ParameterMeta" = copy.deepcopy(getattr(mod.bias, "meta")) - if bias_meta.is_modified_meta: - assert bias_meta.is_tied, "only tied parameters could already have modified meta" - else: - bias_meta.need_initialize = True - bias_meta.init_fn = torch.zero_ - bias_meta.is_modified_meta = True - - if weight_meta.dim == 0: - bias_meta.dim = 0 - bias_meta.is_parallel = True - for _, Slice in bias_meta.mapping.items(): - Slice.index = get_current_slice(Slice.shape, 0) + bias_meta.need_initialize = True + bias_meta.init_fn = torch.zero_ + if weight_meta.dim == 0: + bias_meta.dim = 0 + bias_meta.is_parallel = True + for _, Slice in bias_meta.mapping.items(): + Slice.index = get_current_slice(Slice.shape, 0) setattr(new_mod.bias, "meta", bias_meta) def handle_linear(self, node: Node, ctx: "ParallelExecutionCtx") -> None: diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index b7b1ccd41c8..e4852bde90f 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -260,21 +260,15 @@ def __exit__(self, exc_type, exc_value, traceback): def initialize_parameter_meta(model: nn.Module) -> None: - parameter_ids = set() for name, tensor in model.named_parameters(remove_duplicate=False): - key = id(tensor) - if key not in parameter_ids: - setattr( - tensor, - "meta", - ParameterMeta( - dim=0, - mapping={HashableSlice(None, None, None): ParameterSlice(source=name, shape=tuple(tensor.shape))}, - ), - ) - parameter_ids.add(key) - else: - tensor.meta.is_tied = True + setattr( + tensor, + "meta", + ParameterMeta( + dim=0, + mapping={HashableSlice(None, None, None): ParameterSlice(source=name, shape=tuple(tensor.shape))}, + ), + ) @torch.no_grad From d68df892901a0cd62e46e4965f6e5e12210a5bde Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Mon, 2 Sep 2024 20:07:37 +0200 Subject: [PATCH 17/24] add weights tying for nanotron backend --- optimum/fx/parallelization/backend/base.py | 2 +- .../fx/parallelization/backend/nanotron.py | 67 +++++++++++++++++-- optimum/fx/parallelization/core.py | 1 + 3 files changed, 64 insertions(+), 6 deletions(-) diff --git a/optimum/fx/parallelization/backend/base.py b/optimum/fx/parallelization/backend/base.py index cb6609e14fd..c1e3e914f40 100644 --- a/optimum/fx/parallelization/backend/base.py +++ b/optimum/fx/parallelization/backend/base.py @@ -79,7 +79,7 @@ def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", co if key not in parameter_mp: parameter_mp[key] = name else: - param_meta: 'ParameterMeta' = getattr(tensor, 'meta') + param_meta: "ParameterMeta" = getattr(tensor, "meta") param_meta.tied_to = parameter_mp[key] return graph_module diff --git a/optimum/fx/parallelization/backend/nanotron.py b/optimum/fx/parallelization/backend/nanotron.py index 77f12a5d4ab..644365e090e 100644 --- a/optimum/fx/parallelization/backend/nanotron.py +++ b/optimum/fx/parallelization/backend/nanotron.py @@ -12,12 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# Nanotron specific imports +from collections import defaultdict from typing import TYPE_CHECKING, Optional, Tuple +import torch.distributed as dist import torch.nn as nn - -# Nanotron specific imports from nanotron.config import Config as NanotronConfig +from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.tensor_parallel.nn import ( @@ -25,18 +27,24 @@ TensorParallelEmbedding, TensorParallelRowLinear, ) +from nanotron.parallel.tied_parameters import tie_parameters from torch.fx import GraphModule from .base import BackEnd if TYPE_CHECKING: - from ..core import Config, ParallelExecutionCtx + from ..core import Config, ParallelExecutionCtx, ParameterMeta class NanotronBackend(BackEnd): - def __init__(self, nanotron_config: NanotronConfig) -> None: + """ + Backend class which glues optimum fx parallelization context and nanotron context. + """ + + def __init__(self, nanotron_config: NanotronConfig, nanotron_context: ParallelContext) -> None: self.config = nanotron_config + self.context = nanotron_context def create_column_parallel_linear( self, @@ -131,7 +139,9 @@ def create_parallel_embedding( def post_process( self, graph_module: GraphModule, parallel_ctx: "ParallelExecutionCtx", config: "Config" ) -> nn.Module: + param_cache, tied_parameter_groups = parallel_ctx.param_cache, defaultdict(list) for name, param in graph_module.named_parameters(): + param_meta: "ParameterMeta" = getattr(param, "meta") if not isinstance(param, NanotronParameter): prefix_and_field = name.rsplit(".", maxsplit=1) if len(prefix_and_field) == 2: @@ -144,5 +154,52 @@ def post_process( assert ( param.device == parallel_ctx.current_device ), "all parameters should already be on the current device" - new_param = NanotronParameter(param.detach(), param.requires_grad) + + if name not in param_cache: + new_param = NanotronParameter(param.detach(), param.requires_grad) + param_cache[name] = new_param + else: + raise RuntimeError( + "Found already initialized parameter which is not a nanotron parameter in parameter cache!" + ) setattr(parent_mod, field, new_param) + elif name not in param_cache: + param_cache[name] = param + + # now we have NanotronParameter anyway + nanotron_param: NanotronParameter = param_cache[name] + # we have tied the parameter, in the very first compilation. + if nanotron_param.is_tied: + continue + + # not tied, must be the very first compilation + assert parallel_ctx.compile_times == 0, "illegal path for recompilation" + host_name = param_meta.tied_to if param_meta.tied_to is not None else name + tied_parameter_groups[host_name].append(name) + + # take care of weights tying + for _, groups in tied_parameter_groups: + # just one parameter in the group, no need for tying + if len(groups) == 1: + continue + + ties = [ + ( + target, + ( + self.context.get_global_rank( + # TODO: modify this accordingly when ep is supported + ep_rank=0, + # TODO: modify this accordingly when pp is supported + pp_rank=0, + dp_rank=dist.get_rank(self.context.dp_pg), + tp_rank=dist.get_rank(self.context.tp_pg), + ), + ), + ) + for target in groups + ] + # no new parameters will be created because we make sure every param is already a NanotronParameter + tie_parameters(graph_module, ties, self.context, dist.ReduceOp.SUM) + + return graph_module diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py index 9fa254cf289..19a70a5a7a3 100644 --- a/optimum/fx/parallelization/core.py +++ b/optimum/fx/parallelization/core.py @@ -92,6 +92,7 @@ class ParameterMeta: - mapping (`Dict[HashableSlice, ParameterSlice]`): Mapping between the current parameter and weight tensor stored in weight map. """ + tied_to: Optional[str] = None is_parallel: bool = False need_initialize: bool = False From c752e294527b918d9b802a34b84e974d059a945b Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Tue, 3 Sep 2024 17:11:44 +0200 Subject: [PATCH 18/24] fix --- optimum/fx/parallelization/backend/base.py | 14 ++---- .../fx/parallelization/backend/nanotron.py | 46 +++++++++++++------ optimum/fx/parallelization/core.py | 16 +++++-- .../parallelization/parallel_layers/linear.py | 7 +-- .../parallelization/test_tensor_parallel.py | 10 ++-- 5 files changed, 53 insertions(+), 40 deletions(-) diff --git a/optimum/fx/parallelization/backend/base.py b/optimum/fx/parallelization/backend/base.py index c1e3e914f40..9b9ad782dd4 100644 --- a/optimum/fx/parallelization/backend/base.py +++ b/optimum/fx/parallelization/backend/base.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional, Tuple +from typing import Optional, Tuple import torch import torch.distributed as dist import torch.nn as nn from torch.fx import GraphModule +from ..core import Config, ParallelExecutionCtx, ParameterMeta from ..distributed import scatter from ..parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding from ..passes import ( @@ -30,10 +31,6 @@ ) -if TYPE_CHECKING: - from ..core import Config, ParallelExecutionCtx, ParameterMeta - - class BackEnd(ABC): @abstractmethod def create_column_parallel_linear( @@ -67,7 +64,6 @@ def create_parallel_embedding( ) -> nn.Module: raise NotImplementedError - @abstractmethod def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule: """ Mark tie information right before we run passes because dynamo tracing will alter the parameter name while our @@ -83,11 +79,9 @@ def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", co param_meta.tied_to = parameter_mp[key] return graph_module - @abstractmethod def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> nn.Module: return graph_module - @abstractmethod def init_parallelization_pass_pipeline( self, ) -> PassPipeline: @@ -165,7 +159,7 @@ def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", c param_meta: "ParameterMeta" = getattr(param, "meta") # skip tied parameters for now - if param_meta.tied_to is not None: + if param_meta.tied_to is not None and param_meta.tied_to != name: continue shape = [ @@ -234,7 +228,7 @@ def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", c # take care of tied parameters for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)): param_meta: "ParameterMeta" = getattr(param, "meta") - if param_meta.tied_to is not None: + if param_meta.tied_to is not None and param_meta.tied_to != name: new_parameters.append((name, param_cache[param_meta.tied_to])) for name, new_param in new_parameters: diff --git a/optimum/fx/parallelization/backend/nanotron.py b/optimum/fx/parallelization/backend/nanotron.py index 644365e090e..ffd6a5396e6 100644 --- a/optimum/fx/parallelization/backend/nanotron.py +++ b/optimum/fx/parallelization/backend/nanotron.py @@ -13,28 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. # Nanotron specific imports +import importlib.util from collections import defaultdict from typing import TYPE_CHECKING, Optional, Tuple import torch.distributed as dist import torch.nn as nn -from nanotron.config import Config as NanotronConfig -from nanotron.parallel import ParallelContext -from nanotron.parallel.parameters import NanotronParameter -from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode -from nanotron.parallel.tensor_parallel.nn import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, -) -from nanotron.parallel.tied_parameters import tie_parameters from torch.fx import GraphModule +from ..core import Config, ParallelExecutionCtx, ParameterMeta from .base import BackEnd +# Check if nanotron is installed +_nanotron_available = importlib.util.find_spec("nanotron") is not None + if TYPE_CHECKING: - from ..core import Config, ParallelExecutionCtx, ParameterMeta + from nanotron.config import Config as NanotronConfig + from nanotron.parallel import ParallelContext + from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + ) class NanotronBackend(BackEnd): @@ -42,7 +43,10 @@ class NanotronBackend(BackEnd): Backend class which glues optimum fx parallelization context and nanotron context. """ - def __init__(self, nanotron_config: NanotronConfig, nanotron_context: ParallelContext) -> None: + def __init__(self, nanotron_config: "NanotronConfig", nanotron_context: "ParallelContext") -> None: + if not _nanotron_available: + raise ImportError("Nanotron is not installed. Please install it to use NanotronBackend.") + self.config = nanotron_config self.context = nanotron_context @@ -53,7 +57,10 @@ def create_column_parallel_linear( sequence_parallel: bool, gather_output: bool, contiguous_chunks: Optional[Tuple[int]] = None, - ) -> TensorParallelColumnLinear: + ) -> "TensorParallelColumnLinear": + from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode + from nanotron.parallel.tensor_parallel.nn import TensorParallelColumnLinear + if gather_output: raise ValueError( "Nanotron backend does not support `gather_output=True` in `TensorParallelColumnLinear` yet" @@ -84,7 +91,10 @@ def create_row_parallel_linear( sequence_parallel: bool, input_is_parallel: bool, contiguous_chunks: Optional[Tuple[int]] = None, - ) -> TensorParallelRowLinear: + ) -> "TensorParallelRowLinear": + from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode + from nanotron.parallel.tensor_parallel.nn import TensorParallelRowLinear + if not input_is_parallel: raise ValueError( "Nanotron backend does not support `input_is_parallel=True` in `TensorParallelRowLinear` yet" @@ -114,7 +124,10 @@ def create_parallel_embedding( parallel_ctx: "ParallelExecutionCtx", sequence_parallel: bool, contiguous_chunks: Optional[Tuple[int]] = None, - ) -> TensorParallelEmbedding: + ) -> "TensorParallelEmbedding": + from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode + from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding + if sequence_parallel and self.config.parallelism.tp_mode != TensorParallelLinearMode.REDUCE_SCATTER: raise ValueError( "`sequence_parallel` can not be activated when `tp_mode` is not set to `REDUCE_SCATTER` in nanotron backend" @@ -139,6 +152,9 @@ def create_parallel_embedding( def post_process( self, graph_module: GraphModule, parallel_ctx: "ParallelExecutionCtx", config: "Config" ) -> nn.Module: + from nanotron.parallel.parameters import NanotronParameter + from nanotron.parallel.tied_parameters import tie_parameters + param_cache, tied_parameter_groups = parallel_ctx.param_cache, defaultdict(list) for name, param in graph_module.named_parameters(): param_meta: "ParameterMeta" = getattr(param, "meta") diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py index 19a70a5a7a3..0fb8517fd99 100644 --- a/optimum/fx/parallelization/core.py +++ b/optimum/fx/parallelization/core.py @@ -14,13 +14,15 @@ # limitations under the License. from dataclasses import dataclass, field from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple import torch import torch.distributed as dist import torch.nn as nn -from .backend import BackEnd, DefaultBackend + +if TYPE_CHECKING: + from .backend import BackEnd class HashableSlice: @@ -113,7 +115,7 @@ class ParallelExecutionCtx: - current_device (`torch.device`): Device correpsonding to the current process. - - backend (`BackEnd`, defaults to `DefaultBackEnd`): + - backend (`Optional[BackEnd]`, defaults to `None`): Backend instance which converts layers into their parallelized counterparts. - example_inputs (`List[Any]`): @@ -144,7 +146,7 @@ class ParallelExecutionCtx: tp_group: dist.ProcessGroup current_device: torch.device - backend: BackEnd = DefaultBackend() + backend: Optional["BackEnd"] = None example_inputs: List[Any] = field(default_factory=list) parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict) param_cache: Dict[str, nn.Parameter] = field(default_factory=dict) @@ -152,6 +154,12 @@ class ParallelExecutionCtx: last_optimized_module: Optional[nn.Module] = None compile_times: int = 0 + def __post_init__(self): + if self.backend is None: + from .backend import DefaultBackend + + self.backend = DefaultBackend() + @dataclass class Config: diff --git a/optimum/fx/parallelization/parallel_layers/linear.py b/optimum/fx/parallelization/parallel_layers/linear.py index 9a995eb2784..ba44f5f15a2 100644 --- a/optimum/fx/parallelization/parallel_layers/linear.py +++ b/optimum/fx/parallelization/parallel_layers/linear.py @@ -12,13 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING - import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F +from ..core import ParallelExecutionCtx from ..distributed import ( differentiable_all_gather, differentiable_all_reduce_sum, @@ -28,10 +27,6 @@ from ..utils import ensure_divisibility -if TYPE_CHECKING: - from ..core import ParallelExecutionCtx - - class ColumnParallelLinear(nn.Module): """ Linear layer with column parallelism. diff --git a/tests/fx/parallelization/test_tensor_parallel.py b/tests/fx/parallelization/test_tensor_parallel.py index 64ce80b78fb..f8aab7a912f 100644 --- a/tests/fx/parallelization/test_tensor_parallel.py +++ b/tests/fx/parallelization/test_tensor_parallel.py @@ -80,7 +80,7 @@ def run_test_all_rank_results_match(rank: int, world_size: int, model_id: str, m device = torch.device(type="cuda", index=torch.cuda.current_device()) ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device) - model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs) + model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs) inputs = prepare_dummy_inputs(model.config) logits = model(**inputs)[0] tensors = gather_at_main_process(tensor=logits, group=tp_group, rank=rank, world_size=world_size) @@ -106,7 +106,7 @@ def run_test_parameters_persist_bewteen_recompile( device = torch.device(type="cuda", index=torch.cuda.current_device()) ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device) - model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs) + model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs) inputs = prepare_dummy_inputs(model.config) # different shape to trigger recompile @@ -141,7 +141,7 @@ def run_test_parallel_results_matches_non_parallel( device = torch.device(type="cuda", index=torch.cuda.current_device()) ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device) - model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs) + model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs) inputs = prepare_dummy_inputs(model.config) set_seed(SEED) @@ -153,7 +153,7 @@ def run_test_parallel_results_matches_non_parallel( tp_group = dist.new_group() set_seed(SEED) ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device) - model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs) + model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs) parallel_logits = model(**inputs)[0] torch.testing.assert_close(logits.cpu(), parallel_logits.cpu(), rtol=1e-4, atol=1e-4) @@ -169,7 +169,7 @@ def run_test_tie_word_embeddings(rank: int, world_size: int, model_id: str, mode # prepare config and context device = torch.device(type="cuda", index=torch.cuda.current_device()) ctx = ParallelExecutionCtx(tp_group=tp_group, current_device=device) - model = parallelize_model(model_id, ctx, skip_load_weights=True, **model_kwargs) + model = parallelize_model(ctx, model_id_or_path=model_id, skip_load_weights=True, **model_kwargs) inputs = prepare_dummy_inputs(model.config) model(**inputs) From 3a1a195963849795d87e89a25da6c484df136510 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Wed, 4 Sep 2024 00:59:26 +0200 Subject: [PATCH 19/24] fix --- tests/fx/parallelization/test_tensor_parallel.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/fx/parallelization/test_tensor_parallel.py b/tests/fx/parallelization/test_tensor_parallel.py index f8aab7a912f..4c9ba131e4b 100644 --- a/tests/fx/parallelization/test_tensor_parallel.py +++ b/tests/fx/parallelization/test_tensor_parallel.py @@ -189,11 +189,7 @@ def run_test_tie_word_embeddings(rank: int, world_size: int, model_id: str, mode if isinstance(mod, ColumnParallelLinear): lm_head_weight = mod.weight break - assert ( - id(embedding_weight) == id(lm_head_weight) - and hasattr(embedding_weight, "meta") - and embedding_weight.meta.is_tied - ) + assert id(embedding_weight) == id(lm_head_weight) dist.barrier(tp_group) tearDown() From 0ff39bb9373cc5b4efa8b5adb7542b8db1c40594 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 20 Sep 2024 18:56:47 +0200 Subject: [PATCH 20/24] address comments --- optimum/fx/parallelization/api.py | 2 +- optimum/fx/parallelization/backend/__init__.py | 2 +- optimum/fx/parallelization/backend/base.py | 6 +++--- optimum/fx/parallelization/backend/nanotron.py | 6 +++--- optimum/fx/parallelization/core.py | 6 +++--- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index bf4ef8fbd12..bdba7de18e8 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -61,7 +61,7 @@ def parallelize_model( Parallel execution context containing process groups the current process belongs to. *model_args (`Any`): Additional postional arguments for intializing the model if a model id is passed. - model_id_or_path (`str`): + model_id_or_path (`Optional[str]`, defaults to `None`): Model to parallelize, a model id on the Huggingface Hub or path to a local directory containing config and weights of the model. model_cls (`Optional[Type[PreTrainedModel]]`, defaults to `None`): diff --git a/optimum/fx/parallelization/backend/__init__.py b/optimum/fx/parallelization/backend/__init__.py index abae86a6185..1f7252122ec 100644 --- a/optimum/fx/parallelization/backend/__init__.py +++ b/optimum/fx/parallelization/backend/__init__.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .base import BackEnd, DefaultBackend +from .base import Backend, DefaultBackend diff --git a/optimum/fx/parallelization/backend/base.py b/optimum/fx/parallelization/backend/base.py index 11cdd766065..728aa807a53 100644 --- a/optimum/fx/parallelization/backend/base.py +++ b/optimum/fx/parallelization/backend/base.py @@ -38,7 +38,7 @@ ) -class BackEnd(ABC): +class Backend(ABC): @abstractmethod def create_column_parallel_linear( self, @@ -85,7 +85,7 @@ def create_parallel_cross_entropy( def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule: """ Mark tie information right before we run passes because dynamo tracing will alter the parameter name while our - passes don't. + passes do not. """ parameter_mp = {} for name, tensor in graph_module.named_parameters(remove_duplicate=False): @@ -121,7 +121,7 @@ def init_parallelization_pass_pipeline( ) -class DefaultBackend(BackEnd): +class DefaultBackend(Backend): def create_column_parallel_linear( self, mod: nn.Linear, diff --git a/optimum/fx/parallelization/backend/nanotron.py b/optimum/fx/parallelization/backend/nanotron.py index ffd6a5396e6..847b7c4b924 100644 --- a/optimum/fx/parallelization/backend/nanotron.py +++ b/optimum/fx/parallelization/backend/nanotron.py @@ -22,13 +22,13 @@ from torch.fx import GraphModule from ..core import Config, ParallelExecutionCtx, ParameterMeta -from .base import BackEnd +from .base import Backend # Check if nanotron is installed _nanotron_available = importlib.util.find_spec("nanotron") is not None -if TYPE_CHECKING: +if TYPE_CHECKING and _nanotron_available: from nanotron.config import Config as NanotronConfig from nanotron.parallel import ParallelContext from nanotron.parallel.tensor_parallel.nn import ( @@ -38,7 +38,7 @@ ) -class NanotronBackend(BackEnd): +class NanotronBackend(Backend): """ Backend class which glues optimum fx parallelization context and nanotron context. """ diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py index 0fb8517fd99..c3c03f93373 100644 --- a/optimum/fx/parallelization/core.py +++ b/optimum/fx/parallelization/core.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: - from .backend import BackEnd + from .backend import Backend class HashableSlice: @@ -115,7 +115,7 @@ class ParallelExecutionCtx: - current_device (`torch.device`): Device correpsonding to the current process. - - backend (`Optional[BackEnd]`, defaults to `None`): + - backend (`Optional[Backend]`, defaults to `None`): Backend instance which converts layers into their parallelized counterparts. - example_inputs (`List[Any]`): @@ -146,7 +146,7 @@ class ParallelExecutionCtx: tp_group: dist.ProcessGroup current_device: torch.device - backend: Optional["BackEnd"] = None + backend: Optional["Backend"] = None example_inputs: List[Any] = field(default_factory=list) parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict) param_cache: Dict[str, nn.Parameter] = field(default_factory=dict) From 5137f68e74fe0851cfd64126531221e0971ec042 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 20 Sep 2024 19:24:34 +0200 Subject: [PATCH 21/24] address comments --- optimum/fx/parallelization/backend/base.py | 32 ++++++++++++++++--- .../fx/parallelization/backend/nanotron.py | 4 +++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/optimum/fx/parallelization/backend/base.py b/optimum/fx/parallelization/backend/base.py index 728aa807a53..79100e7154a 100644 --- a/optimum/fx/parallelization/backend/base.py +++ b/optimum/fx/parallelization/backend/base.py @@ -26,8 +26,8 @@ from ..parallel_layers import ( ColumnParallelLinear, RowParallelLinear, - VocabParallelEmbedding, VocabParallelCrossEntropyLoss, + VocabParallelEmbedding, sharded_cross_entropy_wrapper_fn, ) from ..passes import ( @@ -39,6 +39,22 @@ class Backend(ABC): + """ + Abstract base class for implementing parallelization backends. + + This class defines the interface for creating parallel versions of various + PyTorch modules and operations. Subclasses should implement the abstract + methods to provide specific parallelization strategies. + + Methods: + create_column_parallel_linear: Create a column-parallel version of a linear layer. + create_row_parallel_linear: Create a row-parallel version of a linear layer. + create_parallel_embedding: Create a parallel version of an embedding layer. + create_parallel_cross_entropy: Create a parallel version of cross entropy loss. + pre_process: Perform pre-processing on the graph module before parallelization. + post_process: Perform post-processing on the graph module after parallelization. + init_parallelization_pass_pipeline: Initialize the parallelization pass pipeline. + """ @abstractmethod def create_column_parallel_linear( self, @@ -82,6 +98,7 @@ def create_parallel_cross_entropy( else: return sharded_cross_entropy_wrapper_fn(process_group=parallel_ctx.tp_group) + @abstractmethod def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule: """ Mark tie information right before we run passes because dynamo tracing will alter the parameter name while our @@ -97,12 +114,16 @@ def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", co param_meta.tied_to = parameter_mp[key] return graph_module + @abstractmethod def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> nn.Module: + """ + This method is called after the parallelization passes have been applied. It is used to perform any backend-specific + post-processing on the graph module. + """ return graph_module - def init_parallelization_pass_pipeline( - self, - ) -> PassPipeline: + @abstractmethod + def init_parallelization_pass_pipeline(self) -> PassPipeline: """ Ensemble a pass pipeline which contains the following passes: 1. `ParallelAxisSolverPass` to find a parallelization solution of tensors in the graph. @@ -165,6 +186,9 @@ def create_parallel_embedding( return VocabParallelEmbedding(parallel_ctx, mod) def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> nn.Module: + """ + Initialize or load parameters from checkpoint, and tie them if needed. + """ world_size = dist.get_world_size(ctx.tp_group) tp_rank = dist.get_rank(ctx.tp_group) diff --git a/optimum/fx/parallelization/backend/nanotron.py b/optimum/fx/parallelization/backend/nanotron.py index 847b7c4b924..76e0553a795 100644 --- a/optimum/fx/parallelization/backend/nanotron.py +++ b/optimum/fx/parallelization/backend/nanotron.py @@ -152,6 +152,10 @@ def create_parallel_embedding( def post_process( self, graph_module: GraphModule, parallel_ctx: "ParallelExecutionCtx", config: "Config" ) -> nn.Module: + """ + Convert parameters to `NanotronParameter` and tie them if needed. Note that we don't initialize or load weights here + because nanotron will do that for us in the trainer class. + """ from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.tied_parameters import tie_parameters From 9dd77defeaf9c4c2ac11a202e38ac626efc37781 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 20 Sep 2024 19:31:41 +0200 Subject: [PATCH 22/24] fix --- optimum/fx/parallelization/backend/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optimum/fx/parallelization/backend/base.py b/optimum/fx/parallelization/backend/base.py index 79100e7154a..4f6c2264177 100644 --- a/optimum/fx/parallelization/backend/base.py +++ b/optimum/fx/parallelization/backend/base.py @@ -55,6 +55,7 @@ class Backend(ABC): post_process: Perform post-processing on the graph module after parallelization. init_parallelization_pass_pipeline: Initialize the parallelization pass pipeline. """ + @abstractmethod def create_column_parallel_linear( self, @@ -116,7 +117,7 @@ def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", co @abstractmethod def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> nn.Module: - """ + """ This method is called after the parallelization passes have been applied. It is used to perform any backend-specific post-processing on the graph module. """ From a375b6df0cab7363bccf802056ad457ea9b057dc Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 20 Sep 2024 19:33:28 +0200 Subject: [PATCH 23/24] fix --- optimum/fx/parallelization/backend/base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/optimum/fx/parallelization/backend/base.py b/optimum/fx/parallelization/backend/base.py index 4f6c2264177..68834ab4726 100644 --- a/optimum/fx/parallelization/backend/base.py +++ b/optimum/fx/parallelization/backend/base.py @@ -99,7 +99,6 @@ def create_parallel_cross_entropy( else: return sharded_cross_entropy_wrapper_fn(process_group=parallel_ctx.tp_group) - @abstractmethod def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule: """ Mark tie information right before we run passes because dynamo tracing will alter the parameter name while our @@ -115,7 +114,6 @@ def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", co param_meta.tied_to = parameter_mp[key] return graph_module - @abstractmethod def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> nn.Module: """ This method is called after the parallelization passes have been applied. It is used to perform any backend-specific @@ -123,7 +121,6 @@ def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", c """ return graph_module - @abstractmethod def init_parallelization_pass_pipeline(self) -> PassPipeline: """ Ensemble a pass pipeline which contains the following passes: From 40880a3dc6afabdad2a1d372d642af09ff12f577 Mon Sep 17 00:00:00 2001 From: Longjie Zheng Date: Fri, 20 Sep 2024 19:42:23 +0200 Subject: [PATCH 24/24] fix --- optimum/fx/parallelization/backend/base.py | 5 +++++ optimum/fx/parallelization/backend/nanotron.py | 8 +++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/optimum/fx/parallelization/backend/base.py b/optimum/fx/parallelization/backend/base.py index 68834ab4726..aeeb0d693ef 100644 --- a/optimum/fx/parallelization/backend/base.py +++ b/optimum/fx/parallelization/backend/base.py @@ -183,6 +183,11 @@ def create_parallel_embedding( return VocabParallelEmbedding(parallel_ctx, mod) + def create_parallel_cross_entropy( + self, mod_or_fn: Union[nn.CrossEntropyLoss, F.cross_entropy], parallel_ctx: ParallelExecutionCtx + ): + return super().create_parallel_cross_entropy(mod_or_fn, parallel_ctx) + def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> nn.Module: """ Initialize or load parameters from checkpoint, and tie them if needed. diff --git a/optimum/fx/parallelization/backend/nanotron.py b/optimum/fx/parallelization/backend/nanotron.py index 76e0553a795..3b6828a2928 100644 --- a/optimum/fx/parallelization/backend/nanotron.py +++ b/optimum/fx/parallelization/backend/nanotron.py @@ -15,10 +15,11 @@ # Nanotron specific imports import importlib.util from collections import defaultdict -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple, Union import torch.distributed as dist import torch.nn as nn +import torch.nn.functional as F from torch.fx import GraphModule from ..core import Config, ParallelExecutionCtx, ParameterMeta @@ -149,6 +150,11 @@ def create_parallel_embedding( contiguous_chunks=contiguous_chunks, ) + def create_parallel_cross_entropy( + self, mod_or_fn: Union[nn.CrossEntropyLoss, F.cross_entropy], parallel_ctx: "ParallelExecutionCtx" + ): + return super().create_parallel_cross_entropy(mod_or_fn, parallel_ctx) + def post_process( self, graph_module: GraphModule, parallel_ctx: "ParallelExecutionCtx", config: "Config" ) -> nn.Module: