File size: 26,412 Bytes
8499c35
 
 
 
 
 
 
 
e512522
8499c35
 
 
35a0403
77b71a6
be67fcf
 
8d7b496
d401ad6
6155281
0566fd9
 
 
a0ade06
08e6e30
 
8e2eef3
08e6e30
 
 
 
 
 
 
6155281
8499c35
 
6155281
8499c35
 
0be8860
8d6cc8d
cedea8d
 
 
08e6e30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63f91c1
74f896e
8499c35
 
63f91c1
 
8499c35
 
e512522
 
8499c35
 
08e6e30
8499c35
f03dbdb
b7ef881
7e12771
aed9c6e
8499c35
aed9c6e
377fd6b
 
 
a31e67a
1ad3fab
377fd6b
74f896e
 
8e2eef3
63f91c1
 
1ad3fab
383e61a
74f896e
 
8499c35
8e2eef3
 
d9164b6
63f91c1
 
08e6e30
b1b5065
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daa7a64
 
08e6e30
daa7a64
 
08e6e30
 
5dfeae8
08e6e30
 
 
 
 
daa7a64
 
 
 
 
0ee4a6b
daa7a64
 
 
8186919
14ec610
8e2eef3
08e6e30
de88316
08e6e30
1b788f1
 
08e6e30
 
 
 
 
 
 
 
 
3365ede
08e6e30
 
de88316
08e6e30
8186919
 
63f91c1
 
 
 
 
3365ede
63f91c1
c57a7fa
08e6e30
8499c35
 
 
 
 
 
f67f4fa
8499c35
 
 
 
 
 
 
a31e67a
8499c35
34b166f
8499c35
 
a31e67a
8499c35
34b166f
8499c35
 
 
 
 
f4b6788
8499c35
 
 
7a5728d
93ef2af
 
 
 
 
 
 
 
 
6a1d689
93ef2af
6a1d689
93ef2af
8499c35
7a5728d
 
 
8499c35
 
 
 
 
 
7a5728d
 
 
f4b6788
63f91c1
8499c35
7a5728d
 
 
 
 
 
 
 
 
 
 
 
 
8499c35
 
7a5728d
 
 
f4b6788
8499c35
 
7a5728d
e1f6c5c
7a5728d
8499c35
 
 
 
 
 
 
 
edd60a3
 
8499c35
 
 
 
 
 
 
 
 
 
 
 
 
 
040ddf0
8499c35
 
 
 
 
 
edd60a3
 
8499c35
 
 
edd60a3
 
8499c35
 
 
 
 
 
 
 
 
 
 
 
 
 
c2ce126
 
8499c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edd60a3
8499c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edd60a3
 
8499c35
 
 
 
 
 
 
 
 
 
 
edd60a3
e512522
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edd60a3
9009cbe
aed9c6e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
import whisper
import os
from pytube import YouTube
import pandas as pd
import plotly_express as px
import nltk
import plotly.graph_objects as go
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import streamlit as st
import en_core_web_lg
import validators
import re
import itertools
import numpy as np
from bs4 import BeautifulSoup   
import base64, time
from annotated_text import annotated_text
import pickle, math
import wikipedia
from pyvis.network import Network
import torch
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings,HuggingFaceInstructEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.text_splitter import CharacterTextSplitter
from langchain.llms import OpenAI
from langchain import VectorDBQA
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate
from langchain.prompts.base import RegexParser

# Fetch the Punkt sentence-tokenizer data that `sent_tokenize` relies on.
# NOTE(review): this runs at import time on every load of the module and may
# reach the network; consider moving it behind a one-time setup step.
nltk.download('punkt')


from nltk import sent_tokenize

# OpenAI key read from the environment (None when OPEN_AI_KEY is unset).
# NOTE(review): not referenced anywhere else in this module; the OpenAI(...)
# calls in embed_text presumably resolve credentials on their own — confirm.
OPEN_AI_KEY = os.environ.get('OPEN_AI_KEY')
# Timestamp captured once at import time; summary_downloader uses it in the
# download file name, so every download in a process shares the same stamp.
time_str = time.strftime("%d%m%Y-%H%M%S")
# Scrollable, bordered HTML container; `{}` is filled via str.format.
HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; 
margin-bottom: 2.5rem">{}</div> """

# "Normal" (stuff) chain: parser that splits the LLM completion into the answer
# text and the trailing "Score: ..." line requested by `template` below.
# NOTE(review): `(.*?)` has no DOTALL flag, so a multi-line answer will not be
# captured in full — confirm this is acceptable.
output_parser = RegexParser(
    regex=r"(.*?)\nScore: (.*)",
    output_keys=["answer", "score"],
)

