# NOTE: the lines originally here were web-page residue from a hosted copy of this file
# (Hugging Face Space, status "Runtime error", file size 25,910 bytes, revision fc5ecba,
# followed by the page's line-number gutter 1-415). They are not part of the program.
import random
import math
import os
import pickle
from collections import defaultdict, namedtuple
import string
os.environ['TOKENIZERS_PARALLELISM'] = 'false' # turn off since we're using multiple threads for loading anyway
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
import numpy as np
from tqdm import tqdm
import torch
from util import suppress_stdout
from poetry_util import is_iambic, count_syllables, get_rhymes, get_rhyme_group
from constants import *
# Vocabulary-level statistics for a corpus; instances are pickled and reloaded by Dataset.
DatasetInfo = namedtuple(
    'DatasetInfo',
    (
        'index2word',
        'word2index',
        'total_words',
        'vocab',
        'glove_embeddings',
    ),
)

# Rhyme-group statistics as produced by load_rhyme_info().
RhymeInfo = namedtuple(
    'RhymeInfo',
    (
        'word2rhyme_group',
        'rhyme_group_counts',
        'rhyme_groups',
        'index2rhyme_group',
        'rhyme_group2index',
        'total_rhyme_groups',
    ),
)
def collate(batch):
    """Merge a list of per-example tuples into batched tensors.

    Every element of ``batch`` is read positionally as
    ``(input_ids, length, future_word, log_prob, pad_id, classification_label,
    syllables_to_go, future_word_num_syllables, rhyme_group_index)``.
    The ``pad_id`` slot (index 4) is not used here: inputs are right-padded
    with 0.

    Args:
        batch: non-empty list of example tuples as described above.
            ``input_ids`` must be a 1-D tensor; ``classification_label`` is
            either an int, or a list with one label per input position.

    Returns:
        A 9-tuple ``(inputs, lengths, future_words, log_probs, labels,
        classification_labels, syllables_to_go, future_word_num_syllables,
        rhyme_group_index)`` where

        * ``inputs`` is ``batch x max_length`` (right-padded with 0),
        * ``future_words`` is ``batch x batch`` (every row lists the future
          word of every example in the batch),
        * ``labels`` is the ``batch x batch`` identity matrix, marking which
          column of ``future_words`` belongs to each row,
        * ``classification_labels`` is ``batch`` for int labels, or
          ``batch x max_length`` padded with -1 for per-position labels.

    Raises:
        IndexError: if ``batch`` is empty.
        AssertionError: if a per-position label list does not match its
            example's length, or labels are neither lists nor ints.
    """
    if not batch:
        raise IndexError('collate() received an empty batch')
    batch_size = len(batch)

    lengths = torch.LongTensor([example[1] for example in batch])
    max_length = int(lengths.max())

    # Right-pad every sequence to the longest one in the batch.
    padded_inputs = []
    for example in batch:
        sequence = example[0]
        missing = max_length - len(sequence)
        if missing > 0:
            sequence = torch.cat([sequence, torch.zeros(missing).long()], dim=0)
        padded_inputs.append(sequence)
    inputs = torch.stack(padded_inputs, dim=0)

    # Row i holds the future words of the whole batch; only column i is the
    # true one for example i, which is what the identity matrix records.
    future_words = (
        torch.LongTensor([example[2] for example in batch])
        .unsqueeze(0)
        .expand(batch_size, -1)
        .clone()
    )  # batch x N=batch
    labels = torch.eye(batch_size, dtype=torch.long)

    log_probs = torch.Tensor([example[3] for example in batch])

    raw_labels = [example[5] for example in batch]
    if isinstance(raw_labels[0], list):
        per_position = []
        for label_seq, seq_len in zip(raw_labels, lengths.tolist()):
            assert len(label_seq) == seq_len
            label_tensor = torch.LongTensor(label_seq)
            missing = max_length - len(label_seq)
            if missing > 0:
                # -1 marks padded positions in the label sequence.
                label_tensor = torch.cat(
                    [label_tensor, torch.full((missing,), -1, dtype=torch.long)], dim=0)
            per_position.append(label_tensor)
        classification_labels = torch.stack(per_position, dim=0)  # batch x seq
    else:
        assert isinstance(raw_labels[0], int)
        classification_labels = torch.LongTensor(raw_labels)  # they're just int labels

    syllables_to_go = torch.LongTensor([example[6] for example in batch])
    future_word_num_syllables = torch.LongTensor([example[7] for example in batch])
    rhyme_group_index = torch.LongTensor([example[8] for example in batch])
    return (inputs, lengths, future_words, log_probs, labels, classification_labels,
            syllables_to_go, future_word_num_syllables, rhyme_group_index)
