import json
import logging

import torch
import torch.nn.functional as F


class TrieNode:
    """One node of a token trie; each node stands for a single token.

    A node stores the raw score and raw likelihood recorded for its token,
    plus ``success_rate``, the running approximation of the expected future
    grammaticality (EFG) of the prefix that ends at this node.
    """

    def __init__(self, token_id=None, raw_likelihood=None, raw_score=None, success_rate=1,
                 is_start_of_sequence=False, is_end_of_sequence=False, eos_token_id=2):
        self.children = {}  # token_id (int) -> TrieNode
        self.parent = None  # set by insert() / from_dict(); None for a root

        self.token_id = token_id
        self.raw_likelihood = raw_likelihood
        self.raw_score = raw_score

        # The default approximation of EFG
        self.success_rate = success_rate

        self.eos_token_id = eos_token_id
        self.is_start_of_sequence = is_start_of_sequence
        self.is_end_of_sequence = is_end_of_sequence

    def insert(self, child_node):
        """Attach ``child_node`` as a child of this node.

        A child whose token id equals ``self.eos_token_id`` is flagged as the
        end of a sequence.

        Returns:
            The value returned by ``update_success_rate()`` when the child
            was added, or 0 when a child with the same token id already
            exists (the trie is then left untouched).
        """
        if child_node.token_id in self.children:
            return 0

        self.children[child_node.token_id] = child_node
        child_node.parent = self

        if child_node.token_id == self.eos_token_id:
            child_node.is_end_of_sequence = True

        # update the success rate of the parent node
        return self.update_success_rate()

    def insert_accepted_tokens(self, scores, acceptance):
        """Create a child node for every accepted token and insert it.

        Args:
            scores: tensor indexed as ``[batch_index, token_id]`` holding the
                raw scores; likelihoods are its softmax over the last dim.
            acceptance: tensor indexed the same way; non-zero entries mark
                the accepted tokens.

        Tokens from every batch row are inserted under this same node.
        """
        likelihoods = F.softmax(scores, dim=-1)

        for batch_index in range(acceptance.size(0)):
            accepted_tokens = acceptance[batch_index].nonzero().squeeze(-1)

            for token in accepted_tokens:
                # Convert to a plain int first: tensors hash by identity, so
                # a tensor would never match the int keys of ``children``.
                token_id = int(token)
                if token_id in self.children:
                    continue

                child_node = TrieNode(
                    token_id=token_id,
                    raw_likelihood=likelihoods[batch_index, token_id].item(),
                    raw_score=scores[batch_index, token_id].item())
                self.insert(child_node)

    def get_success_rate(self, token_id):
        """Return the approximated expected future grammaticality of ``token_id``.

        Tokens that have no child node yet get the default value 1.
        """
        child = self.children.get(token_id)
        return 1 if child is None else child.success_rate

    def update_success_rate(self):
        """Re-compute the success rate from the children and propagate upwards.

        The new success rate is the sum over the children of
        ``raw_likelihood * success_rate`` (children are expected to carry a
        numeric ``raw_likelihood``). The update is then repeated on every
        ancestor.

        Returns:
            The amount by which the success rate dropped at the top-most
            ancestor reached, or 0 when this node has no children and there
            is therefore nothing to re-compute.
        """
        if not self.children:
            return 0

        total_success_rate = sum(child.raw_likelihood * child.success_rate
                                 for child in self.children.values())

        # Get how much of unexplored nodes are covered with this update
        updated_rate = self.success_rate - total_success_rate
        self.success_rate = total_success_rate

        # Back propagate the success rate
        if self.parent is not None:
            return self.parent.update_success_rate()

        return updated_rate

    def prefix_raw_likelihood(self):
        """Return the product of raw likelihoods from the root down to this node.

        Nodes without a likelihood (``raw_likelihood is None``, e.g. the root
        created by ``Trie()``) contribute the neutral factor 1.
        """
        likelihood = 1
        node = self
        while node is not None:
            if node.raw_likelihood is not None:
                likelihood *= node.raw_likelihood
            node = node.parent
        return likelihood

    def search_token(self, token_id):
        """Return the child node for ``token_id``, or None if there is none."""
        return self.children.get(token_id)

    def to_dict(self):
        """Convert the sub-trie rooted here into a nested dictionary.

        Parent pointers are dropped; children are stored as a list under the
        ``"children"`` key.
        """
        return {
            "token_id": self.token_id,
            "raw_likelihood": self.raw_likelihood,
            "raw_score": self.raw_score,
            "success_rate": self.success_rate,
            "eos_token_id": self.eos_token_id,
            "is_start_of_sequence": self.is_start_of_sequence,
            "is_end_of_sequence": self.is_end_of_sequence,
            "children": [child.to_dict() for child in self.children.values()]
        }

    @staticmethod
    def from_dict(d):
        """Recursively (re)construct a trie from a dictionary made by ``to_dict``.

        Parent pointers are restored for every reconstructed child. A missing
        ``"children"`` entry is treated as "no children".
        """
        node = TrieNode(
            token_id=d['token_id'],
            raw_likelihood=d['raw_likelihood'],
            raw_score=d['raw_score'],
            success_rate=d['success_rate'],
            is_start_of_sequence=d['is_start_of_sequence'],
            is_end_of_sequence=d['is_end_of_sequence'],
            eos_token_id=d['eos_token_id'])

        # Read the children from the serialized dict, not from the (still
        # empty) children mapping of the freshly created node.
        for child_dict in d.get('children', ()):
            child = TrieNode.from_dict(child_dict)
            child.parent = node
            node.children[child_dict['token_id']] = child

        return node

    def __repr__(self):
        parent_token_id = 'None (Root Node)' if self.parent is None else self.parent.token_id
        return (f"TrieNode(token_id={self.token_id}', "
                f"raw_likelihood={self.raw_likelihood}, raw_score={self.raw_score}, children={list(self.children.keys())}, "
                f"parent={parent_token_id}, success rate={self.success_rate})")


