|
import copy |
|
import logging |
|
from functools import lru_cache |
|
from typing import List, Tuple, Dict |
|
|
|
from transformers_gad.parser import ( |
|
END_OF_RULE_MARKER, |
|
END_OF_ALTERNATE_MARKER, |
|
parse_ebnf, |
|
REF_RULE_MARKER, |
|
) |
|
from transformers_gad.utf8_utils import PartialUTF8, decode_utf8 |
|
from transformers_gad.utils import intervals_intersect |
|
import logging |
|
|
|
|
|
class AcceptState:
    """Snapshot of the recognizer's progress: the set of live parse stacks
    plus any partially-decoded multi-byte UTF-8 sequence carried over."""

    def __init__(self, stacks, partial_utf8):
        # Live parse stacks; an empty list means no parse can continue.
        self.stacks = stacks
        # Carry-over state for a UTF-8 sequence split across calls.
        self.partial_utf8 = partial_utf8

    @staticmethod
    def empty_state():
        """Return a state with no live stacks (nothing can be accepted)."""
        return AcceptState([], PartialUTF8())
|
|
|
|
|
class StringRecognizer: |
|
def __init__( |
|
self, |
|
grammar_encoding: List[int], |
|
start_rule_id: int = None, |
|
rule_offsets: List[int] = None, |
|
stacks: List[List[int]] = None, |
|
): |
|
|
|
|
|
|
|
self.grammar_encoding = grammar_encoding |
|
if rule_offsets is not None: |
|
self.rule_offsets = rule_offsets |
|
else: |
|
if start_rule_id is None: |
|
raise ValueError("start_rule_id cannot be None if rule_offsets is None") |
|
self.rule_offsets = self.init_rules(start_rule_id) |
|
|
|
|
|
if stacks is not None: |
|
self.stacks = stacks |
|
else: |
|
if start_rule_id is None: |
|
raise ValueError("start_rule_id cannot be None if stacks is None") |
|
self.stacks: List[List[int]] = self.init_stack(start_rule_id) |
|
self.start_rule_id = start_rule_id |
|
|
|
    def init_rules(self, start_rule_id: int) -> List[int]:
        """Scan the serialized grammar and record, for each rule id, the
        offset where that rule's encoding begins.

        The encoding is a flat int list of rule records terminated by the
        sentinel 0xFFFF; each record is a rule id followed by one or more
        length-prefixed alternates and an END_OF_RULE_MARKER.

        Args:
            start_rule_id: id of the start rule; used only for a consistency
                check on the resulting table.

        Returns:
            A list mapping rule_id -> offset into ``self.grammar_encoding``
            (slots for unseen rule ids are filled with -1).
        """
        _rule_offset = 0
        rule_offsets = []

        # 0xFFFF terminates the whole grammar encoding.
        while self.grammar_encoding[_rule_offset] != 0xFFFF:
            rule_id = self.grammar_encoding[_rule_offset]

            # Grow the table on demand; rule ids may arrive out of order.
            if len(rule_offsets) <= rule_id:
                rule_offsets.extend([-1] * (rule_id - len(rule_offsets) + 1))
            rule_offsets[rule_id] = _rule_offset

            # Skip past the rule id to the rule's first alternate.
            simple_rhs_offset = _rule_offset + 1

            # Each alternate is length-prefixed; hop alternate-to-alternate
            # until the end-of-rule marker is reached.
            while self.grammar_encoding[simple_rhs_offset] != END_OF_RULE_MARKER:
                simple_rhs_offset = (
                    simple_rhs_offset + 1 + self.grammar_encoding[simple_rhs_offset]
                )

            # Step over the END_OF_RULE_MARKER to the next rule record.
            _rule_offset = simple_rhs_offset + 1

        # Sanity check: the offset stored for the start rule must point at a
        # record whose rule id is the start rule id.
        retrieved_start_rule_id = self.grammar_encoding[rule_offsets[start_rule_id]]
        assert retrieved_start_rule_id == start_rule_id

        return rule_offsets
