# Imports # Core Imports import torch # Model-related Imports from transformers import BartTokenizer, BartForConditionalGeneration # fine-tuned BART model from transformers import AutoTokenizer, AutoModelForTokenClassification # restore punct from transformers import pipeline # restore punct import gradio as gr # Instantiate model to restore punctuation print("1/4 - Instantiating model to restore punctuation") punct_model_path = "felflare/bert-restore-punctuation" # Load punct tokenizer and model punct_tokenizer = AutoTokenizer.from_pretrained(punct_model_path) punct_model = AutoModelForTokenClassification.from_pretrained(punct_model_path) punct_restorer = pipeline("token-classification", model=punct_model, tokenizer=punct_tokenizer) # Instantiate fine-tuned horror BART model print("2/4 - Instantiating two-sentence horror generation model") model_path = 'voacado/bart-two-sentence-horror' # Load tokenizer and model tokenizer = BartTokenizer.from_pretrained(model_path) model = BartForConditionalGeneration.from_pretrained(model_path) # Set up inference print("3/4 - Setting parameters for inference") # Set the model to evaluation mode model.eval() # If GPU, use it device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) # Restore punct def restore_punctuation(text, restorer): # Use the model to predict punctuation punctuated_output = restorer(text) punctuated_text = [] # Define punctuation marks (note: not including left-side because we want space still) punctuation_marks = ["!", "?", ".", "-", ":", ";", "'", "’", ",", ")", "]", "}", "…", "”", "’’", "''"] for elem in punctuated_output: cur_token = elem.get('word') # If token is punctuation, append to previous token if cur_token in punctuation_marks: punctuated_text[-1] += cur_token # If previous token is quotations, append to previous token elif punctuated_text and punctuated_text[-1] in ["'", "’", "“", "‘", "‘‘", "““"]: punctuated_text[-1] += cur_token # If token is a contraction or a quote, append to previous token (no space) elif cur_token.lower() in ["s", "t", "re", "ve", "ll", "d", "m"]: # Remove space for contractions punctuated_text[-1] += cur_token # if prediction is LABEL_0, token should be capitalized elif elem.get('entity') == 'LABEL_0': punctuated_text.append(cur_token.capitalize()) # else if prediction is LABEL_1, token should be lowercase # elif elem.get('entity') == 'LABEL_1': else: punctuated_text.append(cur_token) # If there's no period at the end of the story, add one if punctuated_text[-1][-1] != '.': punctuated_text[-1] = punctuated_text[-1] + '.' return ' '.join(punctuated_text) def generate_text(input_text): # Encode the input text input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device) # Generate text with torch.no_grad(): output_ids = model.generate(input_ids, max_length=50) # Decode the generated text generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) # Restore punctuation generated_text_punct = restore_punctuation(generated_text, punct_restorer) return generated_text_punct # Create gradio demo print("4/4 - Launching demo") title = "👻 🫣 Generate a Two-Sentence Horror Story 😱 👻" description = """