Add argument return_z_loss to flce #530

Merged
merged 1 commit into from
Jan 18, 2025

Tcc0403 (Collaborator) commented Jan 18, 2025

Summary

Resolves #527
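
For context, here is a minimal, dependency-free sketch of what a `return_z_loss` flag on a fused linear cross entropy could expose. It is a reference computation only, not Liger Kernel's Triton implementation or its API. The function name, the list-based inputs, the mean reduction, and the choice to include the z-loss term in the returned loss are all assumptions made for illustration. `lse_square_scale` is assumed to be the coefficient on the squared log-sum-exp.

```python
import math


def fused_linear_cross_entropy_reference(
    hidden, weight, targets, lse_square_scale=0.0, return_z_loss=False
):
    """Hypothetical reference: project hidden states, then compute CE plus z-loss.

    hidden:  list of N rows, each a list of H floats
    weight:  list of V rows, each a list of H floats (one row per class)
    targets: list of N class indices
    """
    ce_total = 0.0
    z_total = 0.0
    for h, t in zip(hidden, targets):
        # Linear projection: one logit per class (row of `weight`).
        logits = [sum(w_i * h_i for w_i, h_i in zip(w, h)) for w in weight]
        # Numerically stable log-sum-exp of the logits.
        m = max(logits)
        lse = m + math.log(sum(math.exp(x - m) for x in logits))
        # Cross entropy for this row: -log softmax(logits)[t].
        ce_total += lse - logits[t]
        # z-loss penalizes large log-partition values (assumed form).
        z_total += lse_square_scale * lse**2
    n = len(targets)
    z_loss = z_total / n
    loss = ce_total / n + z_loss  # assumption: loss already includes z_loss
    return (loss, z_loss) if return_z_loss else loss


# Usage: with the flag set, the auxiliary term is returned alongside the loss.
hidden = [[1.0, 0.0], [0.0, 1.0]]
weight = [[1.0, 0.0], [0.0, 1.0]]
loss, z_loss = fused_linear_cross_entropy_reference(
    hidden, weight, [0, 1], lse_square_scale=1e-4, return_z_loss=True
)
print(loss, z_loss)
```

Returning the z-loss separately lets a training loop log it without recomputing the logits, which is the point of fusing the projection and the loss in the first place.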

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Tcc0403 (Collaborator, Author) commented Jan 18, 2025

Unit tests failed only on an unrelated test (rope), caused by breaking changes in transformers v4.48.0. The runs below use transformers 4.47.1:

❯ uv pip install transformers==4.47.1
Resolved 17 packages in 387ms
Uninstalled 1 package in 110ms
Installed 1 package in 180ms
 - transformers==4.48.0
 + transformers==4.47.1
❯ make test
python -m pytest --disable-warnings test/ --ignore=test/convergence
====================================================== test session starts =======================================================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
configfile: pyproject.toml
plugins: rerunfailures-15.0, xdist-3.6.1
collected 1021 items
...
=============================== 806 passed, 215 skipped, 41 warnings, 1 rerun in 174.73s (0:02:54) ===============================
❯ make test-convergence
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models.py
====================================================== test session starts =======================================================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
configfile: pyproject.toml
plugins: rerunfailures-15.0, xdist-3.6.1
collecting ...
------------------------------------------------------ live log collection -------------------------------------------------------
INFO     datasets:config.py:54 PyTorch version 2.5.1 available.
collected 17 items

test/convergence/test_mini_models.py::test_mini_model[mini_llama3-32-0.0001-dtype0-1e-08-2e-05-0.0001-1e-05-0.005-1e-05] PASSED [  5%]
test/convergence/test_mini_models.py::test_mini_model[mini_llama3-32-0.0001-dtype1-0.001-0.01-0.1-0.01-0.01-0.01] PASSED   [ 11%]
test/convergence/test_mini_models.py::test_mini_model[mini_mllama-32-0.0001-dtype2-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 17%]
test/convergence/test_mini_models.py::test_mini_model[mini_mllama-32-0.0001-dtype3-0.001-0.01-0.1-0.01-0.01-0.01] PASSED   [ 23%]
test/convergence/test_mini_models.py::test_mini_model[mini_qwen2-32-0.0001-dtype4-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 29%]
test/convergence/test_mini_models.py::test_mini_model[mini_qwen2-32-0.0001-dtype5-0.001-0.01-0.1-0.01-0.01-0.01] PASSED    [ 35%]
test/convergence/test_mini_models.py::test_mini_model[mini_qwen2_vl-32-0.0001-dtype6-1e-05-0.1-0.005-1e-05-0.005-1e-05] PASSED [ 41%]
test/convergence/test_mini_models.py::test_mini_model[mini_qwen2_vl-32-0.0001-dtype7-0.001-0.05-0.1-0.01-0.01-0.01] PASSED [ 47%]
test/convergence/test_mini_models.py::test_mini_model[mini_phi3-32-0.0001-dtype8-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 52%]
test/convergence/test_mini_models.py::test_mini_model[mini_phi3-32-0.0001-dtype9-0.001-0.01-0.1-0.01-0.01-0.01] PASSED     [ 58%]
test/convergence/test_mini_models.py::test_mini_model[mini_mistral-32-0.0001-dtype10-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 64%]
test/convergence/test_mini_models.py::test_mini_model[mini_mistral-32-0.0001-dtype11-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 70%]
test/convergence/test_mini_models.py::test_mini_model[mini_gemma1-32-0.0001-dtype12-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [ 76%]
test/convergence/test_mini_models.py::test_mini_model[mini_gemma1-32-0.0001-dtype13-0.001-0.01-0.1-0.01-0.01-0.01] PASSED  [ 82%]
test/convergence/test_mini_models.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype14-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [ 88%]
test/convergence/test_mini_models.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype15-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 94%]
test/convergence/test_mini_models.py::test_mini_model[mini_gemma2-32-0.0001-dtype16-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [100%]

=========================================== 17 passed, 1 warning in 150.97s (0:02:30) ============================================
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_multimodal.py
====================================================== test session starts =======================================================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
configfile: pyproject.toml
plugins: rerunfailures-15.0, xdist-3.6.1
collecting ...
------------------------------------------------------ live log collection -------------------------------------------------------
INFO     datasets:config.py:54 PyTorch version 2.5.1 available.
collected 4 items

test/convergence/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_qwen2_vl-32-0.0001-dtype0-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 25%]
test/convergence/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_qwen2_vl-32-0.0001-dtype1-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 50%]
test/convergence/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_mllama-32-0.0001-dtype2-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 75%]
test/convergence/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_mllama-32-0.0001-dtype3-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [100%]

============================================ 4 passed, 5 warnings in 83.36s (0:01:23) ============================================
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_with_logits.py
====================================================== test session starts =======================================================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
configfile: pyproject.toml
plugins: rerunfailures-15.0, xdist-3.6.1
collecting ...
------------------------------------------------------ live log collection -------------------------------------------------------
INFO     datasets:config.py:54 PyTorch version 2.5.1 available.
collected 17 items

test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_llama3-32-0.0001-dtype0-1e-08-2e-05-0.0001-1e-05-0.005-1e-05] PASSED [  5%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_llama3-32-0.0001-dtype1-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 11%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_mllama-32-0.0001-dtype2-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 17%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_mllama-32-0.0001-dtype3-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 23%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_qwen2-32-0.0001-dtype4-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 29%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_qwen2-32-0.0001-dtype5-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 35%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_qwen2_vl-32-0.0001-dtype6-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 41%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_qwen2_vl-32-0.0001-dtype7-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 47%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_phi3-32-0.0001-dtype8-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 52%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_phi3-32-0.0001-dtype9-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 58%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_mistral-32-0.0001-dtype10-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 64%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_mistral-32-0.0001-dtype11-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 70%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma1-32-0.0001-dtype12-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [ 76%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma1-32-0.0001-dtype13-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 82%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype14-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [ 88%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype15-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 94%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma2-32-0.0001-dtype16-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [100%]

=========================================== 17 passed, 1 warning in 150.29s (0:02:30) ============================================

@Tcc0403 Tcc0403 merged commit f958596 into main Jan 18, 2025
3 of 5 checks passed
@Tcc0403 Tcc0403 deleted the tcc/flce-return-z-loss branch January 18, 2025 11:50
Linked issue that merging this pull request may close: return_z_loss is not supported for LigerFusedLinearCrossEntropyFunction and LigerFusedLinearCrossEntropyLoss