|
|
|
    def init_stack(self, start_rule_id: int) -> List[List[int]]:
        """Build the initial set of parse stacks for the start rule.

        One stack is seeded per alternate of the start rule; each stack is
        then expanded through ``advance_stack`` so that every stack top is a
        terminal (char-range) element rather than a rule reference.

        Args:
            start_rule_id: id of the rule to start parsing from.

        Returns:
            A list of stacks, each a list of offsets into
            ``self.grammar_encoding``.
        """
        stacks = []

        # The first alternate starts right after the rule id.
        sub_rhs_offset = self.rule_offsets[start_rule_id] + 1
        # A zero length prefix marks the end of the rule's alternates.
        while self.grammar_encoding[sub_rhs_offset]:
            stack: List[int] = []

            # Push the alternate's first element unless the alternate is empty.
            element_offset = sub_rhs_offset + 1
            if self.grammar_encoding[element_offset] != END_OF_ALTERNATE_MARKER:
                stack.append(element_offset)
            stacks.extend(self.advance_stack(tuple(stack)))
            # Hop to the next length-prefixed alternate.
            sub_rhs_offset += 1 + self.grammar_encoding[sub_rhs_offset]
        return stacks
|
|
|
def get_initial_accept_state(self) -> AcceptState: |
|
return AcceptState(self.init_stack(self.start_rule_id), PartialUTF8()) |
|
|
|
def get_termination_accept_state(self) -> AcceptState: |
|
return AcceptState([], PartialUTF8()) |
|
|
|
    @lru_cache(maxsize=32768)
    def advance_stack(self, stack: Tuple[int]) -> List[List[int]]:
        """Expand a stack until its top element is a terminal.

        If the top of ``stack`` is a rule reference, replace it with one new
        stack per alternate of the referenced rule (recursively), so callers
        only ever see stacks topped by character-range elements.

        Args:
            stack: tuple of grammar-encoding offsets (tuple so it is hashable
                for the lru_cache).

        Returns:
            A list of expanded stacks.

        NOTE(review): the returned lists are held by the lru_cache; callers
        appear to treat them as read-only — confirm no caller mutates them.
        """
        stack = list(stack)
        # An empty stack is already fully expanded.
        if len(stack) == 0:
            return [stack]

        cur_element_offset = stack[-1]

        # Terminal (char-range) element on top: nothing to expand.
        if self.grammar_encoding[cur_element_offset] != REF_RULE_MARKER:
            return [stack]

        else:
            # Top element references another rule: branch into one new stack
            # per alternate of that rule.
            ref_rule_id = self.grammar_encoding[cur_element_offset + 1]

            ref_subrule_offset = self.rule_offsets[ref_rule_id] + 1
            new_stacks: List[List[int]] = []

            while self.grammar_encoding[ref_subrule_offset] != END_OF_RULE_MARKER:

                # Drop the reference element itself.
                new_stack = stack[:-1]

                # If anything follows the reference in the current alternate,
                # it is parsed after the referenced rule completes.
                next_element_offset = cur_element_offset + 2
                if (
                    self.grammar_encoding[next_element_offset]
                    != END_OF_ALTERNATE_MARKER
                ):
                    new_stack.append(next_element_offset)

                # Push the first element of the referenced alternate, unless
                # that alternate is empty.
                ref_element_offset = ref_subrule_offset + 1
                if self.grammar_encoding[ref_element_offset] != END_OF_ALTERNATE_MARKER:
                    new_stack.append(ref_element_offset)

                new_stacks.extend(self.advance_stack(tuple(new_stack)))
                # Advance to the next length-prefixed alternate.
                ref_subrule_offset += self.grammar_encoding[ref_subrule_offset] + 1

            return new_stacks
