Skip to content

Commit

Permalink
Improve ViT README (#201)
Browse files Browse the repository at this point in the history
- add more detailed data download instructions
- fix train/eval index paths
- add an out-of-the-box interactive pretraining example
- add a `Hardware Specifications` section
  • Loading branch information
ashors1 authored Sep 8, 2023
1 parent 632b622 commit bc9c0ac
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 15 deletions.
3 changes: 3 additions & 0 deletions rosetta/rosetta/data/generate_wds_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os

import wds2idx
from absl import logging
from braceexpand import braceexpand


Expand All @@ -38,3 +39,5 @@
creator = wds2idx.IndexCreator(url, os.path.join(args.index_dir, f'idx_{i}.txt'))
creator.create_index()
creator.close()

logging.info(f'Done! Index files written to {args.index_dir}.')
54 changes: 41 additions & 13 deletions rosetta/rosetta/projects/vit/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

This directory provides an implementation of the [Vision Transformer (ViT)](https://arxiv.org/pdf/2010.11929.pdf) model. This implementation is a direct adaptation of Google's [original ViT implementation](https://github.com/google-research/vision_transformer/tree/main). We have extended the original ViT implementation to include model parallel support. Model configurations are also based on the the original ViT implementation. Presently, convergence has been verified on ViT-B/16. Support for a wider range of models will be added in the future.

## Hardware Specifications
Convergence and performance has been validated on NVIDIA DGX A100 (8x A100 80G) nodes. Pretraining and fine-tuning of ViT/B-16 can be performed on a single DGX A100 80G node. We provide both singlenode and multinode support for pre-training and fine-tuning. If running on a machine with less than 80G memory, some of the default configurations may run out of memory; if you run out of memory and have more GPUs available, increase your GPU count and decrease your batch size per GPU. You may also use gradient accumulation to reduce the microbatch size while keeping the global batch size and GPU count constant, but note that gradient accumulation with ViT works only with scale-invariant optimizers such as Adam and Adafactor. See the [known issues](#Known-issues) section below for more information.

## Building a Container
We provide and fully built and ready-to-use container here: `ghcr.io/nvidia/rosetta-t5x:vit-2023-07-21`

If you do not plan on making changes to the Rosetta source code and would simply like to run experiments on top of Rosetta, we strongly recommend using the pre-built container. Run the following command to launch a container interactively:
```
export CONTAINER=ghcr.io/nvidia/rosetta-t5x:vit-2023-07-21
docker run -ti --gpus=all --net=host --ipc=host -v <IMAGENET_PATH>:/opt/rosetta/datasets/imagenet -v <WORKSPACE_PATH>:/opt/rosetta/workspace -v <TRAIN_INDEX_PATH>:/opt/rosetta/train_idxs -v <EVAL_INDEX_PATH>:/opt/rosetta/eval_idxs --privileged $CONTAINER /bin/bash
```
where `<IMAGENET_PATH>` is the path to the ImageNet-1k dataset (see the [downloading the dataset](#Downloading-the-dataset) section below for details) and ``<TRAIN_INDEX_PATH>`` and ``<EVAL_INDEX_PATH>`` refer to the paths to the train and eval indices for the ImageNet tar files (see the [before launching a run](#before-launching-a-run) section below for more information about these paths). ``<WORKSPACE_PATH>`` refers to the directory where you would like to store any persistent files. Any custom configurations or run scripts needed for your experiments should reside here.
Expand Down Expand Up @@ -40,40 +44,60 @@ wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/
### exit the `raw_data` directory
cd ../../..
```
4. Rosetta expects the dataset to be in WebDataset format. To convert the data to the appropriate format, first download "Development kit (Tasks 1 & 2)" from [here](https://image-net.org/challenges/LSVRC/2012/2012-downloads.php) and place it in the root data directory (`raw_data/imagenet_1k`).
Next, create the dataset directory and then run [makeshards.py](https://github.com/webdataset/webdataset-lightning/blob/main/makeshards.py) as follows:
4. Download "Development kit (Tasks 1 & 2)" from [here](https://image-net.org/challenges/LSVRC/2012/2012-downloads.php) and place it in the root data directory (`raw_data/imagenet_1k`).
5. Rosetta expects the dataset to be in WebDataset format. To convert the data to the appropriate format, first enter into the container. Be sure to mount the raw data to the container. Also, mount the location where you would like to store the dataset (referred to here as `<IMAGENET_PATH>`) to the container:
```
docker run -ti --gpus=all --net=host --ipc=host -v ${PWD}/raw_data/imagenet_1k:/opt/rosetta/raw_data/imagenet_1k -v <IMAGENET_PATH>:/opt/rosetta/datasets/imagenet --privileged $CONTAINER /bin/bash
```

6. Next, we will run [makeshards.py](https://github.com/webdataset/webdataset-lightning/blob/main/makeshards.py) to convert the data to WebDataset format. `makeshards.py` requires torchvision, which can be installed in the container using the following command:
```
pip install torchvision
```
7. Download `makeshards.py` and write the WebDataset shards to `datasets/imagenet`:
```
mkdir -p datasets/imagenet
wget https://raw.githubusercontent.com/webdataset/webdataset-lightning/main/makeshards.py
python3 makeshards.py --shards datasets/imagenet --data raw_data/imagenet_1k
```
Note that torchvision is required and must be manually installed prior to performing this preprocessing step.

## Before Launching a Run
ViT uses [DALI](https://github.com/NVIDIA/DALI/tree/c4f105e1119ef887f037830a5551c04f9062bb74) on CPU for performant dataloading. Loading WebDataset tar files is done using DALI's [webdataset reader](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/operations/nvidia.dali.fn.readers.webdataset.html#nvidia.dali.fn.readers.webdataset). The reader expects each tar file to have a corresponding index file. These files can be generated using `rosetta/data/generate_wds_indices.py`. To generate the indices for the training data, use the following command. Note that braceexpand notation is used to specify the range of tar files to generate index files for.

ViT uses [DALI](https://github.com/NVIDIA/DALI/tree/c4f105e1119ef887f037830a5551c04f9062bb74) on CPU for performant dataloading. Loading WebDataset tar files is done using DALI's [webdataset reader](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/operations/nvidia.dali.fn.readers.webdataset.html#nvidia.dali.fn.readers.webdataset). The reader expects each tar file to have a corresponding index file. These files can be generated using `rosetta/data/generate_wds_indices.py`. To generate the indices for the training data, first enter into the container, being sure to mount the locations where you would like to store the index files:
```
docker run -ti --gpus=all --net=host --ipc=host -v <IMAGENET_PATH>:/opt/rosetta/datasets/imagenet -v <TRAIN_INDEX_PATH>:/opt/rosetta/train_idxs -v <EVAL_INDEX_PATH>:/opt/rosetta/eval_idxs --privileged $CONTAINER /bin/bash
```
python3 -m rosetta.data.generate_wds_indices --archive "/opt/rosetta/datasets/imagenet/imagenet-train-{000000..000146}.tar" --index_dir "/opt/rosetta/train_indices"
The train and eval indices will be saved to `<TRAIN_INDEX_PATH>` and `<EVAL_INDEX_PATH>`, respectively. Once inside of the container, run the following command. Note that braceexpand notation is used to specify the range of tar files to generate index files for.
```
python3 -m rosetta.data.generate_wds_indices --archive "/opt/rosetta/datasets/imagenet/imagenet-train-{000000..000146}.tar" --index_dir "/opt/rosetta/train_idxs"
```
Similarly, to generate indices for the validation dataset,
```
python3 -m rosetta.data.generate_wds_indices --archive "/opt/rosetta/datasets/imagenet/imagenet-val-{000000..000006}.tar" --index_dir "/opt/rosetta/eval_indices"
python3 -m rosetta.data.generate_wds_indices --archive "/opt/rosetta/datasets/imagenet/imagenet-val-{000000..000006}.tar" --index_dir "/opt/rosetta/eval_idxs"
```
This step is optional. If no indices are provided to the WebDataset reader, they will be inferred automatically, but it typically takes around 10 minutes to infer the index files for the train dataset.
When launching subsequent runs in the container, mount `<TRAIN_INDEX_PATH>` and `<EVAL_INDEX_PATH>` to the container to avoid having to regenerate the indices.

_Note_: This step is optional. If no indices are provided to the webdataset reader, they will be inferred automatically, but it typically takes around 10 minutes to infer the index files for the train dataset.

## Training Runs

### Pre-training
#### Single-process
Use the following command to launch a single-process pre-training run interactively from the top-level directory of the repository:
```
bash rosetta/projects/vit/scripts/singleprocess_pretrain.sh base bfloat16 <NUM GPUS> <BATCH SIZE PER GPU> <LOG DIR> <MODEL DIR LOCAL> <TRAIN INDEX DIR> <EVAL INDEX DIR>
bash rosetta/projects/vit/scripts/singleprocess_pretrain.sh base bfloat16 <NUM GPUS> <BATCH SIZE PER GPU> <LOG DIR> <MODEL DIR LOCAL> <TRAIN INDEX DIR> <EVAL INDEX DIR> <GRADIENT ACCUMULATION>
```
`<MODEL DIR LOCAL>` refers to the _relative_ path to save checkpoints, summaries and configuration details to.

The following command can be used to launch a pre-training convergence run interactively on 8 GPUs:
```
bash rosetta/projects/vit/scripts/singleprocess_pretrain.sh base bfloat16 8 512 log_dir base_pretrain_dir /opt/rosetta/train_idxs /opt/rosetta/eval_idxs 1
```

#### Multi-process
See `rosetta/projects/vit/scripts/example_slurm_pretrain.sub` for an example submit file that can be used to launch a multiprocess pre-training run with a SLURM + pyxis cluster. The following command can be used to launch a pre-training convergence run on a single node:
```
BASE_WORKSPACE_DIR=<PATH TO WORKSPACE> BASE_WDS_DATA_DIR=<PATH TO DATASET> BASE_TRAIN_IDX_DIR=<PATH TO TRAIN INDICES> BASE_EVAL_IDX_DIR=<PATH TO EVAL INDICES> VIT_SIZE=base PREC=bfloat16 GPUS_PER_NODE=8 BSIZE_PER_GPU=512 MODEL_DIR_LOCAL=base_pretrain_dir sbatch -N 1 -A <ACCOUNT> -p <PARTITION> -J <JOBNAME> example_slurm_pretrain.sub
BASE_WORKSPACE_DIR=<PATH TO WORKSPACE> BASE_WDS_DATA_DIR=<PATH TO DATASET> BASE_TRAIN_IDX_DIR=<PATH TO TRAIN INDICES> BASE_EVAL_IDX_DIR=<PATH TO EVAL INDICES> VIT_SIZE=base PREC=bfloat16 GPUS_PER_NODE=8 BSIZE_PER_GPU=512 GRAD_ACCUM=1 MODEL_DIR_LOCAL=base_pretrain_dir sbatch -N 1 -A <ACCOUNT> -p <PARTITION> -J <JOBNAME> example_slurm_pretrain.sub
```
Here, `MODEL_DIR_LOCAL` is the directory to save checkpoints, summaries and configuration details to, relative to `BASE_WORKSPACE_DIR`.

### Pre-training to Fine-tuning
For improved fine-tuning accuracy, ViT pre-trains using a resolution of 224 and finetunes using a resolution of 384. Additionally, the classification heads used during pre-training and fine-tuning differ: the classification head consists of a two-layer MLP during pre-training and a single linear layer during fine-tuning. The script `rosetta/projects/vit/scripts/convert_t5x_pre-train_to_finetune_ckpt.py` converts the pre-trained checkpoint to a checkpoint that is compatible with the desired fine-tuning configuration. Run the following command to generate the checkpoint to be used during fine-tuning:
Expand All @@ -86,14 +110,14 @@ python3 -m rosetta.projects.vit.scripts.convert_t5x_pretrain_to_finetune_ckpt --
#### Single-process
Use the following command to launch a single-process fine-tuning run:
```
bash rosetta/projects/vit/scripts/singleprocess_finetune.sh base bfloat16 <NUM GPUS> <BATCH SIZE PER GPU> <LOG DIR> <MODEL DIR LOCAL> <TRAIN INDEX DIR> <EVAL INDEX DIR>
bash rosetta/projects/vit/scripts/singleprocess_finetune.sh base bfloat16 <NUM GPUS> <BATCH SIZE PER GPU> <LOG DIR> <MODEL DIR LOCAL> <TRAIN INDEX DIR> <EVAL INDEX DIR> <GRADIENT ACCUMULATION>
```
where `<MODEL DIR LOCAL>` corresponds to the directory containing the converted pre-training checkpoint.

#### Multi-process
See `rosetta/projects/vit/scripts/example_slurm_finetune.sub` for an example submit file for launching a fine-tuning run with a SLURM + pyxis cluster. The following command can be used to launch a fine-tuning convergence run:
```
BASE_WORKSPACE_DIR=<PATH TO WORKSPACE> BASE_WDS_DATA_DIR=<PATH TO DATASET> BASE_TRAIN_IDX_DIR=<PATH TO TRAIN INDICES> BASE_EVAL_IDX_DIR=<PATH TO EVAL INDICES> VIT_SIZE=base PREC=bfloat16 GPUS_PER_NODE=8 BSIZE_PER_GPU=128 MODEL_DIR_LOCAL=base_finetune_dir sbatch -N 1 -A <ACCOUNT> -p <PARTITION> -J <JOBNAME> example_slurm_finetune.sub
BASE_WORKSPACE_DIR=<PATH TO WORKSPACE> BASE_WDS_DATA_DIR=<PATH TO DATASET> BASE_TRAIN_IDX_DIR=<PATH TO TRAIN INDICES> BASE_EVAL_IDX_DIR=<PATH TO EVAL INDICES> VIT_SIZE=base PREC=bfloat16 GPUS_PER_NODE=8 BSIZE_PER_GPU=128 GRAD_ACCUM=1 MODEL_DIR_LOCAL=base_finetune_dir sbatch -N 1 -A <ACCOUNT> -p <PARTITION> -J <JOBNAME> example_slurm_finetune.sub
```

## Convergence Results
Expand Down Expand Up @@ -130,3 +154,7 @@ Pre-training was performed on 1 node with a global batch size of 4096. Models we

## Future Improvements
1. ViT currently does not support [Transformer Engine](https://github.com/NVIDIA/TransformerEngine). We plan to add Transformer Engine support to further accelerate pre-training and fine-tuning in the near future.

## Known Issues
1. By default, gradient accumulation (GA) sums loss across the microbatches. As a result, loss is scaled up when using gradient accumulation, and training with GA only works when using a scale-invariant optimizer such as Adam or Adafactor. ViT fine-tuning is performed using SGD; thus, GA should not be used when fine-tuning.

Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,12 @@ PREC=${PREC:="bfloat16"} # bfloat16, float32
GPUS_PER_NODE=${GPUS_PER_NODE:=8} # usually 8
BSIZE_PER_GPU=${BSIZE_PER_GPU:=128} # local batch size/gpu
MODEL_DIR_LOCAL=${MODEL_DIR_LOCAL:="finetune_dir"} # directory to save checkpoints and config dump to, relative to BASE_WORKSPACE_DIR
GRAD_ACCUM=${GRAD_ACCUM:=1}

read -r -d '' cmd <<EOF
echo "*******STARTING********" \
&& nvidia-smi \
&& bash rosetta/projects/vit/scripts/multiprocess_finetune.sh $VIT_SIZE $PREC $GPUS_PER_NODE $BSIZE_PER_GPU workspace/$MODEL_DIR_LOCAL $TRAIN_IDX_DIR $EVAL_IDX_DIR
&& bash rosetta/projects/vit/scripts/multiprocess_finetune.sh $VIT_SIZE $PREC $GPUS_PER_NODE $BSIZE_PER_GPU workspace/$MODEL_DIR_LOCAL $TRAIN_IDX_DIR $EVAL_IDX_DIR $GRAD_ACCUM
EOF

OUTPUT_DIR="${BASE_WORKSPACE_DIR}/outputs/finetune-${VIT_SIZE}-${PREC}-N${SLURM_JOB_NUM_NODES}-n${GPUS_PER_NODE}-BS${BSIZE_PER_GPU}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,12 @@ PREC=${PREC:="bfloat16"} # bfloat16, float32
GPUS_PER_NODE=${GPUS_PER_NODE:=8} # usually 8
BSIZE_PER_GPU=${BSIZE_PER_GPU:=512} # local batch size/gpu
MODEL_DIR_LOCAL=${MODEL_DIR_LOCAL:="pretrain_dir"} # directory to save checkpoints and config dump to, relative to BASE_WORKSPACE_DIR
GRAD_ACCUM=${GRAD_ACCUM:=1}

read -r -d '' cmd <<EOF
echo "*******STARTING********" \
&& nvidia-smi \
&& bash rosetta/projects/vit/scripts/multiprocess_pretrain.sh $VIT_SIZE $PREC $GPUS_PER_NODE $BSIZE_PER_GPU workspace/$MODEL_DIR_LOCAL $TRAIN_IDX_DIR $EVAL_IDX_DIR
&& bash rosetta/projects/vit/scripts/multiprocess_pretrain.sh $VIT_SIZE $PREC $GPUS_PER_NODE $BSIZE_PER_GPU workspace/$MODEL_DIR_LOCAL $TRAIN_IDX_DIR $EVAL_IDX_DIR $GRAD_ACCUM
EOF

OUTPUT_DIR="${BASE_WORKSPACE_DIR}/outputs/pretrain-${VIT_SIZE}-${PREC}-N${SLURM_JOB_NUM_NODES}-n${GPUS_PER_NODE}-BS${BSIZE_PER_GPU}"
Expand Down

0 comments on commit bc9c0ac

Please sign in to comment.