[BUG] MoE pre-training does not scale beyond DP dim>8 #1258

Open
hwang595 opened this issue Oct 25, 2024 · 0 comments

hwang595 commented Oct 25, 2024

Describe the bug
I used {2, 4, 8, 16, 32} 8xH100 DGX nodes to run Mixtral-style MoE pre-training following the Megatron MoE guidance. The model size is roughly 8x3B, and I used only EP and DP to scale it up. However, when fixing EP=8 and scaling DP={2, 4, 8, 16, 32} (please see the detailed script below), I observed that the speed does not improve beyond DP=8.
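
For reference, a minimal sketch of how the parallel layout works out under these settings (my reading, assuming TP=PP=1 as in the script below, so the DP values quoted in this report correspond to the number of expert-data-parallel replicas, i.e. total GPUs divided by EP=8):

# Illustrative layout arithmetic only; the names here are ad hoc, not Megatron internals.
GPUS_PER_NODE = 8
EP = 8

for nnodes in (2, 4, 8, 16, 32):
    world_size = nnodes * GPUS_PER_NODE      # total GPUs (the full DP size when TP=PP=1)
    expert_dp_replicas = world_size // EP    # the "DP" dimension quoted in this report
    print(f"nodes={nnodes:2d}  world_size={world_size:3d}  expert-DP replicas={expert_dp_replicas:2d}")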

I'm reporting the runtime results below:
DP=2, EP=8

 [2024-10-25 04:08:43] iteration        4/  365000 | consumed samples:         8192 | elapsed time per iteration (ms): 33450.2 | learning rate: 6.000000E-09 | global batch size:  2048 | lm loss: 1.049151E+01 | z_loss: 4.799740E+00 | load_balancing_loss: 1.261538E+00 | loss scale: 1.0 | grad norm: 17.428 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2024-10-25 04:09:49] iteration        6/  365000 | consumed samples:        12288 | elapsed time per iteration (ms): 33031.0 | learning rate: 9.000000E-09 | global batch size:  2048 | lm loss: 1.049228E+01 | z_loss: 4.805847E+00 | load_balancing_loss: 1.257355E+00 | loss scale: 1.0 | grad norm: 17.315 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2024-10-25 04:10:54] iteration        8/  365000 | consumed samples:        16384 | elapsed time per iteration (ms): 32436.5 | learning rate: 1.200000E-08 | global batch size:  2048 | lm loss: 1.049336E+01 | z_loss: 4.809865E+00 | load_balancing_loss: 1.263811E+00 | loss scale: 1.0 | grad norm: 17.803 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations:   0 | number of nan iterations:   0 |

DP=4, EP=8

0 | load_balancing_loss: 1.257406E+00 | loss scale: 1.0 | grad norm: 17.701 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2024-10-25 04:01:53] iteration        4/  365000 | consumed samples:         8192 | elapsed time per iteration (ms): 23922.8 | learning rate: 6.000000E-09 | global batch size:  2048 | lm loss: 1.049150E+01 | z_loss: 4.805318E+00 | load_balancing_loss: 1.261618E+00 | loss scale: 1.0 | grad norm: 17.428 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2024-10-25 04:02:40] iteration        6/  365000 | consumed samples:        12288 | elapsed time per iteration (ms): 23492.5 | learning rate: 9.000000E-09 | global batch size:  2048 | lm loss: 1.049229E+01 | z_loss: 4.813999E+00 | load_balancing_loss: 1.253946E+00 | loss scale: 1.0 | grad norm: 17.316 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2024-10-25 04:03:26] iteration        8/  365000 | consumed samples:        16384 | elapsed time per iteration (ms): 23066.2 | learning rate: 1.200000E-08 | global batch size:  2048 | lm loss: 1.049337E+01 | z_loss: 4.805440E+00 | load_balancing_loss: 1.264813E+00 | loss scale: 1.0 | grad norm: 17.802 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2024-10-25 04:04:13] iteration       10/  365000 | consumed samples:        20480 | elapsed time per iteration (ms): 23237.8 | learning rate: 1.500000E-08 | global batch size:  2048 | lm loss: 1.049290E+01 | z_loss: 4.794414E+00 | load_balancing_loss: 1.258429E+00 | loss scale: 1.0 | grad norm: 17.760 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations:   0 | number of nan iterations:   0 |

DP=8, EP=8

 [2024-10-25 02:59:01] iteration       20/  365000 | consumed samples:        40960 | elapsed time per iteration (ms): 18228.8 | learning rate: 3.000000E-06 | global batch size:  2048 | lm loss: 9.122166E+00 | z_loss: 2.849494E+00 | load_balancing_loss: 1.052246E+00 | loss scale: 1.0 | grad norm: 3.607 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2024-10-25 02:59:39] iteration       22/  365000 | consumed samples:        45056 | elapsed time per iteration (ms): 18968.6 | learning rate: 3.300000E-06 | global batch size:  2048 | lm loss: 9.057307E+00 | z_loss: 2.501926E+00 | load_balancing_loss: 1.040945E+00 | loss scale: 1.0 | grad norm: 3.386 | num zeros: 0.0 | params norm: 1125.557 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2024-10-25 03:00:17] iteration       24/  365000 | consumed samples:        49152 | elapsed time per iteration (ms): 18932.1 | learning rate: 3.600000E-06 | global batch size:  2048 | lm loss: 8.980194E+00 | z_loss: 2.145369E+00 | load_balancing_loss: 1.034725E+00 | loss scale: 1.0 | grad norm: 3.216 | num zeros: 0.0 | params norm: 1125.557 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2024-10-25 03:00:53] iteration       26/  365000 | consumed samples:        53248 | elapsed time per iteration (ms): 17997.6 | learning rate: 3.900000E-06 | global batch size:  2048 | lm loss: 8.913023E+00 | z_loss: 1.799559E+00 | load_balancing_loss: 1.032197E+00 | loss scale: 1.0 | grad norm: 3.193 | num zeros: 0.0 | params norm: 1125.557 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2024-10-25 03:01:31] iteration       28/  365000 | consumed samples:        57344 | elapsed time per iteration (ms): 18797.5 | learning rate: 4.200000E-06 | global batch size:  2048 | lm loss: 8.858095E+00 | z_loss: 1.508859E+00 | load_balancing_loss: 1.028084E+00 | loss scale: 1.0 | grad norm: 3.021 | num zeros: 0.0 | params norm: 1125.557 | number of skipped iterations:   0 | number of nan iterations:   0 |

DP=16, EP=8

 [2024-10-25 02:42:33] iteration        4/  365000 | consumed samples:         8192 | elapsed time per iteration (ms): 18489.0 | learning rate: 6.000000E-07 | global batch size:  2048 | lm loss: 1.048763E+01 | z_loss: 4.795546E+00 | load_balancing_loss: 1.251609E+00 | loss scale: 1.0 | grad norm: 17.397 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2024-10-25 02:43:10] iteration        6/  365000 | consumed samples:        12288 | elapsed time per iteration (ms): 18519.2 | learning rate: 9.000000E-07 | global batch size:  2048 | lm loss: 1.043374E+01 | z_loss: 4.780572E+00 | load_balancing_loss: 1.244789E+00 | loss scale: 1.0 | grad norm: 16.892 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2024-10-25 02:43:47] iteration        8/  365000 | consumed samples:        16384 | elapsed time per iteration (ms): 18453.2 | learning rate: 1.200000E-06 | global batch size:  2048 | lm loss: 1.020294E+01 | z_loss: 4.610379E+00 | load_balancing_loss: 1.236636E+00 | loss scale: 1.0 | grad norm: 17.293 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations:   0 | number of nan iterations:   0 |

DP=32, EP=8 (speed starts to saturate when scaling the DP dim)

 [2024-10-25 02:47:00] iteration        4/  365000 | consumed samples:         8192 | elapsed time per iteration (ms): 18215.0 | learning rate: 6.000000E-07 | global batch size:  2048 | lm loss: 1.048763E+01 | z_loss: 4.808811E+00 | load_balancing_loss: 1.253611E+00 | loss scale: 1.0 | grad norm: 17.397 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2024-10-25 02:47:38] iteration        6/  365000 | consumed samples:        12288 | elapsed time per iteration (ms): 18854.0 | learning rate: 9.000000E-07 | global batch size:  2048 | lm loss: 1.043371E+01 | z_loss: 4.762641E+00 | load_balancing_loss: 1.241172E+00 | loss scale: 1.0 | grad norm: 16.892 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2024-10-25 02:48:16] iteration        8/  365000 | consumed samples:        16384 | elapsed time per iteration (ms): 19142.1 | learning rate: 1.200000E-06 | global batch size:  2048 | lm loss: 1.020293E+01 | z_loss: 4.619304E+00 | load_balancing_loss: 1.248783E+00 | loss scale: 1.0 | grad norm: 17.295 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2024-10-25 02:48:54] iteration       10/  365000 | consumed samples:        20480 | elapsed time per iteration (ms): 18787.4 | learning rate: 1.500000E-06 | global batch size:  2048 | lm loss: 9.840721E+00 | z_loss: 4.333767E+00 | load_balancing_loss: 1.271878E+00 | loss scale: 1.0 | grad norm: 15.654 | num zeros: 0.0 | params norm: 1125.558 | number of skipped iterations:   0 | number of nan iterations:   0 |

I can of course start introducing PP>1 as described in the MoE doc, but not being able to scale the DP dimension beyond 8 still seems to be an issue. Any help would be appreciated!
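
To quantify the saturation, here is a quick back-of-the-envelope calculation (plain Python; the per-iteration times are rough averages read from the logs above):

# Approximate elapsed time per iteration (seconds), averaged from the logs above.
iter_time_s = {2: 33.0, 4: 23.4, 8: 18.6, 16: 18.5, 32: 18.8}
GLOBAL_BATCH = 2048

base_dp, base_t = 2, iter_time_s[2]
for dp, t in iter_time_s.items():
    throughput = GLOBAL_BATCH / t    # end-to-end samples/sec
    speedup = base_t / t             # relative to DP=2
    ideal = dp / base_dp             # linear-scaling expectation
    print(f"DP={dp:2d}: {throughput:6.1f} samples/s, speedup {speedup:.2f}x (ideal {ideal:.0f}x)")

Throughput grows from DP=2 to DP=8 (roughly 62 to 110 samples/s) and then stays flat at DP=16 and DP=32.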

To Reproduce
Below is my launch script:

export CUDA_DEVICE_MAX_CONNECTIONS=1

GPUS_PER_NODE=8
NNODES=16

nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

echo Node IP: $head_node_ip
echo $SLURM_JOB_NODELIST
export LOGLEVEL=INFO

DISTRIBUTED_ARGS=(
    --nproc_per_node $GPUS_PER_NODE
    --nnodes $NNODES
    --rdzv_id $RANDOM
    --rdzv_backend c10d
    --rdzv_endpoint $head_node_ip:29500
)

MODEL_ARGS=(
    --disable-bias-linear
    --seq-length 2048
    --max-position-embeddings 32768
    --num-layers 36
    --hidden-size 2304
    --ffn-hidden-size 7680
    --num-attention-heads 36
    --init-method-std 0.01  # need to check source code
    --attention-dropout 0.0
    --hidden-dropout 0.0
    --normalization RMSNorm
    --position-embedding-type rope
    --rotary-base 200000
    --swiglu
    --untie-embeddings-and-output-weights
    --no-position-embedding
)

MOE_ARGS=(
    --num-experts 8
    --expert-model-parallel-size 8
    --moe-router-topk 2
    --moe-grouped-gemm # no big influence
    --moe-token-dispatcher-type alltoall
    --moe-router-load-balancing-type aux_loss # options: aux_loss, sinkhorn, None. Default is aux_loss.
    --moe-aux-loss-coeff 1e-2
    --moe-z-loss-coeff 1e-3
    --moe-expert-capacity-factor 1.0
    --moe-token-drop-policy probs
)

DATA_ARGS=(
    --tokenizer-type ${TOKENIZER_TYPE}
    --tokenizer-model ${TOKENIZER_MODEL}
    --data-path $DATA_PATH
    --split 99990,8,2
)


TRAINING_ARGS=(
    --micro-batch-size 2
    --global-batch-size 2048
    --seed 42
    --lr 3e-6 # 3e-4
    --adam-beta1 0.9
    --adam-beta2 0.95
    --train-iters 365000 
    --lr-decay-iters 330000
    --lr-decay-style cosine
    --min-lr 3e-7
    --lr-warmup-iters 2000
    --weight-decay 0.1
    --clip-grad 1.0
    --bf16
    --overlap-param-gather
    --overlap-grad-reduce
)


MODEL_PARALLEL_ARGS=(
    --tensor-model-parallel-size 1
    --pipeline-model-parallel-size 1
    --sequence-parallel
    --use-distributed-optimizer
    --distributed-timeout-minutes 60
)

LOGGING_ARGS=(
    --log-interval 2
    --save-interval 50000
    --tensorboard-dir "${CHECKPOINT_PATH}/tensorboard"
    --log-memory-to-tensorboard 
    --log-params-norm
    --log-timers-to-tensorboard
)

CHECKPOINT_ARGS=(
    --eval-interval 1000
    --eval-iters 5
    --save $CHECKPOINT_PATH
    --load $CHECKPOINT_PATH
)

srun --container-mounts=MOUNT-DIRS --container-name=CONTAINER-NAME \
    --container-image=IMAGE-NAME \
    torchrun ${DISTRIBUTED_ARGS[@]} $WORK_DIR/pretrain_gpt.py \
    ${MODEL_ARGS[@]} \
    ${MOE_ARGS[@]} \
    ${DATA_ARGS[@]} \
    ${TRAINING_ARGS[@]} \
    ${MODEL_PARALLEL_ARGS[@]} \
    ${CHECKPOINT_ARGS[@]} \
    ${LOGGING_ARGS[@]}

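As a possible diagnostic (my own illustrative sketch, not part of the original setup): a standalone torch.distributed all-reduce benchmark, launched with the same torchrun/srun configuration, can show whether cross-node collective bandwidth flattens at the same scale as training throughput. The buffer size and iteration counts below are arbitrary assumptions; adjust them to match the real gradient bucket sizes.

# bench_allreduce.py -- hypothetical standalone benchmark, launched via torchrun as above.
import os
import time

import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="nccl")    # env:// rendezvous provided by torchrun
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # ~1 GiB of bf16, standing in for one gradient bucket (size is an assumption).
    buf = torch.randn(512 * 1024 * 1024, dtype=torch.bfloat16, device="cuda")

    for _ in range(5):                         # warm-up
        dist.all_reduce(buf)
    torch.cuda.synchronize()

    iters = 20
    t0 = time.perf_counter()
    for _ in range(iters):
        dist.all_reduce(buf)
    torch.cuda.synchronize()
    t1 = time.perf_counter()

    if dist.get_rank() == 0:
        gib = buf.numel() * buf.element_size() / 2**30
        print(f"world={dist.get_world_size()}  all_reduce of {gib:.2f} GiB: "
              f"{(t1 - t0) / iters * 1e3:.1f} ms/iter")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
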
Expected behavior
Training throughput should continue to scale along the DP dimension when DP>8.

Stack trace/logs
N/A

Environment (please complete the following information):
NGC official PyTorch Container 24.09
nvcr.io/nvidia/pytorch:24.09-py3

Proposed fix
N/A
