This is the official code for *SAMPa: Sharpness-aware Minimization Parallelized*, accepted at NeurIPS 2024.
SAMPa introduces a fully parallelized version of sharpness-aware minimization (SAM) by allowing the two gradient computations to occur simultaneously:

$$x_{t+1} = x_t - \eta \nabla f(y_t), \qquad y_{t+1} = x_{t+1} + \rho \frac{\nabla f(x_t)}{\lVert \nabla f(x_t) \rVert},$$

where the gradients $\nabla f(x_t)$ and $\nabla f(y_t)$ are independent of each other and can therefore be computed in parallel.
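Below is a minimal single-process sketch of this update rule, for illustration only: `model` and `perturbed_model` (two copies of the same network holding $x_t$ and $y_t$), `loss_fn`, `inputs`, and `targets` are hypothetical placeholders, and the actual `train.py` runs the two gradient computations on two GPUs at once rather than sequentially as here.

```python
# Hypothetical single-process sketch of the SAMPa update rule above.
import torch

def sampa_step(model, perturbed_model, loss_fn, inputs, targets, lr=0.1, rho=0.1):
    params_x = list(model.parameters())            # x_t
    params_y = list(perturbed_model.parameters())  # y_t

    # The two gradients of the update rule; they depend only on x_t and y_t,
    # so they are independent and can in principle run in parallel.
    grads_x = torch.autograd.grad(loss_fn(model(inputs), targets), params_x)
    grads_y = torch.autograd.grad(loss_fn(perturbed_model(inputs), targets), params_y)

    grad_x_norm = torch.sqrt(sum(g.pow(2).sum() for g in grads_x))

    with torch.no_grad():
        for px, py, gx, gy in zip(params_x, params_y, grads_x, grads_y):
            px.sub_(lr * gy)                       # x_{t+1} = x_t - eta * grad f(y_t)
            py.copy_(px + rho * gx / grad_x_norm)  # y_{t+1} = x_{t+1} + rho * grad f(x_t) / ||grad f(x_t)||
```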
SAMPa serves as one of the most efficient SAM variants: with two GPUs available, each optimization step costs roughly the wall-clock time of a single gradient computation, rather than SAM's two sequential ones.
To set up the environment:

```bash
conda create -n sampa python=3.8
conda activate sampa

# On GPU
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia

pip install -r requirements.txt
```
This code implements SAMPa by parallelizing the two gradient computations across 2 GPUs. Specifically, in `train.py`, `global_rank:0` and `global_rank:1` each handle one of the two gradients, $\nabla f(x_t)$ and $\nabla f(y_t)$, simultaneously.
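The two-rank pattern can be sketched as follows. This is a simplified illustration, not the repo's exact code: the function and variable names are hypothetical, `torch.distributed` is assumed to be initialized with a world size of 2 and one GPU per rank, and both ranks are assumed to receive the same batch.

```python
# Hypothetical sketch of splitting the two SAMPa gradients across two ranks.
import torch
import torch.distributed as dist

def sampa_two_rank_step(model, loss_fn, inputs, targets, perturbation, lr, rho):
    """`perturbation` holds rho * grad f(x_t) / ||grad f(x_t)|| from the
    previous step, so both ranks start from the same iterate x_t."""
    rank = dist.get_rank()
    params = list(model.parameters())

    if rank == 1:
        with torch.no_grad():  # rank 1 moves to the perturbed point y_t
            for p, e in zip(params, perturbation):
                p.add_(e)

    # Each rank evaluates one gradient: rank 0 at x_t, rank 1 at y_t.
    loss = loss_fn(model(inputs), targets)
    grads = torch.autograd.grad(loss, params)

    if rank == 1:
        with torch.no_grad():  # undo the perturbation: back to x_t
            for p, e in zip(params, perturbation):
                p.sub_(e)

    # Exchange the gradients so every rank holds both grad f(x_t) and grad f(y_t).
    grads_x = [g.clone() if rank == 0 else torch.empty_like(g) for g in grads]
    grads_y = [g.clone() if rank == 1 else torch.empty_like(g) for g in grads]
    for gx, gy in zip(grads_x, grads_y):
        dist.broadcast(gx, src=0)
        dist.broadcast(gy, src=1)

    # SAMPa update (equations above), applied identically on both ranks.
    gx_norm = torch.sqrt(sum(g.pow(2).sum() for g in grads_x))
    with torch.no_grad():
        for p, gx, gy, e in zip(params, grads_x, grads_y, perturbation):
            p.sub_(lr * gy)              # x_{t+1} = x_t - eta * grad f(y_t)
            e.copy_(rho * gx / gx_norm)  # perturbation defining y_{t+1}
```

The broadcast exchange at the end is just one way to make both ranks hold both gradients; the actual implementation in `train.py` may communicate differently.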
To train ResNet-56 on CIFAR-10 using SAMPa, use the following command:

```bash
CUDA_VISIBLE_DEVICES=0,1 python train.py --model resnet56 --dataset cifar10 --rho 0.1 --epochs 200
```
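Here `--rho` corresponds to the perturbation radius $\rho$ in the update rule above, and the two visible GPUs (`CUDA_VISIBLE_DEVICES=0,1`) host the two parallel gradient computations.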