Skip to content

Commit

Permalink
Allow additional args to be passed many times
Browse files Browse the repository at this point in the history
  • Loading branch information
gpupuck authored Oct 15, 2024
1 parent 70f67a8 commit 9983cb0
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions .github/container/test-maxtext.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ usage() {
echo "Usage: $0 [OPTIONS]"
echo ""
echo " OPTIONS DESCRIPTION"
echo " -a, --additional-args Additional args to pass to MaxText/train.py"
echo " -a, --additional-args Additional args to pass to MaxText/train.py. Can be passed many times."
echo " --mem-fraction Specify the percentage of memory to preallocate for XLA. Example: 0.90, 0.85, 0.65". Default to 0.90, contradicting JAX default of 0.75.
echo " --model-name Specify the model names to run [Preferred]. If you specify model name then you do not need to specify decoder-block. Currently supported ootb models:
gemma-2b, gemma-7b, gpt3-175b, gpt3-22b, gpt3-52k, gpt3-6b, llama2-13b, llama2-70b, llama2-7b, llama3-70b, llama3-8b, mistral-7b, mixtral-8x7b"
Expand All @@ -34,7 +34,7 @@ usage() {
1. test-maxtext.sh -b 2 --model-name=gpt3-52k
2. test-maxtext.sh -b 2 --model-name=gemma-2b --dtype=fp8
3. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess
4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess -a scan_layers=false max_target_length=4096 use_iota_embed=true logits_dot_in_fp32=false
4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess -a "scan_layers=false max_target_length=4096 use_iota_embed=true logits_dot_in_fp32=false"
5. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --dtype=fp8 --steps=10 --fsdp=8 --output train_output --multiprocess
6. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --fsdp=8 --data-parallel=8 --multiprocess
7. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --fsdp=4 --tensor-parallel=2 --data-parallel=8 --multiprocess
Expand Down Expand Up @@ -76,7 +76,7 @@ eval set -- "$args"
while [ : ]; do
case "$1" in
-a | --additional-args)
ADDITIONAL_ARGS="$2"
ADDITIONAL_ARGS="$ADDITIONAL_ARGS $2"
shift 2
;;
--mem-fraction)
Expand Down

0 comments on commit 9983cb0

Please sign in to comment.