
Memory Optimization with Liger Kernel Shows Limited Effect on larger Model (more than 7B) #517

Open
dyyoungg opened this issue Jan 8, 2025 · 3 comments



dyyoungg commented Jan 8, 2025

I have been using the Liger Kernel to replace standard operators when training the Qwen2.5 model series with the DeepSpeed ZeRO-3 strategy.
It significantly reduces memory usage on a 7B model (by about 36%); however, it yields only a limited memory saving (about 6%) on a 14B model.

Questions:

  1. Are there known limitations in Liger Kernel optimizations for larger models like 14B?
  2. Is there any recommended configuration or parameter adjustment to improve memory efficiency for larger models?

@DandinPower (Contributor)

Hi @dyyoungg,
I’m curious about this issue and wanted to share some insights based on my past experience.

In a similar scenario, I encountered a memory spike during the optimizer step while using the PyTorch Adam optimizer with the default setting foreach=True, because it required an extra copy of the model weights. I wonder whether your situation involves a similar memory peak, one not caused solely by the cross-entropy logits.
Even if the cross-entropy logits peak is mitigated by enabling the Liger Kernel, other memory bottlenecks might still determine the overall peak memory usage, which could explain why the reduction isn't significant.

Could you provide more details about your setup? I can then try to reproduce it and perform an analysis. Specifically:

  1. How many GPUs are you using? (ZeRO-3 partitions memory differently depending on the number of GPUs, which can significantly impact results.)
  2. What is the micro-batch size per GPU and the context length? (These parameters greatly affect activation memory usage.)
  3. Are you using other memory-efficient techniques, such as gradient checkpointing?

@dyyoungg (Author)


I tested training on 8×A100 GPUs; here are my training settings:

```shell
deepspeed train.py \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path qwen25_14B \
    --data_path xx.json \
    --bf16 True \
    --output_dir ./output \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --save_strategy "steps" \
    --save_steps 500 \
    --save_total_limit 2 \
    --learning_rate 1e-5 \
    --weight_decay 0.0 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 4096 \
    --gradient_checkpointing True \
    --dataloader_num_workers 2 \
    --report_to tensorboard
```

@DandinPower (Contributor)

DandinPower commented Jan 14, 2025

Hi @dyyoungg,

During training, GPU memory is used for two kinds of data:

  1. Static Memory: Includes model weights, gradients, and optimizer states.
  2. Dynamic Memory: Includes activation memory and intermediate values.

There are many ways to optimize GPU memory usage. In your case, you're using DeepSpeed ZeRO Stage 3, a data-parallel method that partitions static memory across GPUs. It gathers the model weights with all-gather during the forward and backward passes, averages gradients with reduce-scatter so that each GPU keeps only its own gradient partition, and then updates its own parameter partition locally.
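The ZeRO-3 data flow can be illustrated with a toy, single-process Python simulation. This is only a sketch: the "ranks" are list indices, no real communication happens, the helper names (`shard`, `all_gather`, `reduce_scatter_mean`) are hypothetical rather than DeepSpeed's API, and the local update is plain SGD instead of Adam to keep it short. Averaging gradients with reduce-scatter gives the same result as an all-reduce followed by each rank keeping only its own slice.

```python
# Toy, single-process simulation of ZeRO-3 style sharding (illustration only).

def shard(flat, world_size):
    """Split a flat parameter list into equal contiguous shards, one per rank."""
    assert len(flat) % world_size == 0, "pad the parameters so they divide evenly"
    size = len(flat) // world_size
    return [flat[r * size:(r + 1) * size] for r in range(world_size)]

def all_gather(shards):
    """Every rank receives the concatenation of all shards (the full weights)."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]

def reduce_scatter_mean(per_rank_grads, world_size):
    """Average full gradients across ranks, then give each rank only its own shard."""
    n = len(per_rank_grads[0])
    mean = [sum(g[i] for g in per_rank_grads) / world_size for i in range(n)]
    return shard(mean, world_size)

world_size = 4
params = [float(i) for i in range(8)]        # full model: 8 parameters
param_shards = shard(params, world_size)      # each rank persistently stores only 2

# Forward/backward: each rank temporarily materializes the full weights
# (real ZeRO-3 does this one layer at a time and frees them afterwards).
gathered = all_gather(param_shards)

# Each rank computes a gradient on its own micro-batch (toy values: rank id + 1).
grads = [[float(r + 1)] * len(params) for r in range(world_size)]
grad_shards = reduce_scatter_mean(grads, world_size)

# Local optimizer step: each rank updates only its own parameter shard (plain SGD).
lr = 0.1
param_shards = [[p - lr * g for p, g in zip(ps, gs)]
                for ps, gs in zip(param_shards, grad_shards)]
print(param_shards)
```

Because each rank persistently holds only 1/world_size of the parameters, gradients, and optimizer states, static memory per GPU shrinks linearly with the number of GPUs.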

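With DeepSpeed and BF16 mixed-precision training, static memory is approximately 16 bytes per parameter (2 for BF16 weights + 2 for BF16 gradients + 4 for FP32 master weights + 4 + 4 for the two FP32 Adam states), i.e. 16 × 14B = 224 GB. With 8 A100 GPUs, each GPU stores 224 GB / 8 = 28 GB.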
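The static-memory arithmetic can be written as a small Python helper. It assumes the usual breakdown of the 16 bytes per parameter (BF16 weights and gradients plus FP32 master weights and two FP32 Adam states) and ignores communication buffers and allocator fragmentation, so real usage will be somewhat higher.

```python
# Bytes per parameter for BF16 mixed precision with Adam (assumed breakdown).
BYTES_PER_PARAM = {
    "bf16 weights": 2,
    "bf16 gradients": 2,
    "fp32 master weights": 4,
    "fp32 Adam momentum": 4,
    "fp32 Adam variance": 4,
}

GB = 1e9  # decimal gigabytes, matching the 224 GB figure

def static_bytes_per_gpu(num_params, num_gpus):
    """ZeRO-3 partitions all static memory evenly across data-parallel GPUs."""
    return num_params * sum(BYTES_PER_PARAM.values()) / num_gpus

per_gpu = static_bytes_per_gpu(14e9, 8)
print(f"total: {14e9 * 16 / GB:.0f} GB, per GPU: {per_gpu / GB:.0f} GB")
# prints "total: 224 GB, per GPU: 28 GB"
```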

For the dynamic memory:

  • Since gradient checkpointing is enabled, you only need to store batch * seq * num_layers * hidden_size * dtype_bytes (Hugging Face models checkpoint the input of every decoder layer). For your case, this equals:
    1 * 4096 * 48 * 5120 * 2 bytes = 1.875 GiB (about 2.0 GB).
  • Assuming you're using the DeepSpeed-provided FusedAdam optimizer, there’s no additional memory overhead for optimizer steps.
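A quick check of the checkpointed-activation estimate, assuming the Qwen2.5-14B shape (48 decoder layers, hidden size 5120), batch size 1, sequence length 4096, and BF16 (2-byte) values. The function name is only for illustration. Note that 1.875 is in binary gibibytes (2^30 bytes); in decimal units it is about 2.0 GB.

```python
GiB = 2**30

def checkpointed_bytes(batch, seq, num_layers, hidden, dtype_bytes=2):
    """Memory for the decoder-layer inputs kept by per-layer gradient checkpointing."""
    return batch * seq * num_layers * hidden * dtype_bytes

# Qwen2.5-14B shape, batch 1, sequence length 4096, bf16.
ckpt = checkpointed_bytes(1, 4096, 48, 5120)
print(ckpt, ckpt / GiB)  # 2013265920 1.875
```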

Now, let’s examine the most memory-intensive operation, the cross-entropy peak:
In Hugging Face’s implementation, the loss calculation creates three logits-sized intermediate tensors. Their total size is:
3 * batch * seq * vocab * 4 bytes (fp32)
Given a vocab size of 152,064 (as in Qwen2.5), this results in:
3 * 1 * 4096 * 152064 * 4 bytes = 6.96 GiB (about 7.5 GB).
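The logits peak can be computed the same way. Here `copies=3` encodes the three logits-sized FP32 intermediates described above; it is an assumption taken from that description, not something measured by this snippet, and the function name is hypothetical.

```python
GiB = 2**30

def hf_logits_peak_bytes(batch, seq, vocab, copies=3, dtype_bytes=4):
    """Peak for the logits-sized fp32 intermediates in the loss computation."""
    return copies * batch * seq * vocab * dtype_bytes

peak = hf_logits_peak_bytes(1, 4096, 152064)
print(f"{peak / GiB:.2f} GiB")  # 6.96 GiB
```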

Since your sequence length and batch size are relatively small, total GPU memory usage is dominated by static memory. A summary of VRAM usage per GPU:

  1. Static Memory: 28 GB (about 26.1 GiB)
  2. Checkpointed Values: 1.875 GiB
  3. Logits-Related Intermediate Values: 6.96 GiB
  4. Other Intermediate Values: temporary activations within the one decoder layer being recomputed during the backward pass (Hugging Face places gradient checkpoints before and after each decoder layer).
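Adding up the first three items (item 4 is left out because it depends on layer internals) shows how much of the estimated per-GPU budget is static memory. The byte counts repeat the assumptions used above; static memory is converted from decimal bytes so that all rows share one unit.

```python
GiB = 2**30

# Per-GPU estimates in bytes (same assumptions as the items above).
budget_bytes = {
    "static (ZeRO-3 shard)": 16 * 14e9 / 8,
    "checkpointed layer inputs": 1 * 4096 * 48 * 5120 * 2,
    "logits intermediates (no Liger)": 3 * 1 * 4096 * 152064 * 4,
}
total = sum(budget_bytes.values())

for name, nbytes in budget_bytes.items():
    print(f"{name:34s} {nbytes / GiB:6.2f} GiB  ({nbytes / total:.0%})")
print(f"{'total (excluding item 4)':34s} {total / GiB:6.2f} GiB")
```

Under these assumptions, static memory is roughly three quarters of the estimated total.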

On Liger Kernel and Chunking Optimization

Liger Kernel significantly reduces logits-related memory peaks. By chunking the tokens and fusing the final linear projection with the cross-entropy computation, it limits the peak logits memory to:
chunk_size * vocab * dtype_bytes
For your case, with a chunk size of 256, this is at most 256 * 152064 * 4 bytes ≈ 148.5 MiB (assuming fp32), which is negligible. This prevents logits-related memory from becoming a bottleneck, even as batch size or sequence length increases.
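The chunk size of 256 is consistent with the chunking heuristic in Liger's fused linear cross entropy as I understand it from the source (an assumption; details may differ between Liger versions): the chunk size is the next power of two of `ceil(BT / ceil(V / H))`, which keeps each logits chunk roughly as large as the hidden-state input. The helper names below are hypothetical, and the 4-byte element size is an FP32 upper bound.

```python
MiB = 2**20

def next_power_of_2(n):
    """Smallest power of two >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()

def cdiv(a, b):
    """Ceiling division for positive integers."""
    return -(-a // b)

def liger_style_chunk_size(bt, hidden, vocab):
    """Assumed Liger-style heuristic: chunk * vocab stays close to bt * hidden."""
    inc_factor = cdiv(vocab, hidden)  # how many times larger V is than H
    return next_power_of_2(cdiv(bt, inc_factor))

chunk = liger_style_chunk_size(bt=1 * 4096, hidden=5120, vocab=152064)
chunk_bytes = chunk * 152064 * 4  # one fp32 logits chunk (upper bound)
print(chunk, f"{chunk_bytes / MiB:.1f} MiB")  # 256 148.5 MiB
```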

Back to Your Situation

The memory reduction after enabling Liger Kernel is only about 6% because your batch size (1) and sequence length (4096) are modest, so static memory dominates overall usage. However, enabling Liger Kernel allows you to scale batch size and sequence length without significant memory spikes. Without it, increasing the batch size from 1 to 2 would double the logits-related memory from 6.96 GiB to 13.92 GiB, likely causing an OOM error.
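To see how the two approaches scale, this sketch compares the logits peak with and without chunking for batch sizes 1, 2 and 4 at sequence length 4096. Without Liger it assumes three FP32 logits-sized copies; with Liger it assumes one FP32 chunk whose size follows an assumed heuristic (next power of two of `ceil(BT / ceil(V / H))`, my reading of the Liger source, which may differ by version). Under that heuristic the chunk also grows with batch size, but it stays in the hundreds of MiB while the unchunked peak grows by about 7 GiB per extra sequence.

```python
GiB, MiB = 2**30, 2**20
SEQ, HIDDEN, VOCAB = 4096, 5120, 152064

def next_power_of_2(n):
    return 1 << (n - 1).bit_length()

def cdiv(a, b):
    return -(-a // b)

def logits_peak_without_liger(batch):
    # Three fp32 logits-sized intermediates (assumption from the text above).
    return 3 * batch * SEQ * VOCAB * 4

def logits_peak_with_liger(batch):
    # One fp32 logits chunk; chunk size from the assumed heuristic.
    bt = batch * SEQ
    chunk = next_power_of_2(cdiv(bt, cdiv(VOCAB, HIDDEN)))
    return chunk * VOCAB * 4

for batch in (1, 2, 4):
    print(batch,
          f"{logits_peak_without_liger(batch) / GiB:.2f} GiB",
          f"{logits_peak_with_liger(batch) / MiB:.1f} MiB")
```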

Conclusion

In conclusion, I don't believe the Liger Kernel has inherent limitations in large-model training. The issue arises when too few GPUs are available to partition the massive static memory required. Static memory is the first bottleneck encountered; once it is addressed (e.g., by partitioning across more GPUs or using offloading), the bottleneck shifts to dynamic memory such as activations and intermediate values, which is where the Liger Kernel excels.

For an extreme example, if you train a 70B model on 8×A100 GPUs, you may hit an out-of-memory (OOM) error during the model loading stage. At that point there are no activations or other dynamic allocations yet, so enabling or disabling the Liger Kernel cannot prevent the OOM: the initial static-memory bottleneck remains unresolved.
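The 70B example follows from the same 16-bytes-per-parameter estimate. This sketch assumes the 80 GB A100 variant (a 40 GB variant also exists) and counts only static memory, before any activations are allocated; the function name is hypothetical.

```python
GB = 1e9
A100_MAX_GB = 80  # larger A100 variant; a 40 GB variant also exists

def static_gb_per_gpu(num_params, num_gpus, bytes_per_param=16):
    """Static memory per GPU under ZeRO-3 (weights, grads, optimizer states)."""
    return num_params * bytes_per_param / num_gpus / GB

for name, params in (("14B", 14e9), ("70B", 70e9)):
    per_gpu = static_gb_per_gpu(params, 8)
    print(f"{name}: {per_gpu:.0f} GB static per GPU, fits in 80 GB: {per_gpu < A100_MAX_GB}")
```

On 8 GPUs the 70B model needs about 140 GB of static memory per GPU, already above the capacity of a single A100 before training starts.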