# Prompt for the stuff QA-with-sources chain built in embed_text:
# `{summaries}` receives the retrieved chunks and `{question}` the user query.
# The model is asked to cite SOURCES and self-report a 0-100 score.
template = """Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES"). 
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
ALWAYS return a "SOURCES" part in your answer.

In addition to giving an answer, also return a score of how fully it answered the user's question. This should be in the following format:

Question: [question here]
Helpful Answer: [answer here]
Score: [score between 0 and 100]

Begin!

Context:
---------
{summaries}
---------
Question: {question}
Helpful Answer:"""

# Refine chain prompt (used as `refine_prompt` by the refine QA chain in
# embed_text): shows the model its previous answer plus one more context chunk
# and asks it to improve the answer only if the new context helps.
refine_prompt_template = (
    "The original question is as follows: {question}\n"
    "We have provided an existing answer: {existing_answer}\n"
    # Trailing space is required: adjacent literals are concatenated verbatim,
    # and without it the prompt read "...existing answer(only if needed)...".
    "We have the opportunity to refine the existing answer "
    "(only if needed) with some more context below.\n"
    "------------\n"
    "{context_str}\n"
    "------------\n"
    "Given the new context, refine the original answer to better "
    "answer the question. "
    "If the context isn't useful, return the original answer."
)
refine_prompt = PromptTemplate(
    input_variables=["question", "existing_answer", "context_str"],
    template=refine_prompt_template,
)


# Question prompt for the first step of the refine chain in embed_text
# (subsequent steps use `refine_prompt`). `{context_str}` receives a retrieved
# chunk and `{question}` the user query.
# NOTE(review): "\n.\n" after {question} appends a stray "." line to the
# prompt — confirm it is intentional.
initial_qa_template = (
    "Context information is below. \n"
    "---------------------\n"
    "{context_str}"
    "\n---------------------\n"
    "Given the context information and not prior knowledge, "
    "answer the question: {question}\n.\n"
)

###################### Functions #######################################################################################

@st.experimental_singleton(suppress_st_warning=True)
def load_models():

    '''Load and cache all the models used by the app.

    Cached with st.experimental_singleton, so the models are loaded once per
    process and shared.

    Returns:
        tuple: (sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model,
        kg_tokenizer, emb_tokenizer, sbert), in that order.
    '''
    # Financial-tone sentiment classifier (ONNX Runtime model via optimum).
    q_model = ORTModelForSequenceClassification.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
    q_tokenizer = AutoTokenizer.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")

    # Token-classification NER model.
    ner_model = AutoModelForTokenClassification.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
    ner_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large-finetuned-conll03-english")

    # Seq2seq relation-extraction model used for the knowledge graph.
    kg_model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
    kg_tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")

    # Tokenizer used to size chunks for semantic search (see process_corpus).
    emb_tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-xl')

    sent_pipe = pipeline("text-classification",model=q_model, tokenizer=q_tokenizer)
    sum_pipe = pipeline("summarization",model="facebook/bart-large-cnn", tokenizer="facebook/bart-large-cnn",clean_up_tokenization_spaces=True)
    ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, grouped_entities=True)
    cross_encoder = CrossEncoder('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1') #cross-encoder/ms-marco-MiniLM-L-12-v2

    # The model id must be a string. Unquoted, it was parsed as the expression
    # `all - MiniLM - L6 - v2` and raised NameError on first load.
    sbert = SentenceTransformer('all-MiniLM-L6-v2')

    return sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer, sbert

@st.experimental_singleton(suppress_st_warning=True)
def load_asr_model(asr_model_name):
    '''Return the Whisper speech-recognition model `asr_model_name`.

    Cached with st.experimental_singleton, so each model name is loaded once.
    '''
    return whisper.load_model(asr_model_name)

@st.experimental_singleton(suppress_st_warning=True)
def process_corpus(corpus, _tokenizer, title, embedding_model, chunk_size=200, overlap=50):
    '''Split `corpus` into token-sized chunks and index them in a FAISS store.

    Args:
        corpus: full document text.
        _tokenizer: Hugging Face tokenizer used to measure chunk length. The
            leading underscore keeps it out of Streamlit's cache hashing.
        title: not used in the body; it only participates in the cache key.
        embedding_model: model name passed to gen_embeddings.
        chunk_size: maximum chunk length in tokens.
        overlap: tokens shared between neighbouring chunks.

    Returns:
        FAISS vector store whose metadata "source" is each chunk's index.
    '''
    splitter = CharacterTextSplitter.from_huggingface_tokenizer(
        _tokenizer, chunk_size=chunk_size, chunk_overlap=overlap, separator='.'
    )
    pieces = splitter.split_text(corpus)
    source_tags = [{"source": position} for position in range(len(pieces))]
    return FAISS.from_texts(pieces, gen_embeddings(embedding_model), metadatas=source_tags)

