Spaces:
Runtime error
Runtime error
# import os | |
# os.chdir('naacl-2021-fudge-controlled-generation/') | |
import gradio as gr | |
from fudge.predict_clickbait import generate_clickbait, tokenizer, classifier_tokenizer | |
from datasets import load_dataset,DatasetDict,Dataset | |
# from datasets import | |
from transformers import AutoTokenizer,AutoModelForSeq2SeqLM | |
import numpy as np | |
from sklearn.model_selection import train_test_split | |
import pandas as pd | |
from sklearn.utils.class_weight import compute_class_weight | |
import torch | |
import pandas as pd | |
from fudge.model import Model | |
import os | |
from argparse import ArgumentParser | |
from collections import namedtuple | |
import mock | |
from tqdm import tqdm | |
import numpy as np | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from data import Dataset | |
from fudge.util import save_checkpoint, ProgressMeter, AverageMeter, num_params | |
from fudge.constants import * | |
# imp.reload(model) | |
# Load the fine-tuned seq2seq headline generator and the FUDGE clickbait
# conditioning model once at import time; every Gradio request reuses them.
pretrained_model = "../checkpoint-150/"
# `device` must be bound before the generator is moved onto it (the original
# assigned it one statement too late, raising NameError at import). Fall back
# to CPU so the app still starts on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pad_id = 0
generation_model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model, return_dict=True).to(device)
generation_model.eval()

# A Mock stands in for a parsed-arguments object: the fields set here are
# explicit, and any other attribute Model() reads resolves to a Mock instead
# of raising AttributeError.
model_args = mock.Mock()
model_args.task = 'clickbait'
model_args.device = device
model_args.checkpoint = '../checkpoint-1464/'
# vocab_size=None: per the original author's note, the GloVe embeddings are
# restored from the model checkpoint, so no vocabulary size is needed here.
conditioning_model = Model(model_args, pad_id, vocab_size=None)
conditioning_model = conditioning_model.to(device)
conditioning_model.eval()

# Decoding settings. precondition_topk and length_cutoff are read by
# clickbait_generator; that function's own condition_lambda parameter shadows
# this module-level value.
condition_lambda = 5.0
length_cutoff = 50
precondition_topk = 200

# NOTE(review): clickbait_generator passes `dataset_info`, which is not
# assigned anywhere in this file as shown — presumably it comes from
# `from fudge.constants import *`; confirm, or load it here.

# Tokenizer saved next to the classifier checkpoint; this rebinds the
# `classifier_tokenizer` name imported from fudge.predict_clickbait.
classifier_tokenizer = AutoTokenizer.from_pretrained(model_args.checkpoint)
def rate_title(input_text, model, tokenizer, device='cuda'):
    """Score a single headline with the clickbait classifier.

    Args:
        input_text: Mapping with a 'postText' entry holding the headline.
            A 'truthClass' entry is forwarded to the preprocessing helper when
            present, but it is not needed: the label is dropped before the
            model is called.
        model: Classifier whose output exposes `.logits`; the logits must
            hold exactly one element because they are converted with float().
        tokenizer: Tokenizer forwarded to
            preprocess_function_title_only_classification.
        device: Device the input tensors are moved to; should match `model`.

    Returns:
        dict with the single key 'predicted_class' mapped to the raw score.
    """
    example = {
        'postText': input_text['postText'],
        # Tolerate a missing label; it is discarded below anyway.
        'truthClass': input_text.get('truthClass'),
    }
    tokenized = preprocess_function_title_only_classification(example, tokenizer=tokenizer)
    # Batch of one: wrap each feature and move it to the target device.
    # 'labels' is skipped because only the prediction is needed.
    model_inputs = {
        name: torch.tensor([values]).to(device)
        for name, values in tokenized.items()
        if name != 'labels'
    }
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        logits = model(**model_inputs).logits
    return {'predicted_class': float(logits)}
def preprocess_function_title_only_classification(examples, tokenizer=None, max_length=25):
    """Tokenize headline text and attach its clickbait label.

    Args:
        examples: Mapping with 'postText' (passed to the tokenizer as-is, so it
            may be a single string or whatever batch form the tokenizer
            accepts) and 'truthClass' (copied verbatim into 'labels').
        tokenizer: Callable tokenizer; required despite the None default,
            which is kept only for backward compatibility of the signature.
        max_length: Truncation length in tokens (default 25, the previously
            hard-coded value).

    Returns:
        The tokenizer's output mapping with an added 'labels' entry.

    Raises:
        TypeError: if no tokenizer is supplied.
    """
    if tokenizer is None:
        raise TypeError("preprocess_function_title_only_classification requires a tokenizer")
    model_inputs = tokenizer(examples['postText'], padding="longest", truncation=True, max_length=max_length)
    model_inputs['labels'] = examples['truthClass']
    return model_inputs
def clickbait_generator(article_content, condition_lambda=5.0):
    """Generate a headline for `article_content` via generate_clickbait.

    Args:
        article_content: Article text from the Gradio text input.
        condition_lambda: Value from the Gradio slider, forwarded unchanged to
            generate_clickbait as its condition_lambda.

    Returns:
        The first generated sequence with every '</s>' and '<pad>' marker
        removed.
    """
    # NOTE(review): dataset_info is not assigned anywhere in this file as
    # shown; presumably it arrives via `from fudge.constants import *` — confirm.
    generated = generate_clickbait(
        model=generation_model,
        tokenizer=tokenizer,
        conditioning_model=conditioning_model,
        input_text=[None],
        dataset_info=dataset_info,
        precondition_topk=precondition_topk,
        length_cutoff=length_cutoff,
        condition_lambda=condition_lambda,
        article_content=article_content,
        device=device,
    )
    headline = generated[0]
    # Strip special tokens in the same order as before: end-of-sequence, then padding.
    for special_token in ('</s>', '<pad>'):
        headline = headline.replace(special_token, '')
    return headline
# Gradio UI: a text box for the article and a slider for condition_lambda,
# passed positionally to clickbait_generator.
title = "Clickbait generator"
description = """
Use the [Fudge](https://github.com/yangkevin2/naacl-2021-fudge-controlled-generation) implementation fine tuned for our purposes to try and create the news headline you are looking for!
"""
article = "Check out [the codebase for our model](https://github.com/dsvilarkovic/naacl-2021-fudge-controlled-generation) that this demo is based off of."
app = gr.Interface(
    fn=clickbait_generator,
    inputs=[
        # gr.Interface has no `label` parameter; the label belongs on the
        # input component it describes.
        gr.Textbox(label='Article content or paragraph'),
        gr.Slider(0, 100, step=0.1, value=5.0),
    ],
    outputs="text",
    title=title,
    description=description,
    # Previously defined but never shown; Interface renders it below the demo.
    article=article,
)
app.launch()