|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""utils for ngram for ZEN model.""" |
|
|
|
import os |
|
import logging |
|
|
|
# Default file name looked up when a directory is passed to ZenNgramDict.
NGRAM_DICT_NAME = 'ngram.txt'

logger = logging.getLogger(__name__)


class ZenNgramDict(object):
    """Dictionary of n-grams loaded from an ``<ngram>,<freq>`` text file.

    Attributes populated by the constructor:

    * ``id_to_ngram_list`` -- list indexed by n-gram id; entry 0 is the
      string ``"[pad]"``, every other entry is a tuple of tokens.
    * ``ngram_to_id_dict`` -- maps ``"[pad]"`` to 0 and each token tuple to
      its id.
    * ``ngram_to_freq_dict`` -- maps the raw n-gram text to its frequency,
      kept as the string read from the file.
    """

    def __init__(self, ngram_freq_path, tokenizer, max_ngram_in_seq=128):
        """Construct a ZenNgramDict from an n-gram frequency file.

        :param ngram_freq_path: path to the frequency file, or a directory
            containing a file named ``NGRAM_DICT_NAME``. Each non-blank line
            has the form ``<ngram>,<freq>``.
        :param tokenizer: object exposing ``tokenize(text)`` that returns an
            iterable of tokens for the n-gram text.
        :param max_ngram_in_seq: value stored on the instance as
            ``max_ngram_in_seq``; this class does not otherwise use it.
        :raises ValueError: if a non-blank line contains no ``,`` separator.
        """
        if os.path.isdir(ngram_freq_path):
            ngram_freq_path = os.path.join(ngram_freq_path, NGRAM_DICT_NAME)
        self.ngram_freq_path = ngram_freq_path
        self.max_ngram_in_seq = max_ngram_in_seq
        self.id_to_ngram_list = ["[pad]"]
        self.ngram_to_id_dict = {"[pad]": 0}
        self.ngram_to_freq_dict = {}

        logger.info("loading ngram frequency file %s", ngram_freq_path)
        with open(ngram_freq_path, "r", encoding="utf-8") as fin:
            for line_no, line in enumerate(fin, start=1):
                # Drop the line terminator so it does not end up inside the
                # stored frequency (which would corrupt files written by
                # save() with an extra blank line per record).
                record = line.rstrip("\r\n")
                if not record.strip():
                    # Tolerate blank lines, e.g. a trailing empty line.
                    continue
                # Split on the LAST comma: the frequency is the final field,
                # so the n-gram text itself may contain commas.
                ngram, sep, freq = record.rpartition(",")
                if not sep:
                    raise ValueError(
                        "{}:{}: expected '<ngram>,<freq>', got {!r}".format(
                            ngram_freq_path, line_no, record))
                tokens = tuple(tokenizer.tokenize(ngram))
                self.ngram_to_freq_dict[ngram] = freq.strip()
                # Derive the id from the list position rather than the file
                # line index so that ids stay aligned with id_to_ngram_list
                # even when blank lines were skipped.
                ngram_id = len(self.id_to_ngram_list)
                self.id_to_ngram_list.append(tokens)
                self.ngram_to_id_dict[tokens] = ngram_id

    def save(self, ngram_freq_path):
        """Write the n-gram frequencies to ``ngram_freq_path``.

        One ``<ngram>,<freq>`` record is written per line, in the iteration
        order of ``ngram_to_freq_dict``. The output can be read back by the
        constructor.

        :param ngram_freq_path: path of the file to create or overwrite.
        """
        with open(ngram_freq_path, "w", encoding="utf-8") as fout:
            for ngram, freq in self.ngram_to_freq_dict.items():
                fout.write("{},{}\n".format(ngram, freq))