# Imports
# Core Imports
import torch
# Model-related Imports
from transformers import BartTokenizer, BartForConditionalGeneration  # fine-tuned BART model
from transformers import AutoTokenizer, AutoModelForTokenClassification  # punctuation restoration
from transformers import pipeline  # punctuation restoration
import gradio as gr



# Instantiate model to restore punctuation
print("1/4 - Instantiating model to restore punctuation")

punct_model_path = "felflare/bert-restore-punctuation"
# Load punct tokenizer and model
punct_tokenizer = AutoTokenizer.from_pretrained(punct_model_path)
punct_model = AutoModelForTokenClassification.from_pretrained(punct_model_path)
punct_restorer = pipeline("token-classification", model=punct_model, tokenizer=punct_tokenizer)
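# Note (assumed output shape): the token-classification pipeline returns one dict
# per token; restore_punctuation() below relies only on the 'word' and 'entity' keys,
# e.g. [{'word': 'hello', 'entity': 'LABEL_0', 'score': 0.99, ...}, ...],
# where LABEL_0 is treated as "start of a sentence, capitalize".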



# Instantiate fine-tuned horror BART model
print("2/4 - Instantiating two-sentence horror generation model")

model_path = 'voacado/bart-two-sentence-horror'
# Load tokenizer and model
tokenizer = BartTokenizer.from_pretrained(model_path)
model = BartForConditionalGeneration.from_pretrained(model_path)



# Set up inference
print("3/4 - Setting parameters for inference")

# Set the model to evaluation mode
model.eval()
# Use a GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Restore punct
def restore_punctuation(text, restorer):
    # Use the model to predict punctuation
    punctuated_output = restorer(text)
    punctuated_text = []
    
    # Closing punctuation marks that attach to the previous token (opening marks
    # are excluded so a space is still kept before them)
    punctuation_marks = ["!", "?", ".", "-", ":", ";", "'", "’", ",", ")", "]", "}", "…", "”", "’’", "''"]
    
    for elem in punctuated_output:
        cur_token = elem.get('word')
        
        # If token is punctuation, append it to the previous token (skip if the
        # list is still empty, e.g. a leading punctuation token)
        if punctuated_text and cur_token in punctuation_marks:
            punctuated_text[-1] += cur_token

        # If the previous token is an opening quote, append to it (no space)
        elif punctuated_text and punctuated_text[-1] in ["'", "’", "“", "‘", "‘‘", "““"]:
            punctuated_text[-1] += cur_token

        # If token is a contraction suffix, append to the previous token (no space)
        elif punctuated_text and cur_token.lower() in ["s", "t", "re", "ve", "ll", "d", "m"]:
            punctuated_text[-1] += cur_token
            
        # If the prediction is LABEL_0, the token starts a sentence and is capitalized
        elif elem.get('entity') == 'LABEL_0':
            punctuated_text.append(cur_token.capitalize())

        # Otherwise (e.g. LABEL_1), keep the token as-is
        else:
            punctuated_text.append(cur_token)
            
    # If the story doesn't end with terminal punctuation, add a period
    if punctuated_text and punctuated_text[-1][-1] not in ".!?":
        punctuated_text[-1] += '.'

    return ' '.join(punctuated_text)
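
# Illustrative sanity check (the exact output depends on the model's predictions):
#   restore_punctuation("i heard a knock it was my own voice", punct_restorer)
#   -> expected to read roughly: "I heard a knock. It was my own voice."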

def generate_text(input_text):
    # Encode the input text
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)

    # Generate text
    with torch.no_grad():
        output_ids = model.generate(input_ids, max_length=50)

    # Decode the generated text
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    
    # Restore punctuation
    generated_text_punct = restore_punctuation(generated_text, punct_restorer)
    
    return generated_text_punct
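
# For more varied stories, generate() also accepts sampling parameters in place
# of the greedy defaults above (a sketch; these values are illustrative, not tuned):
#   output_ids = model.generate(
#       input_ids,
#       max_length=50,
#       do_sample=True,    # sample from the distribution instead of greedy decoding
#       top_p=0.92,        # nucleus sampling
#       temperature=0.9,
#   )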



# Create gradio demo
print("4/4 - Launching demo")

title = "πŸ‘» 🫣 Generate a Two-Sentence Horror Story 😱 πŸ‘»"
description = """
<center>The bot was trained to generate two-sentence horror stories based on r/TwoSentenceHorror. <i>Spooky!</i></center>
"""

article = "Check out [the subreddit](https://www.reddit.com/r/TwoSentenceHorror) that this demo is based on. Or, check out the dataset [here](https://www.kaggle.com/datasets/voanthony/two-sentence-horror-jan-2015-apr-2023)."


demo = gr.Interface(
    fn=generate_text, 
    inputs=gr.Textbox(lines=4, placeholder="Enter the first sentence of your horror story here...", label="First Sentence"),
    outputs=gr.Textbox(lines=4, label="Second Sentence"),
    title=title,
    description=description,
    article=article,
    examples=[["My parents told me not to go upstairs."], ["There was a ghost."], ["Sometimes I catch myself staring at those missing person flyers at the store."]],
)

demo.launch(share=True)