from copy import copy
from functools import partial

from outlines.fsm.guide import RegexGuide
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase


def merge_successive_transitions(states_to_token_maps: dict[int, dict[int, int]], i: int, j: int) -> dict[int, dict[int, int]]:
    """Merge transitions on token `i` with the subsequent transitions on token `j`,
    so that reading `i` jumps directly to where the chain of `j` transitions ends."""
    states_to_token_maps = dict(states_to_token_maps)
    transitions_i = {(s1, states_to_token_maps[s1][i]) for s1 in states_to_token_maps if i in states_to_token_maps[s1]}
    transitions_j = {(s1, states_to_token_maps[s1][j]) for s1 in states_to_token_maps if j in states_to_token_maps[s1]}
    transitions_i, transitions_j = dict(transitions_i - transitions_j), dict(transitions_j - transitions_i)
    for s1, s2 in transitions_i.items():
        # Follow the chain of `j` transitions starting from the target of the `i` transition
        while s2 in transitions_j:
            s2 = transitions_j[s2]
        if s2 != transitions_i[s1]:
            states_to_token_maps[s1] = dict(states_to_token_maps[s1])
            states_to_token_maps[s1][i] = s2
    return states_to_token_maps


def replace_transitions(states_to_token_maps: dict[int, dict[int, int]], i: int, j: int) -> dict[int, dict[int, int]]:
    """Relabel transitions on token `i` to read token `j` instead, swapping the two
    in states that already have a transition on `j`."""
    states_to_token_maps = dict(states_to_token_maps)
    transitions_i = {(s1, states_to_token_maps[s1][i]) for s1 in states_to_token_maps if i in states_to_token_maps[s1]}
    transitions_j = {(s1, states_to_token_maps[s1][j]) for s1 in states_to_token_maps if j in states_to_token_maps[s1]}
    transitions_i, transitions_j = dict(transitions_i - transitions_j), dict(transitions_j - transitions_i)
    for s1, s2 in transitions_i.items():
        if s2 != transitions_j.get(s1):
            states_to_token_maps[s1] = dict(states_to_token_maps[s1])
            if s1 in transitions_j:
                states_to_token_maps[s1][i] = transitions_j[s1]
            else:
                states_to_token_maps[s1].pop(i)
            states_to_token_maps[s1][j] = s2
    return states_to_token_maps


def find_paths_with_transitions(states_to_token_maps: dict[int, dict[int, int]], transitions: list[int]) -> list[list[int]]:
    """Return every path of states that reads exactly the token sequence `transitions`."""
    possible_s0 = {s0 for s0 in states_to_token_maps if transitions[0] in states_to_token_maps[s0]}
    possible_s1 = {s1 for s1 in states_to_token_maps if transitions[1] in states_to_token_maps[s1]} - possible_s0
    starts = sorted(
        s0 for s0 in possible_s0
        if states_to_token_maps[s0][transitions[0]] in possible_s1
    )
    paths = [[start] for start in starts]
    for path in paths:
        for i in transitions:
            if i in states_to_token_maps[path[-1]]:
                path.append(states_to_token_maps[path[-1]][i])
            else:
                break
    return [path for path in paths if len(path) == len(transitions) + 1]


def replace_fields(
    fsm: RegexGuide,
    model: BaseModel,
    new_fields: list[str],
    tokenizer: PreTrainedTokenizerBase,
    make_infinite_loop: bool = False,
) -> RegexGuide:
    """Return a copy of `fsm` whose placeholder field names are replaced by `new_fields`."""
    assert len(new_fields) <= len(model.model_fields)
    sttm = dict(fsm.states_to_token_maps)
    encode = partial(tokenizer.encode, add_special_tokens=False)
    quote = encode('"')[0]
    # Replace the placeholder fields from the model in the finite state machine with the new fields
    for orig_field, new_field in zip(model.model_fields, new_fields):
        orig_field_tokens = [encode(orig_field_char)[0] for orig_field_char in orig_field]
        new_field_tokens = encode(new_field)
        assert len(new_field_tokens) <= len(orig_field_tokens)
        # Merge transitions until the number of transitions equals the number of tokens in the new field name
        for k in reversed(range(len(new_field_tokens), len(orig_field_tokens))):
            sttm = merge_successive_transitions(sttm, orig_field_tokens[k - 1], orig_field_tokens[k])
        # Replace the token ids in the transitions with the ones of the new field name
        for k in range(len(new_field_tokens)):
            sttm = replace_transitions(sttm, orig_field_tokens[k], new_field_tokens[k])
    if len(new_fields) < len(model.model_fields) or make_infinite_loop:
        # Rewire the last state of the new last field to the last state of the
        # original last field, so the FSM can stop after generating fewer fields
        # than the model defines.
        # We need to do this for every possible path,
        # e.g. multiple paths are used to count items when setting a min/max length.
        orig_last_field = list(model.model_fields)[-1]
        new_last_field = new_fields[-1]
        orig_last_field_paths = find_paths_with_transitions(sttm, [quote] + [encode(c)[0] for c in orig_last_field])
        new_last_field_paths = find_paths_with_transitions(sttm, [quote] + encode(new_last_field))
        if make_infinite_loop:
            # This is a hack to loop over the same states again and again
            orig_last_field_paths = [orig_last_field_paths[0]] * len(orig_last_field_paths)
        for orig_last_field_path, new_last_field_path in zip(
            orig_last_field_paths, new_last_field_paths
        ):
            orig_last_field_last_state = orig_last_field_path[-1]
            new_last_field_second_last_state = new_last_field_path[-2]
            sttm[new_last_field_second_last_state] = dict(sttm[new_last_field_second_last_state])
            sttm[new_last_field_second_last_state][encode(new_last_field)[-1]] = orig_last_field_last_state
    fsm = copy(fsm)
    fsm.states_to_token_maps = sttm
    return fsm
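

# Example usage: a minimal sketch showing how `replace_fields` could reuse one
# compiled FSM for schemas with different field names. The placeholder model,
# the "gpt2" tokenizer, and the outlines helpers `build_regex_from_schema` and
# `TransformerTokenizer` are assumptions; their locations and signatures vary
# across outlines versions.
if __name__ == "__main__":
    import json

    from outlines.fsm.json_schema import build_regex_from_schema  # assumed import path
    from outlines.models.transformers import TransformerTokenizer  # assumed import path
    from pydantic import create_model
    from transformers import AutoTokenizer

    # Placeholder model: each field name is a run of single-character tokens,
    # long enough to absorb any replacement name of at most 8 tokens.
    Placeholder = create_model("Placeholder", aaaaaaaa=(str, ...), bbbbbbbb=(str, ...))

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer choice
    regex = build_regex_from_schema(json.dumps(Placeholder.model_json_schema()))
    fsm = RegexGuide(regex, TransformerTokenizer(tokenizer))  # assumed constructor signatures

    # Swap in the real field names without recompiling the FSM.
    new_fsm = replace_fields(fsm, Placeholder, ["name", "job"], tokenizer)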