Spaces:
Runtime error
Runtime error
import os | |
import random | |
import time | |
import pickle | |
import math | |
from argparse import ArgumentParser | |
from tqdm import tqdm | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model | |
from data import Dataset | |
from model import Model | |
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params | |
from constants import * | |
def main(args):
    """Load the language model and conditioning checkpoint, then generate text.

    Args:
        args: parsed command-line namespace. This function reads
            ``dataset_info``, ``condition_words``, ``model_string``,
            ``device``, ``ckpt``, ``input_text``, ``precondition_topk``,
            ``topk``, ``length_cutoff`` and ``condition_lambda``.

    Raises:
        ValueError: if a condition word is not a key of
            ``dataset_info.word2index``.
    """
    # NOTE: pickle.load (and torch.load below) deserialize pickled data, which
    # can execute arbitrary code. Only point these at files you trust.
    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)

    # Validate explicitly rather than with `assert`: asserts are stripped when
    # Python runs with -O, which would defer the failure to a KeyError later.
    missing_words = [cw for cw in args.condition_words.split()
                     if cw not in dataset_info.word2index]
    if missing_words:
        raise ValueError(
            'condition word(s) not found in dataset_info.word2index: {}'.format(
                ', '.join(missing_words)))

    gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
    gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
    gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device)
    gpt_model.eval()

    checkpoint = torch.load(args.ckpt, map_location=args.device)
    model_args = checkpoint['args']
    conditioning_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    conditioning_model.load_state_dict(checkpoint['state_dict'])
    conditioning_model = conditioning_model.to(args.device)
    conditioning_model.eval()
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(args.ckpt, checkpoint['epoch']))
    print('num params', num_params(conditioning_model))

    # Interactive loop: print one generation, then stop in the debugger.
    # Continuing from pdb runs another generation with the same arguments.
    while True:
        results = predict(gpt_model,
                          gpt_tokenizer,
                          conditioning_model,
                          [args.input_text],
                          args.condition_words,
                          dataset_info,
                          args.precondition_topk,
                          args.topk,
                          args.length_cutoff,
                          condition_lambda=args.condition_lambda,
                          device=args.device)
        print(results)
        import pdb; pdb.set_trace()
