|
|
|
|
|
import torch |
|
|
|
|
|
from transformers import BartTokenizer, BartForConditionalGeneration |
|
from transformers import AutoTokenizer, AutoModelForTokenClassification |
|
from transformers import pipeline |
|
import gradio as gr |
|
|
|
|
|
from sklearn.feature_extraction.text import TfidfVectorizer |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import pandas as pd |
|
import string |
|
|
|
|
|
|
|
|
|
# Step 1: load the punctuation-restoration token classifier and wrap it in a
# token-classification pipeline. `punct_restorer` is the object that
# generate_text() hands to restore_punctuation().
print("1/7 - Instantiating model to restore punctuation")

# Model identifier passed to from_pretrained() for both tokenizer and weights.
punct_model_path = "felflare/bert-restore-punctuation"

punct_tokenizer = AutoTokenizer.from_pretrained(punct_model_path)
punct_model = AutoModelForTokenClassification.from_pretrained(punct_model_path)
# No aggregation_strategy is given, so the pipeline uses its default; the
# result dicts are read via their 'word' and 'entity' keys in
# restore_punctuation(). No device argument is passed here either.
punct_restorer = pipeline("token-classification", model=punct_model, tokenizer=punct_tokenizer)
|
|
|
|
|
|
|
|
|
# Step 2: load the fine-tuned BART model that writes the second sentence.
print("2/7 - Instantiating two-sentence horror generation model")

model_path = 'voacado/bart-two-sentence-horror'

tokenizer = BartTokenizer.from_pretrained(model_path)
# Switched to eval mode and moved to `device` later, in step 6/7.
model = BartForConditionalGeneration.from_pretrained(model_path)
|
|
|
|
|
|
|
|
|
# Step 3: load the Reddit corpus used as the reference set for evaluation.
print("3/7 - Reading in data")

data = pd.read_csv("./reddit_cleansed_data.csv")

# Popularity proxy shown to the user: score + 10 * comments + 100 * gilds.
data['weighted_score'] = data['score'] + (10 * data['num_comments']) + (100 * data['gilded_count'])

# pandas parses empty CSV fields as NaN, and NaN + ' ' + text is NaN, which
# TfidfVectorizer.fit_transform rejects. Missing title/body text is therefore
# treated as an empty string. Position i in this list is row position i of
# `data`, which the similarity functions rely on.
dataset_stories = (data['title'].fillna('') + ' ' + data['selftext'].fillna('')).to_list()
|
|
|
|
|
|
|
|
|
# Step 4: fit TF-IDF on the whole corpus once at start-up.
print("4/7 - Instantiating evaluation metrics - Cosine Similarity with TF-IDF")

vectorizer = TfidfVectorizer()
# Row i of dataset_matrix is the TF-IDF vector of dataset_stories[i].
dataset_matrix = vectorizer.fit_transform(dataset_stories)
|
|
|
def eval_cosine_similarity(input_sentence: str) -> tuple:
    """
    Find the dataset story most similar to the input under TF-IDF cosine similarity.

    Args:
        input_sentence (str): user story (first sentence, or the full story)

    Returns:
        tuple: (most similar story, its weighted score). Ties resolve to the
        earliest story. If the input shares no vocabulary with the corpus
        (every similarity is 0), returns ('', None) rather than an unrelated row.
    """
    input_vec = vectorizer.transform([input_sentence])

    # cosine_similarity returns shape (1, n_stories); take the single row.
    similarities = cosine_similarity(input_vec, dataset_matrix)[0]

    best_idx = int(similarities.argmax())
    if similarities[best_idx] <= 0:
        # argmax of an all-zero row is 0, which would report an arbitrary post.
        return '', None

    # best_idx is a position in dataset_stories, so index `data` positionally.
    return dataset_stories[best_idx], data['weighted_score'].iloc[best_idx]
|
|
|
|
|
|
|
|
|
# Step 5: Jaccard similarity over lower-cased, punctuation-stripped word sets.
print("5/7 - Instantiating evaluation metrics - Jaccard Similarity")
|
def tokenize(text: str):
    """
    Normalize a story into its set of distinct word tokens.

    The text is lower-cased, every character in ``string.punctuation`` is
    deleted, and the remainder is split on whitespace.

    Args:
        text (str): user story

    Returns:
        set: the distinct tokens found in ``text``
    """
    drop_punctuation = str.maketrans('', '', string.punctuation)
    normalized = text.lower().translate(drop_punctuation)
    return {word for word in normalized.split()}
