File size: 5,881 Bytes
3090b4b
 
 
1f0c01c
3090b4b
 
 
c2f789d
70b186b
 
c2f789d
10fdf8a
3090b4b
75d092f
 
 
 
70b186b
3090b4b
 
c2f789d
 
3090b4b
c2f789d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3090b4b
c2f789d
75d092f
 
 
 
 
 
 
 
 
 
 
68e73b8
75d092f
 
 
 
 
e395983
 
8a6156f
75d092f
 
 
8a6156f
75d092f
 
8a6156f
 
 
 
75d092f
 
55caca8
 
3090b4b
 
 
 
c2f789d
4390c00
 
 
 
22a8fb2
 
 
 
 
 
 
3090b4b
55caca8
3090b4b
 
 
 
 
55caca8
3090b4b
 
 
 
 
 
 
 
 
 
 
c2f789d
3090b4b
 
 
 
 
 
 
8027117
9c74623
8027117
3090b4b
 
7c39280
3090b4b
 
 
 
 
 
 
 
 
 
 
 
55caca8
 
 
 
40670fd
4bc07c4
3090b4b
 
10fdf8a
 
c2f789d
 
3090b4b
 
 
752219e
55caca8
 
c2f789d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from predict import run_prediction
from io import StringIO
import json
import gradio as gr
import spacy
from spacy import displacy
from transformers import AutoTokenizer, AutoModelForTokenClassification,RobertaTokenizer,pipeline
import torch
import nltk
from nltk.tokenize import sent_tokenize
from fin_readability_sustainability import BERTClass, do_predict
import pandas as pd
import en_core_web_sm
from fincat_utils import extract_context_words
from fincat_utils import bert_embedding_extract
import pickle
# Pre-trained classifier used by score_fincat (it calls predict on it).
# The file is opened in a context manager so the handle is closed after loading.
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# deploy a pickle that comes from a trusted source.
with open("lr_clf_FiNCAT.pickle", 'rb') as _clf_file:
    lr_clf = pickle.load(_clf_file)

# spaCy English pipeline; used for sentence splitting in split_in_sentences.
nlp = en_core_web_sm.load()
# Tokenizer data required by nltk's sent_tokenize.
nltk.download('punkt')
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#SUSTAINABILITY STARTS
# Tokenizer handed to do_predict together with model_sustain (stock 'roberta-base').
tokenizer_sus = RobertaTokenizer.from_pretrained('roberta-base')
# BERTClass is imported from fin_readability_sustainability; the meaning of its
# arguments (2, "sustanability") is defined there.
# NOTE(review): the string looks misspelled but may be used as a lookup key
# inside BERTClass - confirm before changing it.
model_sustain = BERTClass(2, "sustanability")
model_sustain.to(device)
# Load the weights stored under 'model_state_dict' in the checkpoint file,
# mapping the tensors onto `device`.
model_sustain.load_state_dict(torch.load('sustainability_model.bin', map_location=device)['model_state_dict'])

def get_sustainability(text):
    """Label every sentence of *text* from its sustainability score.

    The text is split with nltk's ``sent_tokenize`` and handed to
    ``do_predict``; element ``[1]`` of the prediction result is read as one
    score per sentence.

    Returns:
        A list of ``(sentence, label)`` tuples. The label is
        'non-sustainable' when the score is >= 4.384316, 'sustainable' when
        it is <= 1.423736 and '-' otherwise.
    """
    def label_for(score):
        # Upper threshold is checked first, exactly as in the if/elif chain.
        if score >= 4.384316:
            return 'non-sustainable'
        if score <= 1.423736:
            return 'sustainable'
        return '-'

    frame = pd.DataFrame({'sentence': sent_tokenize(text)})
    scores = do_predict(model_sustain, tokenizer_sus, frame)[1]
    return [
        (sentence, label_for(score))
        for sentence, score in zip(frame['sentence'].values, scores)
    ]
#SUSTAINABILITY ENDS

#CLAIM STARTS
def score_fincat(txt):
    """Label the numerals in *txt* as in-claim or out-of-claim.

    The text is split on whitespace. Tokens without a digit are passed
    through with a blank label. For a token that contains a digit, one
    trailing punctuation character (if any) is dropped, context is gathered
    with ``extract_context_words``, features come from
    ``bert_embedding_extract`` and ``lr_clf`` predicts the label.

    Args:
        txt: The text to analyse.

    Returns:
        A list of ``(token, label)`` tuples in text order, where label is
        '    In-claim', 'Out-of-Claim' or '    ' (no label). If feature
        extraction reports 'None' for a numeral, ``('None', '    ')`` is
        appended and the list built so far is returned immediately.
    """
    import re  # local import so the block does not depend on a file-level import

    trailing_punct = ('.', ',', ';', ':', '-', '!', '?', ')', '"', "'")
    highlight = []
    # Keep the one-space padding: extract_context_words receives the padded
    # paragraph and the offsets below are relative to it.
    txt = " " + txt + " "
    # Iterate over regex matches instead of txt.split() + txt.find(): the match
    # position is the token's real offset, so repeated numerals and numerals
    # next to tabs/newlines no longer resolve to the wrong place.
    for match in re.finditer(r'\S+', txt):
        word = match.group()
        if not any(char.isdigit() for char in word):
            highlight.append((word, '    '))
            continue
        if word.endswith(trailing_punct):
            # Only a single trailing punctuation character is stripped.
            word = word[:-1]
        st = match.start()
        ed = st + len(word)
        x = {'paragraph': txt, 'offset_start': st, 'offset_end': ed}
        context_text = extract_context_words(x)
        features = bert_embedding_extract(context_text, word)
        if features[0] == 'None':
            # NOTE(review): presumably the helper's "no embedding" sentinel;
            # the remaining tokens are dropped here - confirm this is intended.
            highlight.append(('None', '    '))
            return highlight
        is_claim = lr_clf.predict(features.reshape(1, 768))[0] == 1
        highlight.append((word, '    In-claim' if is_claim else 'Out-of-Claim'))
    return highlight


##Summarization
# Summarization pipeline, created once when the module is imported.
summarizer = pipeline("summarization", model="knkarthick/MEETING_SUMMARY")


def summarize_text(text):
    """Return the 'summary_text' of the first result ``summarizer`` gives for *text*."""
    return summarizer(text)[0]['summary_text']


def split_in_sentences(text):
    """Split *text* into sentences using the spaCy pipeline ``nlp``.

    Returns:
        A list of sentence strings with leading/trailing whitespace removed.
    """
    sentences = []
    for span in nlp(text).sents:
        sentences.append(str(span).strip())
    return sentences
def make_spans(text, results):
    """Pair the sentences of *text* with the labels found in *results*.

    Args:
        text: Text that is re-split with ``split_in_sentences``.
        results: Sequence of mappings, each providing a 'label' entry.

    Returns:
        A list of ``(sentence, label)`` tuples; pairing stops at the shorter
        of the two sequences.
    """
    labels = [entry['label'] for entry in results]
    return list(zip(split_in_sentences(text), labels))
##Forward Looking Statement
# Text-classification pipeline for forward-looking statements, created once
# when the module is imported.
fls_model = pipeline("text-classification", model="yiyanghkust/finbert-fls", tokenizer="yiyanghkust/finbert-fls")


def fls(text):
    """Classify every sentence of *text* with ``fls_model``.

    The text is split into sentences once and the same list is used both as
    model input and for pairing with the predicted labels, so the spaCy
    pipeline is not run twice over the same text.

    Args:
        text: The text to analyse.

    Returns:
        A list of ``(sentence, label)`` tuples; an empty list when the text
        contains no sentences.
    """
    sentences = split_in_sentences(text)
    if not sentences:
        # Nothing to classify: skip the model call entirely.
        return []
    results = fls_model(sentences)
    return list(zip(sentences, (result['label'] for result in results)))
    
