Commit

Merge branch 'HabanaAI:habana_main' into Intermediate_states_accumulation
jkaniecki authored Dec 17, 2024
2 parents f4f9322 + da61ecf commit 8252fe6
Showing 6 changed files with 51 additions and 17 deletions.
@@ -5,10 +5,10 @@ tasks:
 - name: "gsm8k_cot_llama"
   metrics:
   - name: "exact_match,strict-match"
-    value: 0.8317
+    value: 0.664
   - name: "exact_match,flexible-extract"
-    value: 0.8355
-limit: null
+    value: 0.676
+limit: 250
 num_fewshot: 8
 dtype: "bfloat16"
 fewshot_as_multiturn: true
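The reference scores drop sharply here, presumably reflecting the switch to the unit-scale FP8 config introduced below and the new 250-sample limit. For context, a minimal sketch of how a harness typically validates a run against such reference values; the tolerance value and helper name are illustrative, and the real check lives in test_lm_eval_correctness.py:

import numpy as np

RTOL = 0.05  # assumed relative tolerance, not taken from the repo

def within_tolerance(expected: float, measured: float) -> bool:
    # Accept the run when the measured score is within RTOL of the
    # reference value recorded in the YAML config above.
    return bool(np.isclose(expected, measured, rtol=RTOL))

assert within_tolerance(0.664, 0.67)      # small drift passes
assert not within_tolerance(0.664, 0.50)  # a real regression fails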
16 changes: 16 additions & 0 deletions .jenkins/lm-eval-harness/inc_unit_scales_config.json
@@ -0,0 +1,16 @@
+{
+    "mode": "QUANTIZE",
+    "observer": "maxabs",
+    "scale_method": "unit_scale",
+    "allowlist": {
+        "types": [],
+        "names": []
+    },
+    "blocklist": {
+        "types": [],
+        "names": [
+            "lm_head"
+        ]
+    },
+    "dump_stats_path": ""
+}
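This new file configures INC (Intel Neural Compressor) FP8 quantization with scale_method "unit_scale", which, as the name suggests, fixes the scales at 1.0 so no calibration statistics are needed (hence the empty dump_stats_path), while the blocklist keeps lm_head unquantized. A hedged sketch of the usual allowlist/blocklist semantics; this is illustrative and not INC's actual implementation:

def should_quantize(name: str, cfg: dict) -> bool:
    # Blocklisted names are never quantized; an empty allowlist is
    # treated as "allow everything else" (assumed semantics).
    if name in cfg["blocklist"]["names"]:
        return False
    allowed = cfg["allowlist"]["names"]
    return not allowed or name in allowed

cfg = {"allowlist": {"names": []}, "blocklist": {"names": ["lm_head"]}}
assert should_quantize("model.layers.0.mlp.gate_proj", cfg)
assert not should_quantize("lm_head", cfg)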
8 changes: 5 additions & 3 deletions .jenkins/lm-eval-harness/run-tests.sh
@@ -43,14 +43,16 @@ do
   export PT_HPU_ENABLE_LAZY_COLLECTIVES=true
   export VLLM_SKIP_WARMUP=true
   RANDOM_SUFFIX=$(tr -dc A-Za-z0-9 </dev/urandom | head -c 4; echo)
-  JUNIT_SUFFIX=""
+  JUNIT_FAMILY=""
+  JUNIT_XML=""
   if [[ -n "$TEST_RESULTS_DIR" ]]; then
     LOG_DIR=$TEST_RESULTS_DIR
     LOG_FILENAME="test_${MODEL_CONFIG}_${RANDOM_SUFFIX}.xml"
     LOG_PATH="${LOG_DIR}/${LOG_FILENAME}"
-    JUNIT_SUFFIX="-o junit_family=xunit1 --junitxml=${LOG_PATH}"
+    JUNIT_FAMILY="-o junit_family=xunit1"
+    JUNIT_XML="--junitxml=${LOG_PATH}"
   fi
-  pytest -s test_lm_eval_correctness.py "$JUNIT_SUFFIX" || LOCAL_SUCCESS=$?
+  pytest -s test_lm_eval_correctness.py "$JUNIT_FAMILY" "$JUNIT_XML" || LOCAL_SUCCESS=$?
 
   if [[ $LOCAL_SUCCESS == 0 ]]; then
     echo "=== PASSED MODEL: ${MODEL_CONFIG} ==="
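The script change is a quoting fix: with both options stored in one JUNIT_SUFFIX variable, the quoted expansion "$JUNIT_SUFFIX" reaches pytest as a single argv element, which pytest cannot split back into two options. An illustrative Python view of the argv pytest receives in each case (the output path is made up):

# Before: "$JUNIT_SUFFIX" expands to ONE argument holding both options.
argv_before = ["pytest", "-s", "test_lm_eval_correctness.py",
               "-o junit_family=xunit1 --junitxml=/tmp/results.xml"]

# After: JUNIT_FAMILY and JUNIT_XML expand to SEPARATE arguments,
# so --junitxml arrives as its own token.
argv_after = ["pytest", "-s", "test_lm_eval_correctness.py",
              "-o junit_family=xunit1", "--junitxml=/tmp/results.xml"]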
8 changes: 3 additions & 5 deletions .jenkins/lm-eval-harness/test_lm_eval_correctness.py
@@ -27,12 +27,10 @@
 TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1)
 
 
-def setup_fp8(model_path, device_type):
-    flavor = f"g{device_type[-1]}"
-    normalized_model_name = Path(model_path).parts[-1].lower()
+def setup_fp8():
     os.environ[
         "QUANT_CONFIG"] = \
-        f"/software/data/vllm-benchmarks/inc/{normalized_model_name}/maxabs_quant_{flavor}.json"
+        "inc_unit_scales_config.json"
 
 
 def fail_on_exit():
@@ -147,7 +145,7 @@ def test_lm_eval_correctness(record_xml_attribute, record_property):
 
     # Set up environment for FP8 inference
     if eval_config.get("fp8"):
-        setup_fp8(eval_config["model_name"], platform)
+        setup_fp8()
     # Launch eval requests.
     start_time = time.perf_counter()
     results = launch_lm_eval(eval_config)
3 changes: 3 additions & 0 deletions vllm/executor/multiproc_hpu_executor.py
@@ -42,6 +42,9 @@ def _check_executor_parameters(self):
             f"please ensure that world_size ({world_size}) "
             f"is less than than max local hpu count ({hpu_device_count})")
 
+    def shutdown_inc(self):
+        self._run_workers("shutdown_inc")
+
     def __del__(self):
         self.shutdown()
 
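The new shutdown_inc hook broadcasts the call so every worker process can finalize its INC (Intel Neural Compressor) quantization state before the executor is torn down, presumably mirroring the equivalent hook on the other HPU executors. A toy sketch of the fan-out pattern assumed behind _run_workers; this is not the executor's real implementation:

class Worker:
    def shutdown_inc(self) -> None:
        print("finalizing INC/FP8 state on this worker")

def run_workers(workers: list, method: str) -> list:
    # Invoke the named method on every worker and collect the results,
    # the same one-call-per-worker shape the executor relies on.
    return [getattr(worker, method)() for worker in workers]

run_workers([Worker(), Worker()], "shutdown_inc")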
27 changes: 21 additions & 6 deletions vllm/worker/hpu_model_runner.py
@@ -296,11 +296,11 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype):
         attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
             mask, -math.inf))
 
-        if not is_fake_hpu() and htorch.utils.internal.is_lazy():
+        if not is_fake_hpu():
             block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
                                                         num_classes=batch_size)
         else:
-            # Unfortunately one_hot on CPU/torch.compile mode/eager mode
+            # Unfortunately one_hot on CPU
             # doesn't handle out of bounds classes so we need to convert
             # all negative values to 0 (block_mapping) or bs (block_groups)
             block_groups = metadata.block_groups.to(torch.long)
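The narrowed comment matches the narrowed condition: the one_hot fast path now also covers HPU eager and torch.compile modes, leaving only the fake-HPU (CPU) path on the workaround. The problem the comment refers to is that torch.nn.functional.one_hot raises on negative class ids, and negative values are exactly how padding slots are marked. A small repro with one possible remap; this is illustrative and not the repo's exact workaround:

import torch
import torch.nn.functional as F

batch_size = 2
block_groups = torch.tensor([0, 1, -1])  # -1 marks a padding slot

# F.one_hot(block_groups, num_classes=batch_size)  # raises on the -1

# Remap negatives to an extra "padding" class, then drop that column.
safe = torch.where(block_groups < 0,
                   torch.tensor(batch_size), block_groups)
one_hot = F.one_hot(safe, num_classes=batch_size + 1)
block_mapping = one_hot[:, :batch_size]  # the padding row is all zeros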
@@ -2019,6 +2019,19 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
 
         return lora_mask, lora_logits_mask
 
+    def add_dummy_seq(self, seq_group_metadata_list, is_prompt):
+        real_batch_size = len(seq_group_metadata_list)
+        batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
+            real_batch_size, is_prompt)
+        batch_size_padding = batch_size_padded - real_batch_size
+        seq_group_metadata_list = seq_group_metadata_list.copy()
+        if batch_size_padding > 0:
+            dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
+                0, 0, is_prompt)
+            seq_group_metadata_list.extend(dummy_seq_group_metadata
+                                           for _ in range(batch_size_padding))
+        return seq_group_metadata_list
+
     @torch.inference_mode()
     def execute_model(
         self,
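add_dummy_seq pads a real batch up to the nearest warmed-up HPU bucket size so execution shapes stay within the pre-compiled set, appending zero-length dummy sequence groups rather than real work. A toy model of the same logic, with made-up bucket sizes and strings standing in for sequence-group metadata:

def get_padded_batch_size(n: int, buckets=(1, 2, 4, 8)) -> int:
    # Assumed bucketing behavior: round up to the next configured bucket.
    return next(b for b in buckets if b >= n)

real_batch = ["seq0", "seq1", "seq2"]
padded_size = get_padded_batch_size(len(real_batch))
padded_batch = real_batch + ["<dummy>"] * (padded_size - len(real_batch))
assert padded_batch == ["seq0", "seq1", "seq2", "<dummy>"]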
@@ -2105,8 +2118,8 @@ def execute_model(
         def try_revert_dummy_output_tokens():
             if len(cache_orig_output_tokens_len) > 0:
                 # Reuse the original output token ids length
-                for i, seq_group_metadata in enumerate(
-                        seq_group_metadata_list):
+                for i in range(len(cache_orig_output_tokens_len)):
+                    seq_group_metadata = seq_group_metadata_list[i]
                     for j, data in seq_group_metadata.seq_data.items():
                         orig_output_tokens_len = \
                             cache_orig_output_tokens_len[i][j]
@@ -2184,16 +2197,18 @@ def try_revert_dummy_output_tokens():
                 else:
                     raise RuntimeError(
                         "seq_group_metadata_list is uninitialized")
-                for i, seq_group_metadata in enumerate(
+                for seq_idx, seq_group_metadata in enumerate(
                         seq_group_metadata_list):
                     # Skip empty steps
                     seq_group_metadata.state.current_step += (
                         num_steps - 2)
                     # Cache the original output token ids
                     cache_orig_output_tokens_len.append({})
                     for j, data in seq_group_metadata.seq_data.items():
-                        cache_orig_output_tokens_len[i][j] = \
+                        cache_orig_output_tokens_len[seq_idx][j] = \
                             len(data.output_token_ids)
+                seq_group_metadata_list = self.add_dummy_seq(
+                    seq_group_metadata_list, is_prompt=False)
                 for seq_group_metadata in seq_group_metadata_list:
                     for data in seq_group_metadata.seq_data.values():
                         max_output_len = sampling_metadata.seq_groups[
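Both loop rewrites above appear to follow from the new padding call: once add_dummy_seq appends dummy groups, seq_group_metadata_list is longer than cache_orig_output_tokens_len, so the revert helper must iterate over the cached entries rather than enumerate the padded list (and the renamed seq_idx keeps the cache indexed by real groups only). A compact illustration of the off-by-padding hazard the rewrite avoids:

# One cached entry per REAL sequence group...
cache = [{"seq0": 5}, {"seq1": 7}]
# ...but the metadata list is longer after padding.
padded = ["meta0", "meta1", "<dummy>"]

# Enumerating padded would visit index 2, where cache[2] raises
# IndexError; ranging over the cache touches only real groups.
for i in range(len(cache)):
    meta = padded[i]  # always a real group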