@st.experimental_singleton(suppress_st_warning=True)
def chunk_and_preprocess_text(text,thresh=500):
    """Greedily pack the sentences of `text` into chunks for summarization.

    `text` is cleaned with clean_text and split with sent_tokenize. Each sentence
    is appended to the current chunk while the chunk's word count (words split on
    single spaces) stays <= `thresh`. Otherwise the sentence starts a new chunk.
    A single sentence longer than `thresh` still forms its own oversized chunk.

    Returns:
        list[str]: the chunks, re-joined with single spaces.
    """
    packed = []  # each entry is the word list of one chunk
    for sentence in sent_tokenize(clean_text(text)):
        words = sentence.split(" ")
        if packed and len(packed[-1]) + len(words) <= thresh:
            packed[-1].extend(words)
        else:
            packed.append(words)
    return [" ".join(chunk_words) for chunk_words in packed]
    
@st.experimental_singleton(suppress_st_warning=True)
def gen_embeddings(embedding_model):
    '''Return a LangChain embeddings wrapper for `embedding_model`.

    Model names containing 'hkunlp' get HuggingFaceInstructEmbeddings with
    finance-specific query/document instructions. Any other name gets plain
    HuggingFaceEmbeddings.
    '''
    if 'hkunlp' not in embedding_model:
        return HuggingFaceEmbeddings(model_name=embedding_model)

    return HuggingFaceInstructEmbeddings(
        model_name=embedding_model,
        query_instruction='Represent the Financial question for retrieving supporting paragraphs: ',
        embed_instruction='Represent the Financial paragraph for retrieval: ',
    )
    
@st.experimental_memo(suppress_st_warning=True)
def embed_text(query,title,embedding_model,_emb_tok,docsearch,chain_type):

    '''Answer `query` with an OpenAI LLM over the 3 chunks most similar to it.

    Args:
        query: the user's question.
        title, embedding_model, _emb_tok: not used by the body; kept for
            signature compatibility (title and embedding_model still take part
            in the memo cache key).
        docsearch: vector store exposing `similarity_search_with_score`
            (e.g. the FAISS store from process_corpus).
        chain_type: 'Normal' for a stuff QA-with-sources chain using `template`,
            or 'Refined' for a refine QA chain using `initial_qa_template` and
            `refine_prompt`.

    Returns:
        The chain's result dict (inputs included, since return_only_outputs=False).

    Raises:
        ValueError: if `chain_type` is neither 'Normal' nor 'Refined'. Previously
            this fell through and raised UnboundLocalError on `answer`.
    '''
    if chain_type not in ('Normal', 'Refined'):
        raise ValueError(f"Unknown chain_type {chain_type!r}; expected 'Normal' or 'Refined'")

    # similarity_search_with_score yields (Document, score) pairs; the chains
    # only take Documents. (The old unused `title.split()[0]` line is gone: it
    # raised IndexError for an empty title.)
    docs = [doc for doc, _score in docsearch.similarity_search_with_score(query, k=3)]

    llm = OpenAI(temperature=0)

    if chain_type == 'Normal':
        prompt = PromptTemplate(template=template,
                                input_variables=["summaries", "question"],
                                output_parser=output_parser)
        chain = load_qa_with_sources_chain(llm, chain_type="stuff", prompt=prompt)
    else:
        initial_qa_prompt = PromptTemplate(
            input_variables=["context_str", "question"], template=initial_qa_template
        )
        chain = load_qa_chain(llm, chain_type="refine", return_refine_steps=False,
                              question_prompt=initial_qa_prompt, refine_prompt=refine_prompt)

    return chain({"input_documents": docs, "question": query}, return_only_outputs=False)
    
@st.experimental_singleton(suppress_st_warning=True)
def get_spacy():
    '''Return the en_core_web_lg spaCy pipeline, loaded once per process.'''
    return en_core_web_lg.load()
    
@st.experimental_memo(suppress_st_warning=True)
def inference(link, upload, _asr_model):
    '''Transcribe a YouTube video or an uploaded audio file to English text.

    Args:
        link: candidate YouTube URL; used when `validators.url` accepts it.
        upload: fallback audio input, used when `link` is not a valid URL.
        _asr_model: Whisper model (leading underscore excludes it from hashing).

    Returns:
        (transcript, title) for a URL or upload, or None when neither is usable.
    '''
    if validators.url(link):
        yt = YouTube(link)
        # Download only the audio stream to a local file for Whisper.
        path = yt.streams.filter(only_audio=True)[0].download(filename="audio.mp4")
        results = _asr_model.transcribe(path, task='transcribe', language='en')
        return results['text'], yt.title

    if upload:
        # Was `trasncribe` (typo), which raised AttributeError on every upload.
        # NOTE(review): assumes `upload` is an input Whisper's transcribe accepts
        # (e.g. a file path or audio array) — confirm against the caller.
        results = _asr_model.transcribe(upload, task='transcribe', language='en')
        return results['text'], "Transcribed Earnings Audio"

    return None
      
