Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix IndexError caused by invalid token IDs in CFGGuide #1251

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions outlines/fsm/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,10 @@ def __init__(self, cfg_string: str, tokenizer):

self.cfg_string = cfg_string
self.tokenizer = tokenizer

# Set eos_token_id if available
self.eos_token_id = self.tokenizer.eos_token_id

self.parser = PartialLark(
cfg_string,
parser="lalr",
Expand Down Expand Up @@ -149,14 +152,20 @@ def get_next_instruction(self, state: CFGState) -> Instruction:
"""

if state.parser_state is None:
return Write(torch.tensor([self.eos_token_id]))
if self.eos_token_id is not None:
return Write(torch.tensor([self.eos_token_id]))
else:
return None # No instruction if eos_token_id is not set

valid_tokens = list(
self.iter_valid_token_ids(state, self.tokenizer.vocabulary.values())
self.iter_valid_token_ids(state, list(self.tokenizer.vocabulary.values()))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why converting this to a list?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for converting self.tokenizer.vocabulary.values() to a list before passing it to the iter_valid_token_ids method is to ensure that we're working with a concrete, indexable collection of token IDs.

)
if len(valid_tokens) == 1:
if not valid_tokens:
return None # No valid tokens to generate
elif len(valid_tokens) == 1:
return Write(torch.tensor(valid_tokens))
return Generate(torch.tensor(valid_tokens))
else:
return Generate(torch.tensor(valid_tokens))

def iter_valid_token_ids(
self, state: CFGState, candidate_token_ids: list
Expand All @@ -177,11 +186,12 @@ def iter_valid_token_ids(
Valid token ids.
"""
if state.parser_state is None:
yield self.eos_token_id
if self.eos_token_id is not None:
yield self.eos_token_id
return

for token_id in candidate_token_ids:
if token_id == self.eos_token_id:
if token_id == self.eos_token_id and self.eos_token_id is not None:
if self.can_terminate_state(state):
yield token_id
else:
Expand Down Expand Up @@ -234,20 +244,14 @@ def _get_parser_state_token_applied(
"""
parser_state = copy.copy(state.parser_state) # prevent side effects

# normalize
if state.prev_token is None:
new_token_str = self.tokenizer.decode([token_id])[0]
else:
prev_token_str = self.tokenizer.decode([[state.prev_token]])[0]
combined_token_str = self.tokenizer.decode([[state.prev_token, token_id]])[
0
]
new_token_str = combined_token_str[len(prev_token_str) :]

if new_token_str == "":
# Decode the token
token_str = self.tokenizer.decode([token_id])
if not token_str:
raise ValueError("empty next token")

# update parser with new token
new_token_str = token_str[0] # Assuming decode returns a list

# Update parser with new token
parser_state.lexer.state.text += new_token_str
self.parser.parse_from_state(parser_state, is_end=False)

Expand Down
52 changes: 35 additions & 17 deletions outlines/processors/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,17 @@ def process_logits(
if self._seq_start_idx is None:
self._seq_start_idx = len(input_ids[0])

sequence_states: List[int] = [] # vector of states corresponding to `input_ids`
sequence_states: List[Any] = [] # vector of states corresponding to `input_ids`

for seq_ids in input_ids:
gen_ids = seq_ids[self._seq_start_idx :]
curr_state_key = hash(tuple(gen_ids.tolist()))

if curr_state_key not in self._guide_states:
prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
prev_state_key = hash(tuple(gen_ids[:-1].tolist()))
prev_state = self._guide_states.get(
prev_state_key, self.guide.initial_state
)
curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
self._guide_states[curr_state_key] = curr_state

Expand All @@ -107,19 +110,26 @@ def process_logits(
allowed_tokens_batch = []
batch_indices = []
for i, guide_state in enumerate(sequence_states):
allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to(
mask.device, non_blocking=True
)
instruction = self.guide.get_next_instruction(guide_state)
if instruction is None:
continue # Skip if no instruction is available
allowed_tokens = instruction.tokens
if allowed_tokens is None:
continue # Skip if no tokens are allowed
allowed_tokens = allowed_tokens.to(mask.device, non_blocking=True)

# Filter out invalid token IDs
allowed_tokens = allowed_tokens[allowed_tokens < logits.size(1)]
allowed_tokens_batch.append(allowed_tokens)
batch_indices.append(
torch.full_like(allowed_tokens, i)
) # Store batch index for each allowed token
batch_indices.append(torch.full_like(allowed_tokens, i))

allowed_tokens_concat = torch.cat(allowed_tokens_batch)
batch_indices_concat = torch.cat(batch_indices)
if allowed_tokens_batch:
allowed_tokens_concat = torch.cat(allowed_tokens_batch)
batch_indices_concat = torch.cat(batch_indices)

mask[batch_indices_concat, allowed_tokens_concat] = False
logits.masked_fill_(mask, float("-inf"))
mask[batch_indices_concat, allowed_tokens_concat] = False

logits = logits.masked_fill(mask, float("-inf"))

return logits

Expand Down Expand Up @@ -221,26 +231,34 @@ def process_logits(
if self._seq_start_idx is None:
self._seq_start_idx = len(input_ids[0])

sequence_states: List = [] # vector of states corresponding to `input_ids`
sequence_states: List[Any] = [] # vector of states corresponding to `input_ids`

for seq_ids in input_ids:
gen_ids = seq_ids[self._seq_start_idx :]
curr_state_key = hash(tuple(gen_ids.tolist()))

if curr_state_key not in self._guide_states:
prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
prev_state_key = hash(tuple(gen_ids[:-1].tolist()))
prev_state = self._guide_states.get(
prev_state_key, self.guide.initial_state
)
curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
self._guide_states[curr_state_key] = curr_state

sequence_states.append(self._guide_states[curr_state_key])

mask = torch.full_like(logits, -math.inf)
for i, guide_state in enumerate(sequence_states):
first_legal_token = next(
valid_tokens = list(
self.guide.iter_valid_token_ids(
guide_state, torch.argsort(logits[i], descending=True)
guide_state, torch.arange(logits.size(1), device=logits.device)
)
)
mask[i, [first_legal_token]] = logits[i, [first_legal_token]]
if valid_tokens:
# Keep only valid tokens
mask[i, valid_tokens] = logits[i, valid_tokens]
else:
# No valid tokens; generation should stop
mask[i] = logits[i]

return mask