File size: 10,520 Bytes
7d45f3e
 
 
40884a0
7d45f3e
 
 
 
 
 
40884a0
 
 
 
 
 
7d45f3e
 
 
40884a0
7d45f3e
 
 
 
 
 
 
 
 
 
40884a0
7d45f3e
 
 
 
 
 
 
 
40884a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d45f3e
40884a0
7d45f3e
 
 
 
 
 
 
 
40884a0
 
 
 
 
 
 
 
 
 
 
7d45f3e
 
40884a0
7d45f3e
 
 
 
 
 
 
 
 
40884a0
7d45f3e
 
40884a0
 
7d45f3e
 
 
 
40884a0
7d45f3e
 
 
40884a0
7d45f3e
 
 
 
40884a0
7d45f3e
 
40884a0
 
7d45f3e
40884a0
7d45f3e
40884a0
 
 
7d45f3e
40884a0
 
 
7d45f3e
40884a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d45f3e
40884a0
 
 
 
 
7d45f3e
 
 
 
40884a0
7d45f3e
 
 
 
 
 
40884a0
 
 
 
 
 
 
 
 
7d45f3e
 
 
 
40884a0
 
 
 
 
 
 
 
 
 
 
7d45f3e
 
 
a48cee6
7d45f3e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
# Imports
# Core Imports
import torch

# Model-related Imports
from transformers import BartTokenizer, BartForConditionalGeneration # fine-tuned BART model
from transformers import AutoTokenizer, AutoModelForTokenClassification # restore punct
from transformers import pipeline # restore punct
import gradio as gr

# Evaluation Imports
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import string



# --- Punctuation restorer ---------------------------------------------------
print("1/7 - Instantiating model to restore punctuation")

punct_model_path = "felflare/bert-restore-punctuation"
# Tokenizer and token-classification model are loaded from the same checkpoint
punct_tokenizer, punct_model = (
    AutoTokenizer.from_pretrained(punct_model_path),
    AutoModelForTokenClassification.from_pretrained(punct_model_path),
)
# Wrap both in a pipeline; restore_punctuation() calls it on a string and
# iterates the per-token dicts it returns
punct_restorer = pipeline(
    "token-classification",
    tokenizer=punct_tokenizer,
    model=punct_model,
)



# --- Fine-tuned BART generator ----------------------------------------------
print("2/7 - Instantiating two-sentence horror generation model")

model_path = "voacado/bart-two-sentence-horror"
# Tokenizer first, then the seq2seq model, both from the fine-tuned checkpoint
tokenizer, model = (
    BartTokenizer.from_pretrained(model_path),
    BartForConditionalGeneration.from_pretrained(model_path),
)



# Load data for evaluation metrics
print("3/7 - Reading in data")
data = pd.read_csv("./reddit_cleansed_data.csv")
# Popularity proxy: post score + 10 per comment + 100 per gilding
data['weighted_score'] = data['score'] + (10 * data['num_comments']) + (100 * data['gilded_count'])
# Story text = title + body. A missing cell is read as NaN, and NaN + str is NaN,
# which TfidfVectorizer rejects and tokenize() cannot lower-case; treat missing
# title/body text as empty instead.
dataset_stories = (
    data['title'].fillna('').astype(str) + ' ' + data['selftext'].fillna('').astype(str)
).to_list()



# Instantiate evaluation metrics - Cosine Similarity with TF-IDF
print("4/7 - Instantiating evaluation metrics - Cosine Similarity with TF-IDF")
# Pre-vectorize dataset once: fit the vocabulary/IDF weights on every story and
# keep the resulting (n_stories x vocab) matrix for per-request comparisons
vectorizer = TfidfVectorizer()
dataset_matrix = vectorizer.fit_transform(dataset_stories)

def eval_cosine_similarity(input_sentence: str) -> [str, str]:
    """
    Evaluate cosine similarity between input sentence and each story in the dataset.

    The input is vectorized with the TF-IDF vocabulary fitted on the dataset, so
    words never seen in the dataset are ignored.

    Args:
        input_sentence (str): user story (first sentence)

    Returns:
        [str, str]: most similar story, weighted score; ('', None) when the
        input shares no vocabulary with any story
    """
    # Vectorize input sentence using the existing vocab
    input_vec = vectorizer.transform([input_sentence])
    # One input row against every story row -> take the single row of scores
    similarities = cosine_similarity(input_vec, dataset_matrix)[0]
    # Find most similar story (position in dataset_stories)
    most_similar_story_idx = int(similarities.argmax())
    # If nothing overlaps, every score is 0 and argmax just picks story 0 --
    # report "no match" rather than an arbitrary story
    if similarities[most_similar_story_idx] <= 0:
        return '', None
    most_similar_story = dataset_stories[most_similar_story_idx]
    # argmax is a position, so look up the score positionally, not by label
    weighted_score = data['weighted_score'].iloc[most_similar_story_idx]

    return most_similar_story, weighted_score



