# ------------------------------------------------------------------------
# Modified from OFA (https://github.com/OFA-Sys/OFA)
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
# ------------------------------------------------------------------------
# Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from collections import defaultdict


class TreeNode:
    """A single trie node whose children are created lazily on first index."""

    def __init__(self):
        # Maps a symbol to its child node. Indexing a missing symbol with
        # ``self.child[symbol]`` creates and stores a fresh, empty TreeNode.
        self.child = defaultdict(TreeNode)


class Trie:
    """Prefix tree over sequences of hashable symbols.

    Sequences are stored only as paths of nodes; there is no explicit
    end-of-sequence marker on nodes.
    """

    def __init__(self, eos):
        """Create an empty trie.

        Args:
            eos: Value returned (wrapped in a list) by ``get_next_layer`` when
                the queried prefix is not present in the trie.
        """
        self.root = TreeNode()
        self.eos = eos

    def insert(self, word):
        """Add ``word`` to the trie as a path of nodes starting at the root.

        Args:
            word: Iterable of hashable symbols. Missing nodes along the path
                are created; existing nodes are reused.
        """
        node = self.root
        for symbol in word:
            # defaultdict indexing creates the child on demand.
            node = node.child[symbol]

    def get_next_layer(self, word):
        """Return the symbols that may follow the prefix ``word``.

        Lookups never modify the trie: membership is tested before indexing,
        so the defaultdict factory is never triggered here.

        Args:
            word: Iterable of hashable symbols forming the prefix to follow.

        Returns:
            ``[self.eos]`` if the prefix leaves the trie at any step.
            Otherwise, the child symbols of the node reached, in insertion
            order. This is an empty list when that node is a leaf.
        """
        node = self.root
        for symbol in word:
            if symbol not in node.child:
                # Prefix is not in the trie.
                return [self.eos]
            node = node.child[symbol]
        return [*node.child]