# A Vo
# Add application file
# 7d45f3e
# raw
# history blame
# 4.35 kB
# Imports
# Core Imports
import torch
# Model-related Imports
from transformers import BartTokenizer, BartForConditionalGeneration # fine-tuned BART model
from transformers import AutoTokenizer, AutoModelForTokenClassification # restore punct
from transformers import pipeline # restore punct
import gradio as gr
# --- Step 1: punctuation-restoration model ---------------------------------
print("1/4 - Instantiating model to restore punctuation")
punct_model_path = "felflare/bert-restore-punctuation"
# The tokenizer and the token-classification weights both come from
# `punct_model_path`; they are then wrapped in a single pipeline object that
# is later handed to `restore_punctuation`.
punct_tokenizer = AutoTokenizer.from_pretrained(punct_model_path)
punct_model = AutoModelForTokenClassification.from_pretrained(punct_model_path)
punct_restorer = pipeline(
    "token-classification",
    model=punct_model,
    tokenizer=punct_tokenizer,
)

# --- Step 2: fine-tuned horror BART model -----------------------------------
print("2/4 - Instantiating two-sentence horror generation model")
model_path = 'voacado/bart-two-sentence-horror'
# Tokenizer and seq2seq model are loaded from the same checkpoint path.
tokenizer = BartTokenizer.from_pretrained(model_path)
model = BartForConditionalGeneration.from_pretrained(model_path)

# --- Step 3: inference configuration ----------------------------------------
print("3/4 - Setting parameters for inference")
# Switch the generation model to evaluation mode.
model.eval()
# Prefer CUDA when torch reports it as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
# Restore punct
def restore_punctuation(text, restorer):
    """Rebuild a spaced, capitalized sentence from ``restorer``'s token output.

    ``restorer`` is called once as ``restorer(text)`` and must return an
    iterable of dict-like items. Each item's ``'word'`` value is the token
    text and its ``'entity'`` value is the predicted label.

    Tokens are stitched together with these rules, checked in order:

    * a closing punctuation token is glued onto the previous piece;
    * any token that follows a bare quote piece is glued onto that quote;
    * a contraction suffix ("s", "t", "re", ...) is glued onto the previous
      piece;
    * a token labelled ``'LABEL_0'`` starts a new, capitalized piece;
    * any other token starts a new piece unchanged.

    The "glue" rules only apply when there is a previous piece; a leading
    punctuation or suffix token simply becomes the first piece instead of
    raising ``IndexError``.

    Args:
        text: The unpunctuated text to pass to ``restorer``.
        restorer: Callable returning the token predictions described above.

    Returns:
        The pieces joined by single spaces, with a trailing ``'.'`` added if
        the last piece does not already end with one. Returns ``''`` when the
        restorer yields no usable tokens.
    """
    # Punctuation that attaches to the token on its left. Left-side marks
    # (opening brackets/quotes) are deliberately absent so they keep the
    # space before them.
    trailing_punctuation = frozenset([
        "!", "?", ".", "-", ":", ";", "'", "’", ",", ")", "]", "}",
        "…", "”", "’’", "''",
    ])
    # A piece that is *only* one of these swallows the next token (no space).
    # The curly opening quotes were mis-encoded in an earlier revision and
    # therefore never matched.
    bare_quotes = frozenset(["'", "’", "“", "‘", "‘‘", "““"])
    # Suffixes left over when a contraction is split around its apostrophe.
    contraction_suffixes = frozenset(["s", "t", "re", "ve", "ll", "d", "m"])

    # NOTE(review): sub-word continuation pieces (e.g. a "##" prefix, if the
    # restorer's tokenizer emits them) are not merged here — confirm whether
    # that case can occur with the configured pipeline.
    pieces = []
    for prediction in restorer(text):
        token = prediction.get('word')
        if not token:
            # Nothing to place; also avoids calling .lower() on None.
            continue

        glue_to_previous = bool(pieces) and (
            token in trailing_punctuation
            or pieces[-1] in bare_quotes
            or token.lower() in contraction_suffixes
        )
        if glue_to_previous:
            pieces[-1] += token
        elif prediction.get('entity') == 'LABEL_0':
            # LABEL_0 marks a token that should be capitalized.
            pieces.append(token.capitalize())
        else:
            pieces.append(token)

    if not pieces:
        return ''

    # Make sure the story ends with a period.
    if not pieces[-1].endswith('.'):
        pieces[-1] += '.'
    return ' '.join(pieces)
def generate_text(input_text):
    """Generate a continuation for ``input_text`` and return it punctuated.

    The prompt is tokenized with the module-level ``tokenizer``, moved to
    ``device``, passed through ``model.generate`` (``max_length=50``) with
    gradient tracking disabled, decoded without special tokens, and finally
    run through ``restore_punctuation`` using ``punct_restorer``.
    """
    # Tokenize the prompt into a tensor and place it on the model's device.
    prompt_tensor = tokenizer.encode(input_text, return_tensors='pt')
    prompt_tensor = prompt_tensor.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        sequences = model.generate(prompt_tensor, max_length=50)

    # Only the first generated sequence is used.
    raw_story = tokenizer.decode(sequences[0], skip_special_tokens=True)

    return restore_punctuation(raw_story, punct_restorer)
# Build and launch the Gradio demo around `generate_text`.
print("4/4 - Launching demo")
# Page title. The emoji were previously mis-encoded (mojibake) and rendered
# as garbage characters in the UI.
title = "👻 🫣 Generate a Two-Sentence Horror Story 😱 👻"
# Short HTML blurb passed to the interface as its description.
description = """
<center>The bot was trained to generate two-sentence horror stories based on r/TwoSentenceHorror. <i>Spooky!</i></center>
"""
# Markdown text with links to the subreddit and the dataset.
article = "Check out [the subreddit](https://www.reddit.com/r/TwoSentenceHorror) that this demo is based off of. Or, check out the dataset [here](https://www.kaggle.com/datasets/voanthony/two-sentence-horror-jan-2015-apr-2023)."
# One text box in (the first sentence), one text box out (the generated text).
demo = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(
        lines=4,
        placeholder="Enter the first sentence of your horror story here...",
        label="First Sentence",
    ),
    outputs=gr.Textbox(lines=4, label="Second Sentence"),
    title=title,
    description=description,
    article=article,
    examples=[
        ["My parents told me not to go upstairs."],
        ["There was a ghost."],
    ],
)
demo.launch(share=True)