import logging
from functools import lru_cache
from typing import List, Optional, Tuple
from transformers_gad.parser import (
END_OF_RULE_MARKER,
END_OF_ALTERNATE_MARKER,
parse_ebnf,
REF_RULE_MARKER,
)
from transformers_gad.utf8_utils import PartialUTF8, decode_utf8
from transformers_gad.utils import intervals_intersect
class AcceptState:
def __init__(self, stacks, partial_utf8):
self.stacks = stacks
self.partial_utf8 = partial_utf8
@staticmethod
def empty_state():
return AcceptState([], PartialUTF8())
class StringRecognizer:
def __init__(
self,
grammar_encoding: List[int],
        start_rule_id: Optional[int] = None,
        rule_offsets: Optional[List[int]] = None,
        stacks: Optional[List[List[int]]] = None,
):
        # strictly speaking, we could defensively copy grammar_encoding here,
        # but we never modify it, so we keep a reference instead; for very
        # large grammars this also avoids the cost of a copy
self.grammar_encoding = grammar_encoding
if rule_offsets is not None:
self.rule_offsets = rule_offsets
else:
if start_rule_id is None:
raise ValueError("start_rule_id cannot be None if rule_offsets is None")
self.rule_offsets = self.init_rules(start_rule_id)
        # each stack is a list of indices into grammar_encoding;
        # each index points to an element of a rule's right-hand side
if stacks is not None:
self.stacks = stacks
else:
if start_rule_id is None:
raise ValueError("start_rule_id cannot be None if stacks is None")
self.stacks: List[List[int]] = self.init_stack(start_rule_id)
self.start_rule_id = start_rule_id
def init_rules(self, start_rule_id: int) -> List[int]:
_rule_offset = 0
rule_offsets = []
        # Build `rule_offsets` as a map from rule IDs to their offsets in
        # `grammar_encoding`; 0xFFFF marks the end of the whole encoding
        while self.grammar_encoding[_rule_offset] != 0xFFFF:
rule_id = self.grammar_encoding[_rule_offset]
# store the offset idx
if len(rule_offsets) <= rule_id:
rule_offsets.extend([-1] * (rule_id - len(rule_offsets) + 1))
rule_offsets[rule_id] = _rule_offset
            # skip the rule ID
            simple_rhs_offset = _rule_offset + 1
# Skip rule alternates
while self.grammar_encoding[simple_rhs_offset] != END_OF_RULE_MARKER:
simple_rhs_offset = (
simple_rhs_offset + 1 + self.grammar_encoding[simple_rhs_offset]
)
            # skip the 0 (END_OF_RULE_MARKER) denoting the end of the rule
            _rule_offset = simple_rhs_offset + 1
retrieved_start_rule_id = self.grammar_encoding[rule_offsets[start_rule_id]]
assert retrieved_start_rule_id == start_rule_id
return rule_offsets
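
    # Layout sketch of `grammar_encoding`, inferred from the loops in
    # init_rules/init_stack/advance_stack (marker values kept symbolic):
    #   [ rule_id,
    #     alt_size, elements..., END_OF_ALTERNATE_MARKER,  # one alternate
    #     ...,                                             # more alternates
    #     END_OF_RULE_MARKER,
    #     ...,                                             # more rules
    #     0xFFFF ]                                         # end of grammar
    # where alt_size counts the ints that follow, up to and including the
    # marker; a terminal element is [num_bounds, lo_1, hi_1, ...] (inclusive
    # code point ranges) and a rule reference is [REF_RULE_MARKER, ref_rule_id].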
def init_stack(self, start_rule_id: int) -> List[List[int]]:
stacks = []
# Loop over alternates of start rule to build initial stacks
sub_rhs_offset = self.rule_offsets[start_rule_id] + 1
while self.grammar_encoding[sub_rhs_offset]:
stack: List[int] = []
# If alternate is nonempty, add to stack
element_offset = sub_rhs_offset + 1
if self.grammar_encoding[element_offset] != END_OF_ALTERNATE_MARKER:
stack.append(element_offset)
stacks.extend(self.advance_stack(tuple(stack)))
sub_rhs_offset += 1 + self.grammar_encoding[sub_rhs_offset]
return stacks
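
    # E.g. for a start rule `root ::= "a" | "b"` there are two alternates, so
    # init_stack yields two single-element stacks, one per alternate.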
def get_initial_accept_state(self) -> AcceptState:
return AcceptState(self.init_stack(self.start_rule_id), PartialUTF8())
def get_termination_accept_state(self) -> AcceptState:
return AcceptState([], PartialUTF8())
@lru_cache(maxsize=32768)
    def advance_stack(self, stack: Tuple[int, ...]) -> List[List[int]]:
stack = list(stack)
if len(stack) == 0:
return [stack]
# we get the last element of the stack, which is the element we are currently processing
cur_element_offset = stack[-1]
# if the element is a terminal, we don't need to advance the stack
if self.grammar_encoding[cur_element_offset] != REF_RULE_MARKER:
return [stack]
# the remaining case is that the element is a non-terminal, i.e. a reference to another rule
else:
ref_rule_id = self.grammar_encoding[cur_element_offset + 1]
# find the offset of the referenced rule
ref_subrule_offset = self.rule_offsets[ref_rule_id] + 1
new_stacks: List[List[int]] = []
# Loop over alternates of referenced rule to build new stacks
while self.grammar_encoding[ref_subrule_offset] != END_OF_RULE_MARKER:
# copy the original stack without the last element
new_stack = stack[:-1]
# if the rule ref is followed by another element, we add it to the stack
next_element_offset = cur_element_offset + 2
if (
self.grammar_encoding[next_element_offset]
!= END_OF_ALTERNATE_MARKER
):
new_stack.append(next_element_offset)
# if the referenced rule is not empty, we add its element offset to the stack
ref_element_offset = ref_subrule_offset + 1
if self.grammar_encoding[ref_element_offset] != END_OF_ALTERNATE_MARKER:
new_stack.append(ref_element_offset)
new_stacks.extend(self.advance_stack(tuple(new_stack)))
ref_subrule_offset += self.grammar_encoding[ref_subrule_offset] + 1
return new_stacks
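
    # E.g. if the top of a stack references a rule `x ::= "a" | "b"`, advancing
    # replaces that stack with two stacks, one positioned at each alternate.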
    def _consume_byte(self, byte: int, accept_state: AcceptState):
        # a multi-byte code point such as 一 (ord('一') == 19968) occupies 3
        # UTF-8 bytes, so this must be called once per byte to match it fully
        return self._consume_bytes(bytes([byte]), accept_state)
def _probe_bytes(
self,
byte_seq: bytes,
stacks: List[List[int]],
partial_utf8: PartialUTF8,
verbose=True,
):
        if isinstance(byte_seq, list):
byte_seq = bytes(byte_seq)
code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8)
if verbose:
logging.debug(
f"code_points: {code_points}; new_partial_utf8: {new_partial_utf8}"
)
new_stacks = self._consume_code_points(code_points, stacks)
for stack in new_stacks:
            # an empty stack means all grammar elements have been consumed,
            # i.e. the byte sequence completes a match
if len(stack) == 0:
return True
element_offset = stack[-1]
if self.partial_utf8_accept_at_element(element_offset, new_partial_utf8):
return True
return False
def _consume_bytes(
self,
byte_seq: bytes,
accept_state: AcceptState = None,
verbose=True,
):
if accept_state is None:
accept_state = self.get_initial_accept_state()
stacks = accept_state.stacks
partial_utf8 = accept_state.partial_utf8
        if isinstance(byte_seq, list):
byte_seq = bytes(byte_seq)
code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8)
if verbose:
logging.debug(
f"code_points: {code_points}; new_partial_utf8: {new_partial_utf8}"
)
new_stacks = self._consume_code_points(code_points, stacks)
        # keep only the stacks that remain consistent with any pending
        # partial UTF-8 sequence
        filtered_stacks = []
        for stack in new_stacks:
            if len(stack) == 0:
                continue
            element_offset = stack[-1]
            if self.partial_utf8_accept_at_element(element_offset, new_partial_utf8):
                filtered_stacks.append(stack)
        return AcceptState(filtered_stacks, new_partial_utf8)
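
    # Usage sketch for incremental byte feeding (`recognizer` is a
    # hypothetical StringRecognizer instance):
    #   state = recognizer.get_initial_accept_state()
    #   state = recognizer._consume_bytes(b"\xe4", state)  # first byte of 一
    #   # len(state.stacks) > 0 iff some stack can still complete the code point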
##########################
#
# Code point recognition
#
##########################
@lru_cache(maxsize=30000)
    def _consume_code_point(
        self, code_point: int, stacks: Tuple[Tuple[int, ...], ...]
    ) -> List[List[int]]:
        """
        Consume a single code point against every stack.
        code_point: a Unicode code point, including ASCII code points, which
        are in the range [0, 127]; 0 signals an incomplete UTF-8 sequence
        and is rejected.
        """
        new_stacks = []
        stacks: List[List[int]] = [list(stack) for stack in stacks]
        if code_point == 0:
            return new_stacks
for stack in stacks:
new_stacks.extend(
self._consume_code_point_per_stack(code_point, tuple(stack))
)
return new_stacks
@lru_cache(maxsize=30000)
    def _consume_code_point_per_stack(
        self, code_point: int, stack: Tuple[int, ...]
    ) -> List[List[int]]:
        """
        Consume a single code point against one stack.
        code_point: a Unicode code point, including ASCII code points, which
        are in the range [0, 127].
        """
        # TODO: callers sometimes pass an empty stack here, hence the guard
        # below; it is unclear why this happens.
        stack = list(stack)
        new_stacks = []
        # code_point == 0 is a special case signalling an incomplete UTF-8
        # sequence; return no stacks to indicate the character is not accepted
        if code_point == 0:
            return new_stacks
# stack is empty
if len(stack) == 0:
return new_stacks
element_offset = stack[-1]
found = self.accept_code_point_at_element(code_point, element_offset)
if not found:
return new_stacks
size = self.grammar_encoding[element_offset]
element_offset += size + 1
new_stack = stack[:-1]
if self.grammar_encoding[element_offset]:
new_stack.append(element_offset)
return self.advance_stack(tuple(new_stack))
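
    # Consuming a code point pops the matched terminal element and, if the
    # alternate continues, pushes the offset of the next element; advance_stack
    # then expands any rule reference that ends up on top.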
def _consume_code_points(
self, code_points: List[int], stacks: List[List[int]], verbose=False
) -> List[List[int]]:
for i, code_point in enumerate(code_points):
# for lru_cache to work, we need to convert the list of stacks into a tuple of stacks
            tuple_stacks: Tuple[Tuple[int, ...], ...] = tuple(
                tuple(stack) for stack in stacks
            )
            stacks = self._consume_code_point(code_point, tuple_stacks)
            if len(stacks) > 0 and verbose:
                accepted_code_points = code_points[: i + 1]
                corresponding_char = chr(code_point)
                logging.debug(
                    f"code points {accepted_code_points} (last char: {corresponding_char}) are accepted"
                )
return stacks
def _accept_code_points(
self, code_points: List[int], stacks: List[List[int]], verbose=False
) -> bool:
stacks = self._consume_code_points(code_points, stacks, verbose)
return len(stacks) > 0
@lru_cache(maxsize=30000)
def accept_code_point_at_element(
self, code_point: int, element_offset: int
) -> bool:
size = self.grammar_encoding[element_offset]
        # advance element_offset to point at the range_start of the first range
element_offset += 1
for i in range(0, size, 2):
if (
self.grammar_encoding[element_offset + i]
<= code_point
<= self.grammar_encoding[element_offset + i + 1]
):
return True
return False
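
    # E.g. an element encoded as [4, 97, 122, 48, 57] declares 4 bounds, i.e.
    # the inclusive ranges a-z (97-122) and 0-9 (48-57); code point 103 ('g')
    # is accepted because 97 <= 103 <= 122.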
#############################
#
# Partial UTF-8 recognition
#
#############################
def partial_utf8_accept_at_element(
self, element_offset: int, partial_utf8: PartialUTF8
) -> bool:
# Extract the accumulated value and the number of remaining bytes from the partial_utf8 object.
partial_value = partial_utf8.value
n_remain = partial_utf8.n_remain
        # Reject invalid sequences: with exactly one byte remaining, an
        # accumulated value < 2 can only decode to an overlong encoding of a
        # code point that fits in fewer bytes, which is invalid UTF-8.
        if n_remain == 1 and partial_value < 2:
            return False
# If there are no remaining bytes, this means we had already consumed a complete UTF-8 sequence.
if n_remain <= 0:
return True
# Calculate the lowest possible Unicode code point that can be formed with the remaining bytes.
low = partial_value << (n_remain * 6)
# Calculate the highest possible Unicode code point by setting all remaining bits to 1.
high = low | ((1 << (n_remain * 6)) - 1)
# If the low end of the range is 0 and a specific number of bytes remain, adjust low to the minimum value
# that can be represented with that number of bytes. This accounts for UTF-8 encoding rules.
if low == 0:
if n_remain == 2:
low = 1 << 11 # Minimum value representable with 2 additional bytes.
elif n_remain == 3:
low = 1 << 16 # Minimum value representable with 3 additional bytes.
# Get the size of the grammar rule starting at the current element_offset.
size = self.grammar_encoding[element_offset]
# Move the element_offset to the start of the grammar rule's definition.
element_offset += 1
# Iterate over the grammar rule, checking if the range defined by low-high overlaps with any specified ranges.
for i in range(0, size, 2):
# If the current range (specified in the grammar encoding) overlaps with the low-high range, return True.
if intervals_intersect(
low,
high,
self.grammar_encoding[element_offset + i],
self.grammar_encoding[element_offset + i + 1],
):
return True
# If no overlap is found with any of the ranges, return False, indicating no valid partial match.
return False
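
    # Worked example: after consuming 0xE4, the lead byte of 一 (U+4E00), we
    # have value == 4 and n_remain == 2, so low = 4 << 12 = 16384 and
    # high = 16384 | 0xFFF = 20479; any grammar range containing 19968 (U+4E00)
    # intersects [low, high], so the partial byte sequence is accepted.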
#############################
#
# String recognition
#
#############################
def _consume_string(self, string: str, accept_state: AcceptState):
code_points = [ord(char) for char in string]
stacks = self._consume_code_points(code_points, accept_state.stacks)
return AcceptState(stacks, accept_state.partial_utf8)
def _accept_prefix(self, string: str, accept_state: AcceptState = None):
if accept_state is None:
accept_state = self.get_initial_accept_state()
new_accept_state = self._consume_string(string, accept_state)
return len(new_accept_state.stacks) > 0
def _accept_string(self, string: str, accept_state: AcceptState = None):
if accept_state is None:
accept_state = self.get_initial_accept_state()
new_accept_state = self._consume_string(string, accept_state)
at_least_one_stack_is_empty = any(
len(stack) == 0 for stack in new_accept_state.stacks
)
return at_least_one_stack_is_empty
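
    # Usage sketch (hypothetical recognizer for `root ::= "ab"`):
    #   recognizer._accept_prefix("a")  -> True   (valid prefix)
    #   recognizer._accept_string("a")  -> False  (incomplete match)
    #   recognizer._accept_string("ab") -> True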
def _can_stop(self, stacks: List[List[int]]):
# This happens in practice, but maybe it shouldn't? TODO
if len(stacks) == 0:
return True
        # if any stack is empty, a complete match has been found and we can stop
        return any(len(stack) == 0 for stack in stacks)
def _must_stop(self, stacks: List[List[int]]):
return len(stacks) == 0 or all(len(stack) == 0 for stack in stacks)
#############################
#
# Not Used
#
#############################
    # For each element in the grammar, cache which Unicode code points it accepts.
@lru_cache(maxsize=None)
def char_acceptance_at_element(self, element_offset):
"""
Caches and returns a dictionary indicating whether a Unicode character is accepted
at a given rule position. This function considers Unicode characters, dynamically
inserting accepted ranges into a dictionary to optimize memory usage.
        Args:
        - element_offset: The offset in the grammar encoding where the element starts.
        Returns:
        - A dictionary mapping each accepted Unicode code point to True.
"""
logging.debug(f"element_offset: {element_offset}")
acceptance = {}
num_chars = self.grammar_encoding[element_offset]
element_offset += 1
for i in range(0, num_chars, 2):
start = self.grammar_encoding[element_offset + i]
end = self.grammar_encoding[element_offset + i + 1]
for j in range(start, end + 1):
acceptance[j] = True
logging.debug(acceptance)
return acceptance
def _consume_code_points_new(
self, code_points: List[int], stacks: List[List[int]], verbose=False
) -> List[List[int]]:
new_stacks: List[List[int]] = []
for stack in stacks:
new_stacks.extend(
self._consume_code_points_per_stack(
tuple(code_points), tuple(stack), verbose
)
)
return new_stacks
@lru_cache(maxsize=30000)
    def _consume_code_points_per_stack(
        self, code_points: Tuple[int, ...], stack: Tuple[int, ...], verbose=False
    ) -> List[List[int]]:
code_points = list(code_points)
stacks = (stack,)
        for code_point in code_points:
            # for lru_cache to work, convert the list of stacks into a tuple of tuples
            stacks = self._consume_code_point(code_point, stacks)
            stacks = tuple(tuple(stack) for stack in stacks)
return [list(stack) for stack in stacks] |
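

if __name__ == "__main__":
    # Minimal self-check sketch, not part of the library: the encoding below
    # is hand-built following the layout inferred above and assumes that
    # END_OF_RULE_MARKER and END_OF_ALTERNATE_MARKER are both 0 (as the
    # truthiness checks in init_rules/init_stack suggest). It represents the
    # single rule `root ::= [a-b]` with rule_id 0.
    demo_encoding = [
        0,  # rule id of `root`
        4,  # alternate size: 4 ints follow, including the end marker
        2,  # terminal element with 2 range bounds
        ord("a"),  # range start
        ord("b"),  # range end
        0,  # END_OF_ALTERNATE_MARKER (assumed 0)
        0,  # END_OF_RULE_MARKER (assumed 0)
        0xFFFF,  # end of the whole grammar
    ]
    recognizer = StringRecognizer(demo_encoding, start_rule_id=0)
    assert recognizer._accept_string("a")
    assert recognizer._accept_string("b")
    assert not recognizer._accept_string("c")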