From 4c9bb9577e9d3d7e92bbd784e30dbc61a4596e61 Mon Sep 17 00:00:00 2001 From: Angel Luu Date: Tue, 17 Sep 2024 16:08:25 -0600 Subject: [PATCH] lint: fix more fmt errors Signed-off-by: Angel Luu --- tests/utils/test_merge_model_utils.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/utils/test_merge_model_utils.py b/tests/utils/test_merge_model_utils.py index f079234e2..e6b5c2687 100644 --- a/tests/utils/test_merge_model_utils.py +++ b/tests/utils/test_merge_model_utils.py @@ -15,19 +15,23 @@ """Unit Tests for SFT Trainer's merge_model_utils functions """ +# Standard +import os +import tempfile + # Third Party from safetensors import safe_open -import tempfile -import torch -import os import pytest +import torch # Local from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens dir_path = os.path.dirname(os.path.realpath(__file__)) -DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS = os.path.join(dir_path, - "../artifacts/tuned_llama_with_added_tokens") +DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS = os.path.join( + dir_path, "../artifacts/tuned_llama_with_added_tokens" +) + @pytest.mark.skipif( not (torch.cuda.is_available()), @@ -40,8 +44,10 @@ def test_post_process_vLLM_adapters_new_tokens(): """ # first, double check dummy tuned llama has a lm_head.weight found_lm_head = False - with safe_open(os.path.join(DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS, "adapter_model.safetensors"), - framework="pt") as f: + with safe_open( + os.path.join(DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS, "adapter_model.safetensors"), + framework="pt", + ) as f: for k in f.keys(): if "lm_head.weight" in k: found_lm_head = True @@ -49,7 +55,9 @@ def test_post_process_vLLM_adapters_new_tokens(): # do the post processing with tempfile.TemporaryDirectory() as tempdir: - post_process_vLLM_adapters_new_tokens(DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS, tempdir) + post_process_vLLM_adapters_new_tokens( + DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS, tempdir + ) # check that new_embeddings.safetensors exist new_embeddings = os.path.join(tempdir, "new_embeddings.safetensors")