Spaces:
Runtime error
Runtime error
import os | |
import random | |
import time | |
import pickle | |
import math | |
from argparse import ArgumentParser | |
import string | |
from collections import defaultdict | |
from tqdm import tqdm | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model | |
from data import Dataset, load_rhyme_info | |
from model import Model | |
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params | |
from constants import * | |
from poetry_util import get_rhymes, count_syllables | |
from predict_poetry import predict_couplet | |
def main(args): | |
with open(args.dataset_info, 'rb') as rf: | |
dataset_info = pickle.load(rf) | |
gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string) | |
gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN}) | |
gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0] | |
gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device) | |
gpt_model.eval() | |
checkpoint = torch.load(args.iambic_ckpt, map_location=args.device) | |
model_args = checkpoint['args'] | |
iambic_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway | |
iambic_model.load_state_dict(checkpoint['state_dict']) | |
iambic_model = iambic_model.to(args.device) | |
iambic_model.eval() | |
if args.verbose: | |
print("=> loaded checkpoint '{}' (epoch {})" | |
.format(args.iambic_ckpt, checkpoint['epoch'])) | |
print('iambic model num params', num_params(iambic_model)) | |
with open(args.rhyme_info, 'rb') as rf: | |
rhyme_info = pickle.load(rf) | |
checkpoint = torch.load(args.rhyme_ckpt, map_location=args.device) | |
model_args = checkpoint['args'] | |
rhyme_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word), rhyme_group_size=len(rhyme_info.index2rhyme_group), verbose=args.verbose) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway | |
rhyme_model.load_state_dict(checkpoint['state_dict']) | |
rhyme_model = rhyme_model.to(args.device) | |
rhyme_model.eval() | |
if args.verbose: | |
print("=> loaded checkpoint '{}' (epoch {})" | |
.format(args.rhyme_ckpt, checkpoint['epoch'])) | |
print('rhyme model num params', num_params(rhyme_model)) | |
checkpoint = torch.load(args.newline_ckpt, map_location=args.device) | |
model_args = checkpoint['args'] | |
newline_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway | |
newline_model.load_state_dict(checkpoint['state_dict']) | |
newline_model = newline_model.to(args.device) | |
newline_model.eval() | |
if args.verbose: | |
print("=> loaded checkpoint '{}' (epoch {})" | |
.format(args.newline_ckpt, checkpoint['epoch'])) | |
print('iambic model num params', num_params(newline_model)) | |
with open(args.prefix_file, 'r') as rf: | |
lines = rf.readlines() | |
for line in tqdm(lines, total=len(lines)): | |
couplet = predict_couplet(gpt_model, | |
gpt_tokenizer, | |
iambic_model, | |
rhyme_model, | |
newline_model, | |
[line], | |
dataset_info, | |
rhyme_info, | |
args.precondition_topk, | |
args.topk, | |
condition_lambda=args.condition_lambda, | |
device=args.device) | |
assert len(couplet) == 2 | |
print(couplet[1].strip().replace('\n', '')) | |
if __name__=='__main__': | |
parser = ArgumentParser() | |
# DATA | |
parser.add_argument('--iambic_ckpt', type=str, required=True) | |
parser.add_argument('--rhyme_ckpt', type=str, required=True) | |
parser.add_argument('--newline_ckpt', type=str, required=True) | |
parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info') | |
parser.add_argument('--rhyme_info', type=str, required=True, help='saved rhyme info') | |
parser.add_argument('--model_string', type=str, default='gpt2-medium') | |
parser.add_argument('--prefix_file', type=str, default=None, required=True, help='file of prefix lines for couplets') | |
parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning') | |
parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step') | |
parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model') | |
parser.add_argument('--seed', type=int, default=1, help='random seed') | |
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda']) | |
parser.add_argument('--debug', action='store_true', default=False) | |
parser.add_argument('--verbose', action='store_true', default=False) | |
args = parser.parse_args() | |
random.seed(args.seed) | |
np.random.seed(args.seed) | |
torch.manual_seed(args.seed) | |
main(args) |