def load_rhyme_info(index2word, vocab):
    """Group the vocabulary by rhyme and collect frequency-weighted group counts.

    Args:
        index2word: iterable of words to classify.
        vocab: mapping from word to its corpus count. Words missing from it
            are counted with weight 1.

    Returns:
        A ``RhymeInfo`` with:

        * ``word2rhyme_group``: word -> rhyme group, only for words whose
          lookup succeeded,
        * ``rhyme_group_counts``: rhyme group -> summed word weights; words
          whose lookup failed are accumulated under ``UNKNOWN_RHYME_GROUP``,
        * ``rhyme_groups``: set of groups that were found,
        * ``index2rhyme_group``: ``UNKNOWN_RHYME_GROUP`` followed by the
          sorted groups,
        * ``rhyme_group2index``: inverse of ``index2rhyme_group``,
        * ``total_rhyme_groups``: sum of all values in ``rhyme_group_counts``.
    """
    word2rhyme_group = {}
    rhyme_group_counts = defaultdict(int)
    rhyme_groups = set()
    for word in index2word:
        weight = vocab[word] if word in vocab else 1  # for rare words not in vocab, just use 1
        try:
            rhyme_group = get_rhyme_group(word)
        except Exception:
            # Best effort: any word the rhyme lookup cannot handle is counted as
            # unknown. `Exception` (rather than a bare except) keeps
            # KeyboardInterrupt / SystemExit from being swallowed.
            rhyme_group_counts[UNKNOWN_RHYME_GROUP] += weight
            continue
        word2rhyme_group[word] = rhyme_group
        rhyme_group_counts[rhyme_group] += weight
        rhyme_groups.add(rhyme_group)

    index2rhyme_group = [UNKNOWN_RHYME_GROUP] + sorted(rhyme_groups)
    rhyme_group2index = {group: i for i, group in enumerate(index2rhyme_group)}
    total_rhyme_groups = sum(rhyme_group_counts.values())
    return RhymeInfo(word2rhyme_group=word2rhyme_group,
                     rhyme_group_counts=dict(rhyme_group_counts),
                     rhyme_groups=rhyme_groups,
                     index2rhyme_group=index2rhyme_group,
                     rhyme_group2index=rhyme_group2index,
                     total_rhyme_groups=total_rhyme_groups)
