Added Distributed (Tensor Parallel) Inference Recipe #2245
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2245
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 3f2d6ce with merge base baae232.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #2245      +/-   ##
==========================================
+ Coverage   64.30%   66.84%      +2.53%
==========================================
  Files         352      352
  Lines       20566    20598         +32
==========================================
+ Hits        13225    13768        +543
+ Misses       7341     6830        -511

☔ View full report in Codecov by Sentry.
Some minor comments/questions.
@@ -45,6 +48,18 @@
    "dev" not in torch_version and torch_version_ge("2.6.0")
) or ("dev" in torch_version and torch_version.split("dev")[1] >= "20241220")

BASE_LLAMA_TP_PLAN = {
Do we need one for each family of models? If so, is this file the right place to store it?
Yeah, I am also curious what the best place to store this info is. The plan should be shared within Llama 3, 3.1, and 3.2, but we should define unique plans for 3.2 Vision and 4. Maybe it should be stored in _model_builders.py? What's a better place?
"tok_embeddings": RowwiseParallel(input_layouts=Replicate()), | ||
"output": ColwiseParallel(output_layouts=Replicate()), | ||
"layers.*.attn.q_proj": ColwiseParallel(), | ||
"layers.*.attn.k_proj": ColwiseParallel(), | ||
"layers.*.attn.v_proj": ColwiseParallel(), | ||
"layers.*.attn.output_proj": RowwiseParallel(), | ||
"layers.*.mlp.w1": ColwiseParallel(), | ||
"layers.*.mlp.w2": RowwiseParallel(), | ||
"layers.*.mlp.w3": ColwiseParallel(), | ||
} |
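For reference, a minimal sketch (not the recipe's exact code) of how a plan like this is typically consumed, assuming model is an already-built Llama decoder and the process group has been initialized:

import torch.distributed as dist
from torch.distributed.tensor.parallel import parallelize_module

# parallelize_module walks the module tree and applies the matching
# ParallelStyle to every submodule whose path matches a key in the plan
# (wildcards like "layers.*" are supported in recent PyTorch versions).
tp_mesh = dist.init_device_mesh("cuda", (dist.get_world_size(),))
model = parallelize_module(model, tp_mesh, BASE_LLAMA_TP_PLAN)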
n00b question: is this row/col the optimal setup? or is it somewhat arbitrary?
For matrix multiplication, we just need to make sure one matrix is colwise and the other is rowwise. For example, because the math is mlp.w2(mlp.w1(x) * mlp.w3(x)), we just need to make sure that w1 and w3 are colwise and w2 is rowwise, or the other way around.
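To make that concrete, here is a tiny self-contained sketch (plain PyTorch, no distributed runtime, activation omitted) showing that column-sharding w1/w3 and row-sharding w2 reproduces the unsharded result with a single final reduction:

import torch

torch.manual_seed(0)
dim, hidden, tp = 8, 16, 2  # toy sizes; tp = number of shards
x = torch.randn(4, dim)
w1, w3 = torch.randn(dim, hidden), torch.randn(dim, hidden)  # column-sharded
w2 = torch.randn(hidden, dim)                                # row-sharded

# Reference (unsharded): w2(w1(x) * w3(x))
ref = ((x @ w1) * (x @ w3)) @ w2

# Each "rank" holds a column slice of w1/w3 and the matching row slice of w2.
# The elementwise product stays local; only the final matmul produces partial
# sums that need to be summed across ranks (the all-reduce).
partials = []
for r in range(tp):
    cols = slice(r * hidden // tp, (r + 1) * hidden // tp)
    h = (x @ w1[:, cols]) * (x @ w3[:, cols])
    partials.append(h @ w2[cols, :])
out = sum(partials)  # stands in for the all-reduce

assert torch.allclose(ref, out, atol=1e-4)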
def get_tp_plan(model_type: str) -> Dict[str, ParallelStyle]:
    """
    Get the TP plan for a given model type.

    Args:
        model_type (str): The model type to get the TP plan for.

    Returns:
        Dict[str, ParallelStyle]: A dictionary mapping layer names to their corresponding TP plan.
    """
    # For now, we only support the base TP plan; more plans will be added later
    return BASE_LLAMA_TP_PLAN
I understand that this is a v0, but should we add something like:

if model_type not in LLAMA_MODEL_TYPES:
    raise ValueError("TP is only supported for llama-type models")
    Returns:
        nn.Module: Adjusted model.
    """
    for transformer_block in model.layers:
This will break for the vision model, since there we do model.decoder.layers, unless we call adjust_attention_for_tp(model=model.decoder).
Thanks! Yeah, I had this in my local changes while trying to make 3.2 Vision work. I made the function a bit ugly:

def adjust_attention_for_tp(model):
    if not hasattr(model, "layers"):
        model = model.decoder

Let me know if you have better ideas.
""" | ||
for transformer_block in model.layers: | ||
# Adjust attention module to use the local number of heads | ||
attn_layer = transformer_block.attn |
nit: this is ok, but maybe a more robust option would be to look for the module type == SelfAttentionLayer
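Something like the following sketch, where SelfAttentionLayer is the class name used in this comment (a placeholder for whatever attention class the model actually uses), tp_size is assumed to be the tensor-parallel world size, and the adjusted attribute names are assumptions about the attention module's interface:

# Walk all submodules and match on the attention class instead of assuming a
# fixed model.layers[i].attn path; this works for both text and vision models.
for module in model.modules():
    if isinstance(module, SelfAttentionLayer):
        # Shrink head counts / dims to the local (per-TP-rank) values
        module.num_heads = module.num_heads // tp_size
        module.num_kv_heads = module.num_kv_heads // tp_size
        module.embed_dim = module.embed_dim // tp_size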
    Expects the YAML to look like:
        system: You are a helpful AI assistant.
        user: What is the capital of France?

    or if it includes an image:
        system: You are a helpful AI assistant.
        user:
            image: url or path_to_image
            text: Describe the image in detail.
We should denote that it is a .. code-block:: yaml; ask some LLM for formatting.
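For example, a sketch of the suggested docstring formatting (the function name here is hypothetical, shown only to illustrate the directive):

def load_prompt(prompt_path: str):  # hypothetical name, for illustration only
    """Load a chat prompt from a YAML file.

    Expects the YAML to look like:

    .. code-block:: yaml

        system: You are a helpful AI assistant.
        user: What is the capital of France?
    """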
        self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
        self._logger = utils.get_logger(cfg.log_level)
        # Set up distributed env
        dist.init_process_group("cuda:nccl")
I have seen this result in errors in other parts of the code if we don't do init_process_group("cuda:nccl,cpu:gloo").
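For reference, a sketch of the suggested initialization, which registers a CPU backend alongside NCCL so collectives on CPU tensors (e.g. during full-state-dict loading or CPU offload) have somewhere to run:

import torch.distributed as dist

# NCCL handles GPU collectives; Gloo covers any CPU-tensor collectives.
dist.init_process_group("cuda:nccl,cpu:gloo")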
        # Set up tensor parallel device mesh
        tp_degree = dist.get_world_size()  # Using all GPUs for TP
        tp_mesh_shape = (tp_degree,)
        tp_device_mesh = dist.init_device_mesh("cuda", tp_mesh_shape)
n00b question: should we worry about other device types, e.g. npu?
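One possible way to keep this device-agnostic (a sketch; assumes self._device has already been resolved earlier in setup):

        # Derive the device type instead of hard-coding "cuda", so other
        # accelerator backends (e.g. "npu") could in principle reuse this path.
        device_type = self._device.type
        tp_device_mesh = dist.init_device_mesh(device_type, (dist.get_world_size(),))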
        # This method will convert the full model state dict into a sharded state
        # dict and load into the model
        training.load_from_full_model_state_dict(
nit: as a rule of thumb, I think it's worth using keyword arguments for all args, not only strict and cpu_offload.
f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s" | ||
) | ||
self._logger.info( | ||
f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB" |
nit: I think in general we prefer to use GiB, otherwise it may appear that we used more memory than the GPU has available. To change, replace 1e9 with 1024**3.
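A hypothetical variant of the log line above using GiB, so the reported number lines up with how GPU memory capacity is usually quoted:

        # 1 GiB = 1024**3 bytes
        max_mem_gib = torch.cuda.max_memory_allocated() / (1024 ** 3)
        self._logger.info(f"Max memory allocated: {max_mem_gib:.02f} GiB")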
Context
What is the purpose of this PR? Is it to add a new feature, fix a bug, or update tests and/or documentation?
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
- Started from dev/generate_v2.py and added TP to the recipe. The main change is in __init__ and __setup__.
- Used distribute to parallelize each module.
- Generalized load_from_full_model_state_dict to general parallelism, not only FSDP.
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
- run unit tests via pytest tests
- run recipe tests via pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.