Spaces:
Runtime error
Runtime error
import os | |
import random | |
import time | |
import pickle | |
import math | |
from argparse import ArgumentParser | |
from tqdm import tqdm | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from data import Dataset | |
from model import Model | |
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params, pad_mask | |
from constants import * | |
def train(model, dataset, optimizer, criterion, epoch, args, data_start_index):
    """Run one training pass and return the index at which the next pass starts.

    When ``args.epoch_max_len`` is set, only the window of the 'train' split
    beginning at ``data_start_index`` and spanning at most ``epoch_max_len``
    indices is visited; otherwise the whole split is used.

    Args:
        model: module called with the batch tensors; put into train mode here.
        dataset: provides ``shuffle``, ``loader`` and ``splits['train']``.
        optimizer: stepped once per processed batch.
        criterion: loss callable taking (scores, targets).
        epoch: epoch number; combined with ``args.seed`` to seed the shuffle.
        args: parsed options (task, device, batch_size, debug, num_workers,
            epoch_max_len, seed, train_print_freq).
        data_start_index: first 'train' index to visit in this pass.

    Returns:
        The start index for the following pass: the end of this window, or 0
        once the window reached the end of the split. Returned unchanged when
        ``args.epoch_max_len`` is None.
    """
    model.train()

    # Only reshuffle when a fresh sweep over the split begins, so that windowed
    # passes keep walking through one consistent ordering.
    if data_start_index == 0:
        dataset.shuffle('train', seed=epoch + args.seed)

    if args.epoch_max_len is None:
        loader = dataset.loader('train', num_workers=args.num_workers)
    else:
        num_train = len(dataset.splits['train'])
        window_end = min(data_start_index + args.epoch_max_len, num_train)
        loader = dataset.loader(
            'train',
            num_workers=args.num_workers,
            indices=list(range(data_start_index, window_end)),
        )
        # Wrap back to the beginning once the split has been exhausted.
        if window_end < num_train:
            data_start_index = window_end
        else:
            data_start_index = 0

    loss_meter = AverageMeter('loss', ':6.4f')
    num_batches = len(loader)
    progress = ProgressMeter(num_batches, [loss_meter], prefix='Training: ')

    # Outside debug mode, every task except formality/iambic skips batches
    # whose size differs from args.batch_size (original note: "it'll screw up
    # the bias...?").
    skip_odd_sized = args.task not in ['formality', 'iambic'] and not args.debug

    for step, batch in enumerate(tqdm(loader, total=num_batches)):
        (inputs, lengths, future_words, log_probs, labels,
         classification_targets, syllables_to_go, future_word_num_syllables,
         rhyme_group_index) = [item.to(args.device) for item in batch]

        if skip_odd_sized and len(inputs) != args.batch_size:
            continue

        scores = model(
            inputs,
            lengths,
            future_words,
            log_probs,
            syllables_to_go,
            future_word_num_syllables,
            rhyme_group_index,
            run_classifier=True,
        )
        flat_scores = scores.flatten()

        if args.task == 'formality':
            # Learning for all positions at once; scores are batch x seq.
            # The per-example target is repeated along the sequence axis and
            # padded positions are dropped via the length mask.
            per_position = classification_targets.unsqueeze(1).expand(-1, scores.shape[1])  # batch x seq
            keep = pad_mask(lengths).permute(1, 0).flatten() == 1  # from batch x seq
            loss = criterion(flat_scores[keep], per_position.flatten().float()[keep])
        elif args.task in ['iambic', 'newline']:
            # Targets equal to -1 are excluded from the loss.
            flat_targets = classification_targets.flatten()
            keep = flat_targets != -1
            loss = criterion(flat_scores[keep], flat_targets.float()[keep])
        else:  # topic, rhyme
            loss = criterion(flat_scores, labels.flatten().float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_meter.update(loss.detach(), len(labels))
        if step % args.train_print_freq == 0:
            progress.display(step)

    progress.display(num_batches)
    return data_start_index
def validate(model, dataset, criterion, epoch, args):
    """Compute the average loss of ``model`` over the 'val' split.

    Gradients are disabled and the model is switched to eval mode. The loss
    for each batch is computed with the same task-dependent rules used in
    training (length-masked for 'formality', -1 targets dropped for
    'iambic'/'newline', plain labels otherwise).

    Args:
        model: module called with the batch tensors.
        dataset: provides ``loader('val', ...)``.
        criterion: loss callable taking (scores, targets).
        epoch: accepted for signature parity; not used in the body.
        args: parsed options (task, device, batch_size, debug, num_workers,
            train_print_freq).

    Returns:
        ``loss_meter.avg`` after all processed batches.
    """
    model.eval()
    # NOTE: this reseeds the process-wide `random` module on every call.
    random.seed(0)

    val_loader = dataset.loader('val', num_workers=args.num_workers)
    loss_meter = AverageMeter('loss', ':6.4f')
    n_batches = len(val_loader)
    progress = ProgressMeter(n_batches, [loss_meter], prefix='Validation: ')

    with torch.no_grad():
        for idx, raw_batch in enumerate(tqdm(val_loader, total=n_batches)):
            on_device = [t.to(args.device) for t in raw_batch]
            (inputs, lengths, future_words, log_probs, labels,
             classification_targets, syllables_to_go,
             future_word_num_syllables, rhyme_group_index) = on_device

            # Tasks other than formality/iambic skip batches whose size is not
            # args.batch_size, unless running in debug mode.
            wrong_size = len(inputs) != args.batch_size
            if args.task not in ['formality', 'iambic'] and not args.debug and wrong_size:
                continue

            scores = model(
                inputs,
                lengths,
                future_words,
                log_probs,
                syllables_to_go,
                future_word_num_syllables,
                rhyme_group_index,
                run_classifier=True,
            )

            if args.task == 'formality':
                # All positions are scored at once; scores are batch x seq.
                seq_targets = classification_targets.unsqueeze(1).expand(-1, scores.shape[1])  # batch x seq
                valid = pad_mask(lengths).permute(1, 0).flatten() == 1  # from batch x seq
                predictions = scores.flatten()[valid]
                targets = seq_targets.flatten().float()[valid]
            elif args.task in ['iambic', 'newline']:
                # Positions whose target is -1 do not contribute to the loss.
                valid = classification_targets.flatten() != -1
                predictions = scores.flatten()[valid]
                targets = classification_targets.flatten().float()[valid]
            else:  # topic, rhyme
                predictions = scores.flatten()
                targets = labels.flatten().float()
            loss = criterion(predictions, targets)

            loss_meter.update(loss.detach(), len(labels))
            if idx % args.train_print_freq == 0:
                progress.display(idx)

    progress.display(n_batches)
    return loss_meter.avg
def main(args):
    """Build the dataset and model, then evaluate once or run the training loop.

    Steps:
      1. Construct the ``Dataset`` and pickle ``dataset_info`` (and
         ``rhyme_info`` for the rhyme task) into ``args.save_dir``.
      2. Either restore model, optimizer, best metric, data offset and epoch
         counter from ``args.ckpt``, or create a fresh model and optimizer.
      3. If ``args.evaluate`` is set, run a single validation pass and return.
      4. Otherwise train, validating every ``args.validation_freq`` epochs and
         writing ``model_best.pth.tar`` / ``model_epoch<N>.pth.tar``
         checkpoints (skipped in debug mode).

    When resuming from a checkpoint, training continues at the epoch after the
    one stored in it, and ``args.epochs`` is the total epoch count. Epoch
    numbers therefore carry on from the interrupted run, so existing
    ``model_epoch<N>`` files are not overwritten and early shuffle seeds are
    not reused.

    Args:
        args: parsed command-line namespace (see the argument parser in the
            ``__main__`` section of this script).
    """
    dataset = Dataset(args)
    os.makedirs(args.save_dir, exist_ok=True)
    with open(os.path.join(args.save_dir, 'dataset_info'), 'wb') as wf:
        pickle.dump(dataset.dataset_info, wf)
    if args.task == 'rhyme':
        with open(os.path.join(args.save_dir, 'rhyme_info'), 'wb') as wf:
            pickle.dump(dataset.rhyme_info, wf)

    rhyme_group_size = len(dataset.index2rhyme_group) if args.task == 'rhyme' else None

    if args.ckpt:
        # NOTE: torch.load unpickles the file (the checkpoint carries the
        # pickled `args` object), so only load checkpoints you trust.
        checkpoint = torch.load(args.ckpt, map_location=args.device)
        start_epoch = checkpoint['epoch'] + 1
        best_val_metric = checkpoint['best_metric']
        model_args = checkpoint['args']
        # no need to get the glove embeddings when reloading since they're
        # saved in model ckpt anyway
        model = Model(model_args, dataset.gpt_pad_id, len(dataset.index2word), rhyme_group_size=rhyme_group_size)
        model.load_state_dict(checkpoint['state_dict'])
        model = model.to(args.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=model_args.lr)
        optimizer.load_state_dict(checkpoint['optimizer'])
        data_start_index = checkpoint['data_start_index']
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.ckpt, checkpoint['epoch']))
        # NOTE: just import pdb after loading the model here if you want to play with it, it's easy
        # model.eval()
        # import pdb; pdb.set_trace()
    else:
        start_epoch = 0
        model = Model(args, dataset.gpt_pad_id, len(dataset.index2word), rhyme_group_size=rhyme_group_size, glove_embeddings=dataset.glove_embeddings)
        model = model.to(args.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        best_val_metric = 1e8  # lower is better for BCE
        data_start_index = 0

    print('num params', num_params(model))
    criterion = nn.BCEWithLogitsLoss().to(args.device)

    if args.evaluate:
        epoch = 0
        validate(model, dataset, criterion, epoch, args)
        return

    # Resume from start_epoch rather than 0 so that a restored run keeps its
    # epoch numbering, checkpoint file names and per-epoch shuffle seeds.
    for epoch in range(start_epoch, args.epochs):
        print("TRAINING: Epoch {} at {}".format(epoch, time.ctime()))
        data_start_index = train(model, dataset, optimizer, criterion, epoch, args, data_start_index)

        if epoch % args.validation_freq != 0:
            continue

        print("VALIDATION: Epoch {} at {}".format(epoch, time.ctime()))
        metric = validate(model, dataset, criterion, epoch, args)

        if args.debug:
            # Debug runs never write checkpoints.
            continue

        if metric < best_val_metric:
            print('new best val metric', metric)
            best_val_metric = metric
            save_checkpoint({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_metric': best_val_metric,
                'optimizer': optimizer.state_dict(),
                'data_start_index': data_start_index,
                'args': args
            }, os.path.join(args.save_dir, 'model_best.pth.tar'))

        # NOTE(review): the per-epoch checkpoint stores this epoch's metric
        # under 'best_metric', not the running best; resuming from it resets
        # the best-so-far threshold to that value. Confirm this is intended.
        save_checkpoint({
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'best_metric': metric,
            'optimizer': optimizer.state_dict(),
            'data_start_index': data_start_index,
            'args': args
        }, os.path.join(args.save_dir, 'model_epoch' + str(epoch) + '.pth.tar'))
# Command-line entry point: parse options, seed the RNGs, and hand off to main().
if __name__ == '__main__':
    parser = ArgumentParser()

    # DATA
    parser.add_argument('--task', type=str, required=True, choices=['iambic', 'rhyme', 'newline', 'topic', 'formality', 'clickbait'])
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--glove_file', type=str, help='glove embedding init, for topic task')

    # SAVE/LOAD
    parser.add_argument('--save_dir', type=str, required=True, help='where to save ckpts')
    parser.add_argument('--ckpt', type=str, default=None, help='load ckpt from file if given')
    parser.add_argument('--dataset_info', type=str, help='saved dataset info')
    parser.add_argument('--rhyme_info', type=str, help='saved dataset rhyme info, for a ckpt with task==rhyme')

    # TRAINING
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--epoch_max_len', type=int, default=None, help='max batches per epoch if set, for more frequent validation')
    parser.add_argument('--validation_freq', type=int, default=1, help='validate every X epochs')
    parser.add_argument('--lr', type=float, default=1e-3, help='Adam learning rate')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
    parser.add_argument('--num_workers', type=int, default=20, help='num workers for data loader')
    parser.add_argument('--evaluate', action='store_true', default=False)
    parser.add_argument('--debug', action='store_true', default=False)

    # PRINTING
    parser.add_argument('--train_print_freq', type=int, default=100, help='how often to print metrics (every X batches)')

    args = parser.parse_args()

    # Evaluation needs a trained model. Report the problem as a usage error
    # instead of an assert, which would be stripped under `python -O`.
    if args.evaluate and args.ckpt is None:
        parser.error('--evaluate requires --ckpt')

    # Seed every RNG the script relies on so runs are repeatable.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    main(args)