From 0a5d52750833433bddf82698740e04ec9752f1f5 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Fri, 27 Dec 2024 00:58:59 +0000
Subject: [PATCH] Update fetch_server_test_models.py

---
 scripts/fetch_server_test_models.py | 124 ++++++++++++++++------------
 1 file changed, 72 insertions(+), 52 deletions(-)
 mode change 100644 => 100755 scripts/fetch_server_test_models.py

diff --git a/scripts/fetch_server_test_models.py b/scripts/fetch_server_test_models.py
old mode 100644
new mode 100755
index 75da54a5dd536..7d7aa2b5992dc
--- a/scripts/fetch_server_test_models.py
+++ b/scripts/fetch_server_test_models.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python
 '''
     This script fetches all the models used in the server tests.
 
@@ -7,13 +8,14 @@
 
     Example:
         python scripts/fetch_server_test_models.py
-        ( cd examples/server/tests && ./tests.sh --tags=slow )
+        ( cd examples/server/tests && ./tests.sh -v -x -m slow )
 '''
-from behave.parser import Parser
+import ast
 import glob
+import logging
 import os
+from typing import Generator
 from pydantic import BaseModel
-import re
 import subprocess
 import sys
 
@@ -26,53 +28,71 @@ class Config:
         frozen = True
 
 
-models = set()
-
-model_file_re = re.compile(r'a model file ([^\s\n\r]+) from HF repo ([^\s\n\r]+)')
-
-
-def process_step(step):
-    if (match := model_file_re.search(step.name)):
-        (hf_file, hf_repo) = match.groups()
-        models.add(HuggingFaceModel(hf_repo=hf_repo, hf_file=hf_file))
-
-
-feature_files = glob.glob(
-    os.path.join(
-        os.path.dirname(__file__),
-        '../examples/server/tests/features/*.feature'))
-
-for feature_file in feature_files:
-    with open(feature_file, 'r') as file:
-        feature = Parser().parse(file.read())
-        if not feature: continue
-
-        if feature.background:
-            for step in feature.background.steps:
-                process_step(step)
-
-        for scenario in feature.walk_scenarios(with_outlines=True):
-            for step in scenario.steps:
-                process_step(step)
-
-cli_path = os.environ.get(
-    'LLAMA_SERVER_BIN_PATH',
-    os.path.join(
-        os.path.dirname(__file__),
-        '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli'))
-
-for m in sorted(list(models), key=lambda m: m.hf_repo):
-    if '<' in m.hf_repo or '<' in m.hf_file:
-        continue
-    if '-of-' in m.hf_file:
-        print(f'# Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file', file=sys.stderr)
-        continue
-    print(f'# Ensuring model at {m.hf_repo} / {m.hf_file} is fetched')
-    cmd = [cli_path, '-hfr', m.hf_repo, '-hff', m.hf_file, '-n', '1', '-p', 'Hey', '--no-warmup', '--log-disable']
-    if m.hf_file != 'tinyllamas/stories260K.gguf' and not m.hf_file.startswith('Mistral-Nemo'):
-        cmd.append('-fa')
+def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, None, None]:
     try:
-        subprocess.check_call(cmd)
-    except subprocess.CalledProcessError:
-        print(f'# Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n  {" ".join(cmd)}', file=sys.stderr)
-        exit(1)
+        with open(test_file) as f:
+            tree = ast.parse(f.read())
+    except Exception as e:
+        logging.error(f'collect_hf_model_test_parameters failed on {test_file}: {e}')
+        return
+
+    for node in ast.walk(tree):
+        if isinstance(node, ast.FunctionDef):
+            for dec in node.decorator_list:
+                if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute) and dec.func.attr == 'parametrize':
+                    param_names = ast.literal_eval(dec.args[0]).split(",")
+                    if not "hf_repo" in param_names or not "hf_file" in param_names:
+                        continue
+
+                    raw_param_values = dec.args[1]
+                    if not isinstance(raw_param_values, ast.List):
+                        logging.warning(f'Skipping non-list parametrize entry at {test_file}:{node.lineno}')
+                        continue
+
+                    hf_repo_idx = param_names.index("hf_repo")
+                    hf_file_idx = param_names.index("hf_file")
+
+                    for t in raw_param_values.elts:
+                        if not isinstance(t, ast.Tuple):
+                            logging.warning(f'Skipping non-tuple parametrize entry at {test_file}:{node.lineno}')
+                            continue
+                        yield HuggingFaceModel(
+                            hf_repo=ast.literal_eval(t.elts[hf_repo_idx]),
+                            hf_file=ast.literal_eval(t.elts[hf_file_idx]))
+
+
+if __name__ == '__main__':
+    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
+
+    models = sorted(list(set([
+        model
+        for test_file in glob.glob('examples/server/tests/unit/test_*.py')
+        for model in collect_hf_model_test_parameters(test_file)
+    ])), key=lambda m: (m.hf_repo, m.hf_file))
+
+    logging.info(f'Found {len(models)} models in parameterized tests:')
+    for m in models:
+        logging.info(f'  - {m.hf_repo} / {m.hf_file}')
+
+    cli_path = os.environ.get(
+        'LLAMA_SERVER_BIN_PATH',
+        os.path.join(
+            os.path.dirname(__file__),
+            '../build/bin/Release/llama-cli.exe' if os.name == 'nt' \
+                else '../build/bin/llama-cli'))
+
+    for m in models:
+        if '<' in m.hf_repo or '<' in m.hf_file:
+            continue
+        if '-of-' in m.hf_file:
+            logging.warning(f'Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file')
+            continue
+        logging.info(f'Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched')
+        cmd = [cli_path, '-hfr', m.hf_repo, '-hff', m.hf_file, '-n', '1', '-p', 'Hey', '--no-warmup', '--log-disable']
+        if m.hf_file != 'tinyllamas/stories260K.gguf' and not m.hf_file.startswith('Mistral-Nemo'):
+            cmd.append('-fa')
+        try:
+            subprocess.check_call(cmd)
+        except subprocess.CalledProcessError:
+            logging.error(f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n  {" ".join(cmd)}')
+            exit(1)
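
Note on what the new collect_hf_model_test_parameters picks up: it only recognizes literal pytest.mark.parametrize decorators whose first argument is a comma-separated name string containing both hf_repo and hf_file (with no spaces around the commas, since the script splits on "," without stripping) and whose second argument is a literal list of tuples. A minimal sketch of a test file it would match follows; the file name, repo, and values are illustrative, not taken from the actual test suite:

    # Hypothetical examples/server/tests/unit/test_example.py (illustrative only)
    import pytest

    # The AST walk matches this decorator: its func is an ast.Attribute whose
    # attr is 'parametrize', "hf_repo" and "hf_file" both appear in the name
    # string, and the second argument is a literal list of tuples.
    @pytest.mark.parametrize("hf_repo,hf_file,n_predict", [
        ("ggml-org/models", "tinyllamas/stories260K.gguf", 16),
    ])
    def test_completion(hf_repo: str, hf_file: str, n_predict: int):
        ...

Parameter lists built at import time (helper calls, variables) are not ast.List / ast.Tuple nodes and would be skipped with the warnings added above.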