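"""Generate text from GPT-2 steered toward user-specified condition words.

Loads a pretrained GPT-2 language model together with a trained conditioning
model (a classifier scoring whether each condition word will appear later in
the text), then decodes by reweighting GPT-2's top-k candidate tokens with the
classifier's log-probabilities at every step.
"""
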
import random
import pickle
import math
from argparse import ArgumentParser

import numpy as np
import torch
import torch.nn.functional as F
# AutoModelWithLMHead is deprecated in recent transformers releases;
# AutoModelForCausalLM is the drop-in replacement for causal LMs like GPT-2.
from transformers import AutoTokenizer, AutoModelWithLMHead

from model import Model
from util import num_params
from constants import *

def main(args):
    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)
    for cw in args.condition_words.split():
        assert cw in dataset_info.word2index

    gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
    gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
    gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device)
    gpt_model.eval()

    checkpoint = torch.load(args.ckpt, map_location=args.device)
    model_args = checkpoint['args']
    # No need to reload the GloVe embeddings here; they are saved in the model checkpoint.
    conditioning_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word))
    conditioning_model.load_state_dict(checkpoint['state_dict'])
    conditioning_model = conditioning_model.to(args.device)
    conditioning_model.eval()
    print("=> loaded checkpoint '{}' (epoch {})".format(args.ckpt, checkpoint['epoch']))
    print('num params', num_params(conditioning_model))

    while True:
        results = predict(gpt_model,
                          gpt_tokenizer,
                          conditioning_model,
                          [args.input_text],
                          args.condition_words,
                          dataset_info,
                          args.precondition_topk,
                          args.topk,
                          args.length_cutoff,
                          condition_lambda=args.condition_lambda,
                          device=args.device)
        print(results)
        # Drop into the debugger after each generation; you can edit `args`
        # at the pdb prompt (e.g. args.condition_lambda = 2.0) and continue
        # to rerun with new settings.
        import pdb; pdb.set_trace()

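
# The decoding rule implemented in predict(), applied over the LM's top-k
# candidate tokens at each step:
#     score(x_t) = log P_LM(x_t | x_<t) + lambda * mean_w log P(w appears later | x_<=t, x_t)
# where the second term converts the conditioning model's per-word logits
# into log-probabilities with a log-sigmoid.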
def predict(gpt_model, gpt_tokenizer, conditioning_model, input_text, condition_words, dataset_info,
            precondition_topk, postcondition_topk, length_cutoff, condition_lambda=1.0, device='cuda'):
    with torch.no_grad():
        batch_size = len(input_text)

        condition_words = condition_words.split()
        future_words = torch.LongTensor([dataset_info.word2index[cw] for cw in condition_words]).to(device) # N
        # Unconditional log-probabilities of each condition word, from corpus counts.
        log_probs = torch.Tensor([math.log(dataset_info.vocab[cw] / dataset_info.total_words) for cw in condition_words]).to(device) # N

        # Assumes all inputs in the batch encode to the same initial length.
        encoded_input = [gpt_tokenizer.encode(it, return_tensors='pt').to(device) for it in input_text] # batch x seq
        encoded_input = torch.cat(encoded_input, dim=0)
        lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)

        gpt_encoded_future_words = [gpt_tokenizer.encode(' ' + cw, return_tensors='pt')[0].to(device) for cw in condition_words] # note: unused below
        while lengths.max() < length_cutoff:
            tokens_left = torch.LongTensor([length_cutoff - lengths.max() for _ in range(batch_size)]).to(device)
            gpt_logits = gpt_model(encoded_input)[0][:, -1, :] # batch x vocab
            top_logits, top_indices = gpt_logits.topk(precondition_topk, dim=1) # batch x topk
            new_input_candidates = torch.cat([encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2) # batch x topk x seq+1
            expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk) # batch x topk
            expanded_future_words = future_words.unsqueeze(0).unsqueeze(1).expand(batch_size, precondition_topk, -1) # batch x topk x N
            expanded_tokens_left = tokens_left.unsqueeze(1).expand(-1, precondition_topk) # batch x topk
            if condition_lambda == 0:
                condition_logits = torch.zeros_like(expanded_future_words).float() # batch x topk x N
            else:
                condition_logits = conditioning_model(new_input_candidates.flatten(0, 1), # batch*topk x seq+1
                                                      expanded_lengths.flatten(0, 1), # batch*topk
                                                      expanded_future_words.flatten(0, 1), # batch*topk x N
                                                      log_probs, # N
                                                      expanded_tokens_left.flatten(0, 1)) # batch*topk
                condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
                # Numerically stable form of x - log(1 + e^x): correct log p(word appears in future).
                condition_logits = F.logsigmoid(condition_logits)
            condition_logits = torch.mean(condition_logits, dim=2) # batch x topk, averaged over the N condition words
            full_logits = top_logits + condition_logits * condition_lambda # batch x topk
            post_logits, post_indices = full_logits.topk(postcondition_topk, dim=1)
            post_probs = F.softmax(post_logits, dim=1)
            index_into_top_indices = post_indices[torch.arange(batch_size).to(post_indices.device), torch.multinomial(post_probs, 1).flatten()] # batch
            next_indices = top_indices[torch.arange(batch_size).to(top_indices.device), index_into_top_indices] # batch
            encoded_input = torch.cat([encoded_input, next_indices.unsqueeze(1)], dim=1) # batch x seq+1
            lengths = lengths + 1 # batch
        return [gpt_tokenizer.decode(s) for s in encoded_input]
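
# Example programmatic use (prompt and words are illustrative; requires a
# trained conditioning-model checkpoint and its dataset_info):
#   results = predict(gpt_model, gpt_tokenizer, conditioning_model,
#                     ['The weather today'], 'rain storm', dataset_info,
#                     precondition_topk=200, postcondition_topk=10,
#                     length_cutoff=80, condition_lambda=1.0, device='cuda')
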
if __name__ == '__main__':
    parser = ArgumentParser()

    # DATA
    parser.add_argument('--ckpt', type=str, required=True)
    parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
    parser.add_argument('--model_string', type=str, default='gpt2-medium')

    parser.add_argument('--input_text', type=str, default=None, required=True, help='initial text')
    parser.add_argument('--condition_words', type=str, default=None, required=True, help='word(s) to optimize for')

    parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
    parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step')
    parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
    parser.add_argument('--length_cutoff', type=int, default=80, help='max length')

    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
    parser.add_argument('--debug', action='store_true', default=False)

    args = parser.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    main(args)
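
# Example invocation (file name, paths, and words are illustrative):
#   python predict.py --ckpt ckpt/model.pth.tar --dataset_info ckpt/dataset_info \
#       --input_text 'The weather today' --condition_words 'rain storm' --device cuda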