Spaces:
Sleeping
Sleeping
A Vo
commited on
Commit
β’
7d45f3e
1
Parent(s):
5bb7d76
Add application file
Browse files
app.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Imports
|
2 |
+
# Core Imports
|
3 |
+
import torch
|
4 |
+
# Model-related Imports
|
5 |
+
from transformers import BartTokenizer, BartForConditionalGeneration # fine-tuned BART model
|
6 |
+
from transformers import AutoTokenizer, AutoModelForTokenClassification # restore punct
|
7 |
+
from transformers import pipeline # restore punct
|
8 |
+
import gradio as gr
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
# Instantiate model to restore punctuation
|
13 |
+
print("1/4 - Instantiating model to restore punctuation")
|
14 |
+
|
15 |
+
punct_model_path = "felflare/bert-restore-punctuation"
|
16 |
+
# Load punct tokenizer and model
|
17 |
+
punct_tokenizer = AutoTokenizer.from_pretrained(punct_model_path)
|
18 |
+
punct_model = AutoModelForTokenClassification.from_pretrained(punct_model_path)
|
19 |
+
punct_restorer = pipeline("token-classification", model=punct_model, tokenizer=punct_tokenizer)
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
# Instantiate fine-tuned horror BART model
|
24 |
+
print("2/4 - Instantiating two-sentence horror generation model")
|
25 |
+
|
26 |
+
model_path = 'voacado/bart-two-sentence-horror'
|
27 |
+
# Load tokenizer and model
|
28 |
+
tokenizer = BartTokenizer.from_pretrained(model_path)
|
29 |
+
model = BartForConditionalGeneration.from_pretrained(model_path)
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
# Set up inference
|
34 |
+
print("3/4 - Setting parameters for inference")
|
35 |
+
|
36 |
+
# Set the model to evaluation mode
|
37 |
+
model.eval()
|
38 |
+
# If GPU, use it
|
39 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
40 |
+
model.to(device)
|
41 |
+
|
42 |
+
# Restore punct
|
43 |
+
def restore_punctuation(text, restorer):
|
44 |
+
# Use the model to predict punctuation
|
45 |
+
punctuated_output = restorer(text)
|
46 |
+
punctuated_text = []
|
47 |
+
|
48 |
+
# Define punctuation marks (note: not including left-side because we want space still)
|
49 |
+
punctuation_marks = ["!", "?", ".", "-", ":", ";", "'", "β", ",", ")", "]", "}", "β¦", "β", "ββ", "''"]
|
50 |
+
|
51 |
+
for elem in punctuated_output:
|
52 |
+
cur_token = elem.get('word')
|
53 |
+
|
54 |
+
# If token is punctuation, append to previous token
|
55 |
+
if cur_token in punctuation_marks:
|
56 |
+
punctuated_text[-1] += cur_token
|
57 |
+
|
58 |
+
# If previous token is quotations, append to previous token
|
59 |
+
elif punctuated_text and punctuated_text[-1] in ["'", "β", "β", "β", "ββ", "ββ"]:
|
60 |
+
punctuated_text[-1] += cur_token
|
61 |
+
|
62 |
+
# If token is a contraction or a quote, append to previous token (no space)
|
63 |
+
elif cur_token.lower() in ["s", "t", "re", "ve", "ll", "d", "m"]:
|
64 |
+
# Remove space for contractions
|
65 |
+
punctuated_text[-1] += cur_token
|
66 |
+
|
67 |
+
# if prediction is LABEL_0, token should be capitalized
|
68 |
+
elif elem.get('entity') == 'LABEL_0':
|
69 |
+
punctuated_text.append(cur_token.capitalize())
|
70 |
+
|
71 |
+
# else if prediction is LABEL_1, token should be lowercase
|
72 |
+
# elif elem.get('entity') == 'LABEL_1':
|
73 |
+
else:
|
74 |
+
punctuated_text.append(cur_token)
|
75 |
+
|
76 |
+
# If there's no period at the end of the story, add one
|
77 |
+
if punctuated_text[-1][-1] != '.':
|
78 |
+
punctuated_text[-1] = punctuated_text[-1] + '.'
|
79 |
+
|
80 |
+
return ' '.join(punctuated_text)
|
81 |
+
|
82 |
+
def generate_text(input_text):
|
83 |
+
# Encode the input text
|
84 |
+
input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
|
85 |
+
|
86 |
+
# Generate text
|
87 |
+
with torch.no_grad():
|
88 |
+
output_ids = model.generate(input_ids, max_length=50)
|
89 |
+
|
90 |
+
# Decode the generated text
|
91 |
+
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
92 |
+
|
93 |
+
# Restore punctuation
|
94 |
+
generated_text_punct = restore_punctuation(generated_text, punct_restorer)
|
95 |
+
|
96 |
+
return generated_text_punct
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
# Create gradio demo
|
101 |
+
print("4/4 - Launching demo")
|
102 |
+
|
103 |
+
title = "π» π«£ Generate a Two-Sentence Horror Story π± π»"
|
104 |
+
description = """
|
105 |
+
<center>The bot was trained to generate two-sentence horror stories based on r/TwoSentenceHorror. <i>Spooky!</i></center>
|
106 |
+
"""
|
107 |
+
|
108 |
+
article = "Check out [the subreddit](https://www.reddit.com/r/TwoSentenceHorror) that this demo is based off of. Or, check out the dataset [here](https://www.kaggle.com/datasets/voanthony/two-sentence-horror-jan-2015-apr-2023)."
|
109 |
+
|
110 |
+
|
111 |
+
demo = gr.Interface(
|
112 |
+
fn=generate_text,
|
113 |
+
inputs=gr.Textbox(lines=4, placeholder="Enter the first sentence of your horror story here...", label="First Sentence"),
|
114 |
+
outputs=gr.Textbox(lines=4, label="Second Sentence"),
|
115 |
+
title=title,
|
116 |
+
description=description,
|
117 |
+
article=article,
|
118 |
+
examples=[["My parents told me not to go upstairs."], ["There was a ghost."]],
|
119 |
+
)
|
120 |
+
|
121 |
+
demo.launch(share=True)
|