Fix IndexError caused by invalid token IDs in CFGGuide #1251

Open · wants to merge 4 commits into main
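The failure mode behind this PR can be reproduced with a small standalone sketch (the tensor sizes and the out-of-range id below are illustrative assumptions, not values taken from the diff): when a guide emits a token id at or beyond `logits.size(1)`, for example an `eos_token_id` equal to the vocabulary size, using it to index the logits mask raises an `IndexError`. The bounds filter added in `outlines/processors/structured.py` below avoids exactly this.

```python
import torch

# Minimal repro sketch (assumed shapes): batch of 1, vocabulary of 2 tokens.
vocab_size = 2
logits = torch.zeros(1, vocab_size)
mask = torch.ones_like(logits, dtype=torch.bool)

# An id equal to vocab_size is out of range, e.g. a misconfigured eos_token_id.
allowed_tokens = torch.tensor([0, vocab_size])

try:
    mask[0, allowed_tokens] = False  # raises IndexError: index out of bounds
except IndexError as exc:
    print(f"unguarded masking failed: {exc}")

# The guarded variant mirrors the fix: drop ids that do not fit the logits width.
allowed_tokens = allowed_tokens[allowed_tokens < logits.size(1)]
mask[0, allowed_tokens] = False
logits = logits.masked_fill(mask, float("-inf"))
print(logits)
```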
40 changes: 22 additions & 18 deletions outlines/fsm/guide.py
@@ -114,7 +114,10 @@ def __init__(self, cfg_string: str, tokenizer):

self.cfg_string = cfg_string
self.tokenizer = tokenizer

# Set eos_token_id if available
self.eos_token_id = self.tokenizer.eos_token_id

self.parser = PartialLark(
cfg_string,
parser="lalr",
@@ -147,14 +150,20 @@ def get_next_instruction(self, state: CFGState) -> Instruction:
"""

if state.parser_state is None:
return Write(torch.tensor([self.eos_token_id]))
if self.eos_token_id is not None:
return Write(torch.tensor([self.eos_token_id]))
else:
return None # No instruction if eos_token_id is not set

valid_tokens = list(
self.iter_valid_token_ids(state, self.tokenizer.vocabulary.values())
self.iter_valid_token_ids(state, list(self.tokenizer.vocabulary.values()))
)
if len(valid_tokens) == 1:
if not valid_tokens:
return None # No valid tokens to generate
elif len(valid_tokens) == 1:
return Write(torch.tensor(valid_tokens))
return Generate(torch.tensor(valid_tokens))
else:
return Generate(torch.tensor(valid_tokens))

def iter_valid_token_ids(
self, state: CFGState, candidate_token_ids: list
@@ -175,11 +184,12 @@
Valid token ids.
"""
if state.parser_state is None:
yield self.eos_token_id
if self.eos_token_id is not None:
yield self.eos_token_id
return

for token_id in candidate_token_ids:
if token_id == self.eos_token_id:
if token_id == self.eos_token_id and self.eos_token_id is not None:
if self.can_terminate_state(state):
yield token_id
else:
@@ -232,20 +242,14 @@ def _get_parser_state_token_applied(
"""
parser_state = copy.copy(state.parser_state) # prevent side effects

# normalize
if state.prev_token is None:
new_token_str = self.tokenizer.decode([token_id])[0]
else:
prev_token_str = self.tokenizer.decode([[state.prev_token]])[0]
combined_token_str = self.tokenizer.decode([[state.prev_token, token_id]])[
0
]
new_token_str = combined_token_str[len(prev_token_str) :]

if new_token_str == "":
# Decode the token
token_str = self.tokenizer.decode([token_id])
if not token_str:
raise ValueError("empty next token")

# update parser with new token
new_token_str = token_str[0] # Assuming decode returns a list

# Update parser with new token
parser_state.lexer.state.text += new_token_str
self.parser.parse_from_state(parser_state, is_end=False)

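With these changes `CFGGuide.get_next_instruction` can return `None` when no EOS token is configured and the parse cannot continue, so callers have to tolerate a missing instruction. A rough caller-side sketch of that contract, using a hypothetical stand-in guide rather than a real `CFGGuide` instance:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Write:
    tokens: torch.Tensor


class StubGuide:
    """Hypothetical stand-in mimicking the contract introduced above."""

    def __init__(self, eos_token_id: Optional[int]):
        self.eos_token_id = eos_token_id

    def get_next_instruction(self, parser_state) -> Optional[Write]:
        if parser_state is None:
            if self.eos_token_id is not None:
                return Write(torch.tensor([self.eos_token_id]))
            return None  # nothing left to emit and no EOS to fall back on
        raise NotImplementedError


for guide in (StubGuide(eos_token_id=2), StubGuide(eos_token_id=None)):
    instruction = guide.get_next_instruction(None)
    if instruction is None:
        print("no instruction -> caller skips or stops this sequence")
    else:
        print("allowed token ids:", instruction.tokens.tolist())
```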
52 changes: 35 additions & 17 deletions outlines/processors/structured.py
@@ -90,14 +90,17 @@ def process_logits(
if self._seq_start_idx is None:
self._seq_start_idx = len(input_ids[0])

sequence_states: List[int] = [] # vector of states corresponding to `input_ids`
sequence_states: List[Any] = [] # vector of states corresponding to `input_ids`

for seq_ids in input_ids:
gen_ids = seq_ids[self._seq_start_idx :]
curr_state_key = hash(tuple(gen_ids.tolist()))

if curr_state_key not in self._guide_states:
prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
prev_state_key = hash(tuple(gen_ids[:-1].tolist()))
prev_state = self._guide_states.get(
prev_state_key, self.guide.initial_state
)
curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
self._guide_states[curr_state_key] = curr_state

@@ -108,19 +111,26 @@
allowed_tokens_batch = []
batch_indices = []
for i, guide_state in enumerate(sequence_states):
allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to(
mask.device, non_blocking=True
)
instruction = self.guide.get_next_instruction(guide_state)
if instruction is None:
continue # Skip if no instruction is available
allowed_tokens = instruction.tokens
if allowed_tokens is None:
continue # Skip if no tokens are allowed
allowed_tokens = allowed_tokens.to(mask.device, non_blocking=True)

# Filter out invalid token IDs
allowed_tokens = allowed_tokens[allowed_tokens < logits.size(1)]
allowed_tokens_batch.append(allowed_tokens)
batch_indices.append(
torch.full_like(allowed_tokens, i)
) # Store batch index for each allowed token
batch_indices.append(torch.full_like(allowed_tokens, i))

allowed_tokens_concat = torch.cat(allowed_tokens_batch)
batch_indices_concat = torch.cat(batch_indices)
if allowed_tokens_batch:
allowed_tokens_concat = torch.cat(allowed_tokens_batch)
batch_indices_concat = torch.cat(batch_indices)

mask[batch_indices_concat, allowed_tokens_concat] = False
logits.masked_fill_(mask, float("-inf"))
mask[batch_indices_concat, allowed_tokens_concat] = False

logits = logits.masked_fill(mask, float("-inf"))

return logits

@@ -222,26 +232,34 @@ def process_logits(
if self._seq_start_idx is None:
self._seq_start_idx = len(input_ids[0])

sequence_states: List = [] # vector of states corresponding to `input_ids`
sequence_states: List[Any] = [] # vector of states corresponding to `input_ids`

for seq_ids in input_ids:
gen_ids = seq_ids[self._seq_start_idx :]
curr_state_key = hash(tuple(gen_ids.tolist()))

if curr_state_key not in self._guide_states:
prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
prev_state_key = hash(tuple(gen_ids[:-1].tolist()))
prev_state = self._guide_states.get(
prev_state_key, self.guide.initial_state
)
curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
self._guide_states[curr_state_key] = curr_state

sequence_states.append(self._guide_states[curr_state_key])

mask = torch.full_like(logits, -math.inf)
for i, guide_state in enumerate(sequence_states):
first_legal_token = next(
valid_tokens = list(
self.guide.iter_valid_token_ids(
guide_state, torch.argsort(logits[i], descending=True)
guide_state, torch.arange(logits.size(1), device=logits.device)
)
)
mask[i, [first_legal_token]] = logits[i, [first_legal_token]]
if valid_tokens:
# Keep only valid tokens
mask[i, valid_tokens] = logits[i, valid_tokens]
else:
# No valid tokens; generation should stop
mask[i] = logits[i]

return mask
42 changes: 42 additions & 0 deletions tests/fsm/test_cfg_guide.py
@@ -455,3 +455,45 @@ def test_cfg_grammar_sample(request, sample_name, tokenizer_name, cleanup_lark_i
state = cfg_guide.get_next_state(state, token_id)
final_instruction = cfg_guide.get_next_instruction(state)
assert tokenizer.eos_token_id in final_instruction.tokens


# Add the first test: Unit test with mock tokenizer
def test_invalid_eos_token_id_handling():
from outlines.fsm.guide import CFGGuide

# Mock tokenizer with limited vocabulary and invalid eos_token_id
class MockTokenizer:
vocabulary = {"a": 0, "b": 1}
token_to_id = vocabulary
id_to_token = {v: k for k, v in vocabulary.items()}
special_tokens = {}
eos_token_id = len(vocabulary) # Invalid eos_token_id

def decode(self, token_ids):
return [self.id_to_token.get(token_id, "") for token_id in token_ids]

# Define a simple CFG
cfg_string = r"""
?start: "a" "b"
"""

# Initialize the guide with the mock tokenizer
tokenizer = MockTokenizer()
guide = CFGGuide(cfg_string, tokenizer)

# Build the initial state
state = guide.initial_state
instruction = guide.get_next_instruction(state)
valid_tokens = instruction.tokens

# Check that valid_tokens do not contain invalid token IDs
invalid_tokens = [
token_id for token_id in valid_tokens if token_id >= len(tokenizer.vocabulary)
]
assert not invalid_tokens, f"Found invalid token IDs: {invalid_tokens}"

try:
next_token_id = valid_tokens[0] # Take the first valid token
next_state = guide.get_next_state(state, next_token_id) # noqa: F841
except IndexError as e:
pytest.fail(f"IndexError encountered: {e}")