Update functions.py
Added Knowledge Graph tab

functions.py (+323 -4)
@@ -6,7 +6,7 @@ import plotly_express as px
 import nltk
 import plotly.graph_objects as go
 from optimum.onnxruntime import ORTModelForSequenceClassification
-from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
+from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForSeq2SeqLM
 from sentence_transformers import SentenceTransformer, CrossEncoder, util
 import streamlit as st
 import en_core_web_lg
@@ -31,6 +31,8 @@ margin-bottom: 2.5rem">{}</div> """
 def load_models():
     q_model = ORTModelForSequenceClassification.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
     ner_model = AutoModelForTokenClassification.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
+    kg_model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
+    kg_tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
     q_tokenizer = AutoTokenizer.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
     ner_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
     sent_pipe = pipeline("text-classification",model=q_model, tokenizer=q_tokenizer)
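The two `kg_*` lines above pull in Babelscape/rebel-large, a seq2seq relation-extraction model. A minimal sanity check of what these objects produce (illustrative, not part of the diff; the sentence and the printed string are made up to show the format):

    # Hedged sketch: decode REBEL's linearized triplet output for one sentence.
    ids = kg_tokenizer("Napoleon was born on the island of Corsica.", return_tensors="pt")
    out = kg_model.generate(**ids, max_length=256, num_beams=3)
    print(kg_tokenizer.batch_decode(out, skip_special_tokens=False)[0])
    # e.g. '<s><triplet> Napoleon <subj> Corsica <obj> place of birth</s>'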
@@ -38,7 +40,7 @@ def load_models():
     ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, grouped_entities=True)
     cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
 
-    return sent_pipe, sum_pipe, ner_pipe, cross_encoder
+    return sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer
 
 @st.experimental_singleton(suppress_st_warning=True)
 def load_asr_model(asr_model_name):
@@ -358,7 +360,324 @@ def make_spans(text,results):
 def fin_ext(text):
     results = remote_clx(sent_tokenizer(text))
     return make_spans(text,results)
-
+
+## Knowledge Graphs code
+
+def extract_relations_from_model_output(text):
+    relations = []
+    relation, subject, relation, object_ = '', '', '', ''
+    text = text.strip()
+    current = 'x'
+    text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
+    for token in text_replaced.split():
+        if token == "<triplet>":
+            current = 't'
+            if relation != '':
+                relations.append({
+                    'head': subject.strip(),
+                    'type': relation.strip(),
+                    'tail': object_.strip()
+                })
+                relation = ''
+            subject = ''
+        elif token == "<subj>":
+            current = 's'
+            if relation != '':
+                relations.append({
+                    'head': subject.strip(),
+                    'type': relation.strip(),
+                    'tail': object_.strip()
+                })
+            object_ = ''
+        elif token == "<obj>":
+            current = 'o'
+            relation = ''
+        else:
+            if current == 't':
+                subject += ' ' + token
+            elif current == 's':
+                object_ += ' ' + token
+            elif current == 'o':
+                relation += ' ' + token
+    if subject != '' and relation != '' and object_ != '':
+        relations.append({
+            'head': subject.strip(),
+            'type': relation.strip(),
+            'tail': object_.strip()
+        })
+    return relations
+
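A note on the parser above: REBEL linearizes each triplet as `<triplet> head <subj> tail <obj> relation`, so the text after `<subj>` is the tail entity and the text after `<obj>` is the relation label; that is why the state machine accumulates `object_` while `current == 's'` and `relation` while `current == 'o'`. The doubled `relation` in the tuple assignment is redundant but harmless. A quick illustration with a hand-written string (not real model output):

    decoded = "<s><triplet> Apple Inc. <subj> Cupertino <obj> headquarters location</s>"
    print(extract_relations_from_model_output(decoded))
    # [{'head': 'Apple Inc.', 'type': 'headquarters location', 'tail': 'Cupertino'}]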
+def from_text_to_kb(text, model, tokenizer, article_url, span_length=128, article_title=None,
+                    article_publish_date=None, verbose=False):
+    # tokenize whole text
+    inputs = tokenizer([text], return_tensors="pt")
+
+    # compute span boundaries
+    num_tokens = len(inputs["input_ids"][0])
+    if verbose:
+        print(f"Input has {num_tokens} tokens")
+    num_spans = math.ceil(num_tokens / span_length)
+    if verbose:
+        print(f"Input has {num_spans} spans")
+    overlap = math.ceil((num_spans * span_length - num_tokens) /
+                        max(num_spans - 1, 1))
+    spans_boundaries = []
+    start = 0
+    for i in range(num_spans):
+        spans_boundaries.append([start + span_length * i,
+                                 start + span_length * (i + 1)])
+        start -= overlap
+    if verbose:
+        print(f"Span boundaries are {spans_boundaries}")
+
+    # transform input with spans
+    tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
+                  for boundary in spans_boundaries]
+    tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
+                    for boundary in spans_boundaries]
+    inputs = {
+        "input_ids": torch.stack(tensor_ids),
+        "attention_mask": torch.stack(tensor_masks)
+    }
+
+    # generate relations
+    num_return_sequences = 3
+    gen_kwargs = {
+        "max_length": 256,
+        "length_penalty": 0,
+        "num_beams": 3,
+        "num_return_sequences": num_return_sequences
+    }
+    generated_tokens = model.generate(
+        **inputs,
+        **gen_kwargs,
+    )
+
+    # decode relations
+    decoded_preds = tokenizer.batch_decode(generated_tokens,
+                                           skip_special_tokens=False)
+
+    # create kb
+    kb = KB()
+    i = 0
+    for sentence_pred in decoded_preds:
+        current_span_index = i // num_return_sequences
+        relations = extract_relations_from_model_output(sentence_pred)
+        for relation in relations:
+            relation["meta"] = {
+                article_url: {
+                    "spans": [spans_boundaries[current_span_index]]
+                }
+            }
+            kb.add_relation(relation, article_title, article_publish_date)
+        i += 1
+
+    return kb
+
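The span logic above slides `span_length`-token windows over the article and spreads the leftover tokens as overlap between consecutive windows. A worked example with illustrative numbers, not from the diff:

    import math

    num_tokens, span_length = 300, 128
    num_spans = math.ceil(num_tokens / span_length)             # 3
    overlap = math.ceil((num_spans * span_length - num_tokens)
                        / max(num_spans - 1, 1))                # ceil(84 / 2) = 42
    start, bounds = 0, []
    for i in range(num_spans):
        bounds.append([start + span_length * i, start + span_length * (i + 1)])
        start -= overlap
    print(bounds)  # [[0, 128], [86, 214], [172, 300]]

Each span is then decoded `num_return_sequences` times (3 beams kept), which is why `i // num_return_sequences` maps a prediction back to its source span.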
+def get_article(url):
+    article = Article(url)
+    article.download()
+    article.parse()
+    return article
+
+def from_url_to_kb(url, model, tokenizer):
+    article = get_article(url)
+    config = {
+        "article_title": article.title,
+        "article_publish_date": article.publish_date
+    }
+    kb = from_text_to_kb(article.text, model, tokenizer, article.url, **config)
+    return kb
+
+def get_news_links(query, lang="en", region="US", pages=1):
+    googlenews = GoogleNews(lang=lang, region=region)
+    googlenews.search(query)
+    all_urls = []
+    for page in range(pages):
+        googlenews.get_page(page)
+        all_urls += googlenews.get_links()
+    return list(set(all_urls))
+
+def from_urls_to_kb(urls, model, tokenizer, verbose=False):
+    kb = KB()
+    if verbose:
+        print(f"{len(urls)} links to visit")
+    for url in urls:
+        if verbose:
+            print(f"Visiting {url}...")
+        try:
+            kb_url = from_url_to_kb(url, model, tokenizer)
+            kb.merge_with_kb(kb_url)
+        except ArticleException:
+            if verbose:
+                print(f"  Couldn't download article at url {url}")
+    return kb
+
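How these helpers compose, as a hedged sketch (the query string is illustrative; `kg_model` and `kg_tokenizer` come from `load_models()` above):

    # Build a knowledge base from one page of news results for a query.
    urls = get_news_links("Google", pages=1)
    kb = from_urls_to_kb(urls, kg_model, kg_tokenizer, verbose=True)
    print(kb.get_textual_representation())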
+def save_network_html(kb, filename="network.html"):
+    # create network
+    net = Network(directed=True, width="700px", height="700px")
+
+    # nodes
+    color_entity = "#00FF00"
+    for e in kb.entities:
+        net.add_node(e, shape="circle", color=color_entity)
+
+    # edges
+    for r in kb.relations:
+        net.add_edge(r["head"], r["tail"],
+                     title=r["type"], label=r["type"])
+
+    # save network
+    net.repulsion(
+        node_distance=200,
+        central_gravity=0.2,
+        spring_length=200,
+        spring_strength=0.05,
+        damping=0.09
+    )
+    net.set_edge_smooth('dynamic')
+    net.show(filename)
+
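`Network` here is presumably pyvis (`from pyvis.network import Network`); `net.show(filename)` writes a standalone HTML file. In a Streamlit app that file would typically be embedded with `streamlit.components.v1` — hypothetical wiring, not shown in this diff:

    import streamlit.components.v1 as components

    save_network_html(kb, filename="network.html")
    with open("network.html", "r", encoding="utf-8") as f:
        components.html(f.read(), height=700, width=700)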
+def save_kb(kb, filename):
+    with open(filename, "wb") as f:
+        pickle.dump(kb, f)
+
+class CustomUnpickler(pickle.Unpickler):
+    def find_class(self, module, name):
+        if name == 'KB':
+            return KB
+        return super().find_class(module, name)
+
+def load_kb(filename):
+    res = None
+    with open(filename, "rb") as f:
+        res = CustomUnpickler(f).load()
+    return res
+
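`CustomUnpickler` exists so a KB pickled from another module (for example a notebook's `__main__`) still resolves the class name `KB` to the definition below. An illustrative round trip:

    save_kb(kb, "kb.pkl")
    kb2 = load_kb("kb.pkl")  # find_class() maps any pickled 'KB' name to this module's KB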
+class KB():
+    def __init__(self):
+        self.entities = {}  # { entity_title: {...} }
+        self.relations = [] # [ head: entity_title, type: ..., tail: entity_title,
+                            #   meta: { article_url: { spans: [...] } } ]
+        self.sources = {}   # { article_url: {...} }
+
+    def merge_with_kb(self, kb2):
+        for r in kb2.relations:
+            article_url = list(r["meta"].keys())[0]
+            source_data = kb2.sources[article_url]
+            self.add_relation(r, source_data["article_title"],
+                              source_data["article_publish_date"])
+
+    def are_relations_equal(self, r1, r2):
+        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])
+
+    def exists_relation(self, r1):
+        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)
+
+    def merge_relations(self, r2):
+        r1 = [r for r in self.relations
+              if self.are_relations_equal(r2, r)][0]
+
+        # if different article
+        article_url = list(r2["meta"].keys())[0]
+        if article_url not in r1["meta"]:
+            r1["meta"][article_url] = r2["meta"][article_url]
+
+        # if existing article
+        else:
+            spans_to_add = [span for span in r2["meta"][article_url]["spans"]
+                            if span not in r1["meta"][article_url]["spans"]]
+            r1["meta"][article_url]["spans"] += spans_to_add
+
+    def get_wikipedia_data(self, candidate_entity):
+        try:
+            page = wikipedia.page(candidate_entity, auto_suggest=False)
+            entity_data = {
+                "title": page.title,
+                "url": page.url,
+                "summary": page.summary
+            }
+            return entity_data
+        except:
+            return None
+
+    def add_entity(self, e):
+        self.entities[e["title"]] = {k:v for k,v in e.items() if k != "title"}
+
+    def add_relation(self, r, article_title, article_publish_date):
+        # check on wikipedia
+        candidate_entities = [r["head"], r["tail"]]
+        entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]
+
+        # if one entity does not exist, stop
+        if any(ent is None for ent in entities):
+            return
+
+        # manage new entities
+        for e in entities:
+            self.add_entity(e)
+
+        # rename relation entities with their wikipedia titles
+        r["head"] = entities[0]["title"]
+        r["tail"] = entities[1]["title"]
+
+        # add source if not in kb
+        article_url = list(r["meta"].keys())[0]
+        if article_url not in self.sources:
+            self.sources[article_url] = {
+                "article_title": article_title,
+                "article_publish_date": article_publish_date
+            }
+
+        # manage new relation
+        if not self.exists_relation(r):
+            self.relations.append(r)
+        else:
+            self.merge_relations(r)
+
+    def get_textual_representation(self):
+        res = ""
+        res += "### Entities\n"
+        for e in self.entities.items():
+            # shorten summary
+            e_temp = (e[0], {k:(v[:100] + "..." if k == "summary" else v) for k,v in e[1].items()})
+            res += f"- {e_temp}\n"
+        res += "\n"
+        res += "### Relations\n"
+        for r in self.relations:
+            res += f"- {r}\n"
+        res += "\n"
+        res += "### Sources\n"
+        for s in self.sources.items():
+            res += f"- {s}\n"
+        return res
+
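Note that the `save_network_html` added below repeats the definition earlier in this hunk, differing only in the `bgcolor="#eeeeee"` argument; at import time the later definition shadows the earlier one, so the first copy is dead code.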
+def save_network_html(kb, filename="network.html"):
+    # create network
+    net = Network(directed=True, width="700px", height="700px", bgcolor="#eeeeee")
+
+    # nodes
+    color_entity = "#00FF00"
+    for e in kb.entities:
+        net.add_node(e, shape="circle", color=color_entity)
+
+    # edges
+    for r in kb.relations:
+        net.add_edge(r["head"], r["tail"],
+                     title=r["type"], label=r["type"])
+
+    # save network
+    net.repulsion(
+        node_distance=200,
+        central_gravity=0.2,
+        spring_length=200,
+        spring_strength=0.05,
+        damping=0.09
+    )
+    net.set_edge_smooth('dynamic')
+    net.show(filename)
+
+
 nlp = get_spacy()
-sent_pipe, sum_pipe, ner_pipe, cross_encoder = load_models()
+sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer = load_models()
 sbert = load_sbert('all-MiniLM-L12-v2')
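The added helpers reference several modules that this diff's import hunk does not add (`math`, `torch`, `pickle`, newspaper's `Article`/`ArticleException`, `GoogleNews`, pyvis's `Network`, `wikipedia`), so presumably they are imported elsewhere in functions.py. For reference, the imports the new code assumes:

    import math
    import pickle
    import torch
    import wikipedia
    from newspaper import Article, ArticleException
    from GoogleNews import GoogleNews
    from pyvis.network import Network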