class Dataset:
    """Loads a text corpus for one task and serves train/val/test loaders.

    ``args.task`` selects one of 'topic', 'formality', 'iambic', 'rhyme' or
    'newline'. The constructor reads the corpus from ``args.data_dir``, builds
    ``self.splits``, and then either loads vocabulary statistics from
    ``args.dataset_info`` or computes them (optionally restricted to words
    present in ``args.glove_file``). Rhyme statistics are prepared only for
    the rhyme task.
    """

    def __init__(self, args):
        """Read the corpus described by ``args`` and prepare splits and vocabulary.

        Attributes of ``args`` that are read: ``seed``, ``batch_size``,
        ``data_dir``, ``task``, ``dataset_info``, ``glove_file``; additionally
        ``debug`` for non-formality tasks and ``rhyme_info`` for the rhyme task.
        """
        print('loading data')
        random.seed(args.seed)
        self.batch_size = args.batch_size
        self.data_dir = args.data_dir

        task = args.task
        self.topic = task == 'topic'
        self.formality = task == 'formality'
        self.iambic = task == 'iambic'
        self.rhyme = task == 'rhyme'
        self.newline = task == 'newline'

        model_string = FORMALITY_MODEL_STRING if self.formality else TOPIC_MODEL_STRING
        self.tokenizer = AutoTokenizer.from_pretrained(model_string)
        self.tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
        self.gpt_pad_id = self.tokenizer.encode(PAD_TOKEN)[0]  # actually just the vocab size

        sentences = []
        self.vocab = defaultdict(lambda: 0)

        if self.formality:
            self.vocab['placeholder'] = 1  # anything so we don't crash

            def clean(line):
                # cutoff words until below max len; chosen so only ~20 examples affected in dataset
                if len(line) > FORMALITY_MAX_LEN:
                    line = ' '.join(line.strip()[:FORMALITY_MAX_LEN].split()[:-1])
                return line.strip()

            train_examples, val_examples, test_examples = [], [], []
            val_per_category = FORMALITY_VAL_SIZE // 2
            for category, label in (('formal', 1), ('informal', 0)):
                # The first lines of each category's train file become validation data.
                with open(os.path.join(args.data_dir, 'train', category), 'r') as rf:
                    for line_no, line in enumerate(rf):
                        bucket = val_examples if line_no < val_per_category else train_examples
                        bucket.append((clean(line), label))
                with open(os.path.join(args.data_dir, 'test', category), 'r') as rf:
                    for line in rf:
                        test_examples.append((clean(line), label))
            self.splits = {'train': train_examples, 'val': val_examples, 'test': test_examples}
        else:  # topic / poetry
            for root, _, filenames in os.walk(args.data_dir):
                for fname in filenames:
                    with open(os.path.join(root, fname), 'r') as rf:
                        for line in rf:
                            text = line.strip()
                            sentences.append(text)
                            for word in text.split(' '):
                                self.vocab[word] += 1
            random.shuffle(sentences)
            if args.debug:
                # Debug mode: every split is the full (shuffled) corpus.
                self.splits = {'val': sentences, 'test': sentences, 'train': sentences}
            else:
                self.splits = {
                    'val': sentences[:TOPIC_VAL_SIZE],
                    'test': sentences[TOPIC_VAL_SIZE:2*TOPIC_VAL_SIZE],
                    'train': sentences[2*TOPIC_VAL_SIZE:],
                }

        if args.dataset_info is not None:
            print('loading dataset info from file')
            # NOTE(review): pickle.load executes arbitrary code from the file; only load trusted files.
            with open(args.dataset_info, 'rb') as rf:
                dataset_info = pickle.load(rf)
            self.vocab = dataset_info.vocab
            self.total_words = dataset_info.total_words
            self.index2word = dataset_info.index2word
            self.word2index = dataset_info.word2index
            self.glove_embeddings = dataset_info.glove_embeddings
            self.dataset_info = dataset_info
        else:
            print('generating dataset info from scratch')
            # Most frequent words first.
            words_values = sorted(self.vocab.items(), key=lambda item: item[1], reverse=True)
            if args.glove_file is None:
                print('no glove embeddings given')
                for word, _ in words_values[VOCAB_SIZE:]:  # only use somewhat common tokens
                    del self.vocab[word]
                glove_embeddings = None
            else:
                print('loading glove embeddings')
                glove_embeddings = {}
                with open(args.glove_file, 'r') as rf:
                    for line_no, line in enumerate(rf):
                        if line_no % GLOVE_PRINT_PROGRESS_FREQ == 0:
                            print(line_no)
                        fields = line.strip().split()
                        if len(fields) != GLOVE_DIM + 1:
                            continue  # skip multi-word embeddings which are rare anyway
                        glove_embeddings[fields[0]] = [float(value) for value in fields[1:]]
                # Keep only words that have an embedding.
                for word, _ in words_values:
                    if word not in glove_embeddings:
                        del self.vocab[word]

            self.total_words = sum(self.vocab.values())
            self.index2word = [PAD_TOKEN] + sorted(self.vocab.keys())
            self.word2index = {word: index for index, word in enumerate(self.index2word)}
            self.vocab = dict(self.vocab)  # so we can pickle later
            if glove_embeddings is None:
                self.glove_embeddings = None
            else:
                # Row 0 (the pad token) is an all-zero vector.
                rows = [torch.zeros(GLOVE_DIM)]
                rows += [torch.Tensor(glove_embeddings[word]) for word in self.index2word[1:]]
                self.glove_embeddings = torch.stack(rows, dim=0)
            self.dataset_info = DatasetInfo(index2word=self.index2word,
                                            word2index=self.word2index,
                                            total_words=self.total_words,
                                            vocab=self.vocab,
                                            glove_embeddings=self.glove_embeddings)

        if self.rhyme:
            if args.rhyme_info is not None:
                print('loading rhyme info from file')
                # NOTE(review): pickle.load executes arbitrary code from the file; only load trusted files.
                with open(args.rhyme_info, 'rb') as rf:
                    self.rhyme_info = pickle.load(rf)
            else:
                self.rhyme_info = load_rhyme_info(self.index2word, self.vocab)
            info = self.rhyme_info
            # Unknown words fall back to UNKNOWN_RHYME_GROUP on lookup.
            self.word2rhyme_group = defaultdict(lambda: UNKNOWN_RHYME_GROUP, info.word2rhyme_group)
            self.rhyme_group_counts = info.rhyme_group_counts
            self.rhyme_groups = info.rhyme_groups
            self.index2rhyme_group = info.index2rhyme_group
            self.rhyme_group2index = info.rhyme_group2index
            self.total_rhyme_groups = info.total_rhyme_groups

        print('done loading data')
        print('split sizes:')
        for split_name in ('train', 'val', 'test'):
            print(split_name, len(self.splits[split_name]))
        if not self.formality:
            print('total words', self.total_words)
            print('vocab size', len(self.index2word))

    def shuffle(self, split, seed=None):
        """Shuffle one split in place, reseeding ``random`` first when ``seed`` is given."""
        assert split in ['train', 'val', 'test']
        if seed is not None:
            random.seed(seed)
        random.shuffle(self.splits[split])

    def loader(self, split, num_workers=20, indices=None):
        """Return a DataLoader over ``split``, optionally restricted to ``indices``."""
        assert split in ['train', 'val', 'test']
        examples = self.splits[split]
        if indices is not None:
            examples = [examples[i] for i in indices]
        return torch.utils.data.DataLoader(
            SplitLoader(examples, self),
            batch_size=self.batch_size,
            pin_memory=True,
            collate_fn=collate,
            num_workers=num_workers,
        )
