Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Distributed(Tensor Parallel) Inference Recipe #2245

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

acisseJZhong
Copy link
Contributor

@acisseJZhong acisseJZhong commented Jan 10, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

  • Copied dev/generate_v2.py and added TP to the recipe. The main change is in __init__ and __setup__.
  • Added model distribute to parallel each module.
  • Generalize load_from_full_model_state_dict to general parallelism, not only FSDP.
  • Added distributed inference config for llama3 70B and 3.1 70B.
  • Fixed a few typos.

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

Copy link

pytorch-bot bot commented Jan 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2245

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3f2d6ce with merge base baae232 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 10, 2025
@acisseJZhong acisseJZhong changed the title [TP] Added Distributed Inference Recipe Added Distributed(Tensor Parallel) Inference Recipe Jan 10, 2025
@codecov-commenter
Copy link

codecov-commenter commented Jan 10, 2025

Codecov Report

Attention: Patch coverage is 47.05882% with 18 lines in your changes missing coverage. Please review.

Project coverage is 66.84%. Comparing base (baae232) to head (f7615a1).

Files with missing lines Patch % Lines
torchtune/modules/transformer.py 38.46% 8 Missing ⚠️
torchtune/modules/attention.py 53.33% 7 Missing ⚠️
torchtune/modules/feed_forward.py 50.00% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2245      +/-   ##
==========================================
+ Coverage   64.30%   66.84%   +2.53%     
==========================================
  Files         352      352              
  Lines       20566    20598      +32     
==========================================
+ Hits        13225    13768     +543     
+ Misses       7341     6830     -511     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@acisseJZhong acisseJZhong requested review from RdoubleA, ebsmothers and joecummings and removed request for RdoubleA and ebsmothers January 10, 2025 07:21
Copy link
Contributor

@felipemello1 felipemello1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some minor comments/question

@@ -45,6 +48,18 @@
"dev" not in torch_version and torch_version_ge("2.6.0")
) or ("dev" in torch_version and torch_version.split("dev")[1] >= "20241220")

BASE_LLAMA_TP_PLAN = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need one for each family of models? If so, is this file the right place to store it?

Copy link
Contributor Author

@acisseJZhong acisseJZhong Jan 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I am also curious what's the best place to store this info. The plan should be shared within llama3, 3.1, and 3.2, but we should define unique plans for 3.2 vision and 4. Maybe it should be stored in _model_builders.py? What's a better place?

Comment on lines +52 to +61
"tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
"output": ColwiseParallel(output_layouts=Replicate()),
"layers.*.attn.q_proj": ColwiseParallel(),
"layers.*.attn.k_proj": ColwiseParallel(),
"layers.*.attn.v_proj": ColwiseParallel(),
"layers.*.attn.output_proj": RowwiseParallel(),
"layers.*.mlp.w1": ColwiseParallel(),
"layers.*.mlp.w2": RowwiseParallel(),
"layers.*.mlp.w3": ColwiseParallel(),
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

n00b question: is this row/col the optimal setup? or is it somewhat arbitrary?

Copy link
Contributor Author

@acisseJZhong acisseJZhong Jan 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For matrix multiplication, we just need to make sure one matrix is Col and the other is Row. For example, because the math is mlp.w2(mlp.w1(x) * mlp.w3(x)), therefore we just need to make sure that w1 and w3 are col and w2 is row, or the other way around.

Comment on lines +566 to +577
def get_tp_plan(model_type: str) -> Dict[str, ParallelStyle]:
"""
Get the TP plan for a given model type.

Args:
model_type (str): The model type to get the TP plan for.

Returns:
Dict[str, str]: A dictionary mapping layer names to their corresponding TP plan.
"""
# For now, we only support base TP plan, will add more plan later
return BASE_LLAMA_TP_PLAN
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that this is a v0, but should we add something like:

if model_type not in LLAMA_MODEL_TYPES:
	raise "TP only supported for llama type models"

Returns:
nn.Module: Adjusted model.
"""
for transformer_block in model.layers:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will break for vision model, since we do model.decoder.layers, unless we call adjust_attention_for_tp(model=model.decoder)

Copy link
Contributor Author

@acisseJZhong acisseJZhong Jan 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks! yeah I had this in my local changes, trying to make vision 3.2 work. I made the function a bit ugly:

def adjust_attention_for_tp():
    if hasattr(model, "layers") is False:
        model = model.decoder

Let me know if you have better ideas.

"""
for transformer_block in model.layers:
# Adjust attention module to use the local number of heads
attn_layer = transformer_block.attn
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is ok, but maybe a more robust option would be to look for the module type == SelfAttentionLayer

Comment on lines +28 to +36
Expects the YAML to look like:
system: You are a helpful AI assistant.
user: What is the capital of France?

or if it includes an image:
system: You are a helpful AI assistant.
user:
image: url or path_to_image
text: Describe the image in detail.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should denote that it is a ::codeblock: yaml, ask some llm for formating

self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
self._logger = utils.get_logger(cfg.log_level)
# Set up distributed env
dist.init_process_group("cuda:nccl")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i have seen in other parts of the code this resulting in errors if we dont do init_process_group("cuda:nccl,cpu:gloo")

# Set up tenosr parallel device mesh
tp_degree = dist.get_world_size() # Using all GPUs for TP
tp_mesh_shape = (tp_degree,)
tp_device_mesh = dist.init_device_mesh("cuda", tp_mesh_shape)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

n00b question: should we worry about other device types, e.g. npu?


# This method will convert the full model state dict into a sharded state
# dict and load into the model
training.load_from_full_model_state_dict(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: as a rule of thumb, i think its worth using key arguments for all args, not only stric and cpu_offload

Comment on lines +151 to +154
f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s"
)
self._logger.info(
f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: i think in general we prefer to use GiB, otherwise it may appear that we used more memory than the GPU has available --> to change replace 1e9 with /1024/1204

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants