HPU: offload logits processing to CPU (HabanaAI#358)
Because logits processing is highly dynamic, it is better to offload it
completely to the CPU instead of computing it on the HPU.
madamczykhabana authored Oct 29, 2024
1 parent 3e135ae commit 3203bd9
Showing 2 changed files with 61 additions and 20 deletions.
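
Before the diffs, the shape of the change in miniature: remember the logits' device, do one device-to-host copy, run all the dynamic per-row processing on CPU, then copy back once. A minimal sketch of that pattern (illustrative only; `apply_processors_with_offload`, `processors`, and `offload` are hypothetical names, not code from this commit):

    import torch

    def apply_processors_with_offload(logits: torch.Tensor, processors,
                                      offload: bool) -> torch.Tensor:
        # Sketch of the offload pattern: one device-to-host copy, all the
        # dynamic per-row work on CPU, one host-to-device copy at the end.
        if offload:
            device = logits.device  # remember the original (e.g. HPU) device
            logits = logits.cpu()
        for row in range(logits.shape[0]):
            for fn in processors:
                logits[row] = fn(logits[row])  # dynamic control flow on CPU
        if offload:
            logits = logits.to(device)
        return logits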
58 changes: 41 additions & 17 deletions vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -30,11 +30,48 @@
 from transformers import PreTrainedTokenizerBase
 
 
+# Unfortunately we cannot use lru_cache as it breaks pickling
+# so we use a simpler implementation
+def _cached(fn):
+    cache = {}
+
+    def cached_fn(*args):
+        if args in cache:
+            result = cache[args]
+        else:
+            result = fn(*args)
+            cache[args] = result
+        return result
+
+    return cached_fn
+
+
 class BaseLogitsProcessor:
 
     def __init__(self, guide: Guide):
         self._guide: Guide = guide
         self._fsm_state: DefaultDict[int, int] = defaultdict(int)
+        self._cached_get_mask_tensor = _cached(self._get_mask_tensor)
+
+    @staticmethod
+    @lru_cache(maxsize=128)
+    def _create_mask_tensor(allowed_tokens, vocab_size, device):
+        mask = torch.full((vocab_size, ), -math.inf, device=device)
+        mask[list(allowed_tokens)] = 0
+        return mask
+
+    def _get_mask_tensor(self, state_id, vocab_size, device):
+        instruction = self._guide.get_next_instruction(state=state_id)
+        if type(instruction) == Generate:  # noqa: E721
+            allowed_tokens = instruction.tokens
+        elif type(instruction) == Write:  # noqa: E721
+            # TODO: support fast forward tokens
+            allowed_tokens = [instruction.tokens[0]]
+        else:
+            raise TypeError(
+                f"Unsupported instruction type {type(instruction)}")
+        return BaseLogitsProcessor._create_mask_tensor(tuple(allowed_tokens),
+                                                       vocab_size, device)
 
     def __call__(self, input_ids: List[int],
                  scores: torch.Tensor) -> torch.Tensor:
@@ -64,23 +101,10 @@ def __call__(self, input_ids: List[int],
                     import_paths=[grammars.GRAMMAR_PATH],
                 )
 
-        instruction = self._guide.get_next_instruction(
-            state=self._fsm_state[seq_id])
-
-        if type(instruction) == Generate:  # noqa: E721
-            allowed_tokens = instruction.tokens
-        elif type(instruction) == Write:  # noqa: E721
-            # TODO: support fast forward tokens
-            allowed_tokens = [instruction.tokens[0]]
-        else:
-            raise TypeError(
-                f"Unsupported instruction type {type(instruction)}")
-
-        mask = torch.full((scores.shape[-1], ),
-                          -math.inf,
-                          device=scores.device)
-        mask[allowed_tokens] = 0
-        scores = scores.add(mask)
+        state_id = self._fsm_state[seq_id]
+        mask = self._cached_get_mask_tensor(state_id, scores.size(-1),
+                                            scores.device)
+        scores.add_(mask)
         return scores


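Aside on the `_cached` helper added above: per the in-code comment, `functools.lru_cache` broke pickling here, so the commit memoizes with a plain closure over a dict. Both caches key on their positional arguments, so those must be hashable, hence `tuple(allowed_tokens)` when calling `_create_mask_tensor`. A standalone demo (the helper is copied from the diff; `expensive_mask` is a hypothetical stand-in):

    def _cached(fn):  # copied from the diff above
        cache = {}

        def cached_fn(*args):
            if args in cache:
                result = cache[args]
            else:
                result = fn(*args)
                cache[args] = result
            return result

        return cached_fn

    calls = []

    def expensive_mask(state_id, vocab_size):  # hypothetical stand-in
        calls.append(state_id)            # record real invocations
        return state_id + vocab_size      # pretend this builds a mask tensor

    cached_mask = _cached(expensive_mask)
    assert cached_mask(3, 100) == 103
    assert cached_mask(3, 100) == 103     # second call served from the cache
    assert len(calls) == 1                # underlying function ran only once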
23 changes: 20 additions & 3 deletions vllm/model_executor/layers/logits_processor.py
@@ -118,12 +118,28 @@ def _prune_hidden_states(
     return hidden_states
 
 
+def get_num_parameters(logits_processor):
+    """Extracts the number of parameters from the
+    signature and stores it for further use"""
+    if hasattr(logits_processor, 'num_parameters'):
+        return logits_processor.num_parameters
+    logits_processor.num_parameters = len(
+        inspect.signature(logits_processor).parameters)
+    return logits_processor.num_parameters
+
+
 def _apply_logits_processors(
     logits: torch.Tensor,
     sampling_metadata: SamplingMetadata,
 ) -> torch.Tensor:
-    found_logits_processors = False
     logits_processed = 0
+    found_logits_processors = any(
+        seq_group.sampling_params.logits_processors
+        for seq_group in sampling_metadata.seq_groups)
+    offload_to_cpu = current_platform.is_hpu() and found_logits_processors
+    if offload_to_cpu:
+        logits_device = logits.device
+        logits = logits.cpu()
     for seq_group in sampling_metadata.seq_groups:
         seq_ids = seq_group.seq_ids
         sampling_params = seq_group.sampling_params
@@ -138,8 +154,7 @@ def _apply_logits_processors(
                 prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
 
                 for logits_processor in logits_processors:
-                    parameters = inspect.signature(logits_processor).parameters
-                    if len(parameters) == 3:
+                    if get_num_parameters(logits_processor) == 3:
                         logits_row = logits_processor(prompt_tokens_ids,
                                                       past_tokens_ids,
                                                       logits_row)
@@ -155,4 +170,6 @@ def _apply_logits_processors(
     if found_logits_processors:
         # verifies that no rows in logits were missed unexpectedly
         assert logits_processed == logits.shape[0]
+    if offload_to_cpu:
+        logits = logits.to(logits_device)
     return logits
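
For reference, why `get_num_parameters` exists: per-request logits processors come in two calling conventions, `(past_token_ids, logits_row)` and `(prompt_token_ids, past_token_ids, logits_row)`, and the dispatch above previously re-ran `inspect.signature` on every row. Caching the arity as an attribute on the callable makes the check cheap. A minimal standalone check (the processors here are hypothetical; `get_num_parameters` is copied from the diff):

    import inspect

    def get_num_parameters(logits_processor):  # copied from the diff above
        if hasattr(logits_processor, 'num_parameters'):
            return logits_processor.num_parameters
        logits_processor.num_parameters = len(
            inspect.signature(logits_processor).parameters)
        return logits_processor.num_parameters

    def two_arg_processor(past_token_ids, logits_row):  # hypothetical
        return logits_row

    def three_arg_processor(prompt_token_ids, past_token_ids,  # hypothetical
                            logits_row):
        return logits_row

    assert get_num_parameters(two_arg_processor) == 2
    assert get_num_parameters(three_arg_processor) == 3
    # The second lookup reads the cached attribute instead of re-inspecting.
    # Note this requires a callable that accepts attribute assignment,
    # which plain functions do:
    assert 'num_parameters' in vars(two_arg_processor)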