@st.experimental_memo(suppress_st_warning=True)
def sentiment_pipe(earnings_text):
    '''Classify the tone of `earnings_text` passage by passage.

    Passages come from chunk_long_text(earnings_text, 150, 1, 1): one sentence
    (or <=150-word piece) per passage.

    Returns:
        (results of the module-level `sent_pipe`, the passages), aligned by position.
    '''
    passages = chunk_long_text(earnings_text, 150, 1, 1)
    return sent_pipe(passages), passages

@st.experimental_memo(suppress_st_warning=True)
def summarize_text(text_to_summarize,max_len,min_len):
    '''Summarize with the module-level BART `sum_pipe` and join the outputs.

    Args:
        text_to_summarize: a string or list of strings accepted by the pipeline.
        max_len, min_len: generation length bounds passed as max/min_length.

    Returns:
        str: every 'summary_text' in the pipeline output, joined by single spaces.
    '''
    generation_options = dict(
        max_length=max_len,
        min_length=min_len,
        clean_up_tokenization_spaces=True,
        no_repeat_ngram_size=4,
        encoder_no_repeat_ngram_size=3,
        repetition_penalty=3.5,
        num_beams=4,
        early_stopping=True,
    )
    outputs = sum_pipe(text_to_summarize, **generation_options)
    return ' '.join(item['summary_text'] for item in outputs)
     
@st.experimental_memo(suppress_st_warning=True)
def clean_text(text):
    '''Strip non-ASCII characters, URLs, @mentions and #hashtags from `text`.

    Removed tokens become a space, and any run of 2+ whitespace characters is
    then collapsed to one space.
    '''
    cleaned = text.encode("ascii", "ignore").decode()  # drop non-ASCII characters
    substitutions = (
        (r"https*\S+", " "),  # URLs
        (r"@\S+", " "),       # @mentions
        (r"#\S+", " "),       # hashtags
        (r"\s{2,}", " "),     # runs of whitespace
    )
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned
       
@st.experimental_memo(suppress_st_warning=True)
def chunk_long_text(text,threshold,window_size=3,stride=2):
    '''Split `text` into sentence windows for sentiment analysis.

    1. `text` is split into sentences with sent_tokenize.
    2. A sentence with fewer than `threshold` words is kept as is. Longer ones
       are cut into consecutive pieces of at most `threshold` words.
    3. Pieces are grouped into windows of up to `window_size` consecutive pieces,
       starting every `stride` pieces, and joined with spaces.

    Returns:
        list[str]: the windows, in order.
    '''
    pieces = []
    for sentence in sent_tokenize(text):
        words = sentence.split()
        if len(words) < threshold:
            pieces.append(sentence)
        else:
            # Step over the real word range. The old bound `num*threshold+1`
            # produced an extra empty '' piece whenever len(words) was an exact
            # multiple of `threshold`, which was then sent to the classifier.
            pieces.extend(' '.join(words[start:start + threshold])
                          for start in range(0, len(words), threshold))

    return [" ".join(pieces[start:start + window_size])
            for start in range(0, len(pieces), stride)]

    
def summary_downloader(raw_text):
    '''Render a Streamlit link that downloads `raw_text` as a .txt file.

    The text is embedded base64-encoded in a data URI. The file name uses the
    module-level `time_str` timestamp.
    '''
    encoded_text = base64.b64encode(raw_text.encode()).decode()
    download_name = "new_text_file_{}_.txt".format(time_str)
    st.markdown("#### Download Summary as a File ###")
    link_html = f'<a href="data:file/txt;base64,{encoded_text}" download="{download_name}">Click to Download!!</a>'
    st.markdown(link_html, unsafe_allow_html=True)

@st.experimental_memo(suppress_st_warning=True)
def get_all_entities_per_sentence(text):
    '''Collect named-entity strings for every sentence of `text`.

    `text` goes through ''.join (a plain string is unchanged; a list of strings
    is concatenated without separators) and is parsed by the module-level spaCy
    pipeline `nlp`. For each spaCy sentence, the result lists spaCy's entity
    texts followed by the 'word' of each entity the module-level `ner_pipe`
    finds in that sentence.

    Returns:
        list[list[str]]: one entity list per sentence.
    '''
    parsed = nlp(''.join(text))
    per_sentence = []
    for sent in parsed.sents:
        spacy_entities = [str(ent) for ent in sent.ents]
        xlm_entities = [str(hit["word"]) for hit in ner_pipe(str(sent))]
        per_sentence.append(spacy_entities + xlm_entities)
    return per_sentence
 
@st.experimental_memo(suppress_st_warning=True)
def get_all_entities(text):
    '''Return every entity string found in `text`, flattened across sentences (duplicates kept).'''
    return [
        entity
        for sentence_entities in get_all_entities_per_sentence(text)
        for entity in sentence_entities
    ]

