File size: 3,968 Bytes
3090b4b
 
 
1f0c01c
3090b4b
 
 
c2f789d
70b186b
 
c2f789d
10fdf8a
3090b4b
c2f789d
70b186b
3090b4b
 
c2f789d
 
3090b4b
c2f789d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3090b4b
c2f789d
3090b4b
 
 
 
 
 
c2f789d
3090b4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2f789d
3090b4b
 
 
 
 
 
 
8027117
9c74623
8027117
3090b4b
 
7c39280
3090b4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10fdf8a
 
c2f789d
 
3090b4b
 
 
4c052dd
c2f789d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import functools
import json
from io import StringIO

import en_core_web_sm
import gradio as gr
import nltk
import pandas as pd
import spacy
import torch
from nltk.tokenize import sent_tokenize
from spacy import displacy
from transformers import AutoTokenizer, AutoModelForTokenClassification, RobertaTokenizer, pipeline

from fin_readability_sustainability import BERTClass, do_predict
from predict import run_prediction


# spaCy small English pipeline, loaded at import time.
# NOTE(review): nothing in this file appears to reference `nlp` — confirm before removing.
nlp = en_core_web_sm.load()
# Fetch NLTK "punkt" data, which sent_tokenize() relies on.
nltk.download('punkt')
# Prefer the GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#SUSTAINABILITY STARTS
# Tokenizer handed to do_predict() alongside the sustainability classifier.
tokenizer_sus = RobertaTokenizer.from_pretrained('roberta-base')
# Project classifier from fin_readability_sustainability; the meaning of the
# constructor arguments is defined there (presumably 2 = number of labels — TODO confirm).
# NOTE(review): "sustanability" looks misspelled, but it is passed through to
# BERTClass, so check how BERTClass uses it before changing it.
model_sustain = BERTClass(2, "sustanability")
# Move the parameters to `device` before loading weights into them.
model_sustain.to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only host.
# The checkpoint must be a dict with a 'model_state_dict' entry.
# torch.load unpickles the file, so only load trusted checkpoints.
model_sustain.load_state_dict(torch.load('sustainability_model.bin', map_location=device)['model_state_dict'])

def get_sustainability(text, *, non_sustainable_threshold=4.384316, sustainable_threshold=1.423736):
    """Label each sentence of ``text`` using the sustainability classifier.

    The text is split with NLTK's ``sent_tokenize`` and scored by
    ``do_predict(model_sustain, tokenizer_sus, df)``. Element ``[1]`` of that
    result is read as one score per sentence (presumably — confirm against
    ``do_predict``). Each sentence gets one label:

    * ``'non-sustainable'`` if its score is >= ``non_sustainable_threshold``;
    * ``'sustainable'`` if its score is <= ``sustainable_threshold``;
    * ``'-'`` otherwise.

    Args:
        text: Raw input text.
        non_sustainable_threshold: Upper cut-off. The default is the value
            previously hard-coded here.
        sustainable_threshold: Lower cut-off. The default is the value
            previously hard-coded here.

    Returns:
        A list of ``(sentence, label)`` tuples in input order. The list is
        empty when ``text`` is empty or whitespace.
    """
    # Guard before touching NLTK or the model: blank input has no sentences,
    # and scoring an empty DataFrame is pointless.
    sentences = sent_tokenize(text) if text and text.strip() else []
    if not sentences:
        return []
    df = pd.DataFrame({'sentence': sentences})
    scores = do_predict(model_sustain, tokenizer_sus, df)[1]
    highlight = []
    for sent, prob in zip(sentences, scores):
        if prob >= non_sustainable_threshold:
            label = 'non-sustainable'
        elif prob <= sustainable_threshold:
            label = 'sustainable'
        else:
            label = '-'
        highlight.append((sent, label))
    return highlight
#SUSTAINABILITY ENDS

##Summarization 
@functools.lru_cache(maxsize=1)
def _get_summarizer():
    """Build the summarization pipeline once per process and reuse it."""
    return pipeline("summarization", model="knkarthick/MEETING_SUMMARY")


def summarize_text(text):
    """Summarize ``text`` with the knkarthick/MEETING_SUMMARY model.

    The Hugging Face pipeline is created lazily on first use and cached, so
    the model is no longer reloaded on every call.

    Args:
        text: Text to summarize.

    Returns:
        The ``'summary_text'`` field of the first pipeline result. Returns an
        empty string when ``text`` is empty, ``None`` or only whitespace.
    """
    if not text or not text.strip():
        return ""
    resp = _get_summarizer()(text)
    return resp[0]['summary_text']

##Forward Looking Statement
@functools.lru_cache(maxsize=1)
def _get_fls_classifier():
    """Build the FinBERT forward-looking-statement classifier once and reuse it."""
    return pipeline("text-classification", model="yiyanghkust/finbert-fls", tokenizer="yiyanghkust/finbert-fls")


