# Imports # Core Imports import torch # Model-related Imports from transformers import BartTokenizer, BartForConditionalGeneration # fine-tuned BART model from transformers import AutoTokenizer, AutoModelForTokenClassification # restore punct from transformers import pipeline # restore punct import gradio as gr # Evaluation Imports from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity import pandas as pd import string # Instantiate model to restore punctuation print("1/7 - Instantiating model to restore punctuation") punct_model_path = "felflare/bert-restore-punctuation" # Load punct tokenizer and model punct_tokenizer = AutoTokenizer.from_pretrained(punct_model_path) punct_model = AutoModelForTokenClassification.from_pretrained(punct_model_path) punct_restorer = pipeline("token-classification", model=punct_model, tokenizer=punct_tokenizer) # Instantiate fine-tuned horror BART model print("2/7 - Instantiating two-sentence horror generation model") model_path = 'voacado/bart-two-sentence-horror' # Load tokenizer and model tokenizer = BartTokenizer.from_pretrained(model_path) model = BartForConditionalGeneration.from_pretrained(model_path) # Load data for evaluation metrics print("3/7 - Reading in data") data = pd.read_csv("./reddit_cleansed_data.csv") data['weighted_score'] = data['score'] + (10 * data['num_comments']) + (100 * data['gilded_count']) dataset_stories = (data['title'] + ' ' + data['selftext']).to_list() # Instantiate evaluation metrics - Cosine Similarity with TF-IDF print("4/7 - Instantiating evaluation metrics - Cosine Similarity with TF-IDF") # Pre-vectorize dataset vectorizer = TfidfVectorizer() dataset_matrix = vectorizer.fit_transform(dataset_stories) def eval_cosine_similarity(input_sentence: str) -> [str, str]: """ Evaluate cosine similarity between input sentence and each story in the dataset. Args: input_sentence (str): user story (first sentence) Returns: [str, str]: most similar story, weighted score """ # Vectorize input sentence using the existing vocab input_vec = vectorizer.transform([input_sentence]) # Get cosine similarity similarities = cosine_similarity(input_vec, dataset_matrix) # Find most similar story most_similar_story_idx = similarities.argmax() most_similar_story = dataset_stories[most_similar_story_idx] # Get weighted score of most similar story weighted_score = data['weighted_score'][most_similar_story_idx] return most_similar_story, weighted_score # Instantiate evaluation metrics - Jaccard Similarity print("5/7 - Instantiating evaluation metrics - Jaccard Similarity") def tokenize(text: str): """ Convert text to lowercase and remove punctuation, then tokenize. Args: text (str): user story Returns: set: set of tokens """ text = text.lower() text = text.translate(str.maketrans('', '', string.punctuation)) tokens = text.split() return set(tokens) def jaccard_similarity(set1: set, set2: set): """ Calculate Jaccard similarity between two sets. Args: set1 (set): user_tokens set2 (set): story_tokens Returns: float: Jaccard similarity """ intersection = set1.intersection(set2) union = set1.union(set2) return len(intersection) / len(union) def eval_jaccard_similarity(input_sentence: str) -> [str, str]: """ Evaluate Jaccard similarity between input sentence and each story in the dataset. Args: input_sentence (str): user story (first sentence) Returns: [str, str]: most similar story, weighted score """ # Tokenize the user story user_tokens = tokenize(input_sentence) # Initialize variables to find the most similar story max_similarity = 0 most_similar_story = '' # Compare with each story in the dataset for story in dataset_stories: story_tokens = tokenize(story) similarity = jaccard_similarity(user_tokens, story_tokens) if similarity > max_similarity: max_similarity = similarity most_similar_story = story max_score = data['weighted_score'][dataset_stories.index(story)] return most_similar_story, max_score # Set up inference print("6/7 - Setting parameters for inference") # Set the model to evaluation mode model.eval() # If GPU, use it device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) # Restore punct def restore_punctuation(text: str, restorer: pipeline) -> str: """ Restore punctuation to text. Args: text (str): full story (first and second sentences) restorer (pipeline): model that restores punctuation Returns: str: punctuated text (based on input) """ # Use the model to predict punctuation punctuated_output = restorer(text) punct_text = [] # Define punctuation marks (note: not including left-side because we want space still) punctuation_marks = ["!", "?", ".", "-", ":", ";", "'", "’", ",", ")", "]", "}", "…", "”", "’’", "''"] for elem in punctuated_output: cur_token = elem.get('word') # If token is punctuation, append to previous token if cur_token in punctuation_marks: punct_text[-1] += cur_token # If previous token is quotations, append to previous token elif punct_text and punct_text[-1] in ["'", "’", "“", "‘", "‘‘", "““"]: punct_text[-1] += cur_token # If token is a contraction or a quote, append to previous token (no space) elif cur_token.lower() in ["s", "t", "re", "ve", "ll", "d", "m"]: # Remove space for contractions punct_text[-1] += cur_token # if prediction is LABEL_0, token should be capitalized elif elem.get('entity') == 'LABEL_0': punct_text.append(cur_token.capitalize()) # else if prediction is LABEL_1, token should be lowercase # elif elem.get('entity') == 'LABEL_1': else: punct_text.append(cur_token) # If there's no period at the end of the story, add one if punct_text[-1][-1] != '.': punct_text[-1] = punct_text[-1] + '.' return ' '.join(punct_text) def generate_text(input_text: str, full_sentence: str) -> [str, str, float, str, float]: """ Generate the second sentence of the horror story given the first (input_text). Args: input_text (str): first sentence of the horror story full_sentence (str): full story (first and second sentences) Returns: gen_text_punct (str): second sentence of the horror story similar_story_cosine (str): most similar story (cosine similarity) cosine_score (float): score of most similar story (cosine similarity) similar_story_jaccard (str): most similar story (Jaccard similarity) jaccard_score (float): score of most similar story (Jaccard similarity) """ # If user only enters first sentence, generate second sentence if not full_sentence: # Encode the input text input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device) # Generate text with torch.no_grad(): output_ids = model.generate(input_ids, max_length=50) # Decode the generated text gen_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) # Restore punctuation gen_text_punct = restore_punctuation(gen_text, punct_restorer) full_sentence = input_text + ' ' + gen_text_punct else: gen_text_punct = "N/A" # Calculate Cosine and Jaccard similarity similar_story_cosine, cosine_score = eval_cosine_similarity(full_sentence) similar_story_jaccard, jaccard_score = eval_jaccard_similarity(full_sentence) return gen_text_punct, similar_story_cosine, cosine_score, similar_story_jaccard, jaccard_score # Create gradio demo print("7/7 - Launching demo") title = "👻 🫣 Generate a Two-Sentence Horror Story 😱 👻" description = """