def _drop_contained_entities(entities):
    '''Drop entities that are a case-insensitive substring of a *different* entity.

    Only the longest mention survives (e.g. "Apple" is dropped when "Apple Inc"
    is present). Entries that differ only in case are treated as the same text
    and do not eliminate each other. Input order is preserved.
    '''
    lowered = [entity.lower() for entity in entities]
    return [
        entity
        for entity, low in zip(entities, lowered)
        if not any(low != other and low in other for other in lowered)
    ]


@st.experimental_memo(suppress_st_warning=True)
def get_and_compare_entities(article_content,summary_output):
    '''Split the summary's entities into those supported / unsupported by the article.

    A summary entity counts as matched when it is a case-insensitive substring
    of some article entity. Failing that, it also counts when the inner product
    of its SBERT embedding with some article entity's embedding exceeds 0.9.

    Returns:
        (matched_entities, unmatched_entities): each de-duplicated in first-seen
        order, with entities contained in a longer entity of the same list removed.
    '''
    entities_article = list(itertools.chain.from_iterable(
        get_all_entities_per_sentence(article_content)))
    entities_summary = list(itertools.chain.from_iterable(
        get_all_entities_per_sentence(summary_output)))

    article_lower = [entity.lower() for entity in entities_article]
    # Encode each article entity once, and only if a fuzzy comparison is needed.
    # The old code re-encoded both sides for every pair (2 * n * m model calls).
    article_embeddings = None

    matched_entities = []
    unmatched_entities = []
    for entity in entities_summary:
        entity_lower = entity.lower()
        if any(entity_lower in art_lower for art_lower in article_lower):
            matched_entities.append(entity)
            continue
        if not entities_article:
            unmatched_entities.append(entity)
            continue
        if article_embeddings is None:
            article_embeddings = [sbert.encode(art_entity, show_progress_bar=False)
                                  for art_entity in entities_article]
        entity_embedding = sbert.encode(entity, show_progress_bar=False)
        if any(np.inner(entity_embedding, art_embedding) > 0.9
               for art_embedding in article_embeddings):
            matched_entities.append(entity)
        else:
            unmatched_entities.append(entity)

    matched_entities = _drop_contained_entities(list(dict.fromkeys(matched_entities)))
    unmatched_entities = _drop_contained_entities(list(dict.fromkeys(unmatched_entities)))

    return matched_entities, unmatched_entities

@st.experimental_memo(suppress_st_warning=True)
def highlight_entities(article_content,summary_output):
    '''Return `summary_output` as HTML with its entities highlighted.

    Entities supported by the article (per get_and_compare_entities) are wrapped
    in a green <mark>, unsupported ones in a red <mark>. The result is
    normalized through BeautifulSoup and wrapped in HTML_WRAPPER.
    '''
    markdown_start_red = "<mark class=\"entity\" style=\"background: rgb(238, 135, 135);\">"
    markdown_start_green = "<mark class=\"entity\" style=\"background: rgb(121, 236, 121);\">"
    markdown_end = "</mark>"

    matched_entities, unmatched_entities = get_and_compare_entities(article_content,summary_output)

    def _mark(text, entity, start_tag):
        # re.escape: entities such as "U.S." or "C++" are literal text, not
        # patterns (unescaped they mis-matched or raised re.error).
        # The negative lookahead skips a match followed by ')' with no 'r', 'g',
        # 'b' or '(' in between — presumably to avoid rewriting text inside the
        # rgb(...) of tags inserted by earlier substitutions.
        pattern = rf'({re.escape(entity)})(?![^rgb\(]*\))'
        replacement = start_tag + entity + markdown_end
        # Callable replacement so backslashes in `entity` are not read as group refs.
        return re.sub(pattern, lambda _match: replacement, text)

    for entity in matched_entities:
        if entity:  # an empty pattern would match at every position
            summary_output = _mark(summary_output, entity, markdown_start_green)

    for entity in unmatched_entities:
        if entity:
            summary_output = _mark(summary_output, entity, markdown_start_red)

    soup = BeautifulSoup(summary_output, features="html.parser")

    return HTML_WRAPPER.format(soup)
    
    
def display_df_as_table(model,top_k,score='score',passages=None):
    '''Build a (Score, Text) DataFrame for the first `top_k` ranked hits.

    Args:
        model: ranked hits; each a mapping with a `score` key and a 'corpus_id'
            index into `passages`.
        top_k: number of leading hits to include.
        score: key in each hit that holds the score.
        passages: sequence of texts indexed by 'corpus_id'. When omitted, a
            module-level `passages` global is used (the legacy behaviour). No such
            global is defined in this file, so callers should pass it explicitly.

    Returns:
        pandas.DataFrame with columns 'Score' (rounded to 2 decimals) and 'Text'.

    Raises:
        ValueError: `passages` was not given and no module-level one exists
            (previously an unexplained NameError).
    '''
    if passages is None:
        passages = globals().get('passages')
        if passages is None:
            raise ValueError("display_df_as_table requires `passages` (texts indexed by 'corpus_id')")

    rows = [(hit[score], passages[hit['corpus_id']]) for hit in model[0:top_k]]
    df = pd.DataFrame(rows, columns=['Score', 'Text'])
    df['Score'] = round(df['Score'], 2)

    return df

      
