diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py
index d46228fe9..bbd3f44ba 100644
--- a/outlines/fsm/guide.py
+++ b/outlines/fsm/guide.py
@@ -116,7 +116,10 @@ def __init__(self, cfg_string: str, tokenizer):
         self.cfg_string = cfg_string
         self.tokenizer = tokenizer
+
+        # Set eos_token_id if available
         self.eos_token_id = self.tokenizer.eos_token_id
+
         self.parser = PartialLark(
             cfg_string,
             parser="lalr",
@@ -149,14 +152,20 @@ def get_next_instruction(self, state: CFGState) -> Instruction:
         """
         if state.parser_state is None:
-            return Write(torch.tensor([self.eos_token_id]))
+            if self.eos_token_id is not None:
+                return Write(torch.tensor([self.eos_token_id]))
+            else:
+                return None  # No instruction if eos_token_id is not set
 
         valid_tokens = list(
-            self.iter_valid_token_ids(state, self.tokenizer.vocabulary.values())
+            self.iter_valid_token_ids(state, list(self.tokenizer.vocabulary.values()))
         )
-        if len(valid_tokens) == 1:
+        if not valid_tokens:
+            return None  # No valid tokens to generate
+        elif len(valid_tokens) == 1:
             return Write(torch.tensor(valid_tokens))
-        return Generate(torch.tensor(valid_tokens))
+        else:
+            return Generate(torch.tensor(valid_tokens))
 
     def iter_valid_token_ids(
         self, state: CFGState, candidate_token_ids: list
@@ -177,11 +186,12 @@ def iter_valid_token_ids(
         Valid token ids.
         """
         if state.parser_state is None:
-            yield self.eos_token_id
+            if self.eos_token_id is not None:
+                yield self.eos_token_id
             return
 
         for token_id in candidate_token_ids:
-            if token_id == self.eos_token_id:
+            if token_id == self.eos_token_id and self.eos_token_id is not None:
                 if self.can_terminate_state(state):
                     yield token_id
             else:
@@ -234,20 +244,14 @@ def _get_parser_state_token_applied(
         """
         parser_state = copy.copy(state.parser_state)  # prevent side effects
 
-        # normalize
-        if state.prev_token is None:
-            new_token_str = self.tokenizer.decode([token_id])[0]
-        else:
-            prev_token_str = self.tokenizer.decode([[state.prev_token]])[0]
-            combined_token_str = self.tokenizer.decode([[state.prev_token, token_id]])[
-                0
-            ]
-            new_token_str = combined_token_str[len(prev_token_str) :]
-
-        if new_token_str == "":
+        # Decode the token
+        token_str = self.tokenizer.decode([token_id])
+        if not token_str:
             raise ValueError("empty next token")
 
-        # update parser with new token
+        new_token_str = token_str[0]  # Assuming decode returns a list
+
+        # Update parser with new token
         parser_state.lexer.state.text += new_token_str
         self.parser.parse_from_state(parser_state, is_end=False)
diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py
index d2bc15f77..50ae6e3ee 100644
--- a/outlines/processors/structured.py
+++ b/outlines/processors/structured.py
@@ -89,14 +89,17 @@ def process_logits(
         if self._seq_start_idx is None:
             self._seq_start_idx = len(input_ids[0])
 
-        sequence_states: List[int] = []  # vector of states corresponding to `input_ids`
+        sequence_states: List[Any] = []  # vector of states corresponding to `input_ids`
 
         for seq_ids in input_ids:
             gen_ids = seq_ids[self._seq_start_idx :]
             curr_state_key = hash(tuple(gen_ids.tolist()))
 
             if curr_state_key not in self._guide_states:
-                prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
+                prev_state_key = hash(tuple(gen_ids[:-1].tolist()))
+                prev_state = self._guide_states.get(
+                    prev_state_key, self.guide.initial_state
+                )
                 curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
                 self._guide_states[curr_state_key] = curr_state
@@ -107,19 +110,26 @@ def process_logits(
         allowed_tokens_batch = []
         batch_indices = []
         for i, guide_state in enumerate(sequence_states):
-            allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to(
-                mask.device, non_blocking=True
-            )
+            instruction = self.guide.get_next_instruction(guide_state)
+            if instruction is None:
+                continue  # Skip if no instruction is available
+            allowed_tokens = instruction.tokens
+            if allowed_tokens is None:
+                continue  # Skip if no tokens are allowed
+            allowed_tokens = allowed_tokens.to(mask.device, non_blocking=True)
+
+            # Filter out invalid token IDs
+            allowed_tokens = allowed_tokens[allowed_tokens < logits.size(1)]
             allowed_tokens_batch.append(allowed_tokens)
-            batch_indices.append(
-                torch.full_like(allowed_tokens, i)
-            )  # Store batch index for each allowed token
+            batch_indices.append(torch.full_like(allowed_tokens, i))
 
-        allowed_tokens_concat = torch.cat(allowed_tokens_batch)
-        batch_indices_concat = torch.cat(batch_indices)
+        if allowed_tokens_batch:
+            allowed_tokens_concat = torch.cat(allowed_tokens_batch)
+            batch_indices_concat = torch.cat(batch_indices)
 
-        mask[batch_indices_concat, allowed_tokens_concat] = False
-        logits.masked_fill_(mask, float("-inf"))
+            mask[batch_indices_concat, allowed_tokens_concat] = False
+
+        logits = logits.masked_fill(mask, float("-inf"))
 
         return logits
@@ -221,14 +231,17 @@ def process_logits(
         if self._seq_start_idx is None:
             self._seq_start_idx = len(input_ids[0])
 
-        sequence_states: List = []  # vector of states corresponding to `input_ids`
+        sequence_states: List[Any] = []  # vector of states corresponding to `input_ids`
 
         for seq_ids in input_ids:
             gen_ids = seq_ids[self._seq_start_idx :]
             curr_state_key = hash(tuple(gen_ids.tolist()))
 
             if curr_state_key not in self._guide_states:
-                prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
+                prev_state_key = hash(tuple(gen_ids[:-1].tolist()))
+                prev_state = self._guide_states.get(
+                    prev_state_key, self.guide.initial_state
+                )
                 curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
                 self._guide_states[curr_state_key] = curr_state
@@ -236,11 +249,16 @@ def process_logits(
         mask = torch.full_like(logits, -math.inf)
         for i, guide_state in enumerate(sequence_states):
-            first_legal_token = next(
+            valid_tokens = list(
                 self.guide.iter_valid_token_ids(
-                    guide_state, torch.argsort(logits[i], descending=True)
+                    guide_state, torch.arange(logits.size(1), device=logits.device)
                 )
             )
-            mask[i, [first_legal_token]] = logits[i, [first_legal_token]]
+            if valid_tokens:
+                # Keep only valid tokens
+                mask[i, valid_tokens] = logits[i, valid_tokens]
+            else:
+                # No valid tokens; generation should stop
+                mask[i] = logits[i]
 
         return mask
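For reference, the masking pattern that the `process_logits` hunks above implement boils down to the following. This is a minimal standalone sketch, not part of the patch; the names `toy_logits` and `allowed_per_seq` are illustrative only.

```python
import torch

# Toy batch: 2 sequences over a vocabulary of 10 tokens.
toy_logits = torch.randn(2, 10)
# Allowed token ids per sequence (what the guide's instruction would provide).
allowed_per_seq = [torch.tensor([1, 4]), torch.tensor([0, 2, 9])]

# Start with everything disallowed, then unmask the allowed ids per row.
mask = torch.ones_like(toy_logits, dtype=torch.bool)
allowed_tokens_batch = []
batch_indices = []
for i, allowed in enumerate(allowed_per_seq):
    allowed_tokens_batch.append(allowed)
    batch_indices.append(torch.full_like(allowed, i))  # row index per allowed id

if allowed_tokens_batch:
    mask[torch.cat(batch_indices), torch.cat(allowed_tokens_batch)] = False

# Out-of-place masked_fill, as in the patched code: disallowed logits become -inf.
masked = toy_logits.masked_fill(mask, float("-inf"))
```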