|
import gradio as gr
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
import torch
|
|
|
|
|
|
# Hugging Face Hub id of the Korean news summarization checkpoint; both the
# tokenizer and the seq2seq model weights are loaded from it.
_CHECKPOINT = "noahkim/KoT5_news_summarization"

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = model.to(device)
|
|
|
|
|
|
def summarize_text(input_text):
    """Summarize a news article with the module-level KoT5 seq2seq model.

    Args:
        input_text: Article text (e.g. from the Gradio textbox). ``None`` or a
            string containing only whitespace is treated as empty.

    Returns:
        The decoded summary string, or ``""`` when there is nothing to
        summarize.
    """
    # Guard: with min_length=128 below, generate() would otherwise be forced
    # to emit at least 128 tokens that are not grounded in any real input.
    if not input_text or not input_text.strip():
        return ""

    # No padding: a single sequence needs none. The old padding="max_length"
    # made every request 2048 tokens long, so the encoder and each beam step
    # attended over mostly-masked positions. The attention mask already
    # excluded those positions, so dropping them only saves compute.
    # Over-long articles are still truncated to 2048 tokens.
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
    )
    inputs = {key: value.to(device) for key, value in inputs.items()}

    summary_text_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=512,
        min_length=128,
        num_beams=6,
        repetition_penalty=1.5,
        no_repeat_ngram_size=15,
    )

    # Only one sequence is returned per input; drop pad/EOS tokens when decoding.
    return tokenizer.decode(summary_text_ids[0], skip_special_tokens=True)
|
|
|
|
|
|
# Web UI: one free-text box in, the generated summary out.
_article_box = gr.Textbox(label="Input Text")
_summary_box = gr.Textbox(label="Summary")

iface = gr.Interface(
    fn=summarize_text,
    inputs=_article_box,
    outputs=_summary_box,
)

# Start the Gradio server when this file is executed.
iface.launch()
|
|
|