diff --git a/.github/container/test-maxtext.sh b/.github/container/test-maxtext.sh index 0dc26c8c1..50a353511 100755 --- a/.github/container/test-maxtext.sh +++ b/.github/container/test-maxtext.sh @@ -13,7 +13,7 @@ usage() { echo "Usage: $0 [OPTIONS]" echo "" echo " OPTIONS DESCRIPTION" - echo " -a, --additional-args Additional args to pass to MaxText/train.py" + echo " -a, --additional-args Additional args to pass to MaxText/train.py. Can be passed many times." echo " --mem-fraction Specify the percentage of memory to preallocate for XLA. Example: 0.90, 0.85, 0.65". Default to 0.90, contradicting JAX default of 0.75. echo " --model-name Specify the model names to run [Preferred]. If you specify model name then you do not need to specify decoder-block. Currently supported ootb models: gemma-2b, gemma-7b, gpt3-175b, gpt3-22b, gpt3-52k, gpt3-6b, llama2-13b, llama2-70b, llama2-7b, llama3-70b, llama3-8b, mistral-7b, mixtral-8x7b" @@ -34,7 +34,7 @@ usage() { 1. test-maxtext.sh -b 2 --model-name=gpt3-52k 2. test-maxtext.sh -b 2 --model-name=gemma-2b --dtype=fp8 3. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess - 4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess -a scan_layers=false max_target_length=4096 use_iota_embed=true logits_dot_in_fp32=false + 4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess -a "scan_layers=false max_target_length=4096 use_iota_embed=true logits_dot_in_fp32=false" 5. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --dtype=fp8 --steps=10 --fsdp=8 --output train_output --multiprocess 6. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --fsdp=8 --data-parallel=8 --multiprocess 7. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --fsdp=4 --tensor-parallel=2 --data-parallel=8 --multiprocess @@ -76,7 +76,7 @@ eval set -- "$args" while [ : ]; do case "$1" in -a | --additional-args) - ADDITIONAL_ARGS="$2" + ADDITIONAL_ARGS="$ADDITIONAL_ARGS $2" shift 2 ;; --mem-fraction)