File size: 6,675 Bytes
901bbd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
from dataclasses import dataclass
from typing import Tuple

from dataclasses import dataclass


@dataclass
class PartialUTF8:
    """
    Decoder state for a UTF-8 character whose bytes are split across inputs.

    Attributes:
    - value (int): Code-point bits accumulated from the bytes of the current
                   character seen so far. `0` when no character is in progress.

    - n_remain (int): Continuation bytes still required to finish the current
                      character. `-1` means no character is in progress.

    Instances compare and hash by the pair ``(value, n_remain)``, so they can be
    used as dictionary or cache keys. The dataclass is not frozen, however:
    mutating an instance after it has been used as a key changes its hash.
    """

    # Accumulated code-point bits; 0 when nothing is pending.
    value: int = 0
    # Continuation bytes still expected; -1 when no sequence is open.
    n_remain: int = -1

    def __hash__(self):
        state = (self.value, self.n_remain)
        return hash(state)

    def __eq__(self, other):
        if isinstance(other, PartialUTF8):
            return self.value == other.value and self.n_remain == other.n_remain
        return NotImplemented


from typing import List, Tuple
from functools import lru_cache


# Length in bytes of a UTF-8 sequence, indexed by the high 4 bits of its first
# byte. 0 marks a continuation byte (10xxxxxx), which can never start a
# sequence. Module-level tuple so the table is built once, not on every call.
_UTF8_SEQUENCE_LENGTH = (1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4)


def _consume_continuation_bytes(src: bytes, pos: int, value: int, n_remain: int):
    """Fold up to ``n_remain`` continuation bytes from ``src[pos:]`` into ``value``.

    Returns the updated ``(pos, value, n_remain)``; ``n_remain`` is still
    positive if ``src`` ran out first. Returns ``None`` as soon as a byte that
    is not a continuation byte (``10xxxxxx``) is encountered.
    """
    end = len(src)
    while pos < end and n_remain > 0:
        byte = src[pos]
        if byte >> 6 != 0b10:
            return None
        # Every continuation byte contributes 6 payload bits.
        value = (value << 6) + (byte & 0x3F)
        pos += 1
        n_remain -= 1
    return pos, value, n_remain


@lru_cache(maxsize=3000000)
def _decode_utf8_cached(
    src: bytes, partial_start: PartialUTF8
) -> Tuple[Tuple[int, ...], PartialUTF8]:
    """Memoized core of ``decode_utf8``.

    Returns the code points as an immutable tuple so the value held in the
    cache cannot be modified through a caller's reference.
    """
    code_points = []
    value = partial_start.value
    n_remain = partial_start.n_remain

    # Finish a character left incomplete by the previous call (no-op when
    # partial_start.n_remain <= 0).
    resumed = _consume_continuation_bytes(src, 0, value, n_remain)
    if resumed is None:
        return (0,), PartialUTF8(0, -1)
    pos, value, n_remain = resumed
    if partial_start.n_remain > 0 and n_remain == 0:
        code_points.append(value)

    # Decode the remaining bytes as new sequences; the last may be incomplete.
    while pos < len(src):
        first_byte = src[pos]
        n_remain = _UTF8_SEQUENCE_LENGTH[first_byte >> 4] - 1
        if n_remain < 0:
            # A continuation byte cannot start a sequence.
            return (0,), PartialUTF8(0, -1)
        # Keep only the payload bits of the lead byte.
        value = first_byte & ((1 << (7 - n_remain)) - 1)
        consumed = _consume_continuation_bytes(src, pos + 1, value, n_remain)
        if consumed is None:
            # Lead byte followed by a non-continuation byte.
            return (0,), PartialUTF8(0, -1)
        pos, value, n_remain = consumed
        if n_remain == 0:
            code_points.append(value)

    # Canonicalize "no character in progress" to PartialUTF8(0, -1) so that
    # equal decoder states hit the same cache entry on the next call.
    if n_remain == 0:
        value, n_remain = 0, -1

    return tuple(code_points), PartialUTF8(value, n_remain)


