Commit
Update fetch_server_test_models.py
ochafik committed Dec 27, 2024
1 parent 0e87ae2 commit 0a5d527
Showing 1 changed file with 72 additions and 52 deletions.
124 changes: 72 additions & 52 deletions scripts/fetch_server_test_models.py
file mode changed: 100644 → 100755
@@ -1,3 +1,4 @@
+#!/usr/bin/env python
 '''
     This script fetches all the models used in the server tests.
@@ -7,13 +8,14 @@
     Example:
         python scripts/fetch_server_test_models.py
-        ( cd examples/server/tests && ./tests.sh --tags=slow )
+        ( cd examples/server/tests && ./tests.sh -v -x -m slow )
 '''
-from behave.parser import Parser
+import ast
 import glob
+import logging
 import os
+from typing import Generator
 from pydantic import BaseModel
-import re
 import subprocess
 import sys
 
Expand All @@ -26,53 +28,71 @@ class Config:
         frozen = True
 
 
-models = set()
-
-model_file_re = re.compile(r'a model file ([^\s\n\r]+) from HF repo ([^\s\n\r]+)')
-
-
-def process_step(step):
-    if (match := model_file_re.search(step.name)):
-        (hf_file, hf_repo) = match.groups()
-        models.add(HuggingFaceModel(hf_repo=hf_repo, hf_file=hf_file))
-
-
-feature_files = glob.glob(
-    os.path.join(
-        os.path.dirname(__file__),
-        '../examples/server/tests/features/*.feature'))
-
-for feature_file in feature_files:
-    with open(feature_file, 'r') as file:
-        feature = Parser().parse(file.read())
-        if not feature: continue
-
-        if feature.background:
-            for step in feature.background.steps:
-                process_step(step)
-
-        for scenario in feature.walk_scenarios(with_outlines=True):
-            for step in scenario.steps:
-                process_step(step)
-
-cli_path = os.environ.get(
-    'LLAMA_SERVER_BIN_PATH',
-    os.path.join(
-        os.path.dirname(__file__),
-        '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli'))
-
-for m in sorted(list(models), key=lambda m: m.hf_repo):
-    if '<' in m.hf_repo or '<' in m.hf_file:
-        continue
-    if '-of-' in m.hf_file:
-        print(f'# Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file', file=sys.stderr)
-        continue
-    print(f'# Ensuring model at {m.hf_repo} / {m.hf_file} is fetched')
-    cmd = [cli_path, '-hfr', m.hf_repo, '-hff', m.hf_file, '-n', '1', '-p', 'Hey', '--no-warmup', '--log-disable']
-    if m.hf_file != 'tinyllamas/stories260K.gguf' and not m.hf_file.startswith('Mistral-Nemo'):
-        cmd.append('-fa')
+def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, None, None]:
     try:
-        subprocess.check_call(cmd)
-    except subprocess.CalledProcessError:
-        print(f'# Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n {" ".join(cmd)}', file=sys.stderr)
-        exit(1)
+        with open(test_file) as f:
+            tree = ast.parse(f.read())
+    except Exception as e:
+        logging.error(f'collect_hf_model_test_parameters failed on {test_file}: {e}')
+        return
+
+    for node in ast.walk(tree):
+        if isinstance(node, ast.FunctionDef):
+            for dec in node.decorator_list:
+                if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute) and dec.func.attr == 'parametrize':
+                    param_names = ast.literal_eval(dec.args[0]).split(",")
+                    if not "hf_repo" in param_names or not "hf_file" in param_names:
+                        continue
+
+                    raw_param_values = dec.args[1]
+                    if not isinstance(raw_param_values, ast.List):
+                        logging.warning(f'Skipping non-list parametrize entry at {test_file}:{node.lineno}')
+                        continue
+
+                    hf_repo_idx = param_names.index("hf_repo")
+                    hf_file_idx = param_names.index("hf_file")
+
+                    for t in raw_param_values.elts:
+                        if not isinstance(t, ast.Tuple):
+                            logging.warning(f'Skipping non-tuple parametrize entry at {test_file}:{node.lineno}')
+                            continue
+                        yield HuggingFaceModel(
+                            hf_repo=ast.literal_eval(t.elts[hf_repo_idx]),
+                            hf_file=ast.literal_eval(t.elts[hf_file_idx]))
+
+
+if __name__ == '__main__':
+    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
+
+    models = sorted(list(set([
+        model
+        for test_file in glob.glob('examples/server/tests/unit/test_*.py')
+        for model in collect_hf_model_test_parameters(test_file)
+    ])), key=lambda m: (m.hf_repo, m.hf_file))
+
+    logging.info(f'Found {len(models)} models in parameterized tests:')
+    for m in models:
+        logging.info(f'  - {m.hf_repo} / {m.hf_file}')
+
+    cli_path = os.environ.get(
+        'LLAMA_SERVER_BIN_PATH',
+        os.path.join(
+            os.path.dirname(__file__),
+            '../build/bin/Release/llama-cli.exe' if os.name == 'nt' \
+                else '../build/bin/llama-cli'))
+
+    for m in models:
+        if '<' in m.hf_repo or '<' in m.hf_file:
+            continue
+        if '-of-' in m.hf_file:
+            logging.warning(f'Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file')
+            continue
+        logging.info(f'Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched')
+        cmd = [cli_path, '-hfr', m.hf_repo, '-hff', m.hf_file, '-n', '1', '-p', 'Hey', '--no-warmup', '--log-disable']
+        if m.hf_file != 'tinyllamas/stories260K.gguf' and not m.hf_file.startswith('Mistral-Nemo'):
+            cmd.append('-fa')
+        try:
+            subprocess.check_call(cmd)
+        except subprocess.CalledProcessError:
+            logging.error(f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n {" ".join(cmd)}')
+            exit(1)

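Note: the rewritten script no longer parses behave .feature files; it walks the AST of examples/server/tests/unit/test_*.py and collects every pytest.mark.parametrize decorator whose parameter list contains both hf_repo and hf_file. A minimal sketch of the kind of test it picks up is below; the file name, test name, repo/file values and the extra n_predict parameter are illustrative placeholders, not taken from this commit.

# Hypothetical test file (e.g. examples/server/tests/unit/test_example.py) showing
# the decorator shape that collect_hf_model_test_parameters() looks for; the repo,
# GGUF file and test function are illustrative, not part of this commit.
import pytest


@pytest.mark.slow
@pytest.mark.parametrize("hf_repo,hf_file,n_predict", [
    ("ggml-org/models", "tinyllamas/stories260K.gguf", 16),
])
def test_example_completion(hf_repo: str, hf_file: str, n_predict: int):
    # The fetch script only reads this decorator via ast; it never executes the test body.
    pass

Run from the repository root, python scripts/fetch_server_test_models.py then invokes llama-cli once per unique (hf_repo, hf_file) pair it found, and LLAMA_SERVER_BIN_PATH can point it at a non-default build of the binary.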