|
import gradio |
|
import torch |
|
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
def chunk_text(text, chunk_size):
    """Split *text* into consecutive slices of at most *chunk_size* characters.

    The split is purely positional (by character count); every slice except
    possibly the last has exactly ``chunk_size`` characters.

    Args:
        text: The string to split.
        chunk_size: Maximum number of characters per slice; must be positive.

    Returns:
        A list of slices in their original order. Empty input yields ``[]``.

    Raises:
        ValueError: If ``chunk_size`` is zero or negative.
    """
    # A negative step would make range() empty and silently drop the whole
    # text; a zero step fails inside range() with an unrelated-looking message.
    if chunk_size <= 0:
        raise ValueError(
            f"chunk_size must be a positive integer, got {chunk_size!r}"
        )
    return [
        text[start:start + chunk_size]
        for start in range(0, len(text), chunk_size)
    ]
|
|
|
# Lazily-built summarization pipeline, shared by all calls to shorten_text.
_SUMMARIZER = None


def _get_summarizer():
    """Return the shared summarization pipeline, creating it on first use."""
    global _SUMMARIZER
    if _SUMMARIZER is None:
        _SUMMARIZER = pipeline("summarization", model="facebook/bart-large-cnn")
    return _SUMMARIZER


def shorten_text(text, min_length, max_length):
    """Summarize *text* chunk by chunk with ``facebook/bart-large-cnn``.

    The text is split into 1024-character chunks via ``chunk_text``; each
    chunk is summarized independently and the partial summaries are joined
    with single spaces.

    Args:
        text: The text to summarize.
        min_length: Minimum length passed to the summarizer for each chunk.
        max_length: Maximum length passed to the summarizer for each chunk.

    Returns:
        The concatenated chunk summaries, or ``""`` when *text* is empty.
    """
    # Values may arrive as non-integer numbers (e.g. from UI sliders), while
    # generation lengths are integer counts.
    min_length = int(min_length)
    max_length = int(max_length)

    chunks = chunk_text(text, 1024)
    if not chunks:
        # Nothing to summarize: skip building/loading the model entirely.
        return ""

    summarizer = _get_summarizer()
    summary_chunks = []
    for chunk in chunks:
        # The limits must be keyword arguments: the pipeline forwards keyword
        # arguments to generation, whereas extra positional arguments are not
        # interpreted as length limits.
        result = summarizer(
            chunk,
            max_length=max_length,
            min_length=min_length,
            do_sample=False,
        )
        summary_chunks.append(result[0]["summary_text"])

    return " ".join(summary_chunks)
|
|
|
# Checkpoint used by paraphrase_text, and the lazily-built
# (tokenizer, model, device) triple shared by all calls.
_PARAPHRASE_MODEL_NAME = "randomshit11/fin-bert-1st-shit"
_PARAPHRASER = None


def _get_paraphraser():
    """Return the shared (tokenizer, model, device), loading them on first use."""
    global _PARAPHRASER
    if _PARAPHRASER is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(_PARAPHRASE_MODEL_NAME)
        model = AutoModelForSeq2SeqLM.from_pretrained(_PARAPHRASE_MODEL_NAME)
        # The inputs are moved to `device` before generation, so the model
        # has to live on the same device.
        model.to(device)
        _PARAPHRASER = (tokenizer, model, device)
    return _PARAPHRASER


def paraphrase_text(text, min_length, max_length):
    """Rewrite *text* chunk by chunk with a sampled seq2seq generation.

    The text is split into 1024-character chunks via ``chunk_text``. Each
    chunk is wrapped as ``"Summary: <chunk> </s>"``, passed through the model
    named by ``_PARAPHRASE_MODEL_NAME`` with top-k/top-p sampling, and the
    decoded outputs are joined with single spaces. Because sampling is
    enabled, repeated calls may return different text.

    Args:
        text: The text to rewrite.
        min_length: Minimum generation length for each chunk.
        max_length: Maximum generation length for each chunk.

    Returns:
        The concatenated generated text, or ``""`` when *text* is empty.
    """
    # Values may arrive as non-integer numbers (e.g. from UI sliders), while
    # generation lengths are integer counts.
    min_length = int(min_length)
    max_length = int(max_length)

    chunks = chunk_text(text, 1024)
    if not chunks:
        # Nothing to rewrite: skip loading the model entirely.
        return ""

    tokenizer, model, device = _get_paraphraser()

    output_chunks = []
    for chunk in chunks:
        # Wrap every chunk individually so each one carries the instruction
        # prefix and end marker, not just the first and last chunk.
        prompt = "Summary: " + chunk + " </s>"
        encoding = tokenizer(
            prompt,
            padding="longest",
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)

        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            min_length=min_length,
            max_length=max_length,
            do_sample=True,
            top_k=120,
            top_p=0.95,
            early_stopping=True,
            # Only the first returned sequence is used below, so a single
            # sample is requested.
            num_return_sequences=1,
        )
        line = tokenizer.decode(
            outputs[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        output_chunks.append(line)

    return " ".join(output_chunks)
|
|
|
def modify_text(mode, text, min_length, max_length):
    """Dispatch *text* to the shortening or paraphrasing routine.

    Args:
        mode: ``"shorten"`` selects ``shorten_text``; any other value selects
            ``paraphrase_text``.
        text: The text to process.
        min_length: Forwarded unchanged to the selected routine.
        max_length: Forwarded unchanged to the selected routine.

    Returns:
        Whatever the selected routine returns.
    """
    handler = shorten_text if mode == "shorten" else paraphrase_text
    return handler(text, min_length, max_length)
|
|
|
# Input widgets, listed in the positional order of modify_text's parameters:
# mode, text, min_length, max_length.
_input_components = [
    gradio.Radio(choices=["shorten", "Summary"], label="Mode"),
    "text",
    gradio.Slider(minimum=5, maximum=200, value=30, label="Min length"),
    gradio.Slider(minimum=5, maximum=500, value=130, label="Max length"),
]

# One pre-filled example row; its values follow the same order as the inputs.
_example_rows = [
    ["shorten", "Your long input text goes here...", 30, 130],
]

gradio_interface = gradio.Interface(
    fn=modify_text,
    inputs=_input_components,
    outputs="text",
    examples=_example_rows,
    title="Text shortener/paraphraser",
    description="Shortening texts using `facebook/bart-large-cnn`, paraphrasing texts using `fin-bert-1st-shit`.",
)

# Start the web UI as soon as this module is executed.
gradio_interface.launch()
|
|