def decode_utf8(
    src: bytes, partial_start: PartialUTF8
) -> Tuple[List[int], PartialUTF8]:
    """Incrementally decode UTF-8 bytes into Unicode code points.

    Decoding resumes from ``partial_start`` (the state returned by a previous
    call), so a multi-byte character split across calls is reassembled.

    Args:
        src: Bytes to decode.
        partial_start: Decoder state carried over from the previous call;
            ``PartialUTF8()`` (value 0, n_remain -1) when starting fresh.

    Returns:
        ``(code_points, partial_end)``. ``code_points`` lists every character
        completed by this call. ``partial_end`` describes a trailing incomplete
        sequence, or is ``PartialUTF8(0, -1)`` when ``src`` ended on a
        character boundary.

        On malformed input the result is ``([0], PartialUTF8(0, -1))`` and any
        code points decoded before the error are discarded. Malformed input
        means a byte that cannot start a sequence, or a non-continuation byte
        where a continuation byte is required.

    Results are memoized (up to 3,000,000 entries); ``decode_utf8.cache_info()``
    and ``decode_utf8.cache_clear()`` remain available.
    """
    code_points, partial_end = _decode_utf8_cached(src, partial_start)
    # Hand out fresh objects: the cached ones are shared by every caller that
    # hits the same entry, so mutating them would corrupt later results.
    return list(code_points), PartialUTF8(partial_end.value, partial_end.n_remain)


# Keep the lru_cache management API reachable under the public name.
decode_utf8.cache_info = _decode_utf8_cached.cache_info
decode_utf8.cache_clear = _decode_utf8_cached.cache_clear


def decode_utf8_leading_char(src: bytes) -> tuple:
    """Split the first UTF-8 encoded character off the front of ``src``.

    The character's byte width is taken from the high 4 bits of its lead byte,
    then those bytes are decoded strictly with Python's UTF-8 codec.

    Returns:
        ``(code_point, rest)``: the ``int`` code point of the first character,
        and the bytes that follow it.

    Raises:
        IndexError: if ``src`` is empty.
        UnicodeDecodeError: if the leading bytes are not a complete, valid
            UTF-8 character.
    """
    lead = src[0]
    # Byte width by high nibble; continuation-byte nibbles (8-11) map to 1 so
    # that the strict decode below reports them as invalid.
    width = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4)[lead >> 4]
    head, tail = src[:width], src[width:]
    return ord(head.decode("utf-8")), tail


def decode_utf8_string(utf8_bytes: bytes) -> list:
    """Decode a complete UTF-8 byte string into a list of code points.

    Produces the same code points as peeling one character at a time with
    ``decode_utf8_leading_char``, but in a single linear pass. The peeling
    approach re-slices (copies) the remaining bytes after every character,
    which makes it quadratic in the input length.

    Args:
        utf8_bytes: Bytes that must be valid UTF-8 in their entirety.

    Returns:
        One ``int`` code point per decoded character; ``[]`` for empty input.

    Raises:
        UnicodeDecodeError: if ``utf8_bytes`` is not valid UTF-8, including a
            truncated trailing character.
    """
    return list(map(ord, utf8_bytes.decode("utf-8")))

if __name__ == "__main__":
    # Demo input: the Euro sign (3 UTF-8 bytes) followed by ASCII "Hello".
    sample_text = "€Hello"
    encoded = sample_text.encode("utf-8")
    assert encoded == b"\xe2\x82\xacHello"

    # Peel off only the first character.
    first_cp, rest = decode_utf8_leading_char(encoded)
    print(f"Code Point: {first_cp}")  # expected: 8364 (Euro sign)
    print(f"Remaining Bytes: {rest}")  # expected: b'Hello'

    # Decode the whole byte string at once.
    all_cps = decode_utf8_string(encoded)
    print(f"Code Points: {all_cps}")  # expected: [8364, 72, 101, 108, 108, 111]

    print("-" * 50)

    # Same bytes through the incremental decoder, starting from a clean state.
    cps, state = decode_utf8(b"\xe2\x82\xacHello", PartialUTF8())
    print("Code Points:", cps)
    print("Remaining UTF-8 State:", state.value, state.n_remain)