# Instantiate evaluation metrics - Jaccard Similarity
print("5/7 - Instantiating evaluation metrics - Jaccard Similarity")
def tokenize(text: str):
    """
    Reduce a story to its set of distinct words.

    The text is lower-cased, every ASCII punctuation character listed in
    string.punctuation is deleted, and what remains is split on whitespace.

    Args:
        text (str): user story

    Returns:
        set: set of tokens
    """
    drop_punctuation = str.maketrans('', '', string.punctuation)
    words = text.lower().translate(drop_punctuation).split()
    return set(words)

def jaccard_similarity(set1: set, set2: set):
    """
    Calculate Jaccard similarity (|A ∩ B| / |A ∪ B|) between two sets.

    Args:
        set1 (set): user_tokens
        set2 (set): story_tokens

    Returns:
        float: Jaccard similarity in [0, 1]; 0.0 when both sets are empty
    """
    # .union/.intersection (not | and &) so set2 may be any iterable, as before
    union = set1.union(set2)
    # Two empty sets share nothing; guard the division (e.g. blank input vs a
    # blank story) instead of raising ZeroDivisionError
    if not union:
        return 0.0
    intersection = set1.intersection(set2)
    return len(intersection) / len(union)

def eval_jaccard_similarity(input_sentence: str) -> [str, str]:
    """
    Evaluate Jaccard similarity between input sentence and each story in the dataset.

    Ties keep the earliest story (only a strictly greater similarity replaces
    the current best).

    Args:
        input_sentence (str): user story (first sentence)

    Returns:
        [str, str]: most similar story, weighted score; ('', None) when no
        story shares a single token with the input
    """
    # Tokenize the user story
    user_tokens = tokenize(input_sentence)

    # Initialize variables to find the most similar story. max_score must be
    # bound here: if no story overlaps the input, the loop never assigns it and
    # the return would raise UnboundLocalError.
    max_similarity = 0
    most_similar_story = ''
    max_score = None

    # Compare with each story in the dataset; enumerate supplies the position
    # directly instead of rescanning the list with .index() on every new best
    for idx, story in enumerate(dataset_stories):
        similarity = jaccard_similarity(user_tokens, tokenize(story))
        if similarity > max_similarity:
            max_similarity = similarity
            most_similar_story = story
            max_score = data['weighted_score'].iloc[idx]

    return most_similar_story, max_score



# Set up inference
print("6/7 - Setting parameters for inference")

# Choose where the generator runs: CUDA when PyTorch reports a GPU, else the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Inference only: put the model in evaluation mode, then move its weights to the device
model.eval()
model.to(device)

# Restore punct
def restore_punctuation(text: str, restorer: "pipeline") -> str:
    """
    Rebuild readable text from the punctuation model's per-token predictions.

    The restorer returns one dict per token (this function reads its 'word' and
    'entity' keys). Tokens are re-joined with single spaces, except that these
    are glued onto the previous token:
      - WordPiece continuation pieces ("##ed"),
      - trailing punctuation,
      - the token right after an opening quote,
      - contraction suffixes that follow an apostrophe ("don'" + "t").
    Tokens predicted as 'LABEL_0' are capitalized. A period is added unless the
    text already ends in terminal punctuation.

    Args:
        text (str): full story (first and second sentences)
        restorer (pipeline): model that restores punctuation; any callable that
            returns a list of {'word': ..., 'entity': ...} dicts works

    Returns:
        str: punctuated text (based on input); '' if the restorer yields no tokens
    """
    # Use the model to predict a label for every token
    punctuated_output = restorer(text)
    punct_text = []

    # Define punctuation marks (note: not including left-side because we want space still)
    punctuation_marks = {"!", "?", ".", "-", ":", ";", "'", "’", ",", ")", "]", "}", "…", "”", "’’", "''"}
    # Opening quotes: the next token is attached to them without a space
    opening_quotes = {"'", "’", "“", "‘", "‘‘", "““"}
    # Suffixes left over when a contraction is split around its apostrophe
    contraction_suffixes = {"s", "t", "re", "ve", "ll", "d", "m"}
    # Endings that already close a sentence, so no extra period is needed
    sentence_enders = (".", "!", "?", "…")

    for elem in punctuated_output:
        cur_token = elem.get('word')
        prev_token = punct_text[-1] if punct_text else None

        # WordPiece continuation piece: it is part of the previous word
        if prev_token is not None and cur_token.startswith("##"):
            punct_text[-1] += cur_token[2:]

        # If token is punctuation, append to previous token (a leading mark
        # with nothing before it falls through and is kept as its own token)
        elif prev_token is not None and cur_token in punctuation_marks:
            punct_text[-1] += cur_token

        # If previous token is an opening quote, append to previous token
        elif prev_token in opening_quotes:
            punct_text[-1] += cur_token

        # Contraction suffix directly after an apostrophe: no space. Requiring
        # the apostrophe keeps standalone words like "s" or "d" separate.
        elif (
            prev_token is not None
            and prev_token.endswith(("'", "’"))
            and cur_token.lower() in contraction_suffixes
        ):
            punct_text[-1] += cur_token

        # if prediction is LABEL_0, token should be capitalized
        elif elem.get('entity') == 'LABEL_0':
            punct_text.append(cur_token.capitalize())

        else:
            punct_text.append(cur_token)

    # Nothing to punctuate (e.g. empty input)
    if not punct_text:
        return ''

    # If the story doesn't already end a sentence, add a period. Trailing
    # closing quotes/brackets are skipped so '“run!”' is left alone.
    if not punct_text[-1].rstrip('"\'”’)]}').endswith(sentence_enders):
        punct_text[-1] += '.'

    return ' '.join(punct_text)

