Hanna Abi Akl
Update app.py
38e02fd
import platform

from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
import gradio as gr
import torch
from torch.nn import functional as F
import seaborn
import matplotlib
from transformers.file_utils import ModelOutput
# On macOS, announce the platform and switch matplotlib to its 'Agg' backend.
# This runs before the matplotlib.pyplot import that follows it in this file.
if "Darwin" == platform.system():
    print("MacOS")
    matplotlib.use('Agg')
import matplotlib.pyplot as plt
import io
from PIL import Image
import matplotlib.font_manager as fm
# global var: the startup model, loaded once at import time, plus the mutable
# buffer (MODEL_BUF) through which predict() and change_model_name() share the
# currently active tokenizer/model/config.
MODEL_NAME = 'yseop/distilbert-base-financial-relation-extraction'
# Load tokenizer, model and config (in that order) from the same identifier.
tokenizer, model, config = (
    loader.from_pretrained(MODEL_NAME)
    for loader in (AutoTokenizer, AutoModelForSequenceClassification, AutoConfig)
)
MODEL_BUF = dict(
    name=MODEL_NAME,
    tokenizer=tokenizer,
    model=model,
    config=config,
)
# Register every font file matplotlib discovers in the current working
# directory, echoing each path, then select NanumGothicCoding as the default
# family (presumably that font file ships next to the app — TODO confirm).
font_dir = ['./']
for font in fm.findSystemFonts(fontpaths=font_dir):
    print(font)
    fm.fontManager.addfont(font)
plt.rcParams["font.family"] = 'NanumGothicCoding'
def change_model_name(name):
    """Load the model identified by ``name`` and make it the active one.

    The tokenizer, model and config are all loaded *before* ``MODEL_BUF`` is
    modified. If any ``from_pretrained`` call raises, the buffer therefore
    still consistently describes the previously active model, instead of
    pairing the new name with stale artifacts.

    Args:
        name: Identifier accepted by the transformers ``from_pretrained``
            loaders (e.g. a hub model id).
    """
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForSequenceClassification.from_pretrained(name)
    config = AutoConfig.from_pretrained(name)
    # Swap everything only once every load has succeeded.
    MODEL_BUF.update(name=name, tokenizer=tokenizer, model=model, config=config)
def predict(model_name, text):
    """Classify the financial relation expressed in ``text``.

    Args:
        model_name: Identifier of the model to use. If it differs from the
            model currently held in ``MODEL_BUF``, that model is loaded first
            via ``change_model_name``. A falsy value (e.g. an unselected
            dropdown) keeps the currently loaded model.
        text: Input sentence to classify.

    Returns:
        dict mapping each label name (from ``config.id2label``) to its
        softmax probability as a Python float.
    """
    # Compare against the *currently loaded* model, not the startup constant
    # MODEL_NAME. Otherwise a non-default model is reloaded on every call, and
    # selecting MODEL_NAME again keeps serving the previously loaded model.
    if model_name and model_name != MODEL_BUF["name"]:
        change_model_name(model_name)
    tokenizer = MODEL_BUF["tokenizer"]
    model = MODEL_BUF["model"]
    config = MODEL_BUF["config"]

    # truncation=True clips over-long inputs to the model's maximum length
    # instead of failing at inference time.
    tokenized_text = tokenizer([text], return_tensors='pt', truncation=True)
    model.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        # With return_dict=False the model returns a tuple; element 0 is the logits.
        logits = model(**tokenized_text, return_dict=False)[0]
    # Single-item batch: take row 0 of the per-label probabilities.
    probs = F.softmax(logits, dim=-1)[0]
    return {config.id2label[idx]: float(prob) for idx, prob in enumerate(probs.tolist())}
if __name__ == '__main__':
    # Sample sentences offered as clickable examples in the demo UI.
    text1 = 'An A-B trust is a joint trust created by a married couple for the purpose of minimizing estate taxes.'
    text2 = 'For example, if the supply of reserves in the fed funds market is greater than the demand, then the fed funds rate falls, and if the supply of reserves is less than the demand, the rate rises.'
    text3 = 'Coupon dates are the dates on which the bond issuer will make interest payments.'
    text4 = "Two features of a bond—credit quality and time to maturity—are the principal determinants of a bond's coupon rate."
    text5 = "When an investment sale is less than a standard lot, it's referred to as a job lot."
    text6 = 'Most bonds can be sold by the initial bondholder to other investors after they have been issued.'
    text7 = 'A bond could be thought of as an I.O.U. between the lender and borrower.'
    example_texts = [text1, text2, text3, text4, text5, text6, text7]
    model_name_list = [
        'yseop/distilbert-base-financial-relation-extraction'
    ]
    # Create a gradio app with a button that calls predict()
    app = gr.Interface(
        fn=predict,
        inputs=[
            # Preselect the startup model: depending on the Gradio version, a
            # Dropdown with no default can submit None as the model name.
            gr.inputs.Dropdown(model_name_list, default=MODEL_NAME, label="Model Name"),
            'text',
        ],
        outputs=['label'],
        # One example row per sample sentence, each paired with the active model.
        examples=[[MODEL_BUF["name"], sample] for sample in example_texts],
        title="FReE (Financial Relation Extraction)",
        description="A model capable of detecting the presence of a relationship between financial terms and qualifying the relationship in case of its presence."
    )
    app.launch(inline=False)