|
|
|
def jaccard_similarity(set1: set, set2: set) -> float:
    """
    Calculate the Jaccard similarity len(A & B) / len(A | B) of two sets.

    Args:
        set1 (set): user_tokens
        set2 (set): story_tokens

    Returns:
        float: Jaccard similarity in [0, 1]. Two empty sets give 0.0 (the ratio
        is undefined there): an empty story is never counted as a match.
    """
    union = set1.union(set2)
    if not union:
        # Guards against ZeroDivisionError when both inputs are empty.
        return 0.0
    return len(set1.intersection(set2)) / len(union)
|
|
|
def eval_jaccard_similarity(input_sentence: str) -> tuple:
    """
    Find the dataset story most similar to the input under Jaccard similarity of word sets.

    Args:
        input_sentence (str): user story (first sentence, or the full story)

    Returns:
        tuple: (most similar story, its weighted score). Ties resolve to the
        earliest story. If no story shares a token with the input (including
        when the input has no tokens at all), returns ('', None).
    """
    user_tokens = tokenize(input_sentence)
    if not user_tokens:
        # Every similarity would be 0, and an empty story would make the
        # Jaccard denominator 0; there is nothing to match.
        return '', None

    best_similarity = 0
    best_idx = None

    # Track the winning position directly so no dataset_stories.index() scan
    # is needed afterwards.
    for idx, story in enumerate(dataset_stories):
        similarity = jaccard_similarity(user_tokens, tokenize(story))
        if similarity > best_similarity:
            best_similarity = similarity
            best_idx = idx

    if best_idx is None:
        # No story overlapped the input at all.
        return '', None

    # best_idx is a position in dataset_stories, so index `data` positionally.
    return dataset_stories[best_idx], data['weighted_score'].iloc[best_idx]
|
|
|
|
|
|
|
|
|
# Step 6: prepare the generation model for inference.
print("6/7 - Setting parameters for inference")

# eval() switches the module out of training mode (e.g. disables dropout).
model.eval()

# Use the GPU when available; generate_text() moves its inputs to this device.
# NOTE: only the BART model is moved; punct_restorer was built without a
# device argument.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
|
|
|
|
|
def restore_punctuation(text: str, restorer: "transformers.Pipeline") -> str:
    """
    Re-assemble text from the tokens returned by a token-classification pipeline.

    The restorer splits ``text`` into tokens, with punctuation as separate
    tokens. Those tokens are glued back into words as follows:
    - sub-word pieces ("##...") attach to the previous word;
    - closing punctuation attaches to the previous word;
    - a token following a lone opening quote attaches to that quote;
    - contraction suffixes attach after an apostrophe.
    Tokens labelled ``LABEL_0`` are capitalized (presumably the model's
    "upper-case" class; confirm against the model config). A final period is
    appended unless the text already ends with sentence-final punctuation.

    Args:
        text (str): text to process (generate_text passes the generated
            second sentence)
        restorer: the token-classification pipeline (any callable mapping
            ``text`` to a list of dicts with ``'word'`` and ``'entity'`` keys)

    Returns:
        str: re-assembled, punctuated text; ``''`` if the restorer yields no tokens
    """
    # Tokens that close the previous word: "word ," -> "word,".
    # Typographic characters are escapes to keep the source ASCII-safe:
    # \u2019 right single quote, \u201d right double quote, \u2026 ellipsis.
    closing_marks = {
        "!", "?", ".", "-", ":", ";", "'", ",", ")", "]", "}",
        "\u2019", "\u201d", "\u2026", "\u2019\u2019", "''",
    }
    # A word that is only an opening quote absorbs the next token.
    # \u2018 left single quote, \u201c left double quote.
    opening_quotes = {"'", "\u2018", "\u201c", "\u2018\u2018", "\u201c\u201c"}
    apostrophes = ("'", "\u2019")
    contraction_suffixes = {"s", "t", "re", "ve", "ll", "d", "m"}
    sentence_end = (".", "!", "?", "\u2026")

    words = []
    for elem in restorer(text):
        token = elem.get('word')
        if not token:
            continue

        if words and token.startswith("##"):
            # Continuation piece (BERT-style WordPiece marker): merge without "##".
            words[-1] += token[2:]
        elif words and token in closing_marks:
            words[-1] += token
        elif words and words[-1] in opening_quotes:
            words[-1] += token
        elif (words and words[-1].endswith(apostrophes)
              and token.lower() in contraction_suffixes):
            # "don" + "'" + "t" -> "don't"; only after an apostrophe, so a
            # standalone word such as "d" is not glued to its neighbour.
            words[-1] += token
        elif elem.get('entity') == 'LABEL_0':
            words.append(token.capitalize())
        else:
            words.append(token)

    # Leading punctuation or suffix tokens fall through to the append branches
    # above, so an empty list here means the restorer returned nothing usable.
    if words and not words[-1].endswith(sentence_end):
        words[-1] += '.'

    return ' '.join(words)