def generate_text(input_text: str, full_sentence: str) -> [str, str, float, str, float]:
    """
    Generate the second sentence of the horror story given the first (input_text),
    then compare the full story against the dataset.

    If full_sentence contains any non-whitespace text, nothing is generated:
    that story is evaluated directly and the generated sentence is "N/A".

    Args:
        input_text (str): first sentence of the horror story
        full_sentence (str): full story (first and second sentences); None,
            empty or whitespace-only means "generate instead"

    Returns:
        gen_text_punct (str): second sentence of the horror story
        similar_story_cosine (str): most similar story (cosine similarity)
        cosine_score (float): score of most similar story (cosine similarity)
        similar_story_jaccard (str): most similar story (Jaccard similarity)
        jaccard_score (float): score of most similar story (Jaccard similarity)
    """
    # Normalize the optional full story so a None or blank box counts as "not
    # provided" and blank text is never sent to the similarity metrics
    full_sentence = (full_sentence or "").strip()

    # If user only enters first sentence, generate second sentence
    if not full_sentence:
        # Encode the input text
        input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)

        # Generate text (no gradients needed for inference)
        with torch.no_grad():
            output_ids = model.generate(input_ids, max_length=50)

        # Decode the generated text
        gen_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        # Restore punctuation
        gen_text_punct = restore_punctuation(gen_text, punct_restorer)
        full_sentence = input_text + ' ' + gen_text_punct
    else:
        gen_text_punct = "N/A"

    # Calculate Cosine and Jaccard similarity
    similar_story_cosine, cosine_score = eval_cosine_similarity(full_sentence)
    similar_story_jaccard, jaccard_score = eval_jaccard_similarity(full_sentence)

    return gen_text_punct, similar_story_cosine, cosine_score, similar_story_jaccard, jaccard_score



# Create gradio demo
print("7/7 - Launching demo")

# Title emoji restored: the ghost had been saved as mis-decoded UTF-8 ("πŸ‘»")
title = "👻 🫣 Generate a Two-Sentence Horror Story 😱 👻"
description = """
<center>The bot was trained to generate two-sentence horror stories based on r/TwoSentenceHorror. <i>Spooky!</i></center>
"""

article = """
Check out [the subreddit](https://www.reddit.com/r/TwoSentenceHorror) that this demo is based off of. Or, check out the dataset [here](https://www.kaggle.com/datasets/voanthony/two-sentence-horror-jan-2015-apr-2023).

The language model is fine-tuned from ['facebook/bart-base'](https://huggingface.co/facebook/bart-base). We import, then update the weights for the model to generate two-sentence horror stories. The model is fine-tuned over 3 epochs to avoid catastrophic forgetting. We also use a separate model (['felflare/bert-restore-punctuation'](https://huggingface.co/felflare/bert-restore-punctuation?text=My+name+is+wolfgang+and+I+live+in+berlin)) to restore punctuation.

For evaluation, the generated story is compared to the most similar Reddit post (using either cosine or Jaccard similarity). The score of the most similar post is also returned. The score is calculated as the sum of the post score, 10 * number of comments, and 100 * number of gilds. The score is used as a proxy for the popularity of the post.

Users may also enter an entire story in the second input prompt rather than generating the remainder of the story. This will be used for evaluation metrics and no story will be generated.
"""


# Two inputs map to generate_text(input_text, full_sentence); the five outputs
# map, in order, to its five return values
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=4, placeholder="Enter the first sentence of your horror story here...", label="First Sentence"),
        gr.Textbox(lines=4, placeholder="Or, enter full story for evaluation here...", label="Eval - Full Story")
    ],
    outputs=[
        gr.Textbox(lines=4, label="Generated Second Sentence"),
        gr.Textbox(lines=3, label="Cosine Similarity - Sentence"),
        gr.Textbox(lines=1, label="Cosine Similarity - Post Score"),
        gr.Textbox(lines=3, label="Jaccard Similarity - Sentence"),
        gr.Textbox(lines=1, label="Jaccard Similarity - Post Score")
    ],
    title=title,
    description=description,
    article=article,
    examples=[["My parents told me not to go upstairs."], ["There was a ghost."], ["Sometimes I catch myself staring at those missing person flyers at the store."]],
)

demo.launch(share=True)