Spaces:
Sleeping
Sleeping
# Imports | |
# Core Imports | |
import torch | |
# Model-related Imports | |
from transformers import BartTokenizer, BartForConditionalGeneration # fine-tuned BART model | |
from transformers import AutoTokenizer, AutoModelForTokenClassification # restore punct | |
from transformers import pipeline # restore punct | |
import gradio as gr | |
# --- 1. Punctuation-restoration model ---------------------------------------
print("1/4 - Instantiating model to restore punctuation")
punct_model_path = "felflare/bert-restore-punctuation"
punct_tokenizer = AutoTokenizer.from_pretrained(punct_model_path)
punct_model = AutoModelForTokenClassification.from_pretrained(punct_model_path)
# Per-token predictions from this pipeline are post-processed by
# restore_punctuation() after generation.
# NOTE(review): no device is passed here, so presumably this pipeline stays on
# its default device (CPU) even when the BART model below is moved to CUDA.
punct_restorer = pipeline(
    task="token-classification",
    model=punct_model,
    tokenizer=punct_tokenizer,
)

# --- 2. Fine-tuned two-sentence-horror BART model ----------------------------
print("2/4 - Instantiating two-sentence horror generation model")
model_path = "voacado/bart-two-sentence-horror"
tokenizer = BartTokenizer.from_pretrained(model_path)
model = BartForConditionalGeneration.from_pretrained(model_path)

# --- 3. Inference settings ---------------------------------------------------
print("3/4 - Setting parameters for inference")
model.eval()  # evaluation mode for inference (e.g. dropout disabled)
# Use a CUDA device when torch can see one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
# Turn the punctuation model's per-token predictions back into readable text.
def restore_punctuation(text, restorer):
    """Rebuild capitalised, correctly spaced text from token predictions.

    ``restorer`` is called once with ``text`` and must return an iterable of
    dicts carrying a ``'word'`` (token string) and an ``'entity'`` (predicted
    label), as the token-classification pipeline used by this app does.
    Tokens are re-joined with single spaces, except that:

    * WordPiece continuation pieces (``"##..."``) are joined to their word;
    * closing punctuation is glued onto the previous token;
    * a token following an opening quote is glued onto that quote;
    * a contraction suffix (``s``, ``t``, ``re``, ``ve``, ``ll``, ``d``, ``m``)
      is glued on only when the previous token ends in an apostrophe;
    * tokens labelled ``LABEL_0`` are capitalised.

    A final ``'.'`` is appended unless the text already ends with ``.``,
    ``!``, ``?`` or an ellipsis (ignoring trailing closing quotes/brackets).

    Args:
        text: Raw text to post-process.
        restorer: Callable mapping ``text`` to a list of token prediction dicts.

    Returns:
        The reconstructed string, or ``""`` when the restorer yields no tokens.
    """
    # Closing punctuation is glued onto the END of the previous token. Opening
    # brackets/quotes are deliberately left out so they keep their leading space.
    # NOTE(review): the non-ASCII entries here and in ``opening_quotes`` were
    # mis-encoded in the previous revision; they are reconstructed as
    # typographic quotes / ellipsis (escaped so they survive re-encoding).
    # Confirm this is the intended set.
    closing_punct = {
        "!", "?", ".", "-", ":", ";", "'", "\u2019", ",", ")", "]", "}",
        "\u2026", "\u201d", "\u2019\u2019", "''",
    }
    # After an opening quote, the next token is attached with no space.
    opening_quotes = {"'", "\u2018", "\u201c", "\u2018\u2018", "\u201c\u201c"}
    # Suffixes of contractions split at the apostrophe (didn't, I'm, we're ...).
    contraction_suffixes = {"s", "t", "re", "ve", "ll", "d", "m"}
    apostrophes = ("'", "\u2019")
    sentence_enders = (".", "!", "?", "\u2026")
    trailing_closers = "'\"\u2019\u201d)]}"

    words = []
    for elem in restorer(text):
        token = elem.get("word")
        # LABEL_0 marks a token that should start with a capital letter.
        cased = token.capitalize() if elem.get("entity") == "LABEL_0" else token
        prev = words[-1] if words else None

        if prev is not None and token.startswith("##"):
            # WordPiece continuation: join to the current word, dropping "##".
            words[-1] = prev + token[2:]
        elif prev is not None and token in closing_punct:
            words[-1] = prev + token
        elif prev in opening_quotes:
            words[-1] = prev + cased
        elif (
            prev is not None
            and token.lower() in contraction_suffixes
            and prev.endswith(apostrophes)
        ):
            words[-1] = prev + token
        else:
            # Also reached when there is no previous token to attach to, so a
            # leading punctuation/suffix token no longer raises IndexError.
            words.append(cased)

    if not words:
        return ""
    # Add a full stop only if the text does not already end a sentence;
    # this avoids outputs such as "there?." or "run!.".
    if not words[-1].rstrip(trailing_closers).endswith(sentence_enders):
        words[-1] += "."
    return " ".join(words)
def generate_text(input_text):
    """Continue a horror story from its first sentence.

    Encodes ``input_text`` with the module-level BART tokenizer, generates with
    the fine-tuned model (``max_length=50``), decodes the first output sequence
    without special tokens, and post-processes it with ``restore_punctuation``
    using the module-level ``punct_restorer`` pipeline.

    Args:
        input_text: The user-supplied first sentence.

    Returns:
        The generated text after punctuation/capitalisation post-processing.
    """
    # truncation=True clips over-long user input to the tokenizer's
    # model_max_length, so it cannot exceed the model's input length limit.
    input_ids = tokenizer.encode(
        input_text, truncation=True, return_tensors='pt'
    ).to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output_ids = model.generate(input_ids, max_length=50)
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return restore_punctuation(generated_text, punct_restorer)
# Build and launch the Gradio demo.
print("4/4 - Launching demo")
# NOTE(review): the emoji in this title were mis-encoded in the previous
# revision. They are reconstructed (escaped) as ghost, face-with-peeking-eye,
# face-screaming-in-fear, ghost. Only the peeking-eye one is unambiguous from
# the garbled bytes; confirm the others.
title = "\U0001F47B \U0001FAE3 Generate a Two-Sentence Horror Story \U0001F631 \U0001F47B"
description = """
<center>The bot was trained to generate two-sentence horror stories based on r/TwoSentenceHorror. <i>Spooky!</i></center>
"""
article = "Check out [the subreddit](https://www.reddit.com/r/TwoSentenceHorror) that this demo is based off of. Or, check out the dataset [here](https://www.kaggle.com/datasets/voanthony/two-sentence-horror-jan-2015-apr-2023)."
# One text box in (first sentence), one text box out (generated continuation).
demo = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(lines=4, placeholder="Enter the first sentence of your horror story here...", label="First Sentence"),
    outputs=gr.Textbox(lines=4, label="Second Sentence"),
    title=title,
    description=description,
    article=article,
    examples=[["My parents told me not to go upstairs."], ["There was a ghost."]],
)
# NOTE(review): share=True requests a public share link when run locally;
# hosted Spaces presumably ignore it. Confirm it is still wanted.
demo.launch(share=True)