def make_spans(text,results):
    '''Pair each sentence of `text` with the 'label' of the matching result.

    Sentences come from nltk's sent_tokenize. Pairing is positional and stops
    at the shorter of the two sequences (zip semantics).

    Returns:
        list[tuple[str, str]]: (sentence, label) pairs.
    '''
    labels = [result['label'] for result in results]
    # The original called `sent_tokenizer`, which is not defined anywhere in
    # this module (NameError); `sent_tokenize` is the imported nltk splitter.
    return list(zip(sent_tokenize(text), labels))

##Fiscal Sentiment by Sentence
def fin_ext(text):
    '''Classify each sentence of `text` and pair it with its label via make_spans.

    NOTE(review): neither `remote_clx` nor `sent_tokenizer` is defined or
    imported in this module, so this raises NameError as written. The
    imported `sent_tokenize` and the module-level `sent_pipe` look like the
    intended local equivalents — confirm before relying on this function.
    '''
    results = remote_clx(sent_tokenizer(text))
    return make_spans(text,results)

## Knowledge Graphs code

def extract_relations_from_model_output(text):
    '''Parse a linearized relation-extraction output into triplet dicts.

    The sequence uses marker tokens: <triplet> starts a new head, <subj> starts
    a tail, and <obj> starts the relation type. <s>, <pad> and </s> are removed
    first. A pending relation is emitted whenever a new <triplet> or <subj>
    begins and once more at the end (only if head, type and tail are all
    non-empty).

    Returns:
        list[dict]: each {'head': str, 'type': str, 'tail': str}, whitespace-stripped.
    '''
    triplets = []
    head = rel = tail = ''
    state = 'x'

    def flush():
        triplets.append({
            'head': head.strip(),
            'type': rel.strip(),
            'tail': tail.strip()
        })

    cleaned = text.strip()
    for marker in ("<s>", "<pad>", "</s>"):
        cleaned = cleaned.replace(marker, "")

    for tok in cleaned.split():
        if tok == "<triplet>":
            state = 't'
            if rel:
                flush()
                rel = ''
            head = ''
        elif tok == "<subj>":
            state = 's'
            if rel:
                flush()
            tail = ''
        elif tok == "<obj>":
            state = 'o'
            rel = ''
        elif state == 't':
            head += ' ' + tok
        elif state == 's':
            tail += ' ' + tok
        elif state == 'o':
            rel += ' ' + tok

    if head and rel and tail:
        flush()
    return triplets

def from_text_to_kb(text, model, tokenizer, article_url, span_length=128, article_title=None,
                    article_publish_date=None, verbose=False):
    '''Extract relation triplets from `text` with a seq2seq model and build a KB.

    The tokenized text is cut into ceil(num_tokens / span_length) windows of
    `span_length` tokens. Consecutive windows overlap by `overlap` tokens. All
    windows are decoded in one `model.generate` call (3 beams, 3 returned
    sequences per window). Each relation parsed from a decoded sequence is added
    to the KB with provenance {article_url: {"spans": [[start, end]]}} for the
    window it came from.

    Args:
        text: document text.
        model, tokenizer: presumably the REBEL model/tokenizer from load_models — confirm.
        article_url: key for relation provenance and the KB's sources.
        span_length: tokens per window.
        article_title, article_publish_date: stored as source metadata.
        verbose: print token/span diagnostics.

    Returns:
        KB
    '''
    encoded = tokenizer([text], return_tensors="pt")
    token_ids = encoded["input_ids"][0]
    token_mask = encoded["attention_mask"][0]

    num_tokens = len(token_ids)
    if verbose:
        print(f"Input has {num_tokens} tokens")
    num_spans = math.ceil(num_tokens / span_length)
    if verbose:
        print(f"Input has {num_spans} spans")
    overlap = math.ceil((num_spans * span_length - num_tokens) /
                        max(num_spans - 1, 1))

    # Window k starts k * (span_length - overlap) tokens into the input.
    step = span_length - overlap
    spans_boundaries = [[k * step, k * step + span_length] for k in range(num_spans)]
    if verbose:
        print(f"Span boundaries are {spans_boundaries}")

    batch = {
        "input_ids": torch.stack([token_ids[lo:hi] for lo, hi in spans_boundaries]),
        "attention_mask": torch.stack([token_mask[lo:hi] for lo, hi in spans_boundaries]),
    }

    num_return_sequences = 3
    generated_tokens = model.generate(
        **batch,
        max_length=256,
        length_penalty=0,
        num_beams=3,
        num_return_sequences=num_return_sequences,
    )
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

    kb = KB()
    for seq_idx, sentence_pred in enumerate(decoded_preds):
        # generate() yields num_return_sequences consecutive outputs per window.
        window = spans_boundaries[seq_idx // num_return_sequences]
        for relation in extract_relations_from_model_output(sentence_pred):
            relation["meta"] = {article_url: {"spans": [window]}}
            kb.add_relation(relation, article_title, article_publish_date)

    return kb

def get_article(url):
    '''Download and parse the article at `url`; return the parsed article object.

    NOTE(review): `Article` is not imported in this module (it looks like
    newspaper3k's `newspaper.Article`), so this raises NameError as written —
    confirm and add the import.
    '''
    article = Article(url)
    article.download()
    article.parse()
    return article

def from_url_to_kb(url, model, tokenizer):
    '''Fetch the article at `url` and build a KB from its text.

    The article's URL, title and publish date are recorded as source metadata.
    '''
    article = get_article(url)
    return from_text_to_kb(
        article.text,
        model,
        tokenizer,
        article.url,
        article_title=article.title,
        article_publish_date=article.publish_date,
    )

def get_news_links(query, lang="en", region="US", pages=1):
    '''Collect result links for `query` from `pages` GoogleNews result pages.

    Returns the links with duplicates removed. Because they pass through a
    set, their order is not stable between runs.

    NOTE(review): `GoogleNews` is not imported in this module (presumably the
    GoogleNews package's class), so this raises NameError as written. Also,
    pages are requested as 0..pages-1 — confirm the library numbers pages from 0.
    '''
    googlenews = GoogleNews(lang=lang, region=region)
    googlenews.search(query)
    all_urls = []
    for page in range(pages):
        googlenews.get_page(page)
        all_urls += googlenews.get_links()
    return list(set(all_urls))

def from_urls_to_kb(urls, model, tokenizer, verbose=False):
    '''Build one KB by extracting a KB from each URL and merging them in order.

    URLs whose download fails with ArticleException are skipped (reported when
    `verbose`). Any other exception propagates.

    NOTE(review): `ArticleException` is not imported in this module. As soon
    as any exception reaches the except clause, evaluating that name raises
    NameError instead — confirm and add the import (newspaper3k?).
    '''
    kb = KB()
    if verbose:
        print(f"{len(urls)} links to visit")
    for url in urls:
        if verbose:
            print(f"Visiting {url}...")
        try:
            kb_url = from_url_to_kb(url, model, tokenizer)
            kb.merge_with_kb(kb_url)
        except ArticleException:
            if verbose:
                print(f"  Couldn't download article at url {url}")
    return kb

def save_network_html(kb, filename="network.html"):
    '''Render `kb` as a directed pyvis graph and save it via `Network.show(filename)`.

    NOTE(review): this module defines `save_network_html` twice. The later
    definition (which adds bgcolor="#eeeeee") rebinds the name when the module
    is imported, so this version is never called — consider deleting it.
    '''
    # create network
    net = Network(directed=True, width="700px", height="700px")

    # nodes
    color_entity = "#00FF00"
    for e in kb.entities:
        net.add_node(e, shape="circle", color=color_entity)

    # edges
    for r in kb.relations:
        net.add_edge(r["head"], r["tail"],
                    title=r["type"], label=r["type"])

    # save network
    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09
    )
    net.set_edge_smooth('dynamic')
    net.show(filename)

def save_kb(kb, filename):
    '''Serialize `kb` to the file `filename` with pickle (binary mode, overwrites).'''
    with open(filename, mode="wb") as handle:
        pickle.dump(kb, handle)

class CustomUnpickler(pickle.Unpickler):
    '''Unpickler that resolves any class named 'KB' to this module's KB class.

    Any pickled class called 'KB' is mapped here regardless of its recorded
    module (presumably so KBs pickled from another module, e.g. `__main__`,
    still load). Security: this is still a full pickle loader — only use it on
    trusted files.
    '''
    def find_class(self, module, name):
        if name != 'KB':
            return super().find_class(module, name)
        return KB

def load_kb(filename):
    '''Load and return the object pickled in `filename` (e.g. by save_kb).

    Uses CustomUnpickler so classes named 'KB' resolve to this module's KB.
    Security: unpickling can execute arbitrary code — only load trusted files.
    '''
    with open(filename, "rb") as handle:
        return CustomUnpickler(handle).load()

