# Spaces:
# Runtime error
# Runtime error
"""
Language-related data loading helper functions and class wrappers.
"""
import codecs
import re

import torch
# Special vocabulary tokens.
UNK_TOKEN = '<unk>'  # substituted for words that are missing from the dictionary
PAD_TOKEN = '<pad>'  # fills a tokenized line up to the requested length
END_TOKEN = '<eos>'  # placed right after the last word of a line shorter than the requested length

# Splits on runs of non-word characters; the capture group keeps the
# separators themselves in the result of ``split``.
SENTENCE_SPLIT_REGEX = re.compile(r'(\W+)')
class Dictionary(object):
    """Two-way mapping between words and consecutive integer indices.

    Indices are handed out in insertion order, starting at 0.
    """

    def __init__(self):
        self.word2idx = {}  # word -> index
        self.idx2word = []  # index -> word

    def add_word(self, word):
        """Register ``word`` if it is new and return its index.

        Adding a word that is already present leaves the mapping untouched
        and returns the index it was given the first time.
        """
        if word not in self.word2idx:
            # The next free index is the current vocabulary size.
            self.word2idx[word] = len(self.idx2word)
            self.idx2word.append(word)
        return self.word2idx[word]

    def __len__(self):
        """Return the number of distinct words stored."""
        return len(self.idx2word)

    def __getitem__(self, a):
        """Translate between words and indices.

        * ``str``  -> index of that word (``KeyError`` if unknown)
        * ``int``  -> word stored at that index (``IndexError`` if out of range)
        * ``list`` / ``tuple`` -> list with every element translated; string
          elements are mapped to indices, all other elements are used as
          indices into the word list.

        Raises:
            TypeError: if ``a`` is none of the supported types.
        """
        if isinstance(a, str):
            return self.word2idx[a]
        if isinstance(a, int):
            return self.idx2word[a]
        if isinstance(a, (list, tuple)):
            # Non-string elements index the word list directly, so anything
            # usable as a list index is accepted.
            return [
                self.word2idx[x] if isinstance(x, str) else self.idx2word[x]
                for x in a
            ]
        raise TypeError(
            "Query word/index argument must be int, str, or a list/tuple of them")

    def __contains__(self, word):
        """Return True if ``word`` has been registered."""
        return word in self.word2idx
class Corpus(object):
    """Vocabulary holder that converts text lines into tensors of word indices."""

    def __init__(self):
        self.dictionary = Dictionary()

    def set_max_len(self, value):
        """Store ``value`` as ``self.max_len``.

        NOTE(review): ``tokenize`` does not read this attribute; it relies on
        its own ``max_len`` argument.
        """
        self.max_len = value

    def load_file(self, filename):
        """Add every whitespace-separated word of a UTF-8 text file to the dictionary.

        ``UNK_TOKEN`` and ``PAD_TOKEN`` are registered after the file's words,
        so they take the next free indices unless the file already contained them.

        NOTE(review): ``END_TOKEN`` is not registered here. Unless the file
        itself contains it, ``tokenize`` will map the end marker to
        ``UNK_TOKEN`` -- confirm this is intended before changing it, since
        registering it would alter the vocabulary size.
        """
        with codecs.open(filename, 'r', 'utf-8') as f:
            for line in f:
                self.add_to_corpus(line.strip())
        self.dictionary.add_word(UNK_TOKEN)
        self.dictionary.add_word(PAD_TOKEN)

    def add_to_corpus(self, line):
        """Register each whitespace-separated word of ``line``, lower-cased."""
        for word in line.split():
            self.dictionary.add_word(word.lower())

    def tokenize(self, line, max_len=20):
        """Convert ``line`` into a ``torch.LongTensor`` of dictionary indices.

        The line is split with ``SENTENCE_SPLIT_REGEX`` and lower-cased. Empty
        strings and single-space separators are dropped (other separators,
        e.g. ``', '``, stay in as tokens), and one trailing ``'.'`` token is
        removed.

        Args:
            line: text to tokenize.
            max_len: if greater than 0, the result has exactly this many
                entries: longer lines are truncated, shorter ones receive
                ``END_TOKEN`` followed by ``PAD_TOKEN`` padding. A line with
                exactly ``max_len`` tokens gets no end marker. If 0 or
                negative, the tokens are returned as they are.

        Returns:
            1-D ``torch.LongTensor`` of indices; tokens missing from the
            dictionary are replaced by the index of ``UNK_TOKEN``.

        Raises:
            KeyError: if an unknown token is met and ``UNK_TOKEN`` has not
                been registered in the dictionary.
        """
        pieces = SENTENCE_SPLIT_REGEX.split(line.strip())
        # Do not include empty strings or a lone space as a token.
        words = [w.lower() for w in pieces if w and w != ' ']

        # Guard on emptiness: an empty or whitespace-only line yields no tokens.
        if words and words[-1] == '.':
            words = words[:-1]

        if max_len > 0:
            if len(words) > max_len:
                words = words[:max_len]
            elif len(words) < max_len:
                padding = [PAD_TOKEN] * (max_len - len(words) - 1)
                words = words + [END_TOKEN] + padding

        # UNK_TOKEN is only looked up when an unknown word is actually met.
        ids = [
            self.dictionary[w] if w in self.dictionary else self.dictionary[UNK_TOKEN]
            for w in words
        ]
        return torch.LongTensor(ids)

    def __len__(self):
        """Return the vocabulary size."""
        return len(self.dictionary)