def predict(gpt_model, gpt_tokenizer, conditioning_model, input_text, condition_words, dataset_info, precondition_topk, postcondition_topk, length_cutoff, condition_lambda=1.0, device='cuda'):
    """Generate continuations of ``input_text`` steered toward ``condition_words``.

    At every step the top ``precondition_topk`` next tokens proposed by
    ``gpt_model`` are rescored by adding ``condition_lambda`` times the mean
    log-sigmoid of the conditioning model's outputs. The best
    ``postcondition_topk`` candidates are kept and one is sampled from the
    softmax over their combined scores.

    Args:
        gpt_model: callable; ``gpt_model(ids)[0]`` is indexed as
            ``[:, -1, :]`` to get next-token logits.
        gpt_tokenizer: object providing ``encode(text, return_tensors='pt')``
            and ``decode(ids)``.
        conditioning_model: callable taking (candidate ids, lengths, future
            word ids, word log-probs, tokens left) and returning logits that
            are reshaped to batch x topk x N.
        input_text: list of prompt strings; they must all encode to the same
            number of tokens because the encodings are concatenated on dim 0.
        condition_words: whitespace-separated string of target words.
        dataset_info: object providing ``word2index``, ``vocab`` and
            ``total_words``.
        precondition_topk: number of language-model candidates to rescore.
        postcondition_topk: number of rescored candidates to sample from.
        length_cutoff: generation stops once the sequence length reaches this.
        condition_lambda: weight on the conditioning score; ``0`` skips the
            conditioning model entirely.
        device: torch device on which the tensors are created.

    Returns:
        A list with one decoded string per entry of ``input_text``.
    """
    with torch.no_grad():
        batch_size = len(input_text)
        condition_words = condition_words.split()
        future_words = torch.LongTensor([dataset_info.word2index[cw] for cw in condition_words]).to(device) # N
        log_probs = torch.Tensor([math.log(dataset_info.vocab[cw] / dataset_info.total_words) for cw in condition_words]).to(device) # N

        # assumes initially all same length.
        encoded_input = [gpt_tokenizer.encode(it, return_tensors='pt').to(device) for it in input_text] # batch x seq
        encoded_input = torch.cat(encoded_input, dim=0)
        lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)

        # Row selector for the per-row gathers below; it is the same every step.
        batch_rows = torch.arange(batch_size, device=encoded_input.device)

        while lengths.max() < length_cutoff:
            tokens_left = torch.full((batch_size,), length_cutoff - int(lengths.max()), dtype=torch.long, device=device) # batch
            gpt_logits = gpt_model(encoded_input)[0][:, -1, :] # batch x vocab
            top_logits, top_indices = gpt_logits.topk(precondition_topk, dim=1) # batch x topk
            new_input_candidates = torch.cat([encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2) # batch x topk x seq+1
            expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk) # batch x topk
            expanded_future_words = future_words.unsqueeze(0).unsqueeze(1).expand(batch_size, precondition_topk, -1) # batch x topk x N
            expanded_tokens_left = tokens_left.unsqueeze(1).expand(-1, precondition_topk) # batch x topk

            if condition_lambda == 0:
                condition_logits = torch.zeros_like(expanded_future_words).float()
            else:
                condition_logits = conditioning_model(new_input_candidates.flatten(0, 1), # batch*topk x seq+1
                                                      expanded_lengths.flatten(0, 1), # batch*topk
                                                      expanded_future_words.flatten(0, 1), # batch*topk x N
                                                      log_probs, # N
                                                      expanded_tokens_left.flatten(0, 1)) # batch*topk
                condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N

            # Log-probabilities via log-sigmoid. The naive form
            # x - log(1 + exp(x)) overflows to -inf for large positive x,
            # which would discard the most strongly preferred candidates.
            condition_logits = F.logsigmoid(condition_logits)
            condition_logits = torch.mean(condition_logits, dim=2)

            full_logits = top_logits + condition_logits * condition_lambda # batch x topk
            post_logits, post_indices = full_logits.topk(postcondition_topk, dim=1)
            post_probs = F.softmax(post_logits, dim=1)
            sampled = torch.multinomial(post_probs, 1).flatten() # batch
            index_into_top_indices = post_indices[batch_rows, sampled] # batch
            next_indices = top_indices[batch_rows, index_into_top_indices] # batch

            encoded_input = torch.cat([encoded_input, next_indices.unsqueeze(1)], dim=1) # batch x seq+1
            lengths = lengths + 1 # batch

        return [gpt_tokenizer.decode(s) for s in encoded_input]
if __name__ == '__main__':
    # Command-line entry point: parse options, seed every RNG, run main().
    parser = ArgumentParser()

    # DATA
    parser.add_argument('--ckpt', type=str, required=True)
    parser.add_argument(
        '--dataset_info', type=str, required=True,
        help='saved dataset info',
    )
    parser.add_argument('--model_string', type=str, default='gpt2-medium')

    parser.add_argument(
        '--input_text', type=str, default=None, required=True,
        help='initial text',
    )
    parser.add_argument(
        '--condition_words', type=str, default=None, required=True,
        help='word(s) to optimize for',
    )

    parser.add_argument(
        '--precondition_topk', type=int, default=200,
        help='consider top k outputs from gpt at each step before conditioning and re-pruning',
    )
    parser.add_argument(
        '--topk', type=int, default=10,
        help='consider top k outputs from gpt at each step',
    )
    parser.add_argument(
        '--condition_lambda', type=float, default=1.0,
        help='lambda weight on conditioning model',
    )
    parser.add_argument(
        '--length_cutoff', type=int, default=80,
        help='max length',
    )

    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument(
        '--device', type=str, default='cuda', choices=['cpu', 'cuda'],
    )
    parser.add_argument('--debug', action='store_true', default=False)

    args = parser.parse_args()

    # Seed Python, NumPy and PyTorch, in that order, with the same value.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(args.seed)

    main(args)