##Company Extraction
# Named-entity-recognition pipeline, created once when the module is imported.
ner = pipeline(
    'ner',
    model='Jean-Baptiste/camembert-ner-with-dates',
    tokenizer='Jean-Baptiste/camembert-ner-with-dates',
    aggregation_strategy="simple",
)


def fin_ner(text):
    """Run the ``ner`` pipeline on *text* and return its output unchanged."""
    return ner(text)
    
     
#CUAD STARTS    
def load_questions():
    """Read 'questions.txt' and return its lines, line endings included."""
    with open('questions.txt') as handle:
        return handle.readlines()


def load_questions_short():
    """Read 'questionshort.txt' and return its lines, line endings included."""
    with open('questionshort.txt') as handle:
        return handle.readlines()

def quad(query, file):
    """Answer *query* against an uploaded text file and analyse the answer.

    Args:
        query: The search question typed by the user.
        file: The uploaded file object; its ``name`` attribute is used as the
            path to read. May be ``None`` when nothing was uploaded.

    Returns:
        A 5-tuple ``(answer, summary, claim_spans, sustainability_spans,
        fls_spans)`` matching the five outputs of the Gradio interface. When
        the query, the upload or the file content is missing, a short message
        is returned as the answer with an empty summary and empty span lists,
        and no model is run.
    """
    no_input = ('Please provide both a search query and a non-empty TXT file', '', [], [], [])
    # Validate before touching the models: previously the emptiness check only
    # guarded a print and the prediction ran regardless.
    if file is None or not query:
        return no_input
    with open(file.name) as f:
        paragraph = f.read()
    if not paragraph:
        return no_input

    print('getting predictions')
    predictions = run_prediction([query], paragraph, 'marshmellow77/roberta-base-cuad', n_best_size=5)
    if predictions['0'] == "":
        answer = 'No answer found in document'
    else:
        # nbest.json is presumably written by run_prediction - TODO confirm.
        with open("nbest.json") as jf:
            data = json.load(jf)
        # Only the top-ranked candidate is reported.
        best = data['0'][0]
        answer = (
            f"Answer 1: {best['text']} -- \n"
            f"Probability: {round(best['probability']*100,1)}%\n\n"
        )
    # NOTE(review): the analyses below run on the formatted answer string
    # (including the "Answer 1:" / "Probability:" wrapper), not on the raw
    # answer text - confirm this is intended.
    return answer, summarize_text(answer), score_fincat(answer), get_sustainability(answer), fls(answer)
                
                   
# b6 = gr.Button("Get Sustainability")
              #b6.click(get_sustainability, inputs = text, outputs = gr.HighlightedText())
              
              
#iface = gr.Interface(fn=get_sustainability, inputs="textbox", title="CONBERT",description="SUSTAINABILITY TOOL", outputs=gr.HighlightedText(), allow_flagging="never")
#iface.launch()

# Web UI: a query box and a file upload feed quad(); its five return values
# fill the five outputs in order.
# NOTE(review): gr.inputs / gr.outputs are the legacy Gradio component
# namespaces - check the pinned Gradio version before upgrading.
iface = gr.Interface(
    fn=quad,
    inputs=[
        gr.inputs.Textbox(label='SEARCH QUERY'),
        gr.inputs.File(label='TXT FILE'),
    ],
    outputs=[
        gr.outputs.Textbox(label='Answer'),
        gr.outputs.Textbox(label='Summary'),
        gr.HighlightedText(label='CLAIM'),
        gr.HighlightedText(label='SUSTAINABILITY'),
        gr.HighlightedText(label='FLS'),
    ],
    title="CONBERT",
    description="SUSTAINABILITY TOOL",
    article='Article',
    allow_flagging="never",
)


iface.launch()