diff --git a/.jenkins/lm-eval-harness/configs/Meta-Llama-3.1-8B-Instruct-mss.yaml b/.jenkins/lm-eval-harness/configs/Meta-Llama-3.1-8B-Instruct-mss.yaml
new file mode 100644
index 0000000000000..8dc02ce0765b2
--- /dev/null
+++ b/.jenkins/lm-eval-harness/configs/Meta-Llama-3.1-8B-Instruct-mss.yaml
@@ -0,0 +1,16 @@
+# FIXME(kzawora): these scores were generated using vLLM on HPU, we need to confirm them on HF
+# VLLM_SKIP_WARMUP=true bash run-lm-eval-gsm-cot-llama-vllm-baseline.sh -m "/mnt/weka/data/pytorch/llama3.1/Meta-Llama-3.1-8B-Instruct" -b 128 -l 1319 -f 8 -t 1
+model_name: "/mnt/weka/data/pytorch/llama3.1/Meta-Llama-3.1-8B-Instruct"
+tasks:
+- name: "gsm8k_cot_llama"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.8317
+  - name: "exact_match,flexible-extract"
+    value: 0.8355
+limit: null
+num_fewshot: 8
+dtype: "bfloat16"
+fewshot_as_multiturn: true
+apply_chat_template: true
+num_scheduler_steps: 2
\ No newline at end of file
diff --git a/.jenkins/lm-eval-harness/configs/models-mss.txt b/.jenkins/lm-eval-harness/configs/models-mss.txt
new file mode 100644
index 0000000000000..cfcc3d42d108f
--- /dev/null
+++ b/.jenkins/lm-eval-harness/configs/models-mss.txt
@@ -0,0 +1 @@
+Meta-Llama-3.1-8B-Instruct-mss.yaml
\ No newline at end of file
diff --git a/.jenkins/lm-eval-harness/test_lm_eval_correctness.py b/.jenkins/lm-eval-harness/test_lm_eval_correctness.py
index 3df0621f49a72..4fce75479972b 100644
--- a/.jenkins/lm-eval-harness/test_lm_eval_correctness.py
+++ b/.jenkins/lm-eval-harness/test_lm_eval_correctness.py
@@ -54,6 +54,9 @@ def launch_lm_eval(eval_config):
         model_args += ",quantization=inc," \
             "kv_cache_dtype=fp8_inc," \
             "weights_load_device=cpu"
+    if eval_config.get("num_scheduler_steps"):
+        model_args += \
+            f",num_scheduler_steps={eval_config.get('num_scheduler_steps')}"
     kwargs = {}
     if 'fewshot_as_multiturn' in eval_config:
         kwargs['fewshot_as_multiturn'] = eval_config['fewshot_as_multiturn']
diff --git a/.jenkins/test_config.yaml b/.jenkins/test_config.yaml
index b32563d6222e9..b4d09bfd85420 100644
--- a/.jenkins/test_config.yaml
+++ b/.jenkins/test_config.yaml
@@ -1,29 +1,40 @@
 # test_config.yaml
 stages:
-  - name: test_gsm8k_small_models
+  # - name: test_gsm8k_small_models
+  #   steps:
+  #     - name: gsm8k_small_g3_tp1
+  #       flavor: g3
+  #       command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 1
+  #     - name: gsm8k_small_g3_tp2
+  #       flavor: g3.s
+  #       command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 2
+  #     - name: gsm8k_small_g2_tp1
+  #       flavor: g2
+  #       command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 1
+  #     - name: gsm8k_small_g2_tp2
+  #       flavor: g2.s
+  #       command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 2
+  # - name: test_gsm8k_large_models
+  #   steps:
+  #     - name: gsm8k_large_g3_tp2
+  #       flavor: g3.s
+  #       command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-large.txt -t 2
+  #     - name: gsm8k_large_g2_tp4
+  #       flavor: g2.m
+  #       command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-large.txt -t 4
+  # - name: test_gsm8k_fp8
+  #   steps:
+  #     - name: gsm8k_small_g3_tp1_fp8
+  #       flavor: g3
+  #       command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-fp8.txt -t 1
+  - name: test_gsm8k_mss
     steps:
-      - name: gsm8k_small_g3_tp1
+      - name: gsm8k_small_g3_tp1_mss
         flavor: g3
-        command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 1
-      - name: gsm8k_small_g3_tp2
-        flavor: g3.s
-        command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 2
-      - name: gsm8k_small_g2_tp1
+        command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-mss.txt -t 1
+      - name: gsm8k_small_g2_tp1_mss
         flavor: g2
-        command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 1
-      - name: gsm8k_small_g2_tp2
-        flavor: g2.s
-        command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-small.txt -t 2
-  - name: test_gsm8k_large_models
-    steps:
-      - name: gsm8k_large_g3_tp2
+        command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-mss.txt -t 1
+      - name: gsm8k_small_g3_tp2_mss
         flavor: g3.s
-        command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-large.txt -t 2
-      - name: gsm8k_large_g2_tp4
-        flavor: g2.m
-        command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-large.txt -t 4
-  - name: test_gsm8k_fp8
-    steps:
-      - name: gsm8k_small_g3_tp1_fp8
-        flavor: g3
-        command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-fp8.txt -t 1
+        command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-mss.txt -t 2
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index c50e4e244dffe..a0cd05bfbd062 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -2109,6 +2109,19 @@ def execute_model(
             # we only want to pythonize in the last step
             sampling_metadata.skip_sampler_cpu_output = True
             self.model.model.sampler.include_gpu_probs_tensor = True
+        cache_orig_output_tokens_len: List[Dict] = []
+
+        def try_revert_dummy_output_tokens():
+            if len(cache_orig_output_tokens_len) > 0:
+                # Reuse the original output token ids length
+                for i, seq_group_metadata in enumerate(
+                        seq_group_metadata_list):
+                    for j, data in seq_group_metadata.seq_data.items():
+                        orig_output_tokens_len = \
+                            cache_orig_output_tokens_len[i][j]
+                        data.output_token_ids = \
+                            data.output_token_ids[:orig_output_tokens_len]
+
         for i in range(num_steps):
             with self.profiler.record_event('internal', model_event_name):
                 hidden_states = self.model.forward(
@@ -2155,17 +2168,22 @@ def execute_model(
                 htorch.core.mark_step()
                 if i < num_steps - 1:
                     if i == 0:
-                        import copy
                         ctx = model_input.async_callback.keywords[  # type: ignore
                             "ctx"]
                         seq_group_metadata_list = ctx.seq_group_metadata_list
-                        seq_group_metadata_list = copy.deepcopy(
-                            seq_group_metadata_list)
+                        # Cache the original output token ids
+                        for i, seq_group_metadata in enumerate(
+                                seq_group_metadata_list):
+                            cache_orig_output_tokens_len.append({})
+                            for j, data in seq_group_metadata.seq_data.items():
+                                cache_orig_output_tokens_len[i][j] = \
+                                    len(data.output_token_ids)
                     for seq_group_metadata in seq_group_metadata_list:
                         for data in seq_group_metadata.seq_data.values():
                             max_output_len = sampling_metadata.seq_groups[
                                 0].sampling_params.max_tokens
                             if len(data.output_token_ids) < max_output_len - 1:
+                                # add a place holder for prepare_decode
                                 # arbitrary value, this could be any token
                                 dummy_token = (540, )
                                 data.output_token_ids += (dummy_token)
@@ -2173,6 +2191,7 @@ def execute_model(
                                 if num_steps == 1:
                                     return [output]
                                 else:
+                                    try_revert_dummy_output_tokens()
                                     return []
 
                     result = self._prepare_decode(seq_group_metadata_list,
@@ -2185,6 +2204,8 @@ def execute_model(
                         "attn_metadata": self.trim_attn_metadata(result.attn_metadata)
                     })
+                else:
+                    try_revert_dummy_output_tokens()
 
         if self.is_driver_worker and self.profiler.enabled:
             # Stop recording 'execute_model' event
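
For context, the hpu_model_runner.py hunks above replace a per-batch copy.deepcopy of seq_group_metadata_list with a cheaper cache-and-revert scheme: on the first step the original output-token lengths are recorded, dummy placeholder tokens are appended between scheduler steps so _prepare_decode has an input token, and try_revert_dummy_output_tokens() truncates each sequence back to its recorded length afterwards. The standalone Python sketch below only illustrates that pattern under simplified assumptions; SeqData, run_multi_step, and the sample values are hypothetical stand-ins, not the actual vLLM types.

# Illustrative sketch of the cache-and-revert pattern (not the vLLM code;
# SeqData and run_multi_step are hypothetical stand-ins).
from typing import Dict, List, Tuple


class SeqData:
    def __init__(self, output_token_ids: Tuple[int, ...] = ()):
        self.output_token_ids = output_token_ids


def run_multi_step(seq_groups: List[Dict[int, SeqData]],
                   num_steps: int,
                   dummy_token: int = 540) -> None:
    # Record each sequence's original output length once, instead of
    # deep-copying the whole metadata structure before mutating it.
    cache_orig_output_tokens_len: List[Dict[int, int]] = [
        {seq_id: len(data.output_token_ids) for seq_id, data in group.items()}
        for group in seq_groups
    ]

    def try_revert_dummy_output_tokens() -> None:
        # Truncate every sequence back to its cached original length,
        # dropping any placeholder tokens appended between steps.
        for group, cached in zip(seq_groups, cache_orig_output_tokens_len):
            for seq_id, data in group.items():
                data.output_token_ids = data.output_token_ids[:cached[seq_id]]

    for step in range(num_steps):
        # ... forward pass and sampling would run here ...
        if step < num_steps - 1:
            # Append an arbitrary placeholder so the next decode step
            # has an output token to schedule against.
            for group in seq_groups:
                for data in group.values():
                    data.output_token_ids += (dummy_token, )
    # After the last step, remove the placeholders.
    try_revert_dummy_output_tokens()


if __name__ == "__main__":
    groups = [{0: SeqData((11, 22))}]
    run_multi_step(groups, num_steps=3)
    assert groups[0][0].output_token_ids == (11, 22)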