class SplitLoader(torch.utils.data.IterableDataset):
    """Single-pass iterator that turns raw sentences into training examples.

    Each yielded example is the 9-tuple
    ``(inp, length, future_word, word_log_prob, pad_id, classification_label,
    syllables_to_go, future_word_num_syllables, rhyme_group_index)``.
    How it is built depends on which task flag is set on ``parent``
    (``topic``, ``formality``, ``iambic``, ``rhyme`` or ``newline``).
    Sentences that cannot produce a valid example are skipped.

    ``__iter__`` returns ``self`` and ``pos`` is never reset, so one instance
    supports one pass over the data.
    """

    def __init__(self, data, parent):
        """Store the examples of one split and the owning dataset object.

        Args:
            data: sequence of raw examples (a sentence string, or a
                ``(sentence, label)`` pair for the formality task).
            parent: object providing the tokenizer, task flags and vocabulary
                statistics read in ``__next__``.
        """
        # The zero-argument form actually runs the base-class initializer;
        # `super(SplitLoader).__init__()` only initialized an unbound super object.
        super().__init__()
        self.data = data
        self.pos = 0
        self.parent = parent

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return self

    def _encode(self, raw_sentence):
        """Tokenize ``raw_sentence`` and return the first row of the result."""
        return self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]

    def _num_words(self, token_ids):
        """Number of whitespace-separated words in the decoded ``token_ids``."""
        return len(self.parent.tokenizer.decode(token_ids).split())

    def _pack(self, inp, length, future_word, word_log_prob, classification_label,
              syllables_to_go, future_word_num_syllables, rhyme_group_index):
        """Assemble the example tuple, filling in the parent's pad id."""
        return (inp, length, future_word, word_log_prob, self.parent.gpt_pad_id,
                classification_label, syllables_to_go, future_word_num_syllables,
                rhyme_group_index)

    def __next__(self):
        """Return the next valid example, skipping sentences that yield none.

        Inside a DataLoader worker the position starts at the worker id and
        advances by the number of workers, so workers read disjoint items.

        Raises:
            StopIteration: when the position runs past the end of the data.
            NotImplementedError: when no task flag is set on ``parent``.
        """
        increment = 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:  # in a worker process
            increment = worker_info.num_workers
            if self.pos == 0:
                self.pos = worker_info.id

        parent = self.parent
        valid = False
        while not valid:
            if self.pos >= len(self):
                raise StopIteration

            if parent.topic:
                raw_sentence = self.data[self.pos]
                original_sentence = raw_sentence.split()
                sentence = self._encode(raw_sentence)
                if len(sentence) > MIN_SENTENCE_LENGTH:  # set to 3. well, everything in data is > 3 for the bag of words task
                    pos_to_split = random.randint(1, len(sentence) - 1)  # for lm, learn all positions at once
                    inp = sentence[:pos_to_split]
                    num_words_in_input = self._num_words(inp)
                    if num_words_in_input < len(original_sentence):
                        # allow the last possibly partial word though
                        future_word_position = random.randint(num_words_in_input - 1, len(original_sentence) - 1)
                        # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
                        future_word = original_sentence[future_word_position].strip().strip(string.punctuation)
                        if future_word in parent.word2index:
                            # roughly baseline prob of word under noise model
                            word_log_prob = math.log(parent.vocab[future_word] / parent.total_words)
                            example = self._pack(inp, len(inp), parent.word2index[future_word],
                                                 word_log_prob, -1, -1, -1, -1)
                            valid = True

            elif parent.formality:
                raw_sentence, classification_label = self.data[self.pos]
                sentence = self._encode(raw_sentence)
                if len(sentence) > MIN_SENTENCE_LENGTH:
                    # no need to split; we're going to train on all possible prefixes simultaneously for efficiency.
                    # The future word and its log-prob are unused placeholders (0) for this task.
                    example = self._pack(sentence, len(sentence), 0, 0, classification_label, -1, -1, -1)
                    valid = True

            elif parent.iambic:
                raw_sentence = self.data[self.pos]
                sentence = self._encode(raw_sentence)
                if len(sentence) > MIN_SENTENCE_LENGTH:
                    pos_to_split = random.randint(0, len(sentence) - 1)
                    # try to get a subseq of exactly 10 syllables
                    inp = sentence[pos_to_split:]
                    num_syllables = 0
                    checked = False
                    for i in range(1, len(inp)):
                        decoded = parent.tokenizer.decode(inp[:i])
                        num_syllables = count_syllables(decoded)
                        if num_syllables > POETRY_LINE_SYLLABLES:
                            # might get a few data points where the split is in the middle of a word, but it should be ok for learning.
                            inp = inp[:i - 1]
                            decoded = parent.tokenizer.decode(inp)
                            num_syllables = count_syllables(decoded)
                            checked = True
                            break
                    if checked and num_syllables == POETRY_LINE_SYLLABLES:
                        length = len(inp)
                        # predict for whole seq including future: the same label at every position,
                        # so it is computed once on the final decoded line.
                        line_is_iambic = is_iambic(decoded)
                        classification_label = [line_is_iambic] * length
                        example = self._pack(inp, length, 0, 0, classification_label, -1, -1, -1)
                        valid = True

            elif parent.rhyme:
                raw_sentence = self.data[self.pos]
                original_sentence = raw_sentence.split()
                sentence = self._encode(raw_sentence)
                if len(sentence) > MIN_SENTENCE_LENGTH:
                    pos_to_split = random.randint(1, len(sentence) - 1)  # for lm, learn all positions at once
                    inp = sentence[:pos_to_split]
                    num_words_in_input = self._num_words(inp)
                    if num_words_in_input < len(original_sentence):
                        # only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
                        future_word_position_max = min(len(original_sentence) - 1,
                                                       num_words_in_input + MAX_COUNT_SYLLABLE_DIST)
                        # allow the last possibly partial word though
                        future_word_position = random.randint(num_words_in_input - 1, future_word_position_max)
                        future_word = original_sentence[future_word_position].strip().strip(string.punctuation)
                        words_in_between = original_sentence[num_words_in_input - 1:future_word_position + 1]
                        syllables_to_go = count_syllables(' '.join(words_in_between))
                        failed = syllables_to_go > MAX_COUNT_SYLLABLE_DIST
                        future_word_num_syllables = count_syllables(future_word)
                        rhyme_group = parent.word2rhyme_group[future_word]
                        rhyme_group_index = parent.rhyme_group2index[rhyme_group]
                        # truncate context a bit since we're just doing couplets. random length from 1 to max desired length for this purpose.
                        # Drawn even for failed samples so the random stream stays the same.
                        desired_length = random.randint(1, MAX_COUNT_SYLLABLE_INPUT_LENGTH)
                        inp = inp[-desired_length:]
                        if not failed and future_word in parent.word2index:
                            word_log_prob = math.log(parent.rhyme_group_counts[rhyme_group] / parent.total_rhyme_groups)
                            # future conditioning is just the rhyme group in this case
                            example = self._pack(inp, len(inp), rhyme_group_index, word_log_prob, -1,
                                                 syllables_to_go, future_word_num_syllables, rhyme_group_index)
                            valid = True

            elif parent.newline:
                raw_sentence = self.data[self.pos]
                original_sentence = raw_sentence.split()
                sentence = self._encode(raw_sentence)
                if len(sentence) > MIN_SENTENCE_LENGTH:
                    pos_to_split = random.randint(1, len(sentence) - 1)  # for lm, learn all positions at once
                    inp = sentence[:pos_to_split]
                    # Extend the prefix while adding one more token does not change the decoded word count.
                    while pos_to_split < len(sentence):
                        if self._num_words(inp) != self._num_words(sentence[:pos_to_split + 1]):
                            break
                        pos_to_split += 1
                        inp = sentence[:pos_to_split]
                    num_words_in_input = self._num_words(inp)
                    if num_words_in_input < len(original_sentence):
                        # allow the last possibly partial word though
                        future_word_position = random.randint(num_words_in_input - 1, len(original_sentence) - 1)
                        unstripped_future_word = original_sentence[future_word_position]
                        future_word = unstripped_future_word.strip().strip(string.punctuation)
                        words_in_between = original_sentence[num_words_in_input - 1:future_word_position + 1]
                        syllables_to_go = count_syllables(' '.join(words_in_between))
                        failed = syllables_to_go > MAX_COUNT_SYLLABLE_DIST
                        # truncate context a bit since we're just doing couplets. random length from 1 to max desired length for this purpose.
                        desired_length = random.randint(1, MAX_COUNT_SYLLABLE_INPUT_LENGTH)
                        inp = inp[-desired_length:]
                        length = len(inp)
                        # common ways to end a phrase
                        true_label = 1 if unstripped_future_word.strip()[-1] in PHRASE_ENDS else 0
                        classification_label = [-1] * length
                        classification_label[-1] = true_label  # only learn at the last position
                        if not failed and future_word in parent.word2index:
                            # roughly baseline prob of word under noise model
                            word_log_prob = math.log(parent.vocab[future_word] / parent.total_words)
                            example = self._pack(inp, length, parent.word2index[future_word], word_log_prob,
                                                 classification_label, syllables_to_go, -1, -1)
                            valid = True

            else:
                raise NotImplementedError

            self.pos += increment
        return example