
Commit

more comments
zhenglongjiepheonix committed Aug 15, 2024
1 parent 4114d3b commit c689402
Showing 2 changed files with 13 additions and 1 deletion.
3 changes: 2 additions & 1 deletion optimum/fx/parallelization/api.py
@@ -54,7 +54,8 @@ def parallelize_model(
     Args:
         model (str):
-            Model to parallelize, a model id on the Huggingface Hub.
+            Model to parallelize, a model id on the Huggingface Hub or path to a local directory containing config and weights
+            of the model.
         parallel_ctx (ParallelExecutionCtx):
             Parallel execution context containing process groups the current process belongs to.
         *model_args (Any):
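For context, `parallelize_model` is the entry point whose docstring this hunk updates. Below is a minimal usage sketch, not part of this commit: it assumes `parallelize_model` and `ParallelExecutionCtx` are importable from `optimum.fx.parallelization`, and the `ParallelExecutionCtx` keyword arguments and distributed setup shown are assumptions for illustration only.

# Minimal usage sketch (not part of this commit); constructor argument names below
# are placeholders, only the parallelize_model(model, parallel_ctx, ...) shape comes
# from the docstring in this diff.
import os

import torch
import torch.distributed as dist

from optimum.fx.parallelization import ParallelExecutionCtx, parallelize_model

# Launched with torchrun, so LOCAL_RANK and the default process group are available.
dist.init_process_group(backend="nccl")
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))

# Hypothetical context construction; the tp_group/current_device names are assumed.
ctx = ParallelExecutionCtx(tp_group=dist.group.WORLD, current_device=device)

# `model` may be a model id on the Hub ...
parallel_model = parallelize_model("meta-llama/Llama-2-7b-hf", ctx)

# ... or, as the updated docstring describes, a path to a local directory
# containing the config and weights of the model.
parallel_model = parallelize_model("/path/to/local/checkpoint", ctx)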
11 changes: 11 additions & 0 deletions optimum/fx/parallelization/decomp.py
@@ -68,6 +68,17 @@ def __init__(self, graph: Graph):
 
 
 class DecompositionInterpreter(Interpreter):
+    """
+    DecompositionInterpreter takes the high-level graph module, runs its internal nodes in topological order, and decomposes
+    high-level pytorch operators into core aten operators along the way by utilizing the torch dispatch infrastructure. Note
+    that certain primitive layers (like `nn.Linear`, `nn.Embedding`, and activation layers) are preserved because we have a
+    specific heuristic-based parallelization strategy for them and can conveniently replace them with their parallelized
+    counterparts in the original graph module.
+    Note that the traced graph is a low-level equivalent representation of the original graph module and is only used for
+    parallel axis propagation and analysis; the original graph module is still used for real execution.
+    """
+
     def __init__(
         self, module: GraphModule, new_graph: Graph, decomposition_table=None, leaf_function_targets=None, **kwargs
     ):
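To make the interpreter's role concrete, here is a sketch assuming the class can be driven standalone the way a stock `torch.fx.Interpreter` is, with `.run(*inputs)` populating the supplied empty graph as a side effect; the toy module and inputs are invented for illustration, and in practice the parallelization pass drives this interpreter internally.

# Illustration only: assumes DecompositionInterpreter behaves like a standard
# torch.fx Interpreter whose run(*inputs) fills `new_graph` while executing nodes.
import torch
from torch import nn
from torch.fx import Graph, GraphModule, symbolic_trace

from optimum.fx.parallelization.decomp import DecompositionInterpreter


class ToyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = nn.Linear(16, 64)
        self.down = nn.Linear(64, 16)

    def forward(self, x):
        # gelu should decompose into core aten ops, while the nn.Linear layers are
        # expected to remain preserved leaves, as the docstring above explains.
        return self.down(torch.nn.functional.gelu(self.up(x)))


high_level: GraphModule = symbolic_trace(ToyMLP())
low_level_graph = Graph()

interp = DecompositionInterpreter(high_level, low_level_graph)
interp.run(torch.randn(2, 16))  # populates low_level_graph as nodes are executed

# The decomposed graph is only used for parallel-axis propagation and analysis;
# the original high-level graph module is still what gets executed.
print(low_level_graph)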

0 comments on commit c689402
