File size: 5,881 Bytes
3090b4b 1f0c01c 3090b4b c2f789d 70b186b c2f789d 10fdf8a 3090b4b 75d092f 70b186b 3090b4b c2f789d 3090b4b c2f789d 3090b4b c2f789d 75d092f 68e73b8 75d092f e395983 8a6156f 75d092f 8a6156f 75d092f 8a6156f 75d092f 55caca8 3090b4b c2f789d 4390c00 22a8fb2 3090b4b 55caca8 3090b4b 55caca8 3090b4b c2f789d 3090b4b 8027117 9c74623 8027117 3090b4b 7c39280 3090b4b 55caca8 40670fd 4bc07c4 3090b4b 10fdf8a c2f789d 3090b4b 752219e 55caca8 c2f789d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
from predict import run_prediction
from io import StringIO
import json
import gradio as gr
import spacy
from spacy import displacy
from transformers import AutoTokenizer, AutoModelForTokenClassification,RobertaTokenizer,pipeline
import torch
import nltk
from nltk.tokenize import sent_tokenize
from fin_readability_sustainability import BERTClass, do_predict
import pandas as pd
import en_core_web_sm
from fincat_utils import extract_context_words
from fincat_utils import bert_embedding_extract
import pickle
# Import-time setup shared by the functions below.
# NOTE(review): pickle.load can execute arbitrary code stored in the file; only
# ship a trusted "lr_clf_FiNCAT.pickle" next to this script.
# The context manager closes the file handle as soon as the classifier is loaded.
with open("lr_clf_FiNCAT.pickle", 'rb') as _clf_file:
    lr_clf = pickle.load(_clf_file)
# spaCy pipeline used for sentence splitting.
nlp = en_core_web_sm.load()
# NLTK's sent_tokenize needs the 'punkt' data; this fetches it on start-up.
nltk.download('punkt')
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# SUSTAINABILITY STARTS
# Tokenizer and classifier used by get_sustainability(); both are built once at
# import time.
tokenizer_sus = RobertaTokenizer.from_pretrained('roberta-base')
# BERTClass comes from fin_readability_sustainability; the meaning of its two
# arguments is not visible in this file.
# NOTE(review): "sustanability" looks misspelled, but it is a runtime string
# handed to BERTClass -- confirm what that class expects before changing it.
model_sustain = BERTClass(2, "sustanability")
model_sustain.to(device)
# The checkpoint file holds a dict; only its 'model_state_dict' entry is loaded
# into the model. map_location=device loads the tensors onto the selected device.
model_sustain.load_state_dict(torch.load('sustainability_model.bin', map_location=device)['model_state_dict'])
def get_sustainability(text):
    """Label every sentence of *text* for the sustainability highlighter.

    The text is split with ``sent_tokenize`` and the sentences are passed, as
    a one-column DataFrame, to ``do_predict`` together with the module-level
    ``model_sustain`` and ``tokenizer_sus``.  The second element of the value
    ``do_predict`` returns is read as one score per sentence.

    Returns:
        A list of ``(sentence, label)`` tuples, where label is
        ``'non-sustainable'`` for a score >= 4.384316, ``'sustainable'`` for
        a score <= 1.423736, and ``'-'`` for anything in between.
    """
    upper_cutoff = 4.384316
    lower_cutoff = 1.423736

    frame = pd.DataFrame({'sentence': sent_tokenize(text)})
    scores = do_predict(model_sustain, tokenizer_sus, frame)[1]

    def label_for(score):
        # Same precedence as an if/elif/else chain: upper cut-off is tested first.
        if score >= upper_cutoff:
            return 'non-sustainable'
        if score <= lower_cutoff:
            return 'sustainable'
        return '-'

    return [
        (sentence, label_for(score))
        for sentence, score in zip(frame['sentence'].values, scores)
    ]
#SUSTAINABILITY ENDS
#CLAIM STARTS
def score_fincat(txt):
    """Tag every whitespace-separated token of *txt* for the claim highlighter.

    Tokens without a digit are passed through with a blank label.  A token
    containing a digit has one trailing punctuation mark removed, is located
    in the (space-padded) text, and is classified by ``lr_clf`` from the
    features returned by ``bert_embedding_extract``.

    Args:
        txt: The text to analyse.

    Returns:
        A list of ``(token, label)`` tuples in reading order.  The label is
        ``' In-claim'`` or ``'Out-of-Claim'`` for numerals and ``' '`` for
        every other token.  If feature extraction reports ``'None'`` for a
        numeral, ``('None', ' ')`` is appended and the list built so far is
        returned immediately.
    """
    # At most one of these is stripped from the end of a numeral token.
    trailing_punct = ('.', ',', ';', ':', '-', '!', '?', ')', '"', "'")

    highlight = []
    # Offsets handed to extract_context_words are relative to this padded text.
    txt = " " + txt + " "
    # End of the previously located token.  Searching from here finds *this*
    # occurrence of the token, so repeated numerals and numerals bordered by
    # newlines or tabs get their real offsets instead of the first match (or
    # offset 0 when a literal-space search fails).
    cursor = 0
    for token in txt.split():
        start = txt.find(token, cursor)
        cursor = start + len(token)

        if not any(char.isdigit() for char in token):
            highlight.append((token, ' '))
            continue

        word = token[:-1] if token[-1] in trailing_punct else token
        x = {'paragraph': txt, 'offset_start': start, 'offset_end': start + len(word)}
        context_text = extract_context_words(x)
        features = bert_embedding_extract(context_text, word)
        # NOTE(review): presumably bert_embedding_extract signals failure with a
        # leading 'None' -- confirm in fincat_utils.  Processing stops at the
        # first such numeral.
        if features[0] == 'None':
            highlight.append(('None', ' '))
            return highlight
        prediction = lr_clf.predict(features.reshape(1, 768))
        highlight.append((word, ' In-claim' if prediction == 1 else 'Out-of-Claim'))
    return highlight
## Summarization
# Hugging Face summarization pipeline, created once at import time and reused
# by summarize_text().
summarizer = pipeline("summarization", model="knkarthick/MEETING_SUMMARY")
def summarize_text(text):
    """Return the summary of *text* produced by the module-level ``summarizer``.

    Only the ``'summary_text'`` field of the first result is returned.
    """
    first_result = summarizer(text)[0]
    return first_result['summary_text']
def split_in_sentences(text):
    """Split *text* into sentences with the module-level spaCy pipeline ``nlp``.

    Returns:
        A list of sentence strings, each with surrounding whitespace stripped.
    """
    sentences = []
    for span in nlp(text).sents:
        sentences.append(str(span).strip())
    return sentences
def make_spans(text, results):
    """Pair each sentence of *text* with the ``'label'`` of the matching result.

    Args:
        text: Text that is re-split with ``split_in_sentences``.
        results: Sequence of mappings, each carrying a ``'label'`` key, in the
            same order as the sentences.

    Returns:
        A list of ``(sentence, label)`` tuples; ``zip`` stops at the shorter
        of the two sequences.
    """
    labels = [result['label'] for result in results]
    return list(zip(split_in_sentences(text), labels))
## Forward Looking Statement
# Text-classification pipeline used by fls(); model and tokenizer come from the
# same "yiyanghkust/finbert-fls" checkpoint and are loaded once at import time.
fls_model = pipeline("text-classification", model="yiyanghkust/finbert-fls", tokenizer="yiyanghkust/finbert-fls")
def fls(text):
    """Classify every sentence of *text* with the forward-looking-statement model.

    The text is segmented once with ``split_in_sentences``; the same sentence
    list is both sent to ``fls_model`` and paired with the returned labels, so
    the text is not parsed a second time just to rebuild the sentences.

    Returns:
        A list of ``(sentence, label)`` tuples, where label is the ``'label'``
        field of the corresponding ``fls_model`` result.
    """
    sentences = split_in_sentences(text)
    results = fls_model(sentences)
    return list(zip(sentences, [result['label'] for result in results]))
## Company Extraction
# Named-entity-recognition pipeline used by fin_ner(), loaded once at import time.
# NOTE(review): the checkpoint name suggests a CamemBERT (French) model -- confirm
# it is the intended model for the documents this app receives.
ner=pipeline('ner',model='Jean-Baptiste/camembert-ner-with-dates',tokenizer='Jean-Baptiste/camembert-ner-with-dates', aggregation_strategy="simple")
def fin_ner(text):
    """Run the module-level ``ner`` pipeline on *text* and return its output unchanged."""
    return ner(text)
#CUAD STARTS
def load_questions(path='questions.txt'):
    """Read the question list, one entry per line.

    Args:
        path: Text file to read.  Defaults to ``'questions.txt'`` in the
            current working directory, so existing zero-argument calls keep
            working.

    Returns:
        The file's lines as a list of strings, line endings included, exactly
        as ``readlines()`` yields them.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Explicit encoding so the result does not depend on the platform default.
    with open(path, encoding='utf-8') as f:
        return f.readlines()
def load_questions_short(path='questionshort.txt'):
    """Read the short question labels, one entry per line.

    Args:
        path: Text file to read.  Defaults to ``'questionshort.txt'`` in the
            current working directory, so existing zero-argument calls keep
            working.

    Returns:
        The file's lines as a list of strings, line endings included, exactly
        as ``readlines()`` yields them.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Explicit encoding so the result does not depend on the platform default.
    with open(path, encoding='utf-8') as f:
        return f.readlines()
def quad(query, file):
    """Answer *query* against an uploaded text file and post-process the answer.

    Args:
        query: The search question typed by the user.
        file: The uploaded document: either an object whose ``name`` attribute
            is the path of a text file, or the path itself as a string.

    Returns:
        A 5-tuple ``(answer, summary, claim_spans, sustainability_spans,
        fls_spans)`` matching the five output components of the interface.
        When the query, the upload or the document text is missing, the first
        element is a short message and the remaining four are empty.
    """
    nothing_to_do = (
        'Please provide a search query and a non-empty text file.',
        '',
        [],
        [],
        [],
    )
    if file is None or not query:
        return nothing_to_do

    # Accept both an upload object exposing `.name` and a plain path string.
    path = getattr(file, 'name', file)
    with open(path) as f:
        paragraph = f.read()
    if not paragraph:
        return nothing_to_do

    print('getting predictions')
    predictions = run_prediction([query], paragraph, 'marshmellow77/roberta-base-cuad', n_best_size=5)
    if predictions['0'] == "":
        answer = 'No answer found in document'
    else:
        # NOTE(review): nbest.json is presumably written by run_prediction into
        # the working directory -- confirm; concurrent requests would share it.
        with open("nbest.json") as jf:
            data = json.load(jf)
        # Only the top-ranked candidate is reported.
        best = data['0'][0]
        answer = (
            f"Answer 1: {best['text']} -- \n"
            f"Probability: {round(best['probability']*100,1)}%\n\n"
        )
    # Every downstream analysis runs on the formatted answer text.
    return answer, summarize_text(answer), score_fincat(answer), get_sustainability(answer), fls(answer)
# Single-page Gradio UI: a search query and an uploaded text file are passed to
# quad(), whose five return values fill the five output components in order.
# The top-level component classes are used throughout, matching the
# gr.HighlightedText outputs, instead of the deprecated gr.inputs / gr.outputs
# namespaces.
iface = gr.Interface(
    fn=quad,
    inputs=[
        gr.Textbox(label='SEARCH QUERY'),
        gr.File(label='TXT FILE'),
    ],
    outputs=[
        gr.Textbox(label='Answer'),
        gr.Textbox(label='Summary'),
        gr.HighlightedText(label='CLAIM'),
        gr.HighlightedText(label='SUSTAINABILITY'),
        gr.HighlightedText(label='FLS'),
    ],
    title="CONBERT",
    description="SUSTAINABILITY TOOL",
    article='Article',
    allow_flagging="never",
)
iface.launch()