[WIP] Refactor to Introduce Backend Abstraction #2011

Open

wants to merge 28 commits into base: main

Changes from 22 commits

Commits (28)
50fcfc0
modify parallelization strategy
zhenglongjiepheonix Aug 14, 2024
4114d3b
only support model id in api now
zhenglongjiepheonix Aug 14, 2024
c689402
more comments
zhenglongjiepheonix Aug 15, 2024
252c3b7
more comments
zhenglongjiepheonix Aug 16, 2024
1be77ed
Merge remote-tracking branch 'upstream/main' into longjie/generalize_…
zhenglongjiepheonix Aug 16, 2024
22d6766
address comments
zhenglongjiepheonix Aug 20, 2024
febac9b
remove idle runner
zhenglongjiepheonix Aug 20, 2024
bf99175
fix
zhenglongjiepheonix Aug 20, 2024
4d9d036
format
zhenglongjiepheonix Aug 20, 2024
44a87f4
more comments
zhenglongjiepheonix Aug 26, 2024
513d516
generalize api & add backend abstraction
zhenglongjiepheonix Aug 26, 2024
8335a35
fix
zhenglongjiepheonix Aug 27, 2024
d051217
copyright
zhenglongjiepheonix Aug 27, 2024
6b03855
fix api
zhenglongjiepheonix Aug 28, 2024
6466ccc
Merge remote-tracking branch 'upstream/main' into longjie/add_backend…
zhenglongjiepheonix Aug 29, 2024
b4166ac
move weights intialization inside post process
zhenglongjiepheonix Aug 29, 2024
576104c
seperate meta update and parallel layer construction
zhenglongjiepheonix Aug 30, 2024
8bbc2e9
move weight intialization & binding inside backend
zhenglongjiepheonix Sep 2, 2024
d68df89
add weights tying for nanotron backend
zhenglongjiepheonix Sep 2, 2024
c752e29
fix
zhenglongjiepheonix Sep 3, 2024
82d1cf9
resolve
zhenglongjiepheonix Sep 3, 2024
3a1a195
fix
zhenglongjiepheonix Sep 3, 2024
b5b371f
fix conflict
zhenglongjiepheonix Sep 20, 2024
0ff39bb
address comments
zhenglongjiepheonix Sep 20, 2024
5137f68
address comments
zhenglongjiepheonix Sep 20, 2024
9dd77de
fix
zhenglongjiepheonix Sep 20, 2024
a375b6d
fix
zhenglongjiepheonix Sep 20, 2024
40880a3
fix
zhenglongjiepheonix Sep 20, 2024
78 changes: 48 additions & 30 deletions optimum/fx/parallelization/api.py
@@ -15,14 +15,14 @@
import importlib
import os
from functools import partial
from typing import Callable, List
from typing import Callable, List, Optional, Type

import torch
import torch.nn as nn
from torch.fx import GraphModule
from transformers import AutoConfig
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel

from .core import Config, ParallelExecutionCtx
from .passes import build_parallel_pass_pipeline
from .utils import (
MetaAwareMethodsPatcher,
download_model_from_hf,
@@ -34,32 +34,40 @@

def parallelize_backend(
graph_module: GraphModule, example_inputs: List[torch.Tensor], ctx: ParallelExecutionCtx, config: Config
) -> GraphModule:
) -> nn.Module:
ctx.example_inputs = example_inputs
pass_pipeline = build_parallel_pass_pipeline()
pass_pipeline = ctx.backend.init_parallelization_pass_pipeline()
graph_module = ctx.backend.pre_process(graph_module=graph_module, ctx=ctx, config=config)
graph_module = pass_pipeline(graph_module=graph_module, ctx=ctx, config=config)
finalized_module = ctx.backend.post_process(graph_module=graph_module, ctx=ctx, config=config)
ctx.compile_times += 1
ctx.last_optimized_graph_module = graph_module
return graph_module
ctx.last_optimized_module = finalized_module
return finalized_module


def parallelize_model(
model: str,
parallel_ctx: ParallelExecutionCtx,
*model_args,
model_id_or_path: Optional[str] = None,
model_cls: Optional[Type[PreTrainedModel]] = None,
model_config: Optional[PretrainedConfig] = None,
**kwargs,
) -> Callable:
"""
API for automatic model parallelism through PyTorch FX.

Args:
model (`str`):
Model to parallelize, a model id on the Huggingface Hub or path to a local directory containing config and weights
of the model.
parallel_ctx (`ParallelExecutionCtx`):
Parallel execution context containing process groups the current process belongs to.
*model_args (`Any`):
Additional positional arguments for initializing the model if a model id is passed.
model_id_or_path (`str`):
Model to parallelize, a model id on the Huggingface Hub or path to a local directory containing config and weights
of the model.
model_cls (`Optional[Type[PreTrainedModel]]`, defaults to `None`):
Model class in the transformers library, e.g., `LlamaForCausalLM`.
model_config (`Optional[PretrainedConfig]`, defaults to `None`):
Model config used to initialize the model.
revision (`str`, defaults to `main`):
Model revision for weights downloading if a model id is passed.
cache_dir (`Optional[str]`, defaults to `None`):
@@ -82,29 +90,39 @@ def parallelize_model(
setattr(parallel_config, k, v)
kwargs.pop(k)

is_local = os.path.isdir(model)
if not is_local:
hf_folder = download_model_from_hf(
model_name_or_path=model,
cache_dir=cache_dir,
revision=revision,
local_files_only=local_files_only,
skip_download_weights=skip_load_weights,
if model_id_or_path is not None and (model_cls is not None or model_config is not None):
raise ValueError(
"Cannot pass `model_id_or_path` together with `model_cls` or `model_config`. Specify either "
"`model_id_or_path` alone, or both `model_cls` and `model_config`, otherwise they may conflict."
)
else:
hf_folder = model

# should be able to load config using only local files
model_config, kwargs = AutoConfig.from_pretrained(
hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs
)
# Init model instance
if model_id_or_path is not None:
is_local = os.path.isdir(model_id_or_path)
if not is_local:
hf_folder = download_model_from_hf(
model_name_or_path=model_id_or_path,
cache_dir=cache_dir,
revision=revision,
local_files_only=local_files_only,
skip_download_weights=skip_load_weights,
)
else:
hf_folder = model_id_or_path

# should be able to load config using only local files
model_config, kwargs = AutoConfig.from_pretrained(
hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs
)

# try getting model class info from config
model_arch = model_config.architectures
model_cls = getattr(importlib.import_module("transformers"), model_arch[0])
# try getting model class info from config
model_arch = model_config.architectures
model_cls = getattr(importlib.import_module("transformers"), model_arch[0])

if not skip_load_weights:
parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder)
if not skip_load_weights:
parallel_ctx.weight_map = try_collect_weight_map(model_id_or_path, cache_dir, hf_folder)
elif model_cls is None or model_config is None:
raise ValueError("Must provide both `model_cls` and `model_config` when `model_id_or_path` is not provided")

torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None
if torch_dtype is not None:
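For context on what the reworked entry point looks like from the caller's side, here is a minimal usage sketch of the two supported modes. It is an illustration only, not part of the diff: the `ParallelExecutionCtx` constructor arguments (`tp_group`, `current_device`), the model id, and the process-group setup are assumptions made for the sketch.

import torch
import torch.distributed as dist
from transformers import LlamaConfig, LlamaForCausalLM

from optimum.fx.parallelization.api import parallelize_model
from optimum.fx.parallelization.core import ParallelExecutionCtx

# Process-group setup and ctx kwargs are assumed for illustration.
dist.init_process_group(backend="nccl")
device = torch.device("cuda", dist.get_rank())
ctx = ParallelExecutionCtx(tp_group=dist.group.WORLD, current_device=device)

# Mode 1: pass a model id or local path; the class is resolved from its config.
parallel_model = parallelize_model(
    parallel_ctx=ctx,
    model_id_or_path="meta-llama/Llama-2-7b-hf",
)

# Mode 2: pass an explicit class and config (mutually exclusive with mode 1).
parallel_model = parallelize_model(
    parallel_ctx=ctx,
    model_cls=LlamaForCausalLM,
    model_config=LlamaConfig(),
)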
15 changes: 15 additions & 0 deletions optimum/fx/parallelization/backend/__init__.py
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import BackEnd, DefaultBackend
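For orientation (not part of the diff): the two exported names are the backend interface and its built-in implementation defined in base.py below. A trivial sketch, assuming the pass constructors do not require an initialized process group:

from optimum.fx.parallelization.backend import BackEnd, DefaultBackend

# DefaultBackend implements the abstract BackEnd interface.
backend = DefaultBackend()
assert isinstance(backend, BackEnd)

# Assembles the default pipeline:
# ParallelAxisSolverPass -> ParallelLayerAnnotatePass -> ParallelLayerReplacePass.
pipeline = backend.init_parallelization_pass_pipeline()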
244 changes: 244 additions & 0 deletions optimum/fx/parallelization/backend/base.py
@@ -0,0 +1,244 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.fx import GraphModule

from ..core import Config, ParallelExecutionCtx, ParameterMeta
from ..distributed import scatter
from ..parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from ..passes import (
ParallelAxisSolverPass,
ParallelLayerAnnotatePass,
ParallelLayerReplacePass,
PassPipeline,
)


class BackEnd(ABC):
@abstractmethod
def create_column_parallel_linear(
self,
mod: nn.Linear,
parallel_ctx: "ParallelExecutionCtx",
sequence_parallel: bool,
gather_output: bool,
contiguous_chunks: Optional[Tuple[int]] = None,
) -> nn.Module:
raise NotImplementedError

@abstractmethod
def create_row_parallel_linear(
self,
mod: nn.Linear,
parallel_ctx: "ParallelExecutionCtx",
sequence_parallel: bool,
input_is_parallel: bool,
contiguous_chunks: Optional[Tuple[int]] = None,
) -> nn.Module:
raise NotImplementedError

@abstractmethod
def create_parallel_embedding(
self,
mod: nn.Embedding,
parallel_ctx: "ParallelExecutionCtx",
sequence_parallel: bool,
contiguous_chunks: Optional[Tuple[int]] = None,
) -> nn.Module:
raise NotImplementedError

def pre_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> GraphModule:
Member: What's a config?

Contributor Author: A `Config` is a data class which records static configurations during the whole process.

"""
Mark tie information right before we run passes because dynamo tracing will alter the parameter name while our
passes don't.
"""

Member (suggested change): Mark information about tied parameters right before running passes because dynamo tracing alters the names of the parameters while our passes do not.

Member: Are you sure naming this `pre_process` is the most appropriate, considering it just marks tied weights?

Contributor Author (zhenglongjiepheonix, Sep 20, 2024): I think it generally means anything that needs to be done before we run passes, and marking weight-tying info happens to be one of them.

parameter_mp = {}
for name, tensor in graph_module.named_parameters(remove_duplicate=False):
key = id(tensor)
if key not in parameter_mp:
parameter_mp[key] = name
else:
param_meta: "ParameterMeta" = getattr(tensor, "meta")
param_meta.tied_to = parameter_mp[key]
return graph_module
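As a reviewer aid (not part of the diff), a tiny self-contained example of the situation this loop records; the module names are assumptions, not code from the PR:

import torch.nn as nn

embed = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = embed.weight  # weight tying, as in embed_tokens <-> lm_head

container = nn.ModuleDict({"embed": embed, "head": head})
params = list(container.named_parameters(remove_duplicate=False))
# Both entries wrap the same tensor, so the id() collision above is what
# identifies the tie, and the first name seen becomes `tied_to`.
assert params[0][0] == "embed.weight" and params[1][0] == "head.weight"
assert params[0][1] is params[1][1]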

def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> nn.Module:
return graph_module

def init_parallelization_pass_pipeline(
self,
) -> PassPipeline:
"""
Assemble a pass pipeline that contains the following passes:
1. `ParallelAxisSolverPass` to find a parallelization solution of tensors in the graph.
2. `ParallelLayerAnnotatePass` to annotate parallelized layers according to the solution found in the first step.
3. `ParallelLayerReplacePass` to do the actual replacement and modification of hard-coded attributes.

Returns:
PassPipeline: the pipeline used for automatic parallelism.
"""
return PassPipeline(
[
ParallelAxisSolverPass(),
ParallelLayerAnnotatePass(),
ParallelLayerReplacePass(),
]
)
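To illustrate the intended extension point (a hedged sketch, not code from this PR or from nanotron): a third-party backend would subclass BackEnd and return its own layer implementations. The MyColumnParallelLinear / MyRowParallelLinear / MyVocabParallelEmbedding names below are placeholders.

class MyBackend(BackEnd):
    # Placeholder layer classes; only the override points mirror the interface above.
    def create_column_parallel_linear(self, mod, parallel_ctx, sequence_parallel,
                                      gather_output, contiguous_chunks=None):
        return MyColumnParallelLinear(parallel_ctx, mod, sequence_parallel, gather_output)

    def create_row_parallel_linear(self, mod, parallel_ctx, sequence_parallel,
                                   input_is_parallel, contiguous_chunks=None):
        return MyRowParallelLinear(parallel_ctx, mod, sequence_parallel, input_is_parallel)

    def create_parallel_embedding(self, mod, parallel_ctx, sequence_parallel,
                                  contiguous_chunks=None):
        return MyVocabParallelEmbedding(parallel_ctx, mod, sequence_parallel)

    # pre_process, post_process and init_parallelization_pass_pipeline keep the
    # defaults from BackEnd unless the backend needs different behaviour.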


class DefaultBackend(BackEnd):
def create_column_parallel_linear(
self,
mod: nn.Linear,
parallel_ctx: "ParallelExecutionCtx",
sequence_parallel: bool,
gather_output: bool,
contiguous_chunks: Optional[Tuple[int]] = None,
) -> nn.Module:
if sequence_parallel or contiguous_chunks is not None:
raise NotImplementedError(
"DefaultBackend does not support `sequence_parallel=True` or specifying contiguous chunks for now"
)
return ColumnParallelLinear(parallel_ctx, mod, gather_output)

def create_row_parallel_linear(
self,
mod: nn.Linear,
parallel_ctx: "ParallelExecutionCtx",
sequence_parallel: bool,
input_is_parallel: bool,
contiguous_chunks: Optional[Tuple[int]] = None,
) -> nn.Module:
if sequence_parallel or contiguous_chunks is not None:
raise NotImplementedError(
"DefaultBackend does not support `sequence_parallel=True` or specifying contiguous chunks for now"
)
return RowParallelLinear(parallel_ctx, mod, input_is_parallel)

def create_parallel_embedding(
self,
mod: nn.Embedding,
parallel_ctx: "ParallelExecutionCtx",
sequence_parallel: bool,
contiguous_chunks: Optional[Tuple[int]] = None,
) -> nn.Module:
if sequence_parallel or contiguous_chunks is not None:
raise NotImplementedError(
"DefaultBackend does not support `sequence_parallel=True` or specifying contiguous chunks for now"
)

return VocabParallelEmbedding(parallel_ctx, mod)

def post_process(self, graph_module: GraphModule, ctx: "ParallelExecutionCtx", config: "Config") -> nn.Module:
world_size = dist.get_world_size(ctx.tp_group)
tp_rank = dist.get_rank(ctx.tp_group)

new_parameters, param_cache = [], ctx.param_cache
for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)):
# skip initializing new params when recompilation happens
if name in param_cache:
new_parameters.append((name, param_cache[name]))
continue

param_meta: "ParameterMeta" = getattr(param, "meta")
# skip tied parameters for now
if param_meta.tied_to is not None and param_meta.tied_to != name:
continue

shape = [
param.size(dim) // world_size if dim == param_meta.dim and param_meta.is_parallel else param.size(dim)
for dim in range(param.ndim)
]

# if the device is correct, then we directly use the parameter
if param.device == ctx.current_device:
new_param = param
else:
new_param = nn.Parameter(
torch.zeros(*shape, dtype=param.dtype, device=ctx.current_device),
requires_grad=param.requires_grad,
)

# load weights if possible
for source, target in sorted(param_meta.mapping.items()):
if target.source in ctx.weight_map:
from safetensors import safe_open

with safe_open(ctx.weight_map[target.source], framework="pt", device="cpu") as fp:
tensor_slice = fp.get_slice(target.source)
source_index = [
source.to_slice() if dim == param_meta.dim else slice(None, None, None)
for dim in range(param.ndim)
]
load_index = [
target.index if dim == param_meta.dim else slice(None, None, None)
for dim in range(param.ndim)
]

tensor = tensor_slice[load_index].contiguous()
tensor = torch.empty_like(tensor).copy_(tensor)
with torch.no_grad():
new_param.data[source_index].copy_(tensor)

# weights initialization
if param_meta.need_initialize:
for source, target in sorted(param_meta.mapping.items()):
if target.source in ctx.weight_map:
continue
if not param_meta.is_parallel or tp_rank == 0:
# initialize weight on master rank
weight = torch.empty(*target.shape, dtype=param.dtype, device="cpu")
init_fn = param_meta.init_fn if param_meta.init_fn else config.weight_init_fn
init_fn(weight)
weight = weight.to(ctx.current_device)
else:
weight = None
index = [
source.to_slice() if dim == param_meta.dim else slice(None, None, None)
for dim in range(param.ndim)
]
with torch.no_grad():
if param_meta.is_parallel:
scatter(ctx.tp_group, weight, new_param.data[index], scatter_dim=param_meta.dim)
else:
new_param.data[index].copy_(weight)
setattr(new_param, "meta", param_meta)

if id(new_param) != id(param):
new_parameters.append((name, new_param))
param_cache[name] = new_param

# take care of tied parameters
for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)):
param_meta: "ParameterMeta" = getattr(param, "meta")
if param_meta.tied_to is not None and param_meta.tied_to != name:
new_parameters.append((name, param_cache[param_meta.tied_to]))

for name, new_param in new_parameters:
prefix_and_field = name.rsplit(".", maxsplit=1)
if len(prefix_and_field) == 2:
parent_mod = graph_module.get_submodule(prefix_and_field[0])
field = prefix_and_field[1]
else:
parent_mod = graph_module
field = name
setattr(parent_mod, field, new_param)

return graph_module
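To make the sharding arithmetic in post_process concrete, a small worked example (the sizes and world size are illustrative, not from the PR): a column-parallel nn.Linear(4096, 11008) sharded across a tensor-parallel group of 4 keeps the input dimension intact and splits the output dimension, exactly as the `param.size(dim) // world_size` computation above does.

import torch.nn as nn

world_size = 4
linear = nn.Linear(4096, 11008, bias=False)  # weight shape: [11008, 4096]
parallel_dim = 0                             # column-parallel: shard the output dim

param = linear.weight
shard_shape = [
    param.size(dim) // world_size if dim == parallel_dim else param.size(dim)
    for dim in range(param.ndim)
]
print(shard_shape)  # [2752, 4096] -- each rank materializes only its own slice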