[BUG] MoE pre-training does not scale beyond DP dim > 8
Describe the bug
I used {2, 4, 8, 16, 32} 8xH100 DGX nodes to run Mixtral-style MoE pre-training following the Megatron MoE guidance. The model size is roughly 8x3B, and I only used EP and DP to scale it up. However, when fixing EP=8 and scaling DP={2, 4, 8, 16, 32} (please see the detailed script below), I observed that the speed does not improve beyond DP=8.
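For reference, the parallelism layout corresponds to Megatron-LM launch flags roughly like the following. This is a minimal sketch of only the relevant options, not my full script; everything elided is marked with `...`:

```bash
# Sketch of the parallelism-relevant Megatron-LM flags (illustrative, not the full script).
# In the notation above, DP = world_size / (EP x TP x PP), so with EP=8 and TP=PP=1
# the DP size is varied purely by the number of 8xH100 nodes (2 to 32).
torchrun --nproc_per_node 8 --nnodes $NNODES ... pretrain_gpt.py \
    --num-experts 8 \
    --expert-model-parallel-size 8 \
    --tensor-model-parallel-size 1 \
    --pipeline-model-parallel-size 1 \
    --global-batch-size 2048 \
    ...
```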
I'm reporting the runtime results below:
DP=2, EP=8
[2024-10-25 04:08:43] iteration 4/ 365000 | consumed samples: 8192 | elapsed time per iteration (ms): 33450.2 | learning rate: 6.000000E-09 | global batch size: 2048 | lm loss: 1.049151E+01 | z_loss: 4.799740E+00 | load_balancing_loss: 1.261538E+00 | loss scale: 1.0 | grad norm: 17.428 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations: 0 | number of nan iterations: 0 |
[2024-10-25 04:09:49] iteration 6/ 365000 | consumed samples: 12288 | elapsed time per iteration (ms): 33031.0 | learning rate: 9.000000E-09 | global batch size: 2048 | lm loss: 1.049228E+01 | z_loss: 4.805847E+00 | load_balancing_loss: 1.257355E+00 | loss scale: 1.0 | grad norm: 17.315 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations: 0 | number of nan iterations: 0 |
[2024-10-25 04:10:54] iteration 8/ 365000 | consumed samples: 16384 | elapsed time per iteration (ms): 32436.5 | learning rate: 1.200000E-08 | global batch size: 2048 | lm loss: 1.049336E+01 | z_loss: 4.809865E+00 | load_balancing_loss: 1.263811E+00 | loss scale: 1.0 | grad norm: 17.803 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations: 0 | number of nan iterations: 0 |
DP=4, EP=8
[2024-10-25 04:01:53] iteration 4/ 365000 | consumed samples: 8192 | elapsed time per iteration (ms): 23922.8 | learning rate: 6.000000E-09 | global batch size: 2048 | lm loss: 1.049150E+01 | z_loss: 4.805318E+00 | load_balancing_loss: 1.261618E+00 | loss scale: 1.0 | grad norm: 17.428 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations: 0 | number of nan iterations: 0 |
[2024-10-25 04:02:40] iteration 6/ 365000 | consumed samples: 12288 | elapsed time per iteration (ms): 23492.5 | learning rate: 9.000000E-09 | global batch size: 2048 | lm loss: 1.049229E+01 | z_loss: 4.813999E+00 | load_balancing_loss: 1.253946E+00 | loss scale: 1.0 | grad norm: 17.316 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations: 0 | number of nan iterations: 0 |
[2024-10-25 04:03:26] iteration 8/ 365000 | consumed samples: 16384 | elapsed time per iteration (ms): 23066.2 | learning rate: 1.200000E-08 | global batch size: 2048 | lm loss: 1.049337E+01 | z_loss: 4.805440E+00 | load_balancing_loss: 1.264813E+00 | loss scale: 1.0 | grad norm: 17.802 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations: 0 | number of nan iterations: 0 |
[2024-10-25 04:04:13] iteration 10/ 365000 | consumed samples: 20480 | elapsed time per iteration (ms): 23237.8 | learning rate: 1.500000E-08 | global batch size: 2048 | lm loss: 1.049290E+01 | z_loss: 4.794414E+00 | load_balancing_loss: 1.258429E+00 | loss scale: 1.0 | grad norm: 17.760 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations: 0 | number of nan iterations: 0 |
DP=8, EP=8
[2024-10-25 02:59:01] iteration 20/ 365000 | consumed samples: 40960 | elapsed time per iteration (ms): 18228.8 | learning rate: 3.000000E-06 | global batch size: 2048 | lm loss: 9.122166E+00 | z_loss: 2.849494E+00 | load_balancing_loss: 1.052246E+00 | loss scale: 1.0 | grad norm: 3.607 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations: 0 | number of nan iterations: 0 |
[2024-10-25 02:59:39] iteration 22/ 365000 | consumed samples: 45056 | elapsed time per iteration (ms): 18968.6 | learning rate: 3.300000E-06 | global batch size: 2048 | lm loss: 9.057307E+00 | z_loss: 2.501926E+00 | load_balancing_loss: 1.040945E+00 | loss scale: 1.0 | grad norm: 3.386 | num zeros: 0.0 | params norm: 1125.557 | number of skipped iterations: 0 | number of nan iterations: 0 |
[2024-10-25 03:00:17] iteration 24/ 365000 | consumed samples: 49152 | elapsed time per iteration (ms): 18932.1 | learning rate: 3.600000E-06 | global batch size: 2048 | lm loss: 8.980194E+00 | z_loss: 2.145369E+00 | load_balancing_loss: 1.034725E+00 | loss scale: 1.0 | grad norm: 3.216 | num zeros: 0.0 | params norm: 1125.557 | number of skipped iterations: 0 | number of nan iterations: 0 |
[2024-10-25 03:00:53] iteration 26/ 365000 | consumed samples: 53248 | elapsed time per iteration (ms): 17997.6 | learning rate: 3.900000E-06 | global batch size: 2048 | lm loss: 8.913023E+00 | z_loss: 1.799559E+00 | load_balancing_loss: 1.032197E+00 | loss scale: 1.0 | grad norm: 3.193 | num zeros: 0.0 | params norm: 1125.557 | number of skipped iterations: 0 | number of nan iterations: 0 |
[2024-10-25 03:01:31] iteration 28/ 365000 | consumed samples: 57344 | elapsed time per iteration (ms): 18797.5 | learning rate: 4.200000E-06 | global batch size: 2048 | lm loss: 8.858095E+00 | z_loss: 1.508859E+00 | load_balancing_loss: 1.028084E+00 | loss scale: 1.0 | grad norm: 3.021 | num zeros: 0.0 | params norm: 1125.557 | number of skipped iterations: 0 | number of nan iterations: 0 |
DP=16, EP=8
[2024-10-25 02:42:33] iteration 4/ 365000 | consumed samples: 8192 | elapsed time per iteration (ms): 18489.0 | learning rate: 6.000000E-07 | global batch size: 2048 | lm loss: 1.048763E+01 | z_loss: 4.795546E+00 | load_balancing_loss: 1.251609E+00 | loss scale: 1.0 | grad norm: 17.397 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations: 0 | number of nan iterations: 0 |
[2024-10-25 02:43:10] iteration 6/ 365000 | consumed samples: 12288 | elapsed time per iteration (ms): 18519.2 | learning rate: 9.000000E-07 | global batch size: 2048 | lm loss: 1.043374E+01 | z_loss: 4.780572E+00 | load_balancing_loss: 1.244789E+00 | loss scale: 1.0 | grad norm: 16.892 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations: 0 | number of nan iterations: 0 |
[2024-10-25 02:43:47] iteration 8/ 365000 | consumed samples: 16384 | elapsed time per iteration (ms): 18453.2 | learning rate: 1.200000E-06 | global batch size: 2048 | lm loss: 1.020294E+01 | z_loss: 4.610379E+00 | load_balancing_loss: 1.236636E+00 | loss scale: 1.0 | grad norm: 17.293 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations: 0 | number of nan iterations: 0 |
DP=32, EP=8 (the speed has saturated along the DP dim)
[2024-10-25 02:47:00] iteration 4/ 365000 | consumed samples: 8192 | elapsed time per iteration (ms): 18215.0 | learning rate: 6.000000E-07 | global batch size: 2048 | lm loss: 1.048763E+01 | z_loss: 4.808811E+00 | load_balancing_loss: 1.253611E+00 | loss scale: 1.0 | grad norm: 17.397 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations: 0 | number of nan iterations: 0 |
[2024-10-25 02:47:38] iteration 6/ 365000 | consumed samples: 12288 | elapsed time per iteration (ms): 18854.0 | learning rate: 9.000000E-07 | global batch size: 2048 | lm loss: 1.043371E+01 | z_loss: 4.762641E+00 | load_balancing_loss: 1.241172E+00 | loss scale: 1.0 | grad norm: 16.892 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations: 0 | number of nan iterations: 0 |
[2024-10-25 02:48:16] iteration 8/ 365000 | consumed samples: 16384 | elapsed time per iteration (ms): 19142.1 | learning rate: 1.200000E-06 | global batch size: 2048 | lm loss: 1.020293E+01 | z_loss: 4.619304E+00 | load_balancing_loss: 1.248783E+00 | loss scale: 1.0 | grad norm: 17.295 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations: 0 | number of nan iterations: 0 |
[2024-10-25 02:48:54] iteration 10/ 365000 | consumed samples: 20480 | elapsed time per iteration (ms): 18787.4 | learning rate: 1.500000E-06 | global batch size: 2048 | lm loss: 9.840721E+00 | z_loss: 4.333767E+00 | load_balancing_loss: 1.271878E+00 | loss scale: 1.0 | grad norm: 15.654 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations: 0 | number of nan iterations: 0 |
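To summarize the logs above (averaging elapsed time per iteration over the reported steps; throughput = 2048-sample global batch / iteration time):

| DP | EP | Nodes (8xH100) | ms/iter (avg) | Samples/s |
|----|----|----------------|---------------|-----------|
| 2  | 8  | 2              | ~33,000       | ~62       |
| 4  | 8  | 4              | ~23,400       | ~87       |
| 8  | 8  | 8              | ~18,600       | ~110      |
| 16 | 8  | 16             | ~18,500       | ~111      |
| 32 | 8  | 32             | ~18,700       | ~110      |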
I could of course introduce PP>1 as described in the MoE doc, but not being able to scale the DP dimension beyond 8 still seems like a real issue. Any help would be appreciated!
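(Introducing PP>1 would just mean setting `--pipeline-model-parallel-size` to 2 or more in the sketch above, but that works around the DP scaling problem rather than explaining it.)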
Expected behavior
Training throughput should scale along the DP dim when DP>8.
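Concretely: with the global batch fixed at 2048 samples, doubling DP halves the work per rank, so going from DP=8 (~18.6 s/iter) to DP=16 should ideally give ~9.3 s/iter. Instead the logs show ~18.5 s/iter, i.e. essentially no speedup.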
Stack trace/logs
N/A
Environment (please complete the following information):
NGC official PyTorch container 24.09: nvcr.io/nvidia/pytorch:24.09-py3
Proposed fix
N/A