Memory Optimization with Liger Kernel Shows Limited Effect on Larger Models (more than 7B) #517
Hi @dyyoungg, in a similar scenario I encountered a memory spike during the optimizer step while using the PyTorch Adam optimizer with its default settings. Could you provide more details about your setup? I can then try to reproduce it and perform an analysis. Specifically:
I tested training on 8×A100 GPUs, and here are my training settings:
Hi @dyyoungg, when training, two types of memory occupy your GPU's RAM: static memory (model weights, gradients, and optimizer states) and dynamic memory (activations and other intermediate values).
There are many ways to optimize and leverage GPU memory. In your case, you're using DeepSpeed ZeRO Stage 3, a data-parallelism method that partitions the static memory across GPUs and gathers the model weights on demand (e.g. via all-gather) for the forward and backward passes.

With DeepSpeed and BF16 mixed-precision training, the static memory is approximately 16 bytes per parameter (2 bf16 weights + 2 bf16 gradients + 4 fp32 master weights + 4 fp32 Adam momentum + 4 fp32 Adam variance), which for a 14B model equals 16 × 14B = 224 GB. With 8 A100 GPUs, each GPU stores 224 GB / 8 = 28 GB.
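As a quick sanity check, here is a minimal arithmetic sketch of that estimate (the 16 bytes/parameter breakdown and the 8-way ZeRO-3 partitioning come from the reasoning above; the helper function itself is just illustrative):

```python
# Back-of-the-envelope ZeRO-3 static memory per GPU for BF16 mixed-precision Adam.
# 16 bytes/param = 2 (bf16 weights) + 2 (bf16 grads) + 4 (fp32 master weights)
#                + 4 (fp32 Adam momentum) + 4 (fp32 Adam variance).
def static_memory_per_gpu_gb(num_params: float, num_gpus: int, bytes_per_param: int = 16) -> float:
    total_gb = num_params * bytes_per_param / 1e9
    return total_gb / num_gpus  # ZeRO-3 partitions weights, grads, and optimizer states

print(static_memory_per_gpu_gb(14e9, 8))  # 14B model on 8 GPUs -> ~28.0 GB per GPU
```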
For the dynamic memory, the most memory-intensive operation is the cross-entropy step over the logits. Since your sequence length and batch size are relatively small, the total GPU memory usage is primarily dominated by static memory. Summarizing the VRAM usage per GPU: roughly 28 GB of partitioned static memory, plus a comparatively small amount of dynamic memory, of which the logits-related peak (about 6.96 GB, discussed below) is the largest single piece.
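The logits tensor is what makes this step expensive: its size scales linearly with batch size, sequence length, and vocabulary size. A rough sketch of the arithmetic follows; the ~152k vocabulary size for Qwen2.5 is my assumption, and the exact peak (e.g. the 6.96 GB figure) depends on how many copies (bf16 logits, fp32 upcast, gradients) are alive at once.

```python
# Rough estimate of the memory held by one copy of the logits tensor.
# vocab_size=152_064 (Qwen2.5) and fp32 upcasting are assumptions for illustration.
def logits_memory_gb(batch_size: int, seq_len: int, vocab_size: int = 152_064,
                     bytes_per_elem: int = 4) -> float:
    return batch_size * seq_len * vocab_size * bytes_per_elem / 1e9

print(logits_memory_gb(1, 4096))  # ~2.5 GB for a single fp32 copy at batch 1, seq 4096
print(logits_memory_gb(2, 4096))  # doubles as soon as the batch size doubles
```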
On Liger Kernel and Chunking Optimization

Liger Kernel significantly reduces the memory peaks related to logits. By chunking and fusing the lm_head projection with the cross-entropy loss, the peak memory for logits is limited to a single chunk at a time rather than the full batch × sequence length × vocabulary tensor (see the sketch after the conclusion for the underlying idea).

Back to Your Situation

The observed 6% memory reduction after enabling Liger Kernel is because your batch size (1) and sequence length (4096) are small, so static memory dominates overall usage. However, enabling Liger Kernel allows you to scale batch size and sequence length without significant memory spikes. Without it, increasing the batch size from 1 to 2 would double the logits-related memory from 6.96 GB to 13.92 GB, likely causing an OOM error.

Conclusion

In conclusion, I don't believe there are inherent limitations with Liger Kernel in large-model training. The issue arises when there is a limited number of GPUs available to partition the massive static memory. The first bottleneck encountered is the static memory. Once this bottleneck is addressed (e.g. by using more GPUs for partitioning or by employing offloading techniques), the next bottleneck shifts to the dynamic memory, such as activations or intermediate values, which is where Liger Kernel excels. For an extreme example, if you are training a 70B model on 8×A100 GPUs, you may encounter an out-of-memory (OOM) error during the model-loading stage. At that point there are no activations or dynamic allocations in play, so enabling or disabling Liger Kernel will not prevent the OOM error, because the initial static-memory bottleneck remains unresolved.
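As referenced above, here is a minimal pure-PyTorch sketch of the chunking idea. It is not Liger's actual implementation (Liger fuses the lm_head matmul, loss, and gradient computation into Triton kernels so chunk activations are not retained for backward); the helper and chunk_size below are illustrative only.

```python
import torch
import torch.nn.functional as F

def chunked_cross_entropy(hidden: torch.Tensor, lm_head_weight: torch.Tensor,
                          labels: torch.Tensor, chunk_size: int = 1024) -> torch.Tensor:
    """Next-token cross-entropy that materializes at most (chunk_size, vocab)
    logits at a time instead of the full (batch*seq, vocab) tensor."""
    hidden = hidden.reshape(-1, hidden.shape[-1])
    labels = labels.reshape(-1)
    total_loss = hidden.new_zeros((), dtype=torch.float32)
    for start in range(0, hidden.shape[0], chunk_size):
        h = hidden[start:start + chunk_size]
        y = labels[start:start + chunk_size]
        logits = h @ lm_head_weight.t()   # only a (chunk, vocab) slice exists at once
        total_loss = total_loss + F.cross_entropy(logits.float(), y, reduction="sum")
    return total_loss / labels.numel()
```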
I have been using Liger Kernel to replace the standard operators when training the Qwen2.5 model series with the DeepSpeed ZeRO-3 strategy.
It significantly reduces memory usage on a 7B model (about 36%); however, it shows only limited memory savings (about 6%) on a 14B model.
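For reference, a minimal sketch of how the Liger operator patching is typically applied to a Qwen2-architecture model (the model name and dtype below are illustrative, not my exact setup):

```python
import torch
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_qwen2

# Patch the HF Qwen2 modeling code with Liger's fused kernels
# (RMSNorm, RoPE, SwiGLU, fused linear cross-entropy) before loading the model.
apply_liger_kernel_to_qwen2()

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-14B",          # illustrative; Qwen2.5 uses the Qwen2 architecture
    torch_dtype=torch.bfloat16,
)
# Training then proceeds as usual, e.g. under DeepSpeed ZeRO-3.
```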
Questions: