forked from pytorch/torchtitan
-
Notifications
You must be signed in to change notification settings - Fork 0
/
parallelize_llama.py
375 lines (328 loc) · 13.7 KB
/
parallelize_llama.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# This file applies the PT-D parallelisms (except pipeline parallelism) and various
# training techniques (e.g. activation checkpointing and compile) to the Llama model.
from collections import defaultdict
import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import (
CPUOffloadPolicy,
fully_shard,
MixedPrecisionPolicy,
)
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
)
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
)
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import logger
from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms.utils import check_if_feature_in_pytorch
def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply the configured parallelisms and training techniques to the model.

    The steps always run in this order: tensor parallelism, activation
    checkpointing, per-TransformerBlock torch.compile, and finally data
    parallelism (FSDP/HSDP when ``parallel_dims.dp_shard_enabled``, otherwise
    DDP when ``parallel_dims.dp_replicate_enabled``). The function has no
    return statement; the helpers it calls operate on ``model`` directly.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.

    Raises:
        RuntimeError: if async TP is requested without ``training.compile``,
            or if DDP is selected while ``world_mesh`` has more than one dim.
        NotImplementedError: if compile is requested together with the
            ``fused_rmsnorm`` norm type.
    """
    # --- Tensor parallelism -------------------------------------------------
    if parallel_dims.tp_enabled:
        async_tp = job_config.experimental.enable_async_tensor_parallel
        if async_tp and not job_config.training.compile:
            raise RuntimeError("Async TP requires --training.compile")

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8=job_config.float8.enable_float8_linear,
            enable_async_tp=async_tp,
        )

    # --- Activation checkpointing -------------------------------------------
    ac_config = job_config.activation_checkpoint
    if ac_config.mode != "none":
        apply_ac(model, ac_config)

    # --- torch.compile ------------------------------------------------------
    # Per-TransformerBlock compile has to come after AC wrapping and before FSDP.
    if job_config.training.compile:
        if job_config.model.norm_type == "fused_rmsnorm":
            raise NotImplementedError(
                "fused_rmsnorm is not compatible with torch.compile yet. "
                "Please use rmsnorm or layernorm."
            )
        apply_compile(model)

    # --- Data parallelism ---------------------------------------------------
    if parallel_dims.dp_shard_enabled:
        # FSDP or HSDP, potentially combined with Context Parallel
        with_cp = parallel_dims.cp_enabled
        try:
            dp_mesh = world_mesh["dp_cp" if with_cp else "dp"]
        except IndexError:
            # Workaround for old pytorch versions that do not include
            # https://github.com/pytorch/pytorch/pull/138945; the helper below
            # is there to encourage users to upgrade to a newer pytorch version.
            check_if_feature_in_pytorch(
                "DeviceMesh flattening over 3D+ meshes",
                "https://github.com/pytorch/pytorch/pull/138945",
                "2.6.0.dev20241030",
            )
            # TODO: remove this workaround once PyTorch 2.6 is released
            if parallel_dims.dp_replicate_enabled:
                dp_dim_names = ("dp_replicate", "dp_shard")
            else:
                dp_dim_names = ("dp",)

            # A mesh can only be flattened from the finest-grained mesh dimensions.
            if with_cp:
                dp_mesh = world_mesh[(*dp_dim_names, "cp")]._flatten("dp_cp")
            else:
                dp_mesh = world_mesh[dp_dim_names]

        training = job_config.training
        apply_fsdp(
            model,
            dp_mesh,
            param_dtype=TORCH_DTYPE_MAP[training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[training.mixed_precision_reduce],
            tp_enabled=parallel_dims.tp_enabled,
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=training.enable_cpu_offload,
        )

        logger.info(
            "Applied HSDP to the model"
            if parallel_dims.dp_replicate_enabled
            else "Applied FSDP to the model"
        )
        if with_cp:
            logger.info("Applied Context Parallel to the model")

    elif parallel_dims.dp_replicate_enabled:
        # Plain DDP is only supported on a 1D mesh.
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")

        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
        )
def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8: bool,
    enable_async_tp: bool,
):
    """
    Apply tensor parallelism (with sequence parallelism) to the model.

    Args:
        model: module exposing ``tok_embeddings``, ``norm``, ``output`` and a
            ``layers`` mapping of transformer blocks (these are the names the
            parallelize plans below refer to).
        tp_mesh: device mesh handed to ``parallelize_module``.
        loss_parallel: when true, the output layer keeps its result sharded on
            the last dim and is configured with ``use_local_output=False``;
            otherwise the output is replicated.
        enable_float8: use the torchao float8 parallel styles for the
            transformer block linears and their input preparation.
        enable_async_tp: turn on inductor micro-pipelined TP and symmetric
            memory for the TP process group.
    """
    # Root-level plan:
    # 1. Parallelize the embedding and shard its outputs (which are the first
    #    transformer block's inputs)
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
    root_plan = {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        "norm": SequenceParallel(),
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Shard(-1) if loss_parallel else Replicate(),
            use_local_output=not loss_parallel,
        ),
    }
    parallelize_module(model, tp_mesh, root_plan)

    # Parallel styles used for transformer block linear weights and their
    # inputs may be different for float8 linears
    if enable_float8:
        # TODO(vkuzo): once float8 configuration supports delayed scaling,
        # add a check here to enforce supported float8 all-gather configurations
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel as colwise_parallel,
            Float8RowwiseParallel as rowwise_parallel,
            PrepareFloat8ModuleInput as prepare_module_input,
        )
    else:
        rowwise_parallel = RowwiseParallel
        colwise_parallel = ColwiseParallel
        prepare_module_input = PrepareModuleInput

    # Apply tensor + sequence parallelism to every transformer block
    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
    #       by folding (and unfolding) the batch dimension and the sequence dimension.
    #       Examples can be found at https://github.com/pytorch/torchtitan/pull/437
    for block in model.layers.values():
        # A fresh plan (fresh style objects) is built for every block.
        block_plan = {
            "attention_norm": SequenceParallel(),
            "attention": prepare_module_input(
                input_layouts=(Shard(1), None),
                desired_input_layouts=(Replicate(), None),
            ),
            "attention.wq": colwise_parallel(),
            "attention.wk": colwise_parallel(),
            "attention.wv": colwise_parallel(),
            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
            "ffn_norm": SequenceParallel(),
            "feed_forward": prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "feed_forward.w1": colwise_parallel(),
            "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
            "feed_forward.w3": colwise_parallel(),
        }
        parallelize_module(block, tp_mesh, block_plan)

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )
# for selective op activation checkpointing
# Ops that the per-op selective AC policy in _apply_ac_to_transformer_block
# marks as MUST_SAVE (with the exception that every second aten.mm is left to
# be recomputed); any op not listed here gets PREFER_RECOMPUTE.
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
}
def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
    """
    Return ``module`` with activation checkpointing applied per ``ac_config``.

    Modes (``ac_config.mode``):
      * ``"full"``: the module is always wrapped with the checkpoint wrapper.
      * ``"selective"`` with ``selective_ac_option == "op"``: the module is
        wrapped with a per-op policy that saves the ops in ``_save_list``
        (skipping every second ``aten.mm``) and prefers recomputing the rest.
      * ``"selective"`` with a digit-string option N: a running call counter
        stored on the checkpoint wrapper function decides whether this call
        wraps the module (counter divisible by N, or N == 0) or returns it
        unchanged.

    Raises:
        ValueError: on an unknown mode or an unknown selective AC option.
    """
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

    assert ac_config.mode == "selective", f"{ac_config.mode}"

    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not (use_op_sac or use_layer_sac):
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
            f"Valid options: 'op' or a positive int representing layer frequency"
        )

    if use_layer_sac:
        # Checkpoint every `every_n` of the modules passed to this function.
        # The running count lives as an attribute on the wrapper function so
        # that it persists across calls.
        every_n = int(ac_config.selective_ac_option)
        seen = getattr(ptd_checkpoint_wrapper, "_count", 0) + 1
        ptd_checkpoint_wrapper._count = seen

        leave_unwrapped = every_n and seen % every_n != 0
        if leave_unwrapped:
            return module
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

    # Per-op selective activation checkpointing
    from torch.utils.checkpoint import (
        CheckpointPolicy,
        create_selective_checkpoint_contexts,
    )

    def selective_checkpointing_context_fn():
        # Fresh counters each time the context fn is invoked; forward and
        # recompute passes are counted under separate keys.
        mm_counts = defaultdict(int)

        def _custom_policy(ctx, func, *args, **kwargs):
            is_mm = func == torch.ops.aten.mm.default
            phase = "recompute" if ctx.is_recompute else "forward"
            count_key = f"{phase}_mm_count"
            if is_mm:
                mm_counts[count_key] += 1

            # Saves output of all compute ops, except every second mm
            if func not in _save_list:
                return CheckpointPolicy.PREFER_RECOMPUTE
            if is_mm and mm_counts[count_key] % 2 == 0:
                return CheckpointPolicy.PREFER_RECOMPUTE
            return CheckpointPolicy.MUST_SAVE

        return create_selective_checkpoint_contexts(_custom_policy)

    return ptd_checkpoint_wrapper(
        module,
        context_fn=selective_checkpointing_context_fn,
        preserve_rng_state=False,
    )
def apply_ac(model: nn.Module, ac_config):
    """
    Apply activation checkpointing to the model.

    Each child of ``model.layers`` is passed through
    ``_apply_ac_to_transformer_block`` and the result is registered back on
    ``model.layers`` under the same name.
    """
    layers = model.layers
    for name, block in layers.named_children():
        layers.register_module(name, _apply_ac_to_transformer_block(block, ac_config))

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
def apply_compile(model: nn.Module):
    """
    Compile every TransformerBlock in ``model.layers`` with torch.compile.

    Compiling block by block keeps compilation efficient because the blocks
    share the same structure. Alternatively one can compile the whole model
    (after applying DP). Each compiled block is registered back on
    ``model.layers`` under its original name.
    """
    layers = model.layers
    for name, block in layers.named_children():
        compiled_block = torch.compile(block, fullgraph=True)
        layers.register_module(name, compiled_block)

    logger.info("Compiling each TransformerBlock with torch.compile")
def apply_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    tp_enabled: bool,
    pp_enabled: bool,
    cpu_offload: bool = False,
):
    """
    Apply data parallelism to the model. FSDP2 is used here.

    Every entry of ``model.layers`` is sharded individually, then the root
    model is sharded as well.

    Args:
        model: module whose ``layers`` mapping is keyed by integer-like ids.
        dp_mesh: mesh passed to ``fully_shard`` as ``mesh``.
        param_dtype: parameter dtype for the mixed precision policy.
        reduce_dtype: reduction dtype for the mixed precision policy.
        tp_enabled: accepted for interface compatibility; not read here.
        pp_enabled: when true, nothing is resharded after forward.
        cpu_offload: when true, a ``CPUOffloadPolicy`` is added to the config.
    """
    fsdp_config = {
        "mesh": dp_mesh,
        "mp_policy": MixedPrecisionPolicy(
            param_dtype=param_dtype, reduce_dtype=reduce_dtype
        ),
    }
    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy()

    last_layer_idx = len(model.layers) - 1
    for layer_id, transformer_block in model.layers.items():
        # For PP, do not reshard after forward to avoid per-microbatch
        # all-gathers, which can be expensive and non-overlapped.
        # Without PP, as an optimization, do not reshard after forward for the
        # last transformer block since FSDP would prefetch it immediately.
        reshard_after_forward = (not pp_enabled) and int(layer_id) < last_layer_idx
        fully_shard(
            transformer_block,
            **fsdp_config,
            reshard_after_forward=reshard_after_forward,
        )

    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
    enable_compiled_autograd: bool,
):
    """
    Apply DDP to the model via the composable ``replicate`` API.

    When compile is enabled, the dynamo ``optimize_ddp`` setting is chosen
    first: ``"python_reducer_without_compiled_forward"`` if compiled autograd
    is also enabled, ``"ddp_optimizer"`` otherwise. Without compile the
    setting is left untouched.
    """
    if enable_compile:
        torch._dynamo.config.optimize_ddp = (
            "python_reducer_without_compiled_forward"
            if enable_compiled_autograd
            else "ddp_optimizer"
        )

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")