|
import torch |
|
import torch.nn.functional as F |
|
import json |
|
import logging |
|
|
|
class TrieNode: |
|
def __init__(self, |
|
token_id=None, raw_likelihood=None, raw_score=None, |
|
success_rate=1, |
|
is_start_of_sequence=False, is_end_of_sequence=False, |
|
eos_token_id=2): |
|
self.children = {} |
|
self.parent = None |
|
self.token_id = token_id |
|
self.raw_likelihood = raw_likelihood |
|
self.raw_score = raw_score |
|
|
|
|
|
self.success_rate = success_rate |
|
|
|
self.eos_token_id = eos_token_id |
|
self.is_start_of_sequence = is_start_of_sequence |
|
self.is_end_of_sequence = is_end_of_sequence |
|
|
|
def insert(self, child_node): |
|
""" |
|
Insert child_node into the children dictionary |
|
""" |
|
if child_node.token_id not in self.children: |
|
self.children[child_node.token_id] = child_node |
|
child_node.parent = self |
|
|
|
if child_node.token_id == self.eos_token_id: |
|
child_node.is_end_of_sequence = True |
|
|
|
|
|
return self.update_success_rate() |
|
else: |
|
return 0 |
|
|
|
def insert_accepted_tokens(self, scores, acceptance): |
|
""" |
|
Create node from acceptance and scores and |
|
insert as children of self node |
|
""" |
|
likelihoods = F.softmax(scores, dim=-1) |
|
|
|
for batch_index in range(acceptance.size(0)): |
|
accepted_tokens = acceptance[batch_index].nonzero().squeeze(-1) |
|
|
|
for token_id in accepted_tokens: |
|
if token_id not in self.children: |
|
raw_likelihood = likelihoods[batch_index, token_id].item() |
|
raw_score = scores[batch_index, token_id].item() |
|
|
|
child_node = TrieNode( |
|
token_id=token_id.item(), |
|
raw_likelihood=raw_likelihood, |
|
raw_score=raw_score) |
|
|
|
self.insert(child_node) |
|
|
|
def get_success_rate(self, token_id): |
|
""" |
|
Return Approximated Expected Future Grammaticality of the token_id |
|
""" |
|
if token_id in self.children: |
|
return self.children[token_id].success_rate |
|
else: |
|
return 1 |
|
|
|
def update_success_rate(self): |
|
""" |
|
Re-compute the success rate from the updated success rate of children |
|
""" |
|
if self.children: |
|
total_success_rate = sum(child.raw_likelihood * child.success_rate for child in self.children.values()) |
|
|
|
|
|
updated_rate = self.success_rate - total_success_rate |
|
self.success_rate = total_success_rate |
|
|
|
|
|
if self.parent: |
|
return self.parent.update_success_rate() |
|
|
|
return updated_rate |
|
|
|
def prefix_raw_likelihood(self): |
|
if self.parent: |
|
return self.raw_likelihood * self.parent.prefix_raw_likelihood() |
|
else: |
|
return self.raw_likelihood |
|
|
|
def search_token(self, token_id): |
|
""" |
|
Check if the self node has a children with token_id |
|
Return the children node if it exists, return None otherwise |
|
""" |
|
if token_id in self.children: |
|
return self.children[token_id] |
|
else: |
|
return None |
|
|
|
def to_dict(self): |
|
""" |
|
Convert a trie into a dictionary by removing the pointer to the parent |
|
""" |
|
return { |
|
"token_id": self.token_id, |
|
"raw_likelihood": self.raw_likelihood, |
|
"raw_score": self.raw_score, |
|
"success_rate": self.success_rate, |
|
"eos_token_id": self.eos_token_id, |
|
"is_start_of_sequence": self.is_start_of_sequence, |
|
"is_end_of_sequence": self.is_end_of_sequence, |
|
"children": [child.to_dict() for child in self.children.values()] |
|
} |
|
|
|
@staticmethod |
|
def from_dict(d): |
|
""" |
|
Recursively (re)construct trie from dictionary |
|
""" |
|
node = TrieNode( |
|
token_id=d['token_id'], |
|
raw_likelihood=d['raw_likelihood'], |
|
raw_score=d['raw_score'], |
|
success_rate=d['success_rate'], |
|
is_start_of_sequence=d['is_start_of_sequence'], |
|
is_end_of_sequence=d['is_end_of_sequence'], |
|
eos_token_id=d['eos_token_id']) |
|
|
|
node.children = {child['token_id']:TrieNode.from_dict(child) for child in node.children} |
|
for child in node.children.values(): |
|
child.parent = node |
|
|
|
return node |
|
|
|
def __repr__(self): |
|
parent_token_id = 'None (Root Node)' if self.parent is None else self.parent.token_id |
|
return (f"TrieNode(token_id={self.token_id}', " |
|
f"raw_likelihood={self.raw_likelihood}, raw_score={self.raw_score}, children={list(self.children.keys())}, " |
|
f"parent={parent_token_id}, success rate={self.success_rate})") |
|
|
|
class Trie: |
|
def __init__(self): |
|
self.root = TrieNode() |
|
self.root.is_start_of_sequence = True |
|
|
|
def search_last_parent(self, prefix: torch.LongTensor): |
|
""" |
|
Search the longest prefix in the trie that matches to the input sequence of tokens 'prefix' |
|
""" |
|
matched_prefix = [] |
|
current_parent = self.root |
|
|
|
|
|
for time_step, token_id in enumerate(prefix[0]): |
|
token_id = token_id.item() |
|
if token_id in current_parent.children: |
|
current_parent = current_parent.children[token_id] |
|
matched_prefix.append(current_parent.token_id) |
|
else: |
|
print( |
|
f"matched prefix is {matched_prefix}; current {token_id} not found in the trie at time step {time_step}") |
|
return None |
|
|
|
return current_parent |
|
|
|
def search(self, sequence): |
|
""" |
|
Return the sequence of nodes that exactly matches with the input |
|
""" |
|
node = self.root |
|
nodes = [] |
|
for token_id in sequence: |
|
if token_id not in node.children: |
|
return None |
|
node = node.children[token_id] |
|
nodes.append(node) |
|
return nodes |
|
|
|
def raw_likelihood(self, sequence): |
|
""" |
|
Return the raw likelihood (before the adjustment) of sequence |
|
""" |
|
if isinstance(sequence, torch.Tensor): |
|
sequence = sequence.tolist() |
|
|
|
nodes = self.search(sequence) |
|
if nodes is None: |
|
return None |
|
|
|
likelihood = 1 |
|
for node in nodes: |
|
likelihood *= node.raw_likelihood |
|
return likelihood |
|
|
|
def json(self): |
|
return json.dumps(self.root.to_dict(), indent=2) |
|
|
|
@staticmethod |
|
def loads(js): |
|
trie = Trie() |
|
trie.root = TrieNode.from_dict(json.loads(js)) |
|
|
|
return trie |
|
|
|
def print_trie(self, node=None, prefix=None): |
|
""" |
|
Print all the leaves in the trie |
|
""" |
|
if node is None: |
|
node = self.root |
|
if prefix is None: |
|
prefix = [] |
|
|
|
|
|
if node.is_end_of_sequence or len(node.children) == 0: |
|
print(prefix) |
|
|
|
|
|
for char, child_node in node.children.items(): |
|
self.print_trie(child_node, prefix + [char]) |
|
|
|
def has_full_information(self): |
|
""" |
|
Checks if all paths in the trie end with an is_end_of_sequence node set to True. |
|
Returns True if the trie has full information, False otherwise. |
|
""" |
|
return self._check_full_information(self.root) |
|
|
|
def _check_full_information(self, node): |
|
|
|
if not node.children: |
|
return node.is_end_of_sequence |
|
|
|
|
|
return all(self._check_full_information(child) for child in node.children.values()) |
|
|
|
def print_all_nodes(self, node=None, depth=0): |
|
""" |
|
Print all the nodes in the trie (including non-leaves) |
|
""" |
|
|
|
if node is None: |
|
node = self.root |
|
|
|
|
|
indent = " " * depth |
|
node_details = (f"{indent}TrieNode(token_id={node.token_id}', " |
|
f"raw_likelihood={node.raw_likelihood}, raw_score={node.raw_score}, success rate={node.success_rate}, " |
|
f"children={list(node.children.keys())}, " |
|
f"parent={node.parent.token_id if node.parent else None}, " |
|
f"is_end_of_sequence={node.is_end_of_sequence})") |
|
print(node_details) |
|
|
|
|
|
for child_node in node.children.values(): |
|
self.print_all_nodes(child_node, depth + 1) |