class KB():
    '''Knowledge base of Wikipedia-resolved entities, relations and their sources.

    A relation is stored only when both its head and tail resolve to a
    Wikipedia page; the endpoints are renamed to the page titles. Relations with
    the same head, type and tail are merged by combining their per-article
    span provenance.
    '''
    def __init__(self):
        self.entities = {} # { entity_title: {...} }
        self.relations = [] # [ head: entity_title, type: ..., tail: entity_title,
          # meta: { article_url: { spans: [...] } } ]
        self.sources = {} # { article_url: {...} }

    def merge_with_kb(self, kb2):
        '''Add every relation of `kb2` to this KB with its source metadata.

        Each relation goes through add_relation, so its endpoints are looked up
        on Wikipedia again. Only the first article URL in a relation's meta is
        used to pick the source metadata.
        '''
        for r in kb2.relations:
            article_url = list(r["meta"].keys())[0]
            source_data = kb2.sources[article_url]
            self.add_relation(r, source_data["article_title"],
                              source_data["article_publish_date"])

    def are_relations_equal(self, r1, r2):
        '''True when `r1` and `r2` share head, type and tail (meta is ignored).'''
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        '''True when an equal relation (see are_relations_equal) is already stored.'''
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def merge_relations(self, r2):
        '''Merge the provenance of `r2` into the equal relation already stored.

        Callers must ensure an equal relation exists; otherwise IndexError is
        raised. Only the first article URL in `r2["meta"]` is considered.
        '''
        r1 = [r for r in self.relations
              if self.are_relations_equal(r2, r)][0]

        # if different article
        article_url = list(r2["meta"].keys())[0]
        if article_url not in r1["meta"]:
            r1["meta"][article_url] = r2["meta"][article_url]

        # if existing article: append only spans not already recorded
        else:
            spans_to_add = [span for span in r2["meta"][article_url]["spans"]
                            if span not in r1["meta"][article_url]["spans"]]
            r1["meta"][article_url]["spans"] += spans_to_add

    def get_wikipedia_data(self, candidate_entity):
        '''Look up `candidate_entity` on Wikipedia (auto-suggest disabled).

        Returns:
            {"title", "url", "summary"} for the page, or None if the lookup
            fails for any reason (missing page, disambiguation, network, ...).
        '''
        try:
            page = wikipedia.page(candidate_entity, auto_suggest=False)
            return {
                "title": page.title,
                "url": page.url,
                "summary": page.summary
            }
        except Exception:
            # Best-effort lookup. Unlike the former bare `except:`, this no
            # longer swallows KeyboardInterrupt / SystemExit.
            return None

    def add_entity(self, e):
        '''Store entity `e` under its "title"; the remaining keys become its data.'''
        self.entities[e["title"]] = {k:v for k,v in e.items() if k != "title"}

    def add_relation(self, r, article_title, article_publish_date):
        '''Add relation `r` if both endpoints resolve on Wikipedia.

        `r` is mutated: head/tail are replaced by the Wikipedia page titles.
        The article source is registered if new. The relation is appended, or
        merged into an equal existing one.
        '''
        # check on wikipedia; stop at the first endpoint that does not resolve
        # (the old code queried the tail even after the head had failed)
        entities = []
        for candidate in (r["head"], r["tail"]):
            entity_data = self.get_wikipedia_data(candidate)
            if entity_data is None:
                return
            entities.append(entity_data)

        # manage new entities
        for e in entities:
            self.add_entity(e)

        # rename relation entities with their wikipedia titles
        r["head"] = entities[0]["title"]
        r["tail"] = entities[1]["title"]

        # add source if not in kb
        article_url = list(r["meta"].keys())[0]
        if article_url not in self.sources:
            self.sources[article_url] = {
                "article_title": article_title,
                "article_publish_date": article_publish_date
            }

        # manage new relation
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

    def get_textual_representation(self):
        '''Return a Markdown-style dump of entities, relations and sources.

        Entity summaries are truncated to 100 characters plus "...".
        '''
        lines = ["### Entities"]
        for title, data in self.entities.items():
            shortened = {k: (v[:100] + "..." if k == "summary" else v) for k, v in data.items()}
            lines.append(f"- {(title, shortened)}")
        lines += ["", "### Relations"]
        lines += [f"- {r}" for r in self.relations]
        lines += ["", "### Sources"]
        lines += [f"- {s}" for s in self.sources.items()]
        return "\n".join(lines) + "\n"
            
def save_network_html(kb, filename="network.html"):
    '''Draw `kb` as a directed pyvis graph (grey background) and save it via `Network.show`.

    Every entity becomes a green circular node and every relation a head → tail
    edge labelled with its type. A repulsion layout and dynamic edge smoothing
    are applied before saving to `filename`.
    '''
    graph = Network(directed=True, width="700px", height="700px", bgcolor="#eeeeee")

    node_color = "#00FF00"
    for entity_title in kb.entities:
        graph.add_node(entity_title, shape="circle", color=node_color)

    for rel in kb.relations:
        graph.add_edge(rel["head"], rel["tail"], title=rel["type"], label=rel["type"])

    layout = dict(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09,
    )
    graph.repulsion(**layout)
    graph.set_edge_smooth('dynamic')
    graph.show(filename)

# Module-level initialization, run at import time: load the spaCy pipeline and
# all transformer models (cached by their st.experimental_singleton
# decorators). The functions above read these globals (nlp, sent_pipe,
# sum_pipe, ner_pipe, sbert, ...) directly.
nlp = get_spacy()    

sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer, sbert  = load_models()