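# Convert the raw.pickle of every benchmark dataset into tokenized pickles
# using three vocabularies: the dataset's own, the pretrained (Twitter)
# vocabulary, and the pretrained vocabulary extended with dataset words.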
from __future__ import print_function

import json
import math
import pickle
import sys
from io import open

import numpy as np
from os.path import abspath, dirname

sys.path.insert(0, dirname(dirname(abspath(__file__))))

from torchmoji.word_generator import WordGenerator
from torchmoji.create_vocab import VocabBuilder
from torchmoji.sentence_tokenizer import SentenceTokenizer, extend_vocab, coverage
from torchmoji.tokenizer import tokenize

try:
    unicode  # Python 2
except NameError:
    unicode = str  # Python 3

IS_PYTHON2 = int(sys.version[0]) == 2
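
# Benchmark datasets to convert; each lives in DIR/<dataset>/raw.pickle
# with texts, labels and train/val/test indices.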
DATASETS = [
    'Olympic',
    'PsychExp',
    'SCv1',
    'SCv2-GEN',
    'SE0714',
    # 'SE1604',  # Excluded due to Twitter's ToS
    'SS-Twitter',
    'SS-Youtube',
]

DIR = '../data'
FILENAME_RAW = 'raw.pickle'
FILENAME_OWN = 'own_vocab.pickle'
FILENAME_OUR = 'twitter_vocab.pickle'
FILENAME_COMBINED = 'combined_vocab.pickle'
def roundup(x):
    return int(math.ceil(x / 10.0)) * 10


def format_pickle(dset, train_texts, val_texts, test_texts,
                  train_labels, val_labels, test_labels):
    return {'dataset': dset,
            'train_texts': train_texts,
            'val_texts': val_texts,
            'test_texts': test_texts,
            'train_labels': train_labels,
            'val_labels': val_labels,
            'test_labels': test_labels}
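
# NOTE: convert_dataset reads the module-level variables (data, dset,
# texts, labels, maxlen) that are assigned inside the loop further down.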
def convert_dataset(filepath, extend_with, vocab):
    print('-- Generating {} '.format(filepath))
    sys.stdout.flush()
    st = SentenceTokenizer(vocab, maxlen)
    tokenized, dicts, _ = st.split_train_val_test(texts,
                                                  labels,
                                                  [data['train_ind'],
                                                   data['val_ind'],
                                                   data['test_ind']],
                                                  extend_with=extend_with)
    pick = format_pickle(dset, tokenized[0], tokenized[1], tokenized[2],
                         dicts[0], dicts[1], dicts[2])
    # Pickles must be written in binary mode ('wb'), not text mode
    with open(filepath, 'wb') as f:
        pickle.dump(pick, f)
    cover = coverage(tokenized[2])
    print('    done. Coverage: {}'.format(cover))
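
# Vocabulary of the pretrained torchMoji model, reused for the
# 'twitter' and 'combined' conversions of every dataset.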
with open('../model/vocabulary.json', 'r') as f:
    vocab = json.load(f)

for dset in DATASETS:
    print('Converting {}'.format(dset))

    PATH_RAW = '{}/{}/{}'.format(DIR, dset, FILENAME_RAW)
    PATH_OWN = '{}/{}/{}'.format(DIR, dset, FILENAME_OWN)
    PATH_OUR = '{}/{}/{}'.format(DIR, dset, FILENAME_OUR)
    PATH_COMBINED = '{}/{}/{}'.format(DIR, dset, FILENAME_COMBINED)

    with open(PATH_RAW, 'rb') as dataset:
        if IS_PYTHON2:
            data = pickle.load(dataset)
        else:
            data = pickle.load(dataset, fix_imports=True)

    # Decode data
    try:
        texts = [unicode(x) for x in data['texts']]
    except UnicodeDecodeError:
        texts = [x.decode('utf-8') for x in data['texts']]

    wg = WordGenerator(texts)
    vb = VocabBuilder(wg)
    vb.count_all_words()

    # Calculate max length of sequences considered
    # Adjust batch_size accordingly to prevent GPU overflow
    lengths = [len(tokenize(t)) for t in texts]
    maxlen = roundup(np.percentile(lengths, 80.0))

    # Extract labels
    labels = [x['label'] for x in data['info']]
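
    # Produce three tokenized versions of the dataset: one with its own
    # vocabulary built from scratch (empty vocab, extend_with=50000), one
    # with the pretrained vocabulary unchanged (extend_with=0), and one
    # with the pretrained vocabulary extended by up to 10000 dataset words.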
    convert_dataset(PATH_OWN, 50000, {})
    convert_dataset(PATH_OUR, 0, vocab)
    convert_dataset(PATH_COMBINED, 10000, vocab)