"""Gradio demo for FReE (Financial Relation Extraction).

Loads a Hugging Face sequence-classification checkpoint, keeps the active
tokenizer/model/config in ``MODEL_BUF`` and exposes ``predict`` through a
Gradio interface when run as a script.
"""
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
import gradio as gr
import torch
from torch.nn import functional as F
import seaborn
import matplotlib
import platform

from transformers.file_utils import ModelOutput

if platform.system() == "Darwin":
    print("MacOS")
    # Switch matplotlib to the Agg backend on macOS.
    matplotlib.use('Agg')

import matplotlib.pyplot as plt
import io
from PIL import Image

import matplotlib.font_manager as fm

# global var
MODEL_NAME = 'yseop/distilbert-base-financial-relation-extraction'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
config = AutoConfig.from_pretrained(MODEL_NAME)

# Holds the checkpoint that `predict` currently serves; replaced as a whole
# by `change_model_name`.
MODEL_BUF = {
    "name": MODEL_NAME,
    "tokenizer": tokenizer,
    "model": model,
    "config": config
}

# Register every font file found in the current working directory, then make
# NanumGothicCoding the default matplotlib font family.
font_dir = ['./']
for font in fm.findSystemFonts(font_dir):
    print(font)
    fm.fontManager.addfont(font)
plt.rcParams["font.family"] = 'NanumGothicCoding'


def change_model_name(name):
    """Load checkpoint ``name`` and make it the active entry in ``MODEL_BUF``.

    All three artifacts are loaded before the buffer is touched, so a failed
    download or an unknown name raises without leaving ``MODEL_BUF`` pointing
    at a name whose model was never loaded.

    Args:
        name: Hugging Face model identifier or local path.
    """
    new_tokenizer = AutoTokenizer.from_pretrained(name)
    new_model = AutoModelForSequenceClassification.from_pretrained(name)
    new_config = AutoConfig.from_pretrained(name)

    MODEL_BUF.update(
        name=name,
        tokenizer=new_tokenizer,
        model=new_model,
        config=new_config,
    )


def predict(model_name, text):
    """Classify ``text`` with ``model_name`` and return label probabilities.

    Args:
        model_name: Checkpoint to use; it is loaded only when it differs from
            the one currently held in ``MODEL_BUF``.
        text: Input sentence.

    Returns:
        dict mapping each label in ``config.id2label`` to its softmax
        probability as a Python float.
    """
    # Compare against the *loaded* checkpoint, not the startup default:
    # otherwise a non-default model is reloaded on every call and switching
    # back to the default never restores it.
    if model_name != MODEL_BUF["name"]:
        change_model_name(model_name)

    active_tokenizer = MODEL_BUF["tokenizer"]
    active_model = MODEL_BUF["model"]
    active_config = MODEL_BUF["config"]

    encoded = active_tokenizer([text], return_tensors='pt')

    active_model.eval()
    # Inference only: skip autograd bookkeeping. Attention tensors were never
    # used by this function, so they are no longer requested; with
    # return_dict=False the first tuple element is the logits.
    with torch.no_grad():
        logits = active_model(**encoded, return_dict=False)[0]
        probabilities = F.softmax(logits, dim=-1)

    return {
        active_config.id2label[idx]: float(score)
        for idx, score in enumerate(probabilities[0].tolist())
    }


if __name__ == '__main__':
    text1 = 'An A-B trust is a joint trust created by a married couple for the purpose of minimizing estate taxes.'
    text2 = 'For example, if the supply of reserves in the fed funds market is greater than the demand, then the fed funds rate falls, and if the supply of reserves is less than the demand, the rate rises.'
    text3 = 'Coupon dates are the dates on which the bond issuer will make interest payments.'
    text4 = "Two features of a bond—credit quality and time to maturity—are the principal determinants of a bond's coupon rate."
    text5 = "When an investment sale is less than a standard lot, it's referred to as a job lot."
    text6 = 'Most bonds can be sold by the initial bondholder to other investors after they have been issued.'
    text7 = 'A bond could be thought of as an I.O.U. between the lender and borrower.'

    model_name_list = [
        'yseop/distilbert-base-financial-relation-extraction'
    ]

    example_texts = [text1, text2, text3, text4, text5, text6, text7]

    # Create a gradio app with a button that calls predict()
    # NOTE(review): `gr.inputs` looks like the legacy Gradio namespace —
    # confirm it still exists in the pinned Gradio version.
    app = gr.Interface(
        fn=predict,
        inputs=[gr.inputs.Dropdown(model_name_list, label="Model Name"), 'text'],
        outputs=['label'],
        examples=[[MODEL_BUF["name"], sample] for sample in example_texts],
        title="FReE (Financial Relation Extraction)",
        description="A model capable of detecting the presence of a relationship between financial terms and qualifying the relationship in case of its presence."
    )
    app.launch(inline=False)