|
|
|
def _consume_byte(self, byte: int, accept_state: AcceptState): |
|
|
|
|
|
self._consume_bytes(bytes([byte]), accept_state) |
|
|
|
|
|
def _probe_bytes( |
|
self, |
|
byte_seq: bytes, |
|
stacks: List[List[int]], |
|
partial_utf8: PartialUTF8, |
|
verbose=True, |
|
): |
|
if type(byte_seq) is list: |
|
byte_seq = bytes(byte_seq) |
|
code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8) |
|
if verbose: |
|
logging.debug( |
|
f"code_points: {code_points}; new_partial_utf8: {new_partial_utf8}" |
|
) |
|
new_stacks = self._consume_code_points(code_points, stacks) |
|
|
|
for stack in new_stacks: |
|
|
|
|
|
if len(stack) == 0: |
|
return True |
|
element_offset = stack[-1] |
|
if self.partial_utf8_accept_at_element(element_offset, new_partial_utf8): |
|
return True |
|
return False |
|
|
|
def _consume_bytes( |
|
self, |
|
byte_seq: bytes, |
|
accept_state: AcceptState = None, |
|
verbose=True, |
|
): |
|
if accept_state is None: |
|
accept_state = self.get_initial_accept_state() |
|
stacks = accept_state.stacks |
|
partial_utf8 = accept_state.partial_utf8 |
|
if type(byte_seq) is list: |
|
byte_seq = bytes(byte_seq) |
|
code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8) |
|
if verbose: |
|
logging.debug( |
|
f"code_points: {code_points}; new_partial_utf8: {new_partial_utf8}" |
|
) |
|
new_stacks = self._consume_code_points(code_points, stacks) |
|
|
|
new_new_stacks = [] |
|
for stack in new_stacks: |
|
if len(stack) == 0: |
|
continue |
|
element_offset = stack[-1] |
|
if self.partial_utf8_accept_at_element(element_offset, new_partial_utf8): |
|
new_new_stacks.append(stack) |
|
return AcceptState(new_new_stacks, new_partial_utf8) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(maxsize=30000) |
|
def _consume_code_point( |
|
self, code_point: int, stacks: Tuple[Tuple[int]] |
|
) -> List[List[int]]: |
|
""" |
|
consume a character from the stack |
|
char_code_point: can be a Unicode code point, including ascii code points which are in the range [0, 127] |
|
""" |
|
new_stacks = [] |
|
|
|
stacks: List[List[int]] = list([list(stack) for stack in stacks]) |
|
if code_point == 0: |
|
return new_stacks |
|
for stack in stacks: |
|
new_stacks.extend( |
|
self._consume_code_point_per_stack(code_point, tuple(stack)) |
|
) |
|
return new_stacks |
|
|
|
@lru_cache(maxsize=30000) |
|
def _consume_code_point_per_stack( |
|
self, code_point: int, stack: Tuple[int] |
|
) -> List[List[int]]: |
|
""" |
|
consume a character from the stack |
|
char_code_point: can be a Unicode code point, including ascii code points which are in the range [0, 127] |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
stack = list(stack) |
|
new_stacks = [] |
|
if code_point == 0: |
|
return new_stacks |
|
|
|
if len(stack) == 0: |
|
return new_stacks |
|
|
|
element_offset = stack[-1] |
|
|
|
found = self.accept_code_point_at_element(code_point, element_offset) |
|
if not found: |
|
return new_stacks |
|
|
|
size = self.grammar_encoding[element_offset] |
|
element_offset += size + 1 |
|
new_stack = stack[:-1] |
|
if self.grammar_encoding[element_offset]: |
|
new_stack.append(element_offset) |
|
return self.advance_stack(tuple(new_stack)) |
|
|
|
def _consume_code_points( |
|
self, code_points: List[int], stacks: List[List[int]], verbose=False |
|
) -> List[List[int]]: |
|
for i, code_point in enumerate(code_points): |
|
|
|
tuple_stacks: Tuple[Tuple[int]] = tuple([tuple(stack) for stack in stacks]) |
|
stacks = self._consume_code_point(code_point, tuple_stacks) |
|
if len(stacks) > 0 and verbose: |
|
accepted_code_point = code_points[: i + 1] |
|
corresponding_char = chr(code_point) |
|
logging.debug( |
|
f"code point {accepted_code_point} corresponding to {corresponding_char} is accepted" |
|
) |
|
return stacks |
|
|
|
def _accept_code_points( |
|
self, code_points: List[int], stacks: List[List[int]], verbose=False |
|
) -> bool: |
|
stacks = self._consume_code_points(code_points, stacks, verbose) |
|
return len(stacks) > 0 |
|
|
|
@lru_cache(maxsize=30000) |
|
def accept_code_point_at_element( |
|
self, code_point: int, element_offset: int |
|
) -> bool: |
|
size = self.grammar_encoding[element_offset] |
|
|
|
element_offset += 1 |
|
for i in range(0, size, 2): |
|
if ( |
|
self.grammar_encoding[element_offset + i] |
|
<= code_point |
|
<= self.grammar_encoding[element_offset + i + 1] |
|
): |
|
return True |
|
return False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def partial_utf8_accept_at_element(
        self, element_offset: int, partial_utf8: PartialUTF8
    ) -> bool:
        """Decide whether a partially-received UTF-8 sequence could still be
        completed into a code point accepted at ``element_offset``.

        Args:
            element_offset: offset of a char-range element in the grammar
                encoding.
            partial_utf8: decoder carry-over — the accumulated value bits and
                the number of continuation bytes still expected.

        Returns:
            True if some completion of the partial sequence lands inside one
            of the element's [lo, hi] ranges.
        """
        partial_value = partial_utf8.value
        n_remain = partial_utf8.n_remain

        # A lead byte expecting one more byte must carry a value of at least
        # 2, otherwise the sequence would be an invalid/overlong encoding.
        if n_remain == 1 and partial_value < 2:
            return False

        # No bytes pending: nothing rules the element out.
        if n_remain <= 0:
            return True

        # Each remaining continuation byte contributes 6 bits, so the
        # completed code point must lie in [low, high].
        low = partial_value << (n_remain * 6)

        high = low | ((1 << (n_remain * 6)) - 1)

        # Raise the lower bound to the smallest code point that actually
        # requires this many bytes (reject overlong encodings).
        if low == 0:
            if n_remain == 2:
                low = 1 << 11
            elif n_remain == 3:
                low = 1 << 16

        # Element encoding: size prefix, then (lo, hi) bound pairs.
        size = self.grammar_encoding[element_offset]

        element_offset += 1

        # Accept if [low, high] overlaps any of the element's ranges.
        for i in range(0, size, 2):

            if intervals_intersect(
                low,
                high,
                self.grammar_encoding[element_offset + i],
                self.grammar_encoding[element_offset + i + 1],
            ):
                return True

        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _consume_string(self, string: str, accept_state: AcceptState): |
|
|
|
code_points = [ord(char) for char in string] |
|
stacks = self._consume_code_points(code_points, accept_state.stacks) |
|
return AcceptState(stacks, accept_state.partial_utf8) |
|
|
|
def _accept_prefix(self, string: str, accept_state: AcceptState = None): |
|
if accept_state is None: |
|
accept_state = self.get_initial_accept_state() |
|
new_accept_state = self._consume_string(string, accept_state) |
|
return len(new_accept_state.stacks) > 0 |
|
|
|
def _accept_string(self, string: str, accept_state: AcceptState = None): |
|
if accept_state is None: |
|
accept_state = self.get_initial_accept_state() |
|
new_accept_state = self._consume_string(string, accept_state) |
|
at_least_one_stack_is_empty = any( |
|
len(stack) == 0 for stack in new_accept_state.stacks |
|
) |
|
return at_least_one_stack_is_empty |
|
|
|
def _can_stop(self, stacks: List[List[int]]): |
|
|
|
if len(stacks) == 0: |
|
return True |
|
|
|
for stack in stacks: |
|
if len(stack) == 0: |
|
return True |
|
else: |
|
return False |
|
|
|
def _must_stop(self, stacks: List[List[int]]): |
|
return len(stacks) == 0 or all(len(stack) == 0 for stack in stacks) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(maxsize=None) |
|
def char_acceptance_at_element(self, element_offset): |
|
""" |
|
Caches and returns a dictionary indicating whether a Unicode character is accepted |
|
at a given rule position. This function considers Unicode characters, dynamically |
|
inserting accepted ranges into a dictionary to optimize memory usage. |
|
|
|
Args: |
|
- rule_offset: The offset in the grammar encoding where the rule starts. |
|
|
|
Returns: |
|
- A dictionary where each key is a Unicode character (or range) and the value is True if accepted. |
|
""" |
|
logging.debug(f"element_offset: {element_offset}") |
|
acceptance = {} |
|
num_chars = self.grammar_encoding[element_offset] |
|
element_offset += 1 |
|
for i in range(0, num_chars, 2): |
|
start = self.grammar_encoding[element_offset + i] |
|
end = self.grammar_encoding[element_offset + i + 1] |
|
for j in range(start, end + 1): |
|
acceptance[j] = True |
|
logging.debug(acceptance) |
|
return acceptance |
|
|
|
def _consume_code_points_new( |
|
self, code_points: List[int], stacks: List[List[int]], verbose=False |
|
) -> List[List[int]]: |
|
new_stacks: List[List[int]] = [] |
|
for stack in stacks: |
|
new_stacks.extend( |
|
self._consume_code_points_per_stack( |
|
tuple(code_points), tuple(stack), verbose |
|
) |
|
) |
|
return new_stacks |
|
|
|
@lru_cache(maxsize=30000) |
|
def _consume_code_points_per_stack( |
|
self, code_points: Tuple[int], stack: Tuple[int], verbose=False |
|
) -> List[List[int]]: |
|
code_points = list(code_points) |
|
stacks = (stack,) |
|
for i, code_point in enumerate(code_points): |
|
|
|
stacks = self._consume_code_point(code_point, stacks) |
|
stacks = tuple([tuple(stack) for stack in stacks]) |
|
return [list(stack) for stack in stacks] |