|
|
|
def generate_text(input_text: str, full_sentence: str) -> tuple:
    """
    Generate the second sentence of a horror story, then score the story against the dataset.

    If ``full_sentence`` is empty or whitespace-only, the BART model continues
    ``input_text``, and the evaluated story is ``input_text`` plus the generated
    sentence. Otherwise nothing is generated and ``full_sentence`` is evaluated
    as given.

    Args:
        input_text (str): first sentence of the horror story (None is treated as '')
        full_sentence (str): full story (first and second sentences), optional

    Returns:
        tuple with five entries:
        - gen_text_punct (str): generated second sentence, or "N/A" if a full
          story was supplied
        - similar_story_cosine (str): most similar story (cosine similarity)
        - cosine_score: as returned by eval_cosine_similarity
        - similar_story_jaccard (str): most similar story (Jaccard similarity)
        - jaccard_score: as returned by eval_jaccard_similarity
    """
    input_text = input_text or ''

    # A blank (or whitespace-only) full-story box means "generate for me".
    if not (full_sentence and full_sentence.strip()):
        # Calling the tokenizer yields input_ids plus attention_mask; passing
        # both to generate() avoids its missing-attention-mask warning.
        inputs = tokenizer(input_text, return_tensors='pt').to(device)

        with torch.no_grad():
            output_ids = model.generate(**inputs, max_length=50)

        gen_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        gen_text_punct = restore_punctuation(gen_text, punct_restorer)
        full_sentence = input_text + ' ' + gen_text_punct
    else:
        gen_text_punct = "N/A"

    similar_story_cosine, cosine_score = eval_cosine_similarity(full_sentence)
    similar_story_jaccard, jaccard_score = eval_jaccard_similarity(full_sentence)

    return gen_text_punct, similar_story_cosine, cosine_score, similar_story_jaccard, jaccard_score
|
|
|
|
|
|
|
|
|
# Step 7: build the Gradio interface and start serving it.
print("7/7 - Launching demo")

# Title emoji written as escapes so the source stays ASCII-safe:
# U+1F47B ghost, U+1FAE3 face with peeking eye, U+1F631 face screaming in fear.
title = "\U0001F47B \U0001FAE3 Generate a Two-Sentence Horror Story \U0001F631 \U0001F47B"

description = """
<center>The bot was trained to generate two-sentence horror stories based on r/TwoSentenceHorror. <i>Spooky!</i></center>
"""

article = """
Check out [the subreddit](https://www.reddit.com/r/TwoSentenceHorror) that this demo is based off of. Or, check out the dataset [here](https://www.kaggle.com/datasets/voanthony/two-sentence-horror-jan-2015-apr-2023).

The language model is fine-tuned from ['facebook/bart-base'](https://huggingface.co/facebook/bart-base). We import, then update the weights for the model to generate two-sentence horror stories. The model is fine-tuned over 3 epochs to avoid catastrophic forgetting. We also use a separate model (['felflare/bert-restore-punctuation'](https://huggingface.co/felflare/bert-restore-punctuation?text=My+name+is+wolfgang+and+I+live+in+berlin)) to restore punctuation.

For evaluation, the generated story is compared to the most similar Reddit post (using either cosine or Jaccard similarity). The score of the most similar post is also returned. The score is calculated as the sum of the post score, 10 * number of comments, and 100 * number of gilds. The score is used as a proxy for the popularity of the post.

Users may also enter an entire story in the second input prompt rather than generating the remainder of the story. This will be used for evaluation metrics and no story will be generated.
"""

# Two inputs map to generate_text(input_text, full_sentence); the five
# outputs match the order of its return tuple.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=4, placeholder="Enter the first sentence of your horror story here...", label="First Sentence"),
        gr.Textbox(lines=4, placeholder="Or, enter full story for evaluation here...", label="Eval - Full Story"),
    ],
    outputs=[
        gr.Textbox(lines=4, label="Generated Second Sentence"),
        gr.Textbox(lines=3, label="Cosine Similarity - Sentence"),
        gr.Textbox(lines=1, label="Cosine Similarity - Post Score"),
        gr.Textbox(lines=3, label="Jaccard Similarity - Sentence"),
        gr.Textbox(lines=1, label="Jaccard Similarity - Post Score"),
    ],
    title=title,
    description=description,
    article=article,
    # Each example supplies only the first input (the opening sentence).
    examples=[["My parents told me not to go upstairs."], ["There was a ghost."], ["Sometimes I catch myself staring at those missing person flyers at the store."]],
)

# share=True asks Gradio for a public share link in addition to the local server.
demo.launch(share=True)