DeepLearning101's picture
Upload 21 files
45311fe
raw
history blame
4.95 kB
# -*- coding: utf-8 -*-
# @Time : 2022/2/15 7:57 下午
# @Author : JianingWang
# @File : trie
import logging
from typing import List
from collections import OrderedDict
logger = logging.getLogger(__name__)
class Trie:
def __init__(self):
self.data = {}
def add(self, word: str):
"""
Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation.
The special key `""` is used to represent termination.
This function is idempotent, adding twice the same word will leave the trie unchanged
Example:
```python
>>> trie = Trie()
>>> trie.add("Hello 友達")
>>> trie.data
{"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}
>>> trie.add("Hello")
>>> trie.data
{"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
```
"""
if not word:
# Prevent empty string
return
ref = self.data
for char in word:
ref[char] = char in ref and ref[char] or {}
ref = ref[char]
ref[""] = 1
def find(self, text: str):
states = OrderedDict()
offsets = []
skip = 0
for current, current_char in enumerate(text):
if skip and current < skip:
continue
to_remove = set()
reset = False
for start, trie_pointer in states.items():
if "" in trie_pointer:
for lookstart, looktrie_pointer in states.items():
if lookstart > start:
break
elif lookstart < start:
lookahead_index = current + 1
end = current + 1
else:
lookahead_index = current
end = current
next_char = text[lookahead_index] if lookahead_index < len(text) else None
if "" in looktrie_pointer:
start = lookstart
end = lookahead_index
skip = lookahead_index
while next_char in looktrie_pointer:
looktrie_pointer = looktrie_pointer[next_char]
lookahead_index += 1
if "" in looktrie_pointer:
start = lookstart
end = lookahead_index
skip = lookahead_index
if lookahead_index == len(text):
break
next_char = text[lookahead_index]
offsets.append([start, end])
reset = True
break
elif current_char in trie_pointer:
trie_pointer = trie_pointer[current_char]
states[start] = trie_pointer
else:
to_remove.add(start)
if reset:
states = {}
else:
for start in to_remove:
del states[start]
if current >= skip and current_char in self.data:
states[current] = self.data[current_char]
for start, trie_pointer in states.items():
if "" in trie_pointer:
end = len(text)
offsets.append([start, end])
break
return offsets
def split(self, text: str) -> List[str]:
"""
Example:
```python
>>> trie = Trie()
>>> trie.split("[CLS] This is a extra_id_100")
["[CLS] This is a extra_id_100"]
>>> trie.add("[CLS]")
>>> trie.add("extra_id_1")
>>> trie.add("extra_id_100")
>>> trie.split("[CLS] This is a extra_id_100")
["[CLS]", " This is a ", "extra_id_100"]
```
"""
word_sets = self.find(text)
offsets = [0]
for w in word_sets:
offsets.extend(w)
return self.cut_text(text, offsets)
def cut_text(self, text, offsets):
offsets.append(len(text))
tokens = []
start = 0
for end in offsets:
if start > end:
logger.error(
"There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway."
)
continue
elif start == end:
continue
tokens.append(text[start:end])
start = end
return tokens
def __reduce__(self):
return None
if __name__ == "__main__":
trie = Trie()
for word in ["A", "AB", "BD", "BWA"]:
trie.add(word)
print(trie.__reduce__())