Proposal for the implementation of token alignment #1239

Draft
Wants to merge 1 commit into base: main
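For context on what the diff implements: token alignment handles the case where the end of the prompt and the beginning of the constrained output could have been encoded as a single vocabulary token, which a guide built only over the output would otherwise never allow. A minimal sketch of the idea, with an invented toy vocabulary rather than anything taken from this diff:

# Hypothetical toy vocabulary: token string -> token id (not a real tokenizer).
vocabulary = {"h": 0, "i": 1, "hi": 2, "(": 3, "(h": 4}

# The prompt ends with "(" and the guide expects the output to start with "h".
# Without alignment, the first generated token must be "h" (id 0): the merged
# token "(h" (id 4) crosses the prompt/output boundary and is never allowed.
prompt = "Answer between parentheses: ("

# Token alignment strips the trailing "(" from the prompt, lets the model emit
# either "(" then "h", or the single token "(h", and finally trims the
# regenerated "(" from the front of the returned completion.
aligned_prompt = prompt[: -len("(")]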
103 changes: 103 additions & 0 deletions outlines/fsm/guide.py
@@ -182,10 +182,38 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
            self.empty_token_ids,
            fsm_finals,
        ) = create_states_mapping(regex_string, tokenizer)

        # token alignment
        crossing_tokens = find_crossing_tokens(self.states_to_token_maps, tokenizer.vocabulary)
        highest_state = max(
            max(self.states_to_token_maps.keys()),
            max(max(items.values()) for items in self.states_to_token_maps.values()),
        )
        prefixes_map = create_prefixes_map({x[1] for x in crossing_tokens}, highest_state + 1, tokenizer.vocabulary)
        for item in crossing_tokens:
            prefix_map = prefixes_map[item[1]]
            self.states_to_token_maps.update(prefix_map)
            prefix_map_starting_state = min(prefix_map.keys())
            self.states_to_token_maps[prefix_map_starting_state][item[0]] = item[3]
        self.crossing_tokens_prefixes_map = {key: min(value.keys()) for key, value in prefixes_map.items()}

        self.eos_token_id = tokenizer.eos_token_id
        self.final_states = fsm_finals | {-1}
        self._cache_state_to_token_tensor()

    def get_starting_states(self, prompts: List[str]) -> List[Tuple[int, str]]:
        """Get the starting state and the character sequence that should be removed from each prompt."""
        results = []
        for prompt in prompts:
            longest_prefix = ""
            target_state = self.initial_state
            for prefix, starting_state in self.crossing_tokens_prefixes_map.items():
                if prompt.endswith(prefix) and len(prefix) > len(longest_prefix):
                    longest_prefix = prefix
                    target_state = starting_state
            results.append((target_state, longest_prefix))
        return results

    def get_next_instruction(self, state: int) -> Instruction:
        """Return the next instruction for guided generation.

@@ -475,3 +503,78 @@ def is_final_state(self, state: int) -> bool:
    def copy(self) -> "CFGGuide":
        """Create a copy of the FSM."""
        return CFGGuide(self.cfg_string, self.tokenizer)


### token alignment functions ###

def find_crossing_tokens(states_to_token_maps: dict, vocabulary: dict) -> List[Tuple[int, str, str, int]]:
    """Find the crossing tokens for a given states_to_token_maps.
    Crossing tokens are tokens that can be decomposed into a prefix and a postfix,
    such that the postfix is a valid sequence of characters for the states_to_token_maps.
    Returns a list of tuples, where each tuple contains the token id, the prefix, the postfix and the target state.
    """

    def get_target_state(vocabulary: dict, states_to_token_map: dict, char_seq: str):
        """Get the target state in the states_to_token_map for a sequence of characters.
        Return None if the sequence is not valid.
        """
        state = 0
        for char in char_seq:
            char_token = vocabulary.get(char)
            try:
                state = states_to_token_map[state][char_token]
            except KeyError:
                return None
        return state

    crossing_tokens = []
    invalid_postfixes = set()
    valid_postfixes = {}

    for char_seq, token_id in vocabulary.items():
        if len(char_seq) == 1:
            continue
        # we want to look at all possible "crossing positions" of the token (between char 1 and 2, 2 and 3, etc)
        for i in range(1, len(char_seq)):
            prefix = char_seq[:i]
            postfix = char_seq[i:]
            if postfix in invalid_postfixes:
                continue
            if postfix in valid_postfixes:
                crossing_tokens.append((token_id, prefix, postfix, valid_postfixes[postfix]))
                continue
            target_state = get_target_state(vocabulary, states_to_token_maps, postfix)
            if target_state is None:
                invalid_postfixes.add(postfix)
            else:
                valid_postfixes[postfix] = target_state
                crossing_tokens.append((token_id, prefix, postfix, target_state))

    return crossing_tokens


def create_prefixes_map(prefixes: List[str], starting_state: int, vocabulary: dict) -> dict:
    """Create a state to token map for each prefix.
    The starting state is the first available state number in the existing FSM.
    Return a dictionary where each key is a prefix and the value is the associated states_to_token_map.
    """

    def get_states_to_token_map(char_seq: str, starting_state: int, states_to_token_map: dict, vocabulary: dict):
        """Create the states_to_token_map representing all ways of generating the sequence of characters."""
        for i in range(1, len(char_seq) + 1):
            if char_seq[:i] in vocabulary:
                if starting_state not in states_to_token_map:
                    states_to_token_map[starting_state] = {}
                if i == len(char_seq):
                    states_to_token_map[starting_state][vocabulary[char_seq[:i]]] = 0
                else:
                    states_to_token_map[starting_state][vocabulary[char_seq[:i]]] = starting_state + i
                    get_states_to_token_map(char_seq[i:], starting_state + i, states_to_token_map, vocabulary)
        return states_to_token_map

    prefixes_map = {}
    for prefix in prefixes:
        prefixes_map[prefix] = get_states_to_token_map(prefix, starting_state, {}, vocabulary)
        starting_state += len(prefix)

    return prefixes_map
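As a rough illustration of what the two helpers above compute, here is a hypothetical run on a tiny FSM and vocabulary; the import assumes this branch is installed, and every id and state number is invented for the example:

from outlines.fsm.guide import create_prefixes_map, find_crossing_tokens

# Toy regex FSM accepting "ab": state 0 --"a" (id 0)--> 1 --"b" (id 1)--> 2.
states_to_token_maps = {0: {0: 1}, 1: {1: 2}}
# Vocabulary with one multi-character token "xa" that can cross the prompt boundary.
vocabulary = {"a": 0, "b": 1, "x": 2, "xa": 3}

crossing = find_crossing_tokens(states_to_token_maps, vocabulary)
# Expected: [(3, "x", "a", 1)] -- token "xa" splits into the prefix "x" (still part
# of the prompt) and the postfix "a", which drives the FSM from state 0 to state 1.

prefixes = create_prefixes_map({prefix for _, prefix, _, _ in crossing}, 3, vocabulary)
# Expected: {"x": {3: {2: 0}}} -- a fresh state 3 from which emitting "x" (id 2)
# returns to the FSM's initial state 0. __init__ then additionally registers the
# crossing token "xa" (id 3) at state 3, pointing directly to state 1.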
10 changes: 9 additions & 1 deletion outlines/generate/api.py
@@ -500,6 +500,8 @@ def format(sequences):
            max_tokens, stop_at, seed
        )

        removed_chars_from_prompts = self.logits_processor.get_removed_chars_from_prompts(prompts)

        completions = self.model.generate(
            prompts,
            generation_params,
@@ -508,7 +510,13 @@ def format(sequences):
            **model_specific_params,
        )

        return format(completions)
        trimmed_completions = [
            completion[len(removed_chars):]
            for completion, removed_chars in zip(completions, removed_chars_from_prompts)
        ]

        return format(trimmed_completions)


    def stream(
        self,
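Since the stripped prompt suffix gets regenerated by the model at the start of the completion, the new trimming step drops it again from the front of each returned string. A standalone sketch of that behaviour, with invented values:

completions = ["(hello)", "42"]         # raw model outputs, invented for the example
removed_chars_from_prompts = ["(", ""]  # suffix stripped from each prompt earlier

trimmed_completions = [
    completion[len(removed_chars):]
    for completion, removed_chars in zip(completions, removed_chars_from_prompts)
]
assert trimmed_completions == ["hello)", "42"]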
22 changes: 16 additions & 6 deletions outlines/processors/structured.py
@@ -61,10 +61,20 @@ def __init__(self, tokenizer: "Tokenizer", fsm: Guide):
            The finite state machine which is used to bias the logits.
        """
        self.tokenizer = tokenizer
        self._fsm_states: Dict[int, int] = {}
        self._fsm_states: List[Dict[int, int]] = []
        self.fsm: Guide = fsm
        self._is_first_token = True
        self._seq_start_idx: Optional[int] = None
        self.default_starting_state = 0
        self.token_alignment_starting_states = []

    def get_removed_chars_from_prompts(self, prompts: List[str]) -> List[str]:
        """For each prompt, get the trailing characters to remove and the resulting starting state.
        Update the token_alignment_starting_states attribute and return the removed character sequences.
        """
        starting_states_and_prefixes = self.fsm.get_starting_states(prompts)
        self.token_alignment_starting_states = [starting_state for starting_state, _ in starting_states_and_prefixes]
        return [prefix for _, prefix in starting_states_and_prefixes]

    def process_logits(
        self, input_ids: List[List[int]], logits: torch.Tensor
@@ -89,18 +99,18 @@ def process_logits(
            self._is_first_token = False
            self._seq_start_idx = len(input_ids[0])

            self._fsm_states = {hash(tuple([])): 0}
            sequence_states = [0] * len(input_ids)
            sequence_states = self.token_alignment_starting_states if self.token_alignment_starting_states else [self.default_starting_state] * len(input_ids)
            self._fsm_states = [{hash(tuple([])): sequence_states[i]} for i in range(len(input_ids))]

        else:
            for seq_ids in input_ids:
            for i, seq_ids in enumerate(input_ids):
                prev_state_key = hash(tuple(seq_ids[self._seq_start_idx : -1]))
                prev_state = self._fsm_states[prev_state_key]
                prev_state = self._fsm_states[i][prev_state_key]

                curr_state_key = hash(tuple(seq_ids[self._seq_start_idx :]))
                curr_state = self.fsm.get_next_state(prev_state, seq_ids[-1])

                self._fsm_states[curr_state_key] = curr_state
                self._fsm_states[i][curr_state_key] = curr_state
                sequence_states.append(curr_state)

        mask = torch.full_like(logits, -math.inf)
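Because each sequence in the batch may now start from its own FSM state, the processor keeps one generated-suffix-to-state dictionary per sequence instead of a single shared one. A small sketch of the first-step bookkeeping, with invented states and token ids:

# Starting states stored by get_removed_chars_from_prompts for a batch of two
# prompts: the first prompt ended on a crossing-token prefix (state 7), the
# second did not (default initial state 0). All numbers are invented.
token_alignment_starting_states = [7, 0]
input_ids = [[11, 12, 13], [21, 22, 23]]

sequence_states = token_alignment_starting_states or [0] * len(input_ids)
fsm_states = [{hash(tuple([])): sequence_states[i]} for i in range(len(input_ids))]
# -> [{hash(()): 7}, {hash(()): 0}]: each sequence maps its generated suffix
#    (empty so far) to its own starting state.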