A Vo committed on
Commit 7d45f3e • 1 Parent(s): 5bb7d76

Add application file

Files changed (1)
  1. app.py +121 -0
app.py ADDED
@@ -0,0 +1,121 @@
+ # Imports
+ # Core Imports
+ import torch
+ # Model-related Imports
+ from transformers import BartTokenizer, BartForConditionalGeneration # fine-tuned BART model
+ from transformers import AutoTokenizer, AutoModelForTokenClassification # restore punct
+ from transformers import pipeline # restore punct
+ import gradio as gr
+
+
+
+ # Instantiate model to restore punctuation
+ print("1/4 - Instantiating model to restore punctuation")
+
+ punct_model_path = "felflare/bert-restore-punctuation"
+ # Load punct tokenizer and model
+ punct_tokenizer = AutoTokenizer.from_pretrained(punct_model_path)
+ punct_model = AutoModelForTokenClassification.from_pretrained(punct_model_path)
+ punct_restorer = pipeline("token-classification", model=punct_model, tokenizer=punct_tokenizer)
+
+
+
+ # Instantiate fine-tuned horror BART model
+ print("2/4 - Instantiating two-sentence horror generation model")
+
+ model_path = 'voacado/bart-two-sentence-horror'
+ # Load tokenizer and model
+ tokenizer = BartTokenizer.from_pretrained(model_path)
+ model = BartForConditionalGeneration.from_pretrained(model_path)
+
+
+
+ # Set up inference
+ print("3/4 - Setting parameters for inference")
+
+ # Set the model to evaluation mode
+ model.eval()
+ # If GPU, use it
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ # Restore punct
+ def restore_punctuation(text, restorer):
+     # Use the model to predict punctuation
+     punctuated_output = restorer(text)
+     punctuated_text = []
+
+     # Define punctuation marks (note: not including left-side because we want space still)
+     punctuation_marks = ["!", "?", ".", "-", ":", ";", "'", "’", ",", ")", "]", "}", "…", "”", "’’", "''"]
+
+     for elem in punctuated_output:
+         cur_token = elem.get('word')
+
+         # If token is punctuation, append to previous token
+         if cur_token in punctuation_marks:
+             punctuated_text[-1] += cur_token
+
+         # If previous token is quotations, append to previous token
+         elif punctuated_text and punctuated_text[-1] in ["'", "’", "“", "‘", "‘‘", "““"]:
+             punctuated_text[-1] += cur_token
+
+         # If token is a contraction or a quote, append to previous token (no space)
+         elif cur_token.lower() in ["s", "t", "re", "ve", "ll", "d", "m"]:
+             # Remove space for contractions
+             punctuated_text[-1] += cur_token
+
+         # if prediction is LABEL_0, token should be capitalized
+         elif elem.get('entity') == 'LABEL_0':
+             punctuated_text.append(cur_token.capitalize())
+
+         # else if prediction is LABEL_1, token should be lowercase
+         # elif elem.get('entity') == 'LABEL_1':
+         else:
+             punctuated_text.append(cur_token)
+
+     # If there's no period at the end of the story, add one
+     if punctuated_text[-1][-1] != '.':
+         punctuated_text[-1] = punctuated_text[-1] + '.'
+
+     return ' '.join(punctuated_text)
+
+ def generate_text(input_text):
+     # Encode the input text
+     input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
+
+     # Generate text
+     with torch.no_grad():
+         output_ids = model.generate(input_ids, max_length=50)
+
+     # Decode the generated text
+     generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+     # Restore punctuation
+     generated_text_punct = restore_punctuation(generated_text, punct_restorer)
+
+     return generated_text_punct
+
+
+
+ # Create gradio demo
+ print("4/4 - Launching demo")
+
+ title = "👻 🫣 Generate a Two-Sentence Horror Story 😱 👻"
+ description = """
+ <center>The bot was trained to generate two-sentence horror stories based on r/TwoSentenceHorror. <i>Spooky!</i></center>
+ """
+
+ article = "Check out [the subreddit](https://www.reddit.com/r/TwoSentenceHorror) that this demo is based off of. Or, check out the dataset [here](https://www.kaggle.com/datasets/voanthony/two-sentence-horror-jan-2015-apr-2023)."
+
+
+ demo = gr.Interface(
+     fn=generate_text,
+     inputs=gr.Textbox(lines=4, placeholder="Enter the first sentence of your horror story here...", label="First Sentence"),
+     outputs=gr.Textbox(lines=4, label="Second Sentence"),
+     title=title,
+     description=description,
+     article=article,
+     examples=[["My parents told me not to go upstairs."], ["There was a ghost."]],
+ )
+
+ demo.launch(share=True)
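
Usage note (not part of the commit): once app.py is running, demo.launch(share=True) prints a local and a public URL, and the interface can also be queried programmatically with gradio_client. The sketch below assumes the standard gradio_client API; the URL is a placeholder for whatever launch() actually prints, and the prompt is just an example input.

# Minimal sketch: querying the running demo from another process with gradio_client.
from gradio_client import Client

client = Client("http://127.0.0.1:7860/")  # placeholder; use the URL printed by demo.launch()
result = client.predict(
    "My parents told me not to go upstairs.",  # value for the "First Sentence" textbox
    api_name="/predict",
)
print(result)  # the generated second sentence, with punctuation restored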