Spaces:
Runtime error
Runtime error
import functools
import re

import gradio as gr
from gradio.mix import Parallel
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)
def clean_text(text):
    """Normalise pasted story text before it is fed to a headline model.

    Non-ASCII characters (e.g. Chinese text, accented letters, curly quotes)
    are dropped, and every run of whitespace -- spaces, tabs, newlines and
    carriage returns -- is collapsed into a single space.

    Args:
        text: Raw text as pasted by the user.

    Returns:
        The cleaned single-line string with no leading/trailing whitespace.
    """
    # Drop non-ASCII characters (including Chinese characters) outright.
    text = text.encode("ascii", errors="ignore").decode("ascii")
    # A single pass covers every whitespace kind, including the "\r" of
    # Windows line endings that a newline-only replacement would leave behind.
    return re.sub(r"\s+", " ", text).strip()
modchoice_1 = "chinhon/headline_writer"


@functools.lru_cache(maxsize=1)
def _load_headline_model_1():
    """Load the tokenizer/model pair for ``modchoice_1`` once per process.

    ``from_pretrained`` reads (and on first use downloads) the whole
    checkpoint, so it is memoised instead of being repeated on every request.
    """
    tokenizer = AutoTokenizer.from_pretrained(modchoice_1)
    model = AutoModelForSeq2SeqLM.from_pretrained(modchoice_1)
    return tokenizer, model


def headline_writer1(text):
    """Suggest a headline for ``text`` with the ``modchoice_1`` model.

    Args:
        text: Raw story text pasted by the user.

    Returns:
        The decoded headline, or ``""`` if nothing remains after cleaning.
    """
    input_text = clean_text(text)
    if not input_text:
        # Nothing to summarise; avoid generating from an empty prompt.
        return ""
    tokenizer_1, model_1 = _load_headline_model_1()
    # Encode as ordinary source text: ``as_target_tokenizer()`` is meant for
    # label encoding (and is deprecated), so it is not used for inputs.
    batch = tokenizer_1(
        input_text, truncation=True, padding="longest", return_tensors="pt"
    )
    output_ids = model_1.generate(**batch)
    summary_1 = tokenizer_1.batch_decode(output_ids, skip_special_tokens=True)
    return summary_1[0]


headline1 = gr.Interface(
    fn=headline_writer1,
    inputs=gr.inputs.Textbox(),
    outputs=gr.outputs.Textbox(label=""),
)
modchoice_2 = "chinhon/pegasus-multi_news-headline"


@functools.lru_cache(maxsize=1)
def _load_headline_model_2():
    """Load the tokenizer/model pair for ``modchoice_2`` once per process.

    ``from_pretrained`` reads (and on first use downloads) the whole
    checkpoint, so it is memoised instead of being repeated on every request.
    """
    tokenizer = AutoTokenizer.from_pretrained(modchoice_2)
    model = AutoModelForSeq2SeqLM.from_pretrained(modchoice_2)
    return tokenizer, model


def headline_writer2(text):
    """Suggest a headline for ``text`` with the ``modchoice_2`` model.

    Args:
        text: Raw story text pasted by the user.

    Returns:
        The decoded headline, or ``""`` if nothing remains after cleaning.
    """
    input_text = clean_text(text)
    if not input_text:
        # Nothing to summarise; avoid generating from an empty prompt.
        return ""
    tokenizer_2, model_2 = _load_headline_model_2()
    # Encode as ordinary source text: ``as_target_tokenizer()`` is meant for
    # label encoding (and is deprecated), so it is not used for inputs.
    batch = tokenizer_2(
        input_text, truncation=True, padding="longest", return_tensors="pt"
    )
    output_ids = model_2.generate(**batch)
    summary_2 = tokenizer_2.batch_decode(output_ids, skip_special_tokens=True)
    return summary_2[0]


headline2 = gr.Interface(
    fn=headline_writer2,
    inputs=gr.inputs.Textbox(),
    outputs=gr.outputs.Textbox(label=""),
)
modchoice_3 = "chinhon/pegasus-newsroom-headline_writer"


@functools.lru_cache(maxsize=1)
def _load_headline_model_3():
    """Load the tokenizer/model pair for ``modchoice_3`` once per process.

    ``from_pretrained`` reads (and on first use downloads) the whole
    checkpoint, so it is memoised instead of being repeated on every request.
    """
    tokenizer = AutoTokenizer.from_pretrained(modchoice_3)
    model = AutoModelForSeq2SeqLM.from_pretrained(modchoice_3)
    return tokenizer, model


def headline_writer3(text):
    """Suggest a headline for ``text`` with the ``modchoice_3`` model.

    Args:
        text: Raw story text pasted by the user.

    Returns:
        The decoded headline, or ``""`` if nothing remains after cleaning.
    """
    input_text = clean_text(text)
    if not input_text:
        # Nothing to summarise; avoid generating from an empty prompt.
        return ""
    tokenizer_3, model_3 = _load_headline_model_3()
    # Encode as ordinary source text: ``as_target_tokenizer()`` is meant for
    # label encoding (and is deprecated), so it is not used for inputs.
    batch = tokenizer_3(
        input_text, truncation=True, padding="longest", return_tensors="pt"
    )
    output_ids = model_3.generate(**batch)
    # ``max_length`` is a generation setting, not a decoding option, so it is
    # not passed to ``batch_decode`` (where it has no effect). To cap headline
    # length, pass ``max_new_tokens``/``max_length`` to ``generate`` instead.
    summary_3 = tokenizer_3.batch_decode(output_ids, skip_special_tokens=True)
    return summary_3[0]


headline3 = gr.Interface(
    fn=headline_writer3,
    inputs=gr.inputs.Textbox(),
    outputs=gr.outputs.Textbox(label=""),
)
# Shared story input: the same pasted text is fanned out to all three
# headline interfaces, and their suggestions are shown side by side.
story_input = gr.inputs.Textbox(
    lines=20,
    label="Paste the first few paragraphs of your story here, and choose from 3 suggested headlines",
)

demo = Parallel(
    headline1,
    headline2,
    headline3,
    inputs=story_input,
    title="AI Headlines Generator",
    theme="darkhuggingface",
)

# Start the web app with request queueing enabled.
demo.launch(enable_queue=True)