sohomju's picture
Update app.py
fac3fce verified
raw
history blame
3.66 kB
import pickle
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import BertTokenizer, BertForSequenceClassification, pipeline, AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline, AutoModelForSeq2SeqLM, AutoModel, RobertaModel, RobertaTokenizer
from sentence_transformers import SentenceTransformer
from fin_readability_sustainability import BERTClass, do_predict
import pandas as pd
# Readability scorer: FinBERT tokenizer plus a BERTClass model whose weights
# are restored from the local checkpoint file 'readability_model.bin'.
tokenizer_read = BertTokenizer.from_pretrained('ProsusAI/finbert')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_read = BERTClass(2, "readability")
model_read.to(device)
# `strict` is a parameter of `load_state_dict`, not of `torch.load` (which
# forwards unknown keyword arguments to the unpickler). strict=False lets the
# checkpoint load even if some keys are missing or unexpected.
model_read.load_state_dict(
    torch.load('readability_model.bin', map_location=device)['model_state_dict'],
    strict=False,
)
# Inference only: switch off dropout and other training-time behaviour.
# Harmless if do_predict also calls eval() — NOTE(review): confirm.
model_read.eval()
def get_readability(text):
    """Score *text* with the readability model and return it rounded to 4 places.

    The text is wrapped in a one-row DataFrame with a ``sentence`` column and
    passed to ``do_predict`` together with ``model_read`` and
    ``tokenizer_read``. The value returned is ``result[1][0]``.
    NOTE(review): presumably the second item of do_predict's result holds
    per-row scores — confirm against fin_readability_sustainability.
    """
    frame = pd.DataFrame({'sentence': [text]})
    predictions = do_predict(model_read, tokenizer_read, frame)
    first_row_score = predictions[1][0]
    return round(first_row_score, 4)
# Seq2seq paraphrasing model and its tokenizer, loaded once at import time.
# Reference : https://huggingface.co/humarin/chatgpt_paraphraser_on_T5_base
_PARAPHRASER_CHECKPOINT = "humarin/chatgpt_paraphraser_on_T5_base"
tokenizer = AutoTokenizer.from_pretrained(_PARAPHRASER_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_PARAPHRASER_CHECKPOINT)
def paraphrase(
    question,
    num_beams=5,
    num_beam_groups=5,
    num_return_sequences=5,
    repetition_penalty=10.0,
    diversity_penalty=3.0,
    no_repeat_ngram_size=2,
    temperature=0.7,
    max_length=128
):
    """Generate paraphrases of *question* with the module-level seq2seq model.

    The input is prefixed with ``'paraphrase: '`` and tokenized (truncated to
    ``max_length`` tokens). The token ids are then passed to
    ``model.generate`` together with every keyword argument of this function.

    NOTE(review): generate() is not given ``do_sample=True``, so transformers
    ignores ``temperature`` here (and may warn about it).

    Returns:
        list[str]: The generated sequences decoded with special tokens
        stripped (``num_return_sequences`` of them for this single input).
    """
    prompt = f'paraphrase: {question}'
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        padding="longest",
        max_length=max_length,
        truncation=True,
    )
    generation_kwargs = dict(
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        num_return_sequences=num_return_sequences,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_beams=num_beams,
        num_beam_groups=num_beam_groups,
        max_length=max_length,
        diversity_penalty=diversity_penalty,
    )
    generated_ids = model.generate(encoded.input_ids, **generation_kwargs)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
def get_most_readable_paraphrse(text, threshold=0.6):
    """Return a message suggesting the most readable rewrite of *text*.

    The candidates are the paraphrases from ``paraphrase(text)`` plus *text*
    itself. Each candidate is scored once with ``get_readability``, and the
    highest score wins (on ties the earliest candidate is kept). A rewrite is
    suggested only if the winner differs from *text* and its score is
    strictly greater than *threshold*.

    Args:
        text: The sentence to simplify.
        threshold: Minimum readability score (exclusive) a paraphrase must
            exceed to be suggested. Defaults to 0.6.

    Returns:
        str: A user-facing message containing either the chosen paraphrase or
        a note that no more readable version was found.
    """
    # The original text competes as well, so a paraphrase is only suggested
    # when it scores higher than the input.
    candidates = [*paraphrase(text), text]
    # max() returns the first maximal element, matching a strict '>' scan.
    score_max, best = max(
        ((get_readability(candidate), candidate) for candidate in candidates),
        key=lambda scored: scored[0],
    )
    if best != text and score_max > threshold:
        return "The most readable version of text that I can think of is:\n" + best
    return ("Sorry! I am not confident. As per my best knowledge, "
            "you already have the most readable version of the text!")
def set_example_text(example_text):
    """Copy a clicked example into the input textbox.

    Args:
        example_text: The clicked sample row from the ``gr.Dataset``. Each
            row holds one value per Dataset component, so ``[0]`` is the
            example string (assumes the Dataset's default ``type="values"``).

    Returns:
        A Gradio update that sets the target component's value.
    """
    # gr.update works in Gradio 3.x and 4.x; the component-specific
    # gr.Textbox.update classmethod was removed in Gradio 4.
    return gr.update(value=example_text[0])
# Gradio UI: an input box, a button that runs the simplifier, an output box,
# and clickable examples that pre-fill the input box.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# FinLanSer
Financial Language Simplifier
""")
    text = gr.Textbox(label="Enter text you want to simplify (make more readable)")
    greet_btn = gr.Button("Simplify/Make Readable")
    output = gr.Textbox(label="Output Box")
    # api_name keeps its historical spelling: API clients call the endpoint by it.
    greet_btn.click(
        fn=get_most_readable_paraphrse,
        inputs=text,
        outputs=output,
        api_name="get_most_raedable_paraphrse",
    )
    example_text = gr.Dataset(
        components=[text],
        samples=[
            ['Legally assured line of credit with a bank'],
            ['A mutual fund is a type of financial vehicle made up of a pool of money collected from many investors to invest in securities like stocks, bonds, money market instruments'],
        ],
    )
    # `text` is the Dataset's only component; reference it directly rather than
    # through the Dataset's internal component list, whose attribute name
    # differs between Gradio versions.
    example_text.click(fn=set_example_text, inputs=example_text, outputs=[text])
demo.launch()