diff --git a/requirements-test.txt b/requirements-test.txt index 9b88fcce3e842..178fb15ad121c 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -22,4 +22,8 @@ timm # required for internvl test aiohttp # quantization -bitsandbytes==0.42.0 \ No newline at end of file +bitsandbytes==0.42.0 + +# HPU lm_eval tests +lm_eval +immutabledict \ No newline at end of file diff --git a/tests/hpu/meta-configs/bbh/bbh_3shot_cot.yaml b/tests/hpu/meta-configs/bbh/bbh_3shot_cot.yaml new file mode 100755 index 0000000000000..7ffb925e17950 --- /dev/null +++ b/tests/hpu/meta-configs/bbh/bbh_3shot_cot.yaml @@ -0,0 +1,28 @@ +dataset_path: meta-llama/Meta-Llama-3.1-8B-evals +dataset_name: Meta-Llama-3.1-8B-evals__bbh__details +task: meta_bbh +output_type: generate_until +process_docs: !function bbh_utils.process_docs +test_split: latest +doc_to_text: !function bbh_utils.doc_to_text +doc_to_target: answer +filter_list: + - name: "strict-match" + filter: + - function: "regex" + regex_pattern: 'the answer is (.*?)\.' + - function: "take_first" +generation_kwargs: + until: "\n\nQ: " + do_sample: false + temperature: 0 + max_gen_toks: 512 +num_fewshot: 0 +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: true +metadata: + version: 1.0 diff --git a/tests/hpu/meta-configs/bbh/bbh_utils.py b/tests/hpu/meta-configs/bbh/bbh_utils.py new file mode 100755 index 0000000000000..1e1ed449c7805 --- /dev/null +++ b/tests/hpu/meta-configs/bbh/bbh_utils.py @@ -0,0 +1,21 @@ +import random +import re + +import datasets + + + +def doc_to_text(doc: dict) -> str: + return doc["input_final_prompts"][0] + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + def _process_doc(doc: dict) -> dict: + out_doc = { + "problem": doc["input_question"], + "answer": doc["input_correct_responses"][0], + } + return out_doc + dataset = dataset.select_columns(["input_question", "input_correct_responses", "input_final_prompts", "is_correct","input_question_hash","output_prediction_text"]) + dataset = dataset.rename_column("is_correct","previously_is_correct") + dataset = dataset.map(_process_doc) + return dataset.map(_process_doc) diff --git a/tests/hpu/meta-configs/gpqa_cot/gpqa_0shot_cot.yaml b/tests/hpu/meta-configs/gpqa_cot/gpqa_0shot_cot.yaml new file mode 100755 index 0000000000000..1a7d8ec6e8eec --- /dev/null +++ b/tests/hpu/meta-configs/gpqa_cot/gpqa_0shot_cot.yaml @@ -0,0 +1,29 @@ +dataset_path: meta-llama/Meta-Llama-3.1-8B-Instruct-evals +dataset_name: Meta-Llama-3.1-8B-Instruct-evals__gpqa__details +task: meta_gpqa +output_type: generate_until +process_docs: !function gpqa_utils.process_docs +test_split: latest +doc_to_text: !function gpqa_utils.doc_to_text +doc_to_target: gold +filter_list: + - name: "strict-match" + filter: + - function: "regex" + group_select: -1 + regex_pattern: 'best answer is ([A-Z])' + - function: "take_first" +generation_kwargs: + until: [] + do_sample: false + temperature: 0 + max_gen_toks: 2048 +num_fewshot: 0 +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: true +metadata: + version: 1.0 diff --git a/tests/hpu/meta-configs/gpqa_cot/gpqa_utils.py b/tests/hpu/meta-configs/gpqa_cot/gpqa_utils.py new file mode 100755 index 0000000000000..6a0349fca5edd --- /dev/null +++ b/tests/hpu/meta-configs/gpqa_cot/gpqa_utils.py @@ -0,0 +1,20 @@ +import random +import re + +import datasets + + + +def doc_to_text(doc: dict) -> str: + return 
doc["input_final_prompts"][0] +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + def _process_doc(doc: dict) -> dict: + out_doc = { + "problem": doc["input_question"], + "gold": doc["input_correct_responses"][0], + } + return out_doc + dataset = dataset.select_columns(["input_question", "input_correct_responses", "input_final_prompts", "is_correct","input_question_hash","input_choice_list","output_prediction_text"]) + dataset = dataset.rename_column("is_correct","previously_is_correct") + dataset = dataset.map(_process_doc) + return dataset.map(_process_doc) diff --git a/tests/hpu/meta-configs/ifeval/ifeval.yaml b/tests/hpu/meta-configs/ifeval/ifeval.yaml new file mode 100755 index 0000000000000..21019dc67af28 --- /dev/null +++ b/tests/hpu/meta-configs/ifeval/ifeval.yaml @@ -0,0 +1,32 @@ +task: meta_ifeval +dataset_path: parquet +dataset_kwargs: + data_files: ./meta-configs/joined_ifeval.parquet +output_type: generate_until +test_split: train +num_fewshot: 0 +doc_to_text: prompt +doc_to_target: 0 +generation_kwargs: + until: [] + do_sample: false + temperature: 0.0 + max_gen_toks: 1280 +process_results: !function ifeval_utils.process_results +metric_list: + - metric: prompt_level_strict_acc + aggregation: mean + higher_is_better: true + - metric: inst_level_strict_acc + aggregation: !function ifeval_utils.agg_inst_level_acc + higher_is_better: true + - metric: prompt_level_loose_acc + aggregation: mean + higher_is_better: true + - metric: inst_level_loose_acc + aggregation: !function ifeval_utils.agg_inst_level_acc + higher_is_better: true +metadata: + version: 2.0 +fewshot_config: + sampler: first_n diff --git a/tests/hpu/meta-configs/ifeval/ifeval_utils.py b/tests/hpu/meta-configs/ifeval/ifeval_utils.py new file mode 100755 index 0000000000000..73e2aab902170 --- /dev/null +++ b/tests/hpu/meta-configs/ifeval/ifeval_utils.py @@ -0,0 +1,139 @@ +import dataclasses +from typing import Dict, List, Optional, Union + +from lm_eval.tasks.ifeval import instructions_registry + + +@dataclasses.dataclass +class InputExample: + key: int + instruction_id_list: List[str] + prompt: str + kwargs: List[Dict[str, Optional[Union[str, int]]]] + + +@dataclasses.dataclass +class OutputExample: + instruction_id_list: List[str] + prompt: str + response: str + follow_all_instructions: bool + follow_instruction_list: List[bool] + + +def test_instruction_following_strict( + inp, + response, +): + """Tests response to see if instructions are followed.""" + instruction_list = inp.instruction_id_list + is_following_list = [] + + for index, instruction_id in enumerate(instruction_list): + instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id] + instruction = instruction_cls(instruction_id) + + # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method. 
+ kwargs = {k: v for k, v in inp.kwargs[index].items() if v} + instruction.build_description(**kwargs) + args = instruction.get_instruction_args() + if args and "prompt" in args: + instruction.build_description(prompt=inp.prompt) + + if response.strip() and instruction.check_following(response): + is_following_list.append(True) + else: + is_following_list.append(False) + + return OutputExample( + instruction_id_list=inp.instruction_id_list, + prompt=inp.prompt, + response=response, + follow_all_instructions=all(is_following_list), + follow_instruction_list=is_following_list, + ) + + +def test_instruction_following_loose( + inp, + response, +): + """Tests response for an upper bound for following instructions.""" + r = response.split("\n") + response_remove_first = "\n".join(r[1:]).strip() + response_remove_last = "\n".join(r[:-1]).strip() + response_remove_both = "\n".join(r[1:-1]).strip() + revised_response = response.replace("*", "") + revised_response_remove_first = response_remove_first.replace("*", "") + revised_response_remove_last = response_remove_last.replace("*", "") + revised_response_remove_both = response_remove_both.replace("*", "") + all_responses = [ + response, + revised_response, + response_remove_first, + response_remove_last, + response_remove_both, + revised_response_remove_first, + revised_response_remove_last, + revised_response_remove_both, + ] + instruction_list = inp.instruction_id_list + is_following_list = [] + + for index, instruction_id in enumerate(instruction_list): + instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id] + instruction = instruction_cls(instruction_id) + + # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method. + kwargs = {k: v for k, v in inp.kwargs[index].items() if v} + instruction.build_description(**kwargs) + args = instruction.get_instruction_args() + if args and "prompt" in args: + instruction.build_description(prompt=inp.prompt) + + is_following = False + for r in all_responses: + if r.strip() and instruction.check_following(r): + is_following = True + break + + is_following_list.append(is_following) + + return OutputExample( + instruction_id_list=inp.instruction_id_list, + prompt=inp.prompt, + response=response, + follow_all_instructions=all(is_following_list), + follow_instruction_list=is_following_list, + ) + + +def process_results(doc, results): + new_kwargs = [] + for item in doc["kwargs"]: + if item["nth_paragraph"]: + item["nth_paragraph"] = int(item["nth_paragraph"]) + new_kwargs.append(item) + inp = InputExample( + key=doc["key"], + instruction_id_list=doc["instruction_id_list"], + prompt=doc["prompt"], + kwargs=new_kwargs, + ) + response = results[0] + + out_strict = test_instruction_following_strict(inp, response) + out_loose = test_instruction_following_loose(inp, response) + + return { + "prompt_level_strict_acc": out_strict.follow_all_instructions, + "inst_level_strict_acc": out_strict.follow_instruction_list, + "prompt_level_loose_acc": out_loose.follow_all_instructions, + "inst_level_loose_acc": out_loose.follow_instruction_list, + } + + +def agg_inst_level_acc(items): + flat_items = [item for sublist in items for item in sublist] + inst_level_acc = sum(flat_items) / len(flat_items) + return inst_level_acc diff --git a/tests/hpu/meta-configs/joined_ifeval.parquet b/tests/hpu/meta-configs/joined_ifeval.parquet new file mode 100755 index 0000000000000..e485e780f4d2f Binary files /dev/null and b/tests/hpu/meta-configs/joined_ifeval.parquet differ diff 
--git a/tests/hpu/meta-configs/joined_math.parquet b/tests/hpu/meta-configs/joined_math.parquet new file mode 100755 index 0000000000000..dfc170d3f6d2e Binary files /dev/null and b/tests/hpu/meta-configs/joined_math.parquet differ diff --git a/tests/hpu/meta-configs/math_hard/math_hard_0shot_cot.yaml b/tests/hpu/meta-configs/math_hard/math_hard_0shot_cot.yaml new file mode 100755 index 0000000000000..b7ddf34e4766d --- /dev/null +++ b/tests/hpu/meta-configs/math_hard/math_hard_0shot_cot.yaml @@ -0,0 +1,21 @@ +dataset_path: parquet +dataset_kwargs: + data_files: ./meta-configs/joined_math.parquet +task: meta_math_hard +process_docs: !function math_hard_utils.process_docs +output_type: generate_until +test_split: train +doc_to_text: !function math_hard_utils.doc_to_text +process_results: !function math_hard_utils.process_results +doc_to_target: answer +generation_kwargs: + until: [] + do_sample: false + temperature: 0 + max_gen_toks: 5120 +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +metadata: + version: 1.0 diff --git a/tests/hpu/meta-configs/math_hard/math_hard_utils.py b/tests/hpu/meta-configs/math_hard/math_hard_utils.py new file mode 100755 index 0000000000000..b6a7913f88faa --- /dev/null +++ b/tests/hpu/meta-configs/math_hard/math_hard_utils.py @@ -0,0 +1,269 @@ +# Most of the code taken from https://github.com/EleutherAI/lm-evaluation-harness/blob/cddce0a148ec1710e2d60546c6f92727dd8a78fd/lm_eval/tasks/leaderboard/math/utils.py +import re +import signal +from typing import Dict, List, Optional + +import datasets + +from lm_eval.utils import eval_logger + + +try: + import sympy + from sympy.parsing.latex import parse_latex +except ModuleNotFoundError: + raise ModuleNotFoundError( + "`sympy` is required for generating translation task prompt templates. 
\ +please install sympy via pip install lm-eval[math] or pip install -e .[math]", + ) + +# taken from +# https://github.com/wellecks/lm-evaluation-harness/blob/master/lm_eval/tasks/minerva_math.py +def doc_to_text(doc: dict) -> str: + return doc["input_final_prompts"][0] + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + def _process_doc(doc: dict) -> dict: + out_doc = { + "problem": doc["input_question"], + "answer": normalize_final_answer( + remove_boxed(last_boxed_only_string(doc["solution"])) + ), + "meta_target": doc["input_correct_responses"] + } + return out_doc + return dataset.map(_process_doc) + + +def process_results(doc: dict, results: List[str]) -> Dict[str, int]: + candidates = results[0] + last_boxed_string = last_boxed_only_string(candidates) + if not last_boxed_string: + # No boxed string found, so we can't evaluate + return {"exact_match": 0} + unnormalized_answer = remove_boxed(last_boxed_string) + answer = normalize_final_answer(unnormalized_answer) + + if answer.strip() == doc["answer"].strip() or is_equiv(answer, doc["answer"]): + retval = 1 + else: + retval = 0 + + return { + "exact_match": retval, + } + + +def last_boxed_only_string(string: str) -> Optional[str]: + idx = string.rfind("\\boxed") + if "\\boxed " in string: + return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + retval = None + else: + retval = string[idx : right_brace_idx + 1] + + return retval + + +def remove_boxed(s: Optional[str]) -> str: + assert s is not None + if "\\boxed " in s: + left = "\\boxed " + assert s[: len(left)] == left + return s[len(left) :] + + left = "\\boxed{" + + assert s[: len(left)] == left + assert s[-1] == "}" + + return s[len(left) : -1] + + +class timeout: + def __init__(self, seconds=1, error_message="Timeout"): + self.seconds = seconds + self.error_message = error_message + + def handle_timeout(self, signum, frame): + raise TimeoutError(self.error_message) + + def __enter__(self): + signal.signal(signal.SIGALRM, self.handle_timeout) + signal.alarm(self.seconds) + + def __exit__(self, type, value, traceback): + signal.alarm(0) + + +def is_equiv(x1: str, x2: str) -> bool: + """ + x1 and x2 are normalized latex string + """ + try: + with timeout(seconds=5): + try: + parsed_x1 = parse_latex(x1) + parsed_x2 = parse_latex(x2) + except ( + sympy.parsing.latex.errors.LaTeXParsingError, + sympy.SympifyError, + TypeError, + ): + eval_logger.debug(f"couldn't parse one of {x1} or {x2}") + return False + + try: + diff = parsed_x1 - parsed_x2 + except TypeError: + eval_logger.debug(f"couldn't subtract {x1} and {x2}") + return False + + try: + if sympy.simplify(diff) == 0: + return True + else: + return False + except ValueError: + eval_logger.debug( + f"Had some trouble simplifying when comparing {x1} and {x2}" + ) + return False + except TimeoutError: + eval_logger.debug(f"Timed out comparing {x1} and {x2}") + return False + except ImportError as e: + eval_logger.error(e) + raise + except Exception as e: + eval_logger.debug(f"Failed comparing {x1} and {x2} with {e}") + return False + + +def get_unnormalized_answer(text: str) -> str: + INVALID_ANSWER = "[invalidanswer]" + end_seq = "I hope it is 
correct." + text += end_seq + match = re.search( + r"Final Answer: The final answer is(.*?). I hope it is correct.", + text, + ) + if match: + return match.group(1).strip() + else: + return INVALID_ANSWER + + +SUBSTITUTIONS = [ + ("an ", ""), + ("a ", ""), + (".$", "$"), + ("\\$", ""), + (r"\ ", ""), + (" ", ""), + ("mbox", "text"), + (",\\text{and}", ","), + ("\\text{and}", ","), + ("\\text{m}", "\\text{}"), +] +REMOVED_EXPRESSIONS = [ + "square", + "ways", + "integers", + "dollars", + "mph", + "inches", + "ft", + "hours", + "km", + "units", + "\\ldots", + "sue", + "points", + "feet", + "minutes", + "digits", + "cents", + "degrees", + "cm", + "gm", + "pounds", + "meters", + "meals", + "edges", + "students", + "childrentickets", + "multiples", + "\\text{s}", + "\\text{.}", + "\\text{\ns}", + "\\text{}^2", + "\\text{}^3", + "\\text{\n}", + "\\text{}", + r"\mathrm{th}", + r"^\circ", + r"^{\circ}", + r"\;", + r",\!", + "{,}", + '"', + "\\dots", +] + + +def normalize_final_answer(final_answer: str) -> str: + """ + Normalize a final answer to a quantitative reasoning question. + + Copied character for character from appendix D of Lewkowycz et al. (2022) + """ + final_answer = final_answer.split("=")[-1] + + for before, after in SUBSTITUTIONS: + final_answer = final_answer.replace(before, after) + for expr in REMOVED_EXPRESSIONS: + final_answer = final_answer.replace(expr, "") + + # Extract answer that is in LaTeX math, is bold, + # is surrounded by a box, etc. + final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) + final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) + + # Normalize shorthand TeX: + # \fracab -> \frac{a}{b} + # \frac{abc}{bef} -> \frac{abc}{bef} + # \fracabc -> \frac{a}{b}c + # \sqrta -> \sqrt{a} + # \sqrtab -> sqrt{a}b + final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) + final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) + final_answer = final_answer.replace("$", "") + + # Normalize 100,000 -> 100000 + if final_answer.replace(",", "").isdigit(): + final_answer = final_answer.replace(",", "") + + return final_answer diff --git a/tests/hpu/meta-configs/meta_instruct.yaml b/tests/hpu/meta-configs/meta_instruct.yaml new file mode 100755 index 0000000000000..2f3a3d5454693 --- /dev/null +++ b/tests/hpu/meta-configs/meta_instruct.yaml @@ -0,0 +1,6 @@ +group: meta_instruct +task: + - meta_ifeval + - meta_math_hard + - meta_gpqa + - meta_mmlu_pro_instruct diff --git a/tests/hpu/meta-configs/meta_pretrain.yaml b/tests/hpu/meta-configs/meta_pretrain.yaml new file mode 100755 index 0000000000000..1b21dd1abed13 --- /dev/null +++ b/tests/hpu/meta-configs/meta_pretrain.yaml @@ -0,0 +1,4 @@ +group: meta_pretrain +task: + - meta_bbh + - meta_mmlu_pro_pretrain diff --git a/tests/hpu/meta-configs/mmlu_pro/mmlu_pro_5shot_cot_instruct.yaml b/tests/hpu/meta-configs/mmlu_pro/mmlu_pro_5shot_cot_instruct.yaml new file mode 100755 index 0000000000000..39a9e62cfba4f --- /dev/null +++ b/tests/hpu/meta-configs/mmlu_pro/mmlu_pro_5shot_cot_instruct.yaml @@ -0,0 +1,29 @@ +task: meta_mmlu_pro_instruct +dataset_path: meta-llama/Meta-Llama-3.1-8B-Instruct-evals +dataset_name: Meta-Llama-3.1-8B-Instruct-evals__mmlu_pro__details +test_split: latest +output_type: generate_until +process_docs: !function mmlu_utils.process_docs 
+doc_to_text: !function mmlu_utils.doc_to_text +doc_to_target: gold +filter_list: + - name: "strict-match" + filter: + - function: "regex" + group_select: -1 + regex_pattern: 'best answer is ([A-Z])' + - function: "take_first" +generation_kwargs: + until: [] + do_sample: false + temperature: 0 + max_gen_toks: 1024 +num_fewshot: 0 +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: true +metadata: + version: 1.0 diff --git a/tests/hpu/meta-configs/mmlu_pro/mmlu_pro_5shot_cot_pretrain.yaml b/tests/hpu/meta-configs/mmlu_pro/mmlu_pro_5shot_cot_pretrain.yaml new file mode 100755 index 0000000000000..dad98120ca0ea --- /dev/null +++ b/tests/hpu/meta-configs/mmlu_pro/mmlu_pro_5shot_cot_pretrain.yaml @@ -0,0 +1,28 @@ +task: meta_mmlu_pro_pretrain +dataset_path: meta-llama/Meta-Llama-3.1-8B-evals +dataset_name: Meta-Llama-3.1-8B-evals__mmlu_pro__details +test_split: latest +output_type: generate_until +process_docs: !function utils.process_docs +doc_to_text: !function utils.doc_to_text +doc_to_target: gold +filter_list: + - name: "strict-match" + filter: + - function: "regex" + regex_pattern: 'answer is \(([A-Z])\)' + - function: "take_first" +generation_kwargs: + until: "\n\nQ: " + do_sample: false + temperature: 0 + max_gen_toks: 512 +num_fewshot: 0 +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: true +metadata: + version: 1.0 diff --git a/tests/hpu/meta-configs/mmlu_pro/mmlu_utils.py b/tests/hpu/meta-configs/mmlu_pro/mmlu_utils.py new file mode 100755 index 0000000000000..51d9f71f92a0a --- /dev/null +++ b/tests/hpu/meta-configs/mmlu_pro/mmlu_utils.py @@ -0,0 +1,21 @@ +import string + + +import datasets + + + +def doc_to_text(doc: dict) -> str: + return doc["input_final_prompts"][0] + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + def _process_doc(doc: dict) -> dict: + out_doc = { + "problem": doc["input_question"], + "gold": doc["input_correct_responses"][0], + } + return out_doc + dataset = dataset.select_columns(["input_question", "input_correct_responses", "input_final_prompts", "is_correct","input_question_hash","input_choice_list","output_prediction_text"]) + dataset = dataset.rename_column("is_correct","previously_is_correct") + dataset = dataset.map(_process_doc) + return dataset.map(_process_doc) diff --git a/tests/hpu/test_hpu_lmeval.py b/tests/hpu/test_hpu_lmeval.py new file mode 100644 index 0000000000000..a59279662a74e --- /dev/null +++ b/tests/hpu/test_hpu_lmeval.py @@ -0,0 +1,322 @@ +import statistics +from dataclasses import replace + +from lm_eval import tasks, evaluator +import numpy as np +import pytest +import itertools +import time +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class LMTask: + + def __init__(self, lm_instance, task_cfg): + self.lm_instance = lm_instance + self.task_cfg = task_cfg + self.task_manager = tasks.TaskManager(include_path="./meta-configs") + assert "task_name" in self.task_cfg, ("Task config must contain " + "a task_name!") + self.task_name = self.task_cfg["task_name"] + self.task_dict = tasks.get_task_dict(self.task_name, self.task_manager) + if "task_config_overrides" in self.task_cfg: + self.task_dict[self.task_name]._config = replace( + self.task_dict[self.task_name]._config, + **self.task_cfg["task_config_overrides"]) + + def patch_parallel_state(self): + # NOTE(kzawora): This a really nasty workaround - for whatever 
reason,
+        # tensor and pipeline parallel states are getting corrupted when moving
+        # to a new test context and need to be re-initialized.
+        # Possibly other vllm globals can be corrupted too.
+        # Recognition to anyone who figures out how to prevent it,
+        # while maintaining vLLM instance reuse across tests.
+        # For now, we restore TP & PP groups saved in lm_instance. Nasty.
+        # If this makes you worried, remove vLLM instance reuse.
+        # (remove scope='module' from lm_instance fixture)
+        vllm = pytest.importorskip('vllm')
+        if vllm.distributed.parallel_state._PP is None:
+            vllm.distributed.parallel_state._PP = self.lm_instance.pp_group
+            logger.warning('vLLM pipeline parallel state is empty!')
+        if vllm.distributed.parallel_state._TP is None:
+            vllm.distributed.parallel_state._TP = self.lm_instance.tp_group
+            logger.warning('vLLM tensor parallel state is empty!')
+        if vllm.distributed.parallel_state._WORLD is None:
+            vllm.distributed.parallel_state._WORLD = self.lm_instance.world
+            logger.warning('vLLM world state is empty!')
+
+    def run_evaluate(self):
+        self.patch_parallel_state()
+        if self.task_cfg.get('eval_kwargs', None) is None:
+            self.task_cfg["eval_kwargs"] = {}
+        results = evaluator.evaluate(lm=self.lm_instance.LM,
+                                     task_dict=self.task_dict,
+                                     **self.task_cfg["eval_kwargs"])
+        return results
+
+
+class LMInstance:
+
+    def __init__(self, lm_instance_cfg, vllm):
+        self.model_name = lm_instance_cfg['model_name']
+        self.cfg = lm_instance_cfg
+        from lm_eval.models.vllm_causallms import VLLM
+        self.LM = VLLM(**lm_instance_cfg["vllm_kwargs"],
+                       **lm_instance_cfg["lm_eval_kwargs"])
+        self.pp_group = vllm.distributed.parallel_state._PP
+        self.tp_group = vllm.distributed.parallel_state._TP
+        self.world = vllm.distributed.parallel_state._WORLD
+
+
+# from lm_eval import api
+# self.LM = api.registry.get_model("vllm").create_from_arg_obj(
+#     lm_instance_cfg["vllm_kwargs"], lm_instance_cfg["lm_eval_kwargs"])
+
+
+def assert_server_idle(lm):
+    running = len(lm.model.llm_engine.scheduler[0].running)
+    waiting = len(lm.model.llm_engine.scheduler[0].waiting)
+    swapped = len(lm.model.llm_engine.scheduler[0].swapped)
+    assert running == 0, f'There are {running} requests running!'
+    assert waiting == 0, f'There are {waiting} requests waiting!'
+    assert swapped == 0, f'There are {swapped} requests swapped!'
+ + +@pytest.fixture(scope='module') +def lm_instance(request): + vllm = pytest.importorskip('vllm') + lm = LMInstance(request.param, vllm) + assert_server_idle(lm.LM) + yield lm + assert_server_idle(lm.LM) + logger.debug('Destroying LM instance') + + +@pytest.fixture +def task_cfg(request) -> dict: + return request.param + + +@pytest.fixture(autouse=True) +def lm_task(lm_instance: LMInstance, task_cfg: dict): + task = LMTask(lm_instance, task_cfg) + assert_server_idle(task.lm_instance.LM) + yield task + assert_server_idle(task.lm_instance.LM) + logger.debug('Destroying task') + + +class LMConfigs: + llama3_1_8b_instruct_bs128_bf16 = { + "model_name": "Meta-Llama-3.1-8B-Instruct", + "lm_eval_kwargs": { + "batch_size": "auto" + }, + "vllm_kwargs": { + "pretrained": + "/mnt/weka/data/pytorch/llama3.1/Meta-Llama-3.1-8B-Instruct", + "max_num_seqs": 128, + "max_model_len": 8192, + "dtype": "bfloat16", + "data_parallel_size": 1, + "tensor_parallel_size": 1, + "disable_log_stats": False + }, + } + llama3_1_8b_bs128_bf16 = { + "model_name": "Meta-Llama-3.1-8B", + "lm_eval_kwargs": { + "batch_size": "auto" + }, + "vllm_kwargs": { + "pretrained": + "/mnt/weka/data/pytorch/llama3.1/Meta-Llama-3.1-8B-Instruct", + "max_num_seqs": 128, + "max_model_len": 8192, + "dtype": "bfloat16", + "data_parallel_size": 1, + "tensor_parallel_size": 1, + "disable_log_stats": False + }, + } + + +class TaskConfigs: + gsm8k_llama_cot = { + "task_name": "gsm8k_cot_llama", + "eval_kwargs": { + "limit": None, + "fewshot_as_multiturn": True, + "apply_chat_template": True, + }, + } + + ifeval = { + "task_name": "ifeval", + "task_config_overrides": { + "fewshot_config": { + "sampler": "first_n" + } + }, + "eval_kwargs": { + "limit": 10, + "fewshot_as_multiturn": True, + "apply_chat_template": True, + }, + } + + meta_mmlu_pro_instruct = { + "task_name": "meta_mmlu_pro_instruct", + } + + meta_mmlu_pro_pretrain = { + "task_name": "meta_mmlu_pro_pretrain", + } + + meta_math_hard = { + "task_name": "meta_math_hard", + } + + meta_gpqa = { + "task_name": "meta_gpqa", + } + + meta_ifeval = { + "task_name": "meta_ifeval", + } + + meta_bbh = { + "task_name": "meta_bbh", + } + + +class LMTaskTargets: + default_atol = 0.05 + default_rtol = 0.05 + targets = { + 'Meta-Llama-3.1-8B-Instruct': { + "gsm8k_cot_llama": { + 'score': 0.845 + }, + "ifeval": { + 'score': 0.804 + }, + "meta_math_hard": { + 'score': 0.804 + }, + "meta_gpqa": { + "score": 0.328 + }, + "meta_ifeval": { + 'score': 0.804 + }, + "meta_mmlu_pro_instruct": { + 'score': 0.47 + }, + }, + 'Meta-Llama-3.1-8B': { + "meta_mmlu_pro_pretrain": { + 'score': 0.356 + }, + "meta_bbh": { + 'score': 0.642 + }, + } + } + + +def get_task_name(task_cfg): + return task_cfg["task_name"] + + +@pytest.mark.parametrize('lm_instance', + [LMConfigs.llama3_1_8b_instruct_bs128_bf16], + ids=['llama3_1_8b_instruct_bs128_bf16'], + indirect=True) +@pytest.mark.parametrize("task_cfg", [ + TaskConfigs.gsm8k_llama_cot, TaskConfigs.ifeval, + TaskConfigs.meta_mmlu_pro_instruct, TaskConfigs.meta_ifeval, + TaskConfigs.meta_gpqa, TaskConfigs.meta_math_hard +], + ids=get_task_name, + indirect=True) +def test_task_instruct(lm_task: LMTask): + generic_test_task(lm_task) + + +@pytest.mark.parametrize('lm_instance', [LMConfigs.llama3_1_8b_bs128_bf16], + ids=['llama3_1_8b_bs128_bf16'], + indirect=True) +@pytest.mark.parametrize( + "task_cfg", [TaskConfigs.meta_bbh, TaskConfigs.meta_mmlu_pro_pretrain], + ids=get_task_name, + indirect=True) +def test_task_pretrain(lm_task: LMTask): + generic_test_task(lm_task) + + +def 
generic_test_task(lm_task: LMTask) -> None: + start = time.perf_counter() + res = lm_task.run_evaluate() + end = time.perf_counter() + total_time = end - start + task_name = lm_task.task_name + model_name = lm_task.lm_instance.model_name + metrics_to_extract = [ + m['metric'] + for m in lm_task.task_dict[lm_task.task_name]._config.metric_list + ] # ugh... + extracted_metrics = { + k: v + for k, v in res['results'][lm_task.task_name].items() + for metric in metrics_to_extract if metric in k and "stderr" not in k + } # UGH... + score = statistics.mean(extracted_metrics.values()) + target_dict = LMTaskTargets.targets[model_name][task_name] + target_score = target_dict['score'] + atol = target_dict[ + 'atol'] if 'atol' in target_dict else LMTaskTargets.default_atol + rtol = target_dict[ + 'rtol'] if 'rtol' in target_dict else LMTaskTargets.default_rtol + if True: + tokenizer = lm_task.lm_instance.LM.tokenizer + samples = res['samples'][lm_task.task_name] + # tokenized_inputs = [tokenizer(x['doc']['prompt'])['input_ids'] + # for x in samples] + tokenized_inputs = [ + tokenizer(x['arguments'][0][0])['input_ids'] for x in samples + ] + tokenized_inputs_lens = [len(x) for x in tokenized_inputs] + tokenized_outputs = [ + list( + itertools.chain.from_iterable( + tokenizer(list(itertools.chain.from_iterable( + x['resps'])))['input_ids'])) for x in samples + ] + tokenized_outputs_lens = [len(x) for x in tokenized_outputs] + report_accuracy(extracted_metrics, score, target_score, atol, rtol) + report_performance(tokenized_inputs_lens, tokenized_outputs_lens, + total_time) + + np.testing.assert_allclose(score, target_score, atol=atol, rtol=rtol) + + +def report_accuracy(metrics, score, target, atol, rtol): + logger.info( + f'accuracy: {metrics}\nfinal score: {score}\n, target: {target} (atol: {atol}, rtol: {rtol})' # noqa: G004, E501 + ) + + +def report_performance(input_lens, output_lens, time): + assert len(input_lens) == len(output_lens) + context_lens = [i + o for i, o in zip(input_lens, output_lens)] + gen_tput = sum(output_lens) / time + logger.info( + f'gen tput: {gen_tput:.2f} tok/s \n' # noqa: G004 + f'input_tokens | min: {min(input_lens)} | max: {max(input_lens)} | mean: {statistics.mean(input_lens):.2f} | stddev: {statistics.stdev(input_lens):.2f}\n' # noqa: E501 + f'output_tokens | min: {min(output_lens)} | max: {max(output_lens)} | mean: {statistics.mean(output_lens):.2f} | stddev: {statistics.stdev(output_lens):.2f}\n' # noqa: E501 + f'context_length | min: {min(context_lens)} | max: {max(context_lens)} | mean: {statistics.mean(context_lens):.2f} | stddev: {statistics.stdev(context_lens):.2f}\n' # noqa: E501 + )
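
For context, below is a minimal standalone sketch of how the new meta-config tasks are meant to be consumed, mirroring what LMTask and LMInstance do in test_hpu_lmeval.py. It only uses lm_eval APIs that already appear in this patch (TaskManager, get_task_dict, evaluator.evaluate, the VLLM wrapper); the checkpoint path and the choice of meta_gpqa are placeholders for illustration, not part of the patch.

# Minimal sketch (assumptions: lm_eval installed per requirements-test.txt;
# the checkpoint path below is a placeholder, not a real location).
from lm_eval import evaluator, tasks
from lm_eval.models.vllm_causallms import VLLM

# Pick up the YAML task definitions added under tests/hpu/meta-configs.
task_manager = tasks.TaskManager(include_path="./meta-configs")
task_dict = tasks.get_task_dict("meta_gpqa", task_manager)

# lm_eval's vLLM backend, configured like LMConfigs.llama3_1_8b_instruct_bs128_bf16.
lm = VLLM(
    pretrained="/path/to/Meta-Llama-3.1-8B-Instruct",  # placeholder path
    max_model_len=8192,
    max_num_seqs=128,
    dtype="bfloat16",
    tensor_parallel_size=1,
    batch_size="auto",
)

# Same entry point the test's run_evaluate() uses.
results = evaluator.evaluate(lm=lm, task_dict=task_dict)
print(results["results"]["meta_gpqa"])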