def fls(text):
    """Classify each sentence of ``text`` with the yiyanghkust/finbert-fls model.

    Sentences are split with NLTK's ``sent_tokenize``, the same splitter
    ``get_sustainability`` uses. This replaces calls to
    ``split_in_sentences``/``make_spans``, which are not defined in this
    file and raised ``NameError``.

    Args:
        text: Raw input text.

    Returns:
        A list of ``(sentence, label)`` tuples in input order. ``label`` is the
        ``'label'`` field of the pipeline result for that sentence. The list
        is empty for empty or whitespace-only input.
    """
    if not text or not text.strip():
        return []
    sentences = sent_tokenize(text)
    if not sentences:
        return []
    results = _get_fls_classifier()(sentences)
    return [(sentence, result['label']) for sentence, result in zip(sentences, results)]
    
##Company Extraction
@functools.lru_cache(maxsize=1)
def _get_ner_pipeline():
    """Build the NER pipeline once per process and reuse it."""
    return pipeline('ner', model='Jean-Baptiste/camembert-ner-with-dates', tokenizer='Jean-Baptiste/camembert-ner-with-dates', aggregation_strategy="simple")


def fin_ner(text):
    """Run named-entity recognition over ``text``.

    Uses a cached Hugging Face ``'ner'`` pipeline with
    ``aggregation_strategy="simple"``, so adjacent sub-tokens are merged into
    entity spans.

    NOTE(review): the model name suggests a CamemBERT (French) checkpoint —
    confirm it is suitable for the documents this app processes.

    Args:
        text: Raw input text.

    Returns:
        The pipeline output: a list of entity dicts. Returns an empty list for
        empty or whitespace-only input.
    """
    if not text or not text.strip():
        return []
    return _get_ner_pipeline()(text)
    
     
#CUAD STARTS    
def load_questions(path='questions.txt'):
    """Return the lines of the questions file as a list of strings.

    Args:
        path: File to read. Defaults to ``questions.txt`` in the working
            directory, as before.

    Returns:
        The file's lines, each keeping its trailing newline (``readlines()``
        semantics).
    """
    with open(path, encoding='utf-8') as f:
        return f.readlines()


def load_questions_short(path='questionshort.txt'):
    """Return the lines of the short-questions file as a list of strings.

    Args:
        path: File to read. Defaults to ``questionshort.txt`` in the working
            directory, as before.

    Returns:
        The file's lines, each keeping its trailing newline (``readlines()``
        semantics).
    """
    with open(path, encoding='utf-8') as f:
        return f.readlines()

def quad(query, file):
    """Answer ``query`` against an uploaded text document with the CUAD model.

    Args:
        query: The question to ask (the Gradio textbox value).
        file: The uploaded file object. It is read via its ``.name`` path and
            may be ``None`` when nothing was uploaded.

    Returns:
        A ``(summary, answer)`` tuple of strings:

        * ``answer`` is the top candidate and its probability, or a message
          when input is missing or no answer is found.
        * ``summary`` is ``summarize_text(answer)``, or the same message in
          the message cases.
    """
    if file is None:
        message = 'Please upload a TXT file.'
        return message, message
    with open(file.name, encoding='utf-8') as f:
        paragraph = f.read()
    # Previously this check only guarded the log line, so the model ran even
    # on an empty query or document.
    if not paragraph.strip() or not query or not query.strip():
        message = 'Please provide a search query and a non-empty TXT file.'
        return message, message
    print('getting predictions')
    predictions = run_prediction([query], paragraph, 'marshmellow77/roberta-base-cuad', n_best_size=5)
    if predictions['0'] == "":
        answer = 'No answer found in document'
        # Nothing meaningful to summarize; skip the summarization model.
        return answer, answer
    # run_prediction is expected to leave its n-best candidates in nbest.json
    # in the working directory. NOTE(review): confirm against predict.py.
    with open("nbest.json") as jf:
        best = json.load(jf)['0'][0]
    answer = f"Answer 1: {best['text']} -- \n"
    answer += f"Probability: {round(best['probability']*100,1)}%\n\n"
    return summarize_text(answer), answer
                
                   
# b6 = gr.Button("Get Sustainability")
              #b6.click(get_sustainability, inputs = text, outputs = gr.HighlightedText())
              
              
#iface = gr.Interface(fn=get_sustainability, inputs="textbox", title="CONBERT",description="SUSTAINABILITY TOOL", outputs=gr.HighlightedText(), allow_flagging="never")
#iface.launch()

# Contract question-answering UI.
# quad() returns (summary, answer), so the output boxes are listed in that
# order; previously they were labelled the other way round.
iface = gr.Interface(
    fn=quad,
    inputs=[gr.inputs.Textbox(label='SEARCH QUERY'), gr.inputs.File(label='TXT FILE')],
    title="CONBERT",
    description="SUSTAINABILITY TOOL",
    article='Article',
    outputs=[gr.outputs.Textbox(label='Summary'), gr.outputs.Textbox(label='Answer')],
    allow_flagging="never",
)
iface.launch()