class Trie:
    """Trie of token sequences whose root is a start-of-sequence node."""

    def __init__(self):
        self.root = TrieNode()
        self.root.is_start_of_sequence = True

    def search_last_parent(self, prefix: torch.LongTensor):
        """Walk the trie along ``prefix[0]`` and return the node reached.

        Only the first row of ``prefix`` is used (one batch is assumed).

        Returns:
            The node matching the last token of the prefix (the root for an
            empty prefix), or None, after printing a diagnostic, as soon as
            a token is not found in the trie.
        """
        matched_prefix = []
        current_parent = self.root

        # Assume one batch of prefix
        for time_step, token_id in enumerate(prefix[0]):
            token_id = token_id.item()
            if token_id not in current_parent.children:
                print(
                    f"matched prefix is {matched_prefix}; current {token_id} not found in the trie at time step {time_step}")
                return None

            current_parent = current_parent.children[token_id]
            matched_prefix.append(current_parent.token_id)

        return current_parent

    def search(self, sequence):
        """Return the list of nodes that exactly matches ``sequence``.

        Returns None when some token of the sequence is not in the trie.
        """
        node = self.root
        nodes = []
        for token_id in sequence:
            if token_id not in node.children:
                return None
            node = node.children[token_id]
            nodes.append(node)
        return nodes

    def raw_likelihood(self, sequence):
        """Return the raw likelihood (before the adjustment) of ``sequence``.

        ``sequence`` may be a tensor, which is converted with ``tolist()``.
        Returns None when the sequence is not in the trie.
        """
        if isinstance(sequence, torch.Tensor):
            sequence = sequence.tolist()

        nodes = self.search(sequence)
        if nodes is None:
            return None

        likelihood = 1
        for node in nodes:
            likelihood *= node.raw_likelihood
        return likelihood

    def json(self):
        """Serialize the trie to an indented JSON string."""
        return json.dumps(self.root.to_dict(), indent=2)

    @staticmethod
    def loads(js):
        """Build a Trie from a JSON string produced by ``json()``."""
        trie = Trie()
        trie.root = TrieNode.from_dict(json.loads(js))
        return trie

    def print_trie(self, node=None, prefix=None):
        """Print the token path of every leaf and end-of-sequence node."""
        if node is None:
            node = self.root
        if prefix is None:
            prefix = []

        # If current node marks the end of a sequence, print the prefix as a list
        if node.is_end_of_sequence or len(node.children) == 0:
            print(prefix)

        # Recursively call print_trie for all children, appending the current character/token to the prefix
        for char, child_node in node.children.items():
            self.print_trie(child_node, prefix + [char])

    def has_full_information(self):
        """
        Checks if all paths in the trie end with an is_end_of_sequence node set to True.
        Returns True if the trie has full information, False otherwise.
        """
        return self._check_full_information(self.root)

    def _check_full_information(self, node):
        """Return True if every leaf below ``node`` is an end-of-sequence node."""
        # If the node has no children, check if it is marked as the end of a sequence
        if not node.children:
            return node.is_end_of_sequence

        # Recursively check all children
        return all(self._check_full_information(child) for child in node.children.values())

    def print_all_nodes(self, node=None, depth=0):
        """Print all the nodes in the trie (including non-leaves)."""
        if node is None:
            node = self.root

        # Print current node's details
        indent = "  " * depth  # Create indentation based on the depth in the trie
        node_details = (f"{indent}TrieNode(token_id={node.token_id}', "
                        f"raw_likelihood={node.raw_likelihood}, raw_score={node.raw_score}, success rate={node.success_rate}, "
                        f"children={list(node.children.keys())}, "
                        f"parent={node.parent.token_id if node.parent else None}, "
                        f"is_end_of_sequence={node.is_end_of_sequence})")
        print(node_details)

        # Recursively call print_all_nodes for all children
        for child_node in node.children.values():
            self.print_all_nodes(child_node, depth + 1)