import whisper
import os
import random
import openai
import yt_dlp
from pytube import YouTube, extract
import pandas as pd
import plotly_express as px
import nltk
import plotly.graph_objects as go
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import streamlit as st
import en_core_web_lg
import validators
import re
import itertools
import numpy as np
from bs4 import BeautifulSoup
import base64
import time
from annotated_text import annotated_text
import pickle
import math
import wikipedia
from pyvis.network import Network
import torch
from pydub import AudioSegment
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings, HuggingFaceInstructEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chat_models import ChatOpenAI
from langchain.chains import QAGenerationChain
from langchain.callbacks import StreamlitCallbackHandler
from langchain.agents import OpenAIFunctionsAgent, AgentExecutor
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
    AgentTokenBufferMemory,
)
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    AIMessagePromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage
)
from langchain.prompts import PromptTemplate

nltk.download('punkt')
from nltk import sent_tokenize

OPEN_AI_KEY = os.environ.get('OPEN_AI_KEY')

# Timestamp captured at import time; used to name downloaded summary files.
time_str = time.strftime("%d%m%Y-%H%M%S")

# Scrollable bordered container used to render highlighted summaries.
HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>"""

# Directory where downloaded audio and audio chunks are written.
OUTPUT_DIR = 'output'
# Files up to this size (MB) are sent to the Whisper API in one request;
# larger files are split into chunks first.
WHISPER_API_MAX_MB = 25
# Length of each audio chunk; pydub slices AudioSegments in milliseconds.
CHUNK_MS = 20 * 60 * 1000

###################### Functions #######################################################################################

#load all required models and cache
@st.cache_resource
def load_models():
    '''Load and cache all the models used by the app.

    Returns:
        Tuple ``(sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model,
        kg_tokenizer, emb_tokenizer, sbert)``.
    '''
    q_model = ORTModelForSequenceClassification.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
    ner_model = AutoModelForTokenClassification.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
    kg_model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
    kg_tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
    q_tokenizer = AutoTokenizer.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
    ner_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
    emb_tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-xl')
    sent_pipe = pipeline("text-classification", model=q_model, tokenizer=q_tokenizer)
    sum_pipe = pipeline("summarization", model="philschmid/flan-t5-base-samsum", clean_up_tokenization_spaces=True)
    # aggregation_strategy="simple" is the non-deprecated spelling of grouped_entities=True.
    ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
    cross_encoder = CrossEncoder('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1')  # cross-encoder/ms-marco-MiniLM-L-12-v2
    sbert = SentenceTransformer('all-MiniLM-L6-v2')

    return sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer, sbert


@st.cache_resource
def get_spacy():
    '''Load and cache the large English spaCy pipeline.'''
    nlp = en_core_web_lg.load()
    return nlp


nlp = get_spacy()

sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer, sbert = load_models()


@st.cache_data
def get_yt_audio(url):
    '''Download a YouTube video with pytube.

    Note: this downloads the highest-resolution *progressive* mp4 stream
    (video plus audio), not an audio-only stream.

    Returns:
        ``(path_to_downloaded_mp4, video_title)``.

    Raises:
        RuntimeError: if the video has no progressive mp4 stream.
    '''
    yt = YouTube(url)
    title = yt.title

    stream = (
        yt.streams.filter(progressive=True, file_extension='mp4')
        .order_by('resolution')
        .desc()
        .first()
    )
    if stream is None:
        raise RuntimeError(f'No progressive mp4 stream available for {url}')
    audio_stream = stream.download()

    return audio_stream, title


@st.cache_data
def get_yt_audio_dl(url):
    '''Back up for when pytube is down: download the audio with yt_dlp as mp3.

    The file is named after the YouTube video id. The result is cached per
    URL, so a shared file name would let a cached entry point at audio that
    was later overwritten by another video.

    Returns:
        ``(path_to_mp3, video_title)``; the title may be ``None``.
    '''
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    ydl_opts = {
        'format': 'bestaudio/best',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
            'preferredquality': '192',
        }],
        'outtmpl': os.path.join(OUTPUT_DIR, '%(id)s.%(ext)s'),
        'quiet': True,
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        # One call both fetches the metadata and downloads the file.
        info = ydl.extract_info(url, download=True)

    title = info.get('title', None)
    # FFmpegExtractAudio replaces the downloaded extension with .mp3.
    audio_file = os.path.join(OUTPUT_DIR, f"{info['id']}.mp3")

    return audio_file, title


@st.cache_resource
def load_asr_model(model_name):
    '''Load the open source whisper model in cases where the API is not working.

    Cached as a resource (shared, not copied per call), as recommended for ML models.
    '''
    model = whisper.load_model(model_name)
    return model


@st.cache_data
def load_whisper_api(audio):
    '''Send an audio file to the OpenAI Whisper API.

    Uses the ``translate`` endpoint, so the returned text is in English.
    Caching is keyed on the file *path*; callers must use distinct paths for
    distinct audio.

    Returns:
        The API response; callers read its ``'text'`` field.
    '''
    with open(audio, "rb") as file:
        transcript = openai.Audio.translate("whisper-1", file)
    return transcript


def _file_size_mb(path):
    '''Return the size of ``path`` in megabytes, rounded to one decimal place.'''
    return round(os.path.getsize(path) / (1024 * 1024), 1)


def _transcribe_in_chunks(audio_path, chunk_tag, audio_format=None):
    '''Split an audio file into CHUNK_MS pieces and transcribe each with the Whisper API.

    Args:
        audio_path: source audio/video file.
        chunk_tag: suffix added to chunk file names so that chunks from
            different sources get different paths (load_whisper_api caches by path).
        audio_format: container format for pydub; ``None`` lets ffmpeg probe the file.

    Returns:
        The chunk transcripts joined with ','.
    '''
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    song = AudioSegment.from_file(audio_path, format=audio_format)
    transcriptions = []
    for i, chunk in enumerate(song[::CHUNK_MS]):
        chunk_path = os.path.join(OUTPUT_DIR, f'chunk_{i}{chunk_tag}.mp4')
        chunk.export(chunk_path, format='mp4')
        transcriptions.append(load_whisper_api(chunk_path)['text'])
    return ','.join(transcriptions)


@st.cache_data
def transcribe_yt_video(url, py_tube=True):
    '''Download and transcribe a YouTube video.

    Args:
        url: YouTube URL.
        py_tube: download with pytube when True, otherwise with yt_dlp.

    Returns:
        ``(transcript_text, video_title)``.

    Side effects:
        Stores the downloaded file path in ``st.session_state['audio']``.
        Streamlit does not replay this on cache hits.
    '''
    if py_tube:
        audio_file, title = get_yt_audio(url)
    else:
        audio_file, title = get_yt_audio_dl(url)

    print(f'audio_file:{audio_file}')
    st.session_state['audio'] = audio_file
    print(f"audio_file_session_state:{st.session_state['audio'] }")

    # Files within the Whisper API upload limit are sent whole; larger ones are chunked.
    audio_size = _file_size_mb(st.session_state['audio'])
    if audio_size <= WHISPER_API_MAX_MB:
        st.info("`Transcribing YT audio...`")
        results = load_whisper_api(st.session_state['audio'])['text']
    else:
        st.warning('File size larger than 24mb, applying chunking and transcription',icon="⚠️")
        # pytube yields .mp4, yt_dlp yields .mp3: decode using the real container.
        audio_format = os.path.splitext(audio_file)[1].lstrip('.') or None
        video_id = extract.video_id(url)
        results = _transcribe_in_chunks(st.session_state['audio'], f'_{video_id}', audio_format)

    st.info("`YT Video transcription process complete...`")

    return results, title


@st.cache_data
def inference(link, upload, _asr_model):
    '''Convert a YouTube video or an uploaded audio file to text.

    Args:
        link: YouTube URL; used when it passes ``validators.url``.
        upload: path of an uploaded audio file; used when ``link`` is not a URL.
        _asr_model: local whisper model. It is currently unused, and the leading
            underscore keeps it out of Streamlit's cache key.

    Returns:
        ``(transcript_text, title)``, or ``None`` when neither input is usable.
    '''
    if link and validators.url(link):
        st.info("`Downloading YT audio...`")
        try:
            results, title = transcribe_yt_video(link)
        except Exception as e:
            # pytube breaks whenever YouTube changes; retry the download with yt_dlp.
            st.error(f'''PyTube Error: {e}, Using yt_dlp module, might take longer than expected''',icon="🚨")
            results, title = transcribe_yt_video(link, py_tube=False)
            # results = _asr_model.transcribe(st.session_state['audio'], task='transcribe', language='en')
        return results, title

    if upload:
        audio_size = _file_size_mb(upload)
        if audio_size <= WHISPER_API_MAX_MB:
            st.info("`Transcribing uploaded audio...`")
            results = load_whisper_api(upload)['text']
        else:
            st.write('File size larger than 24mb, applying chunking and transcription')
            st.info("`Transcribing uploaded audio...`")
            # Tag chunks with the upload's name so chunk paths (the whisper cache key)
            # differ between uploads. NOTE(review): two different uploads saved under
            # the same name would still share cache entries.
            upload_tag = '_' + os.path.splitext(os.path.basename(upload))[0]
            results = _transcribe_in_chunks(upload, upload_tag)

        st.info("`Uploaded audio transcription process complete...`")
        return results, "Transcribed Earnings Audio"

    return None


@st.cache_data
def clean_text(text):
    '''Clean transcript text after inference.

    The text is reduced to ASCII. URLs, @mentions and #hashtags are replaced
    with spaces, and runs of whitespace are collapsed to one space.
    '''
    text = text.encode("ascii", "ignore").decode()  # unicode
    text = re.sub(r"https*\S+", " ", text)  # url
    text = re.sub(r"@\S+", " ", text)  # mentions
    text = re.sub(r"#\S+", " ", text)  # hashtags
    text = re.sub(r"\s{2,}", " ", text)  # over spaces

    return text


@st.cache_data
def chunk_long_text(text,threshold,window_size=3,stride=2):
    '''Split text into passages for sentiment analysis.

    Sentences with ``threshold`` or more words are cut into pieces of at most
    ``threshold`` words. Consecutive pieces are then joined into windows of
    ``window_size`` pieces, starting every ``stride`` pieces.

    Returns:
        A list of passage strings.
    '''
    sentences = sent_tokenize(text)

    pieces = []
    for sentence in sentences:
        words = sentence.split()
        if len(words) < threshold:
            pieces.append(sentence)
        else:
            # Stop at len(words) so no empty trailing piece is produced when
            # the word count is an exact multiple of threshold.
            pieces.extend(' '.join(words[i:i + threshold]) for i in range(0, len(words), threshold))

    passages = []
    for start_idx in range(0, len(pieces), stride):
        end_idx = min(start_idx + window_size, len(pieces))
        passages.append(" ".join(pieces[start_idx:end_idx]))

    return passages


@st.cache_data
def sentiment_pipe(earnings_text):
    '''Classify the sentiment of each sentence-sized chunk of the text.

    Returns:
        ``(sentiment_results, chunks)``, aligned index by index.
    '''
    earnings_sentences = chunk_long_text(earnings_text,150,1,1)
    earnings_sentiment = sent_pipe(earnings_sentences)

    return earnings_sentiment, earnings_sentences


@st.cache_data
def chunk_and_preprocess_text(text, model_name= 'philschmid/flan-t5-base-samsum'):
    '''Chunk text for summarization.

    Sentences are packed into chunks whose token count stays within the
    tokenizer's ``max_len_single_sentence``. A single sentence longer than
    that limit becomes its own (oversized) chunk.

    Returns:
        A list of chunk strings; each sentence is followed by a space.
    '''
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    sentences = sent_tokenize(text)

    chunks = []
    chunk = ""
    length = 0
    for sentence in sentences:
        n_tokens = len(tokenizer.tokenize(sentence))
        # Flush the current chunk before it would overflow; never emit an empty chunk.
        if chunk and length + n_tokens > tokenizer.max_len_single_sentence:
            chunks.append(chunk)
            chunk = ""
            length = 0
        chunk += sentence + " "
        length += n_tokens

    # The last chunk is always kept, including when its final sentence overflowed.
    if chunk:
        chunks.append(chunk)

    return chunks


@st.cache_data
def summarize_text(text_to_summarize,max_len,min_len):
    '''Summarize text with the HF summarization pipeline.

    Returns:
        The summaries of all inputs joined with a single space.
    '''
    summarized_text = sum_pipe(text_to_summarize,
                               max_length=max_len,
                               min_length=min_len,
                               do_sample=False,
                               early_stopping=True,
                               num_beams=4)
    summarized_text = ' '.join([summ['summary_text'] for summ in summarized_text])
    return summarized_text


@st.cache_data
def get_all_entities_per_sentence(text):
    '''Return, for each spaCy sentence, the entities found by spaCy and by the XLM NER pipeline.'''
    doc = nlp(''.join(text))

    sentences = list(doc.sents)

    entities_all_sentences = []
    for sentence in sentences:
        entities_this_sentence = []

        # SPACY ENTITIES
        for entity in sentence.ents:
            entities_this_sentence.append(str(entity))

        # XLM ENTITIES
        entities_xlm = [entity["word"] for entity in ner_pipe(str(sentence))]
        for entity in entities_xlm:
            entities_this_sentence.append(str(entity))

        entities_all_sentences.append(entities_this_sentence)

    return entities_all_sentences


@st.cache_data
def get_all_entities(text):
    '''Return all entities of ``text`` as one flat list.'''
    all_entities_per_sentence = get_all_entities_per_sentence(text)
    return list(itertools.chain.from_iterable(all_entities_per_sentence))


def _drop_contained_entities(entities):
    '''Remove entities that are a case-insensitive substring of a different entity in the list.'''
    lowered = [e.lower() for e in entities]
    return [
        entity for entity, low in zip(entities, lowered)
        if not any(entity != other and low in other_low for other, other_low in zip(entities, lowered))
    ]


@st.cache_data
def get_and_compare_entities(article_content,summary_output):
    '''Split the summary's entities into those supported by the article and the rest.

    A summary entity counts as matched when it is a case-insensitive substring
    of an article entity, or when its SBERT embedding has an inner product
    > 0.9 with any article entity's embedding.

    Returns:
        ``(matched_entities, unmatched_entities)``. Both are de-duplicated in
        order, with entities contained in another entity of the same list removed.
    '''
    entities_article = get_all_entities(article_content)
    entities_summary = get_all_entities(summary_output)

    article_lower = [e.lower() for e in entities_article]
    # Encoded once, and only if some summary entity needs the semantic check.
    article_embeddings = None

    matched_entities = []
    unmatched_entities = []
    for entity in entities_summary:
        entity_lower = entity.lower()
        if any(entity_lower in art for art in article_lower):
            matched_entities.append(entity)
            continue
        if entities_article:
            if article_embeddings is None:
                article_embeddings = sbert.encode(entities_article, show_progress_bar=False)
            entity_embedding = sbert.encode(entity, show_progress_bar=False)
            if np.any(np.inner(entity_embedding, article_embeddings) > 0.9):
                matched_entities.append(entity)
                continue
        unmatched_entities.append(entity)

    matched_entities = _drop_contained_entities(list(dict.fromkeys(matched_entities)))
    unmatched_entities = _drop_contained_entities(list(dict.fromkeys(unmatched_entities)))

    return matched_entities, unmatched_entities


@st.cache_data
def highlight_entities(article_content,summary_output):
    '''Wrap summary entities in coloured <mark> tags.

    Entities supported by the article are marked green and unsupported ones red.

    Returns:
        The highlighted summary as HTML inside HTML_WRAPPER.
    '''
    markdown_start_red = "<mark class=\"entity\" style=\"background: rgb(238, 135, 135);\">"
    markdown_start_green = "<mark class=\"entity\" style=\"background: rgb(121, 236, 121);\">"
    markdown_end = "</mark>"

    matched_entities, unmatched_entities = get_and_compare_entities(article_content,summary_output)

    for entities, markdown_start in ((matched_entities, markdown_start_green),
                                     (unmatched_entities, markdown_start_red)):
        for entity in entities:
            # Escape the entity so characters like '$', '.', '(' are matched literally.
            # The lookahead avoids rewriting text inside an already-inserted rgb(...).
            pattern = rf'({re.escape(entity)})(?![^rgb\(]*\))'
            replacement = markdown_start + entity + markdown_end
            # A callable replacement stops backslashes in the entity being read as escapes.
            summary_output = re.sub(pattern, lambda _match, r=replacement: r, summary_output)

    soup = BeautifulSoup(summary_output, features="html.parser")

    return HTML_WRAPPER.format(soup)


def summary_downloader(raw_text):
    '''Render a download link for ``raw_text`` as a .txt file (base64 data URI).'''
    b64 = base64.b64encode(raw_text.encode()).decode()
    new_filename = "new_text_file_{}_.txt".format(time_str)
    st.markdown("#### Download Summary as a File ###")
    href = f'<a href="data:file/txt;base64,{b64}" download="{new_filename}">Click to Download!!</a>'
    st.markdown(href,unsafe_allow_html=True)


@st.cache_data
def generate_eval(raw_text, N, chunk):
    '''Generate Q&A pairs from random passages of the text.

    Args:
        raw_text: source text.
        N: number of passages to sample (one QA-generation call each).
        chunk: passage length in characters. If the text is shorter than
            this, passages start at index 0.

    Returns:
        A flat list of the QA items produced by QAGenerationChain. Passages
        whose generation fails are skipped with a warning.
    '''
    # raw_text = ','.join(raw_text)

    update = st.empty()
    ques_update = st.empty()
    update.info("`Generating sample questions ...`")
    n = len(raw_text)
    # Clamp so randint never gets an empty range when the text is shorter than chunk.
    max_start = max(n - chunk, 0)
    starting_indices = [random.randint(0, max_start) for _ in range(N)]
    sub_sequences = [raw_text[i:i+chunk] for i in starting_indices]
    chain = QAGenerationChain.from_llm(ChatOpenAI(temperature=0))
    eval_set = []

    for i, b in enumerate(sub_sequences):
        try:
            qa = chain.run(b)
            eval_set.append(qa)
            ques_update.info(f"Creating Question: {i+1}")
        except Exception as e:
            print(e)
            st.warning(f'Error in generating Question: {i+1}...', icon="⚠️")
            continue

    eval_set_full = list(itertools.chain.from_iterable(eval_set))

    update.empty()
    ques_update.empty()

    return eval_set_full


@st.cache_resource
def create_prompt_and_llm():
    '''Create the agent prompt (system message plus a "history" placeholder) and the streaming GPT-4 chat model.

    Returns:
        ``(prompt, llm)``.
    '''
    llm = ChatOpenAI(temperature=0, streaming=True, model="gpt-4")

    message = SystemMessage(
        content=(
            "You are a helpful chatbot who is tasked with answering questions accurately about the earnings call transcript provided. "
            "Unless otherwise explicitly stated, it is probably fair to assume that questions are about the earnings call transcript. "
            "If there is any ambiguity, you probably assume they are about that. "
            "Do not use any information not provided in the earnings context and remember to speak like a finance expert. "
            "If you don't know the answer, just say 'There is no relevant answer in the given earnings call transcript'; "
            "don't try to make up an answer."
        )
    )

    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=message,
        extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
    )

    return prompt, llm


@st.cache_resource
def gen_embeddings(embedding_model):
    '''Create the LangChain embedding wrapper that matches the model name.

    - Names containing 'hkunlp' use Instructor embeddings with finance-specific instructions.
    - Names containing 'mpnet' use plain HuggingFace embeddings.
    - Names containing 'FlagEmbedding' use BGE embeddings with normalized vectors.

    Raises:
        ValueError: if the model name matches none of these families.
    '''
    if 'hkunlp' in embedding_model:
        embeddings = HuggingFaceInstructEmbeddings(model_name=embedding_model,
                                                   query_instruction='Represent the Financial question for retrieving supporting paragraphs: ',
                                                   embed_instruction='Represent the Financial paragraph for retrieval: ')
    elif 'mpnet' in embedding_model:
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
    elif 'FlagEmbedding' in embedding_model:
        encode_kwargs = {'normalize_embeddings': True}
        embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model,
                                              encode_kwargs = encode_kwargs
                                              )
    else:
        raise ValueError(f'Unsupported embedding model: {embedding_model!r}')

    return embeddings


@st.cache_data
def create_vectorstore(corpus, title, embedding_model, chunk_size=1000, overlap=50):
    '''Split the corpus into character chunks and index them in FAISS.

    ``title`` is not used in the body; it only makes the cache key
    distinguish transcripts. Each chunk's metadata ``source`` is its index.
    '''
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,chunk_overlap=overlap)
    texts = text_splitter.split_text(corpus)

    embeddings = gen_embeddings(embedding_model)

    vectorstore = FAISS.from_texts(texts, embeddings, metadatas=[{"source": i} for i in range(len(texts))])

    return vectorstore


@st.cache_data
def create_memory_and_agent(query,_docsearch):
    '''Build an OpenAI-functions agent that searches the given vector store.

    Returns:
        ``(memory, agent_executor)``.

    NOTE(review): ``_docsearch`` is excluded from the cache key (leading
    underscore), so the cache is keyed on ``query`` alone. A repeated query
    after loading a new transcript may return an agent bound to the old
    store. Confirm whether this caching is intended.
    '''
    # Retriever returning the 4 most similar chunks.
    retriever = _docsearch.as_retriever(search_kwargs={"k": 4})

    tool = create_retriever_tool(
        retriever,
        "earnings_call_search",
        "Searches and returns documents using the earnings context provided as a source, relevant to the user input question.",
    )

    tools = [tool]

    prompt,llm = create_prompt_and_llm()

    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)

    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True,
        return_intermediate_steps=True,
    )

    memory = AgentTokenBufferMemory(llm=llm)

    return memory, agent_executor


@st.cache_data
def gen_sentiment(text):
    '''Return the sentiment label of the given text.'''
    return sent_pipe(text)[0]['label']


@st.cache_data
def gen_annotated_text(df):
    '''Build (text, label, colour) tuples for annotated_text from a two-column (text, label) DataFrame.

    Positive rows are green, Negative rows red, and anything else black.
    '''
    tag_list=[]
    for row in df.itertuples():
        label = row[2]
        text = row[1]
        if label == 'Positive':
            tag_list.append((text,label,'#8fce00'))
        elif label == 'Negative':
            tag_list.append((text,label,'#f44336'))
        else:
            tag_list.append((text,label,'#000000'))

    return tag_list


def display_df_as_table(model,top_k,score='score'):
    '''Display the top_k hits with their scores and passage text as a DataFrame.

    NOTE(review): reads a global ``passages`` that is not defined in this
    module; it must exist in the caller's globals or this raises NameError.
    '''
    df = pd.DataFrame([(hit[score],passages[hit['corpus_id']]) for hit in model[0:top_k]],columns=['Score','Text'])
    df['Score'] = round(df['Score'],2)

    return df


def make_spans(text,results):
    '''Pair each sentence of ``text`` with the label of the matching classifier result.

    NOTE(review): ``sent_tokenizer`` is not defined in this module
    (``sent_tokenize`` is imported); confirm the intended tokenizer.
    '''
    results_list = [result['label'] for result in results]
    facts_spans = list(zip(sent_tokenizer(text),results_list))
    return facts_spans


##Fiscal Sentiment by Sentence
def fin_ext(text):
    '''Classify each sentence of ``text`` and return (sentence, label) spans.

    NOTE(review): ``remote_clx`` and ``sent_tokenizer`` are not defined in this module.
    '''
    results = remote_clx(sent_tokenizer(text))
    return make_spans(text,results)


## Knowledge Graphs code

@st.cache_data
def extract_relations_from_model_output(text):
    '''Parse REBEL output into ``{'head', 'type', 'tail'}`` relation dicts.

    REBEL linearises triplets as ``<triplet> head <subj> tail <obj> relation``.
    The ``<s>``, ``</s>`` and ``<pad>`` special tokens are removed first.
    '''
    relations = []
    relation, subject, object_ = '', '', ''
    text = text.strip()
    current = 'x'
    text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text_replaced.split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        relations.append({
            'head': subject.strip(),
            'type': relation.strip(),
            'tail': object_.strip()
        })
    return relations


def from_text_to_kb(text, model, tokenizer, article_url, span_length=128, article_title=None,
                    article_publish_date=None, verbose=False):
    '''Extract relations from ``text`` with a REBEL model and build a KB.

    The tokenized text is split into ``span_length``-token spans (evenly
    overlapped). Three beam-search candidates are generated per span, and
    every parsed relation is recorded with its span boundaries under
    ``article_url``.
    '''
    # tokenize whole text
    inputs = tokenizer([text], return_tensors="pt")

    # compute span boundaries
    num_tokens = len(inputs["input_ids"][0])
    if verbose:
        print(f"Input has {num_tokens} tokens")
    num_spans = math.ceil(num_tokens / span_length)
    if verbose:
        print(f"Input has {num_spans} spans")
    overlap = math.ceil((num_spans * span_length - num_tokens) /
                        max(num_spans - 1, 1))
    spans_boundaries = []
    start = 0
    for i in range(num_spans):
        spans_boundaries.append([start + span_length * i,
                                 start + span_length * (i + 1)])
        start -= overlap
    if verbose:
        print(f"Span boundaries are {spans_boundaries}")

    # transform input with spans
    tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
                  for boundary in spans_boundaries]
    tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
                    for boundary in spans_boundaries]
    inputs = {
        "input_ids": torch.stack(tensor_ids),
        "attention_mask": torch.stack(tensor_masks)
    }

    # generate relations
    num_return_sequences = 3
    gen_kwargs = {
        "max_length": 256,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": num_return_sequences
    }
    generated_tokens = model.generate(
        **inputs,
        **gen_kwargs,
    )

    # decode relations (special tokens kept: the parser needs <triplet>/<subj>/<obj>)
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)

    # create kb
    kb = KB()
    for i, sentence_pred in enumerate(decoded_preds):
        # generate() returns num_return_sequences outputs per input span, in order.
        current_span_index = i // num_return_sequences
        relations = extract_relations_from_model_output(sentence_pred)
        for relation in relations:
            relation["meta"] = {
                article_url: {
                    "spans": [spans_boundaries[current_span_index]]
                }
            }
            kb.add_relation(relation, article_title, article_publish_date)

    return kb


def get_article(url):
    '''Download and parse a news article.

    NOTE(review): ``Article`` (newspaper3k) is not imported in this module.
    '''
    article = Article(url)
    article.download()
    article.parse()
    return article


def from_url_to_kb(url, model, tokenizer):
    '''Build a KB from the article at ``url``.'''
    article = get_article(url)
    config = {
        "article_title": article.title,
        "article_publish_date": article.publish_date
    }
    kb = from_text_to_kb(article.text, model, tokenizer, article.url, **config)
    return kb


def get_news_links(query, lang="en", region="US", pages=1):
    '''Return de-duplicated Google News links for ``query``.

    NOTE(review): ``GoogleNews`` is not imported in this module.
    '''
    googlenews = GoogleNews(lang=lang, region=region)
    googlenews.search(query)
    all_urls = []
    for page in range(pages):
        googlenews.get_page(page)
        all_urls += googlenews.get_links()
    return list(set(all_urls))


def from_urls_to_kb(urls, model, tokenizer, verbose=False):
    '''Build one KB by merging the KBs of several article URLs.

    Articles that fail to download are skipped.
    NOTE(review): ``ArticleException`` (newspaper3k) is not imported in this module.
    '''
    kb = KB()
    if verbose:
        print(f"{len(urls)} links to visit")
    for url in urls:
        if verbose:
            print(f"Visiting {url}...")
        try:
            kb_url = from_url_to_kb(url, model, tokenizer)
            kb.merge_with_kb(kb_url)
        except ArticleException:
            if verbose:
                print(f"  Couldn't download article at url {url}")
    return kb


def save_kb(kb, filename):
    '''Pickle ``kb`` to ``filename``.'''
    with open(filename, "wb") as f:
        pickle.dump(kb, f)


class CustomUnpickler(pickle.Unpickler):
    '''Unpickler that resolves the name ``KB`` to this module's KB class.'''

    def find_class(self, module, name):
        if name == 'KB':
            return KB
        return super().find_class(module, name)


def load_kb(filename):
    '''Load a pickled KB.

    SECURITY: unpickling can execute arbitrary code. Only load files this
    app wrote itself, never user-supplied uploads.
    '''
    res = None
    with open(filename, "rb") as f:
        res = CustomUnpickler(f).load()
    return res


class KB():
    '''Knowledge base of Wikipedia-grounded entities, relations, and their source articles.'''

    def __init__(self):
        self.entities = {}  # { entity_title: {...} }
        self.relations = []  # [ head: entity_title, type: ..., tail: entity_title,
        # meta: { article_url: { spans: [...] } } ]
        self.sources = {}  # { article_url: {...} }

    def merge_with_kb(self, kb2):
        '''Add every relation of ``kb2`` to this KB, carrying over its source metadata.'''
        for r in kb2.relations:
            article_url = list(r["meta"].keys())[0]
            source_data = kb2.sources[article_url]
            self.add_relation(r, source_data["article_title"],
                              source_data["article_publish_date"])

    def are_relations_equal(self, r1, r2):
        '''Relations are equal when head, type and tail all match.'''
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        '''Return True if an equal relation is already stored.'''
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def merge_relations(self, r2):
        '''Merge r2's article/span metadata into the stored relation equal to it.'''
        r1 = [r for r in self.relations
              if self.are_relations_equal(r2, r)][0]

        # if different article
        article_url = list(r2["meta"].keys())[0]
        if article_url not in r1["meta"]:
            r1["meta"][article_url] = r2["meta"][article_url]

        # if existing article
        else:
            spans_to_add = [span for span in r2["meta"][article_url]["spans"]
                            if span not in r1["meta"][article_url]["spans"]]
            r1["meta"][article_url]["spans"] += spans_to_add

    def get_wikipedia_data(self, candidate_entity):
        '''Return the Wikipedia title, url and summary for an entity, or None if lookup fails.'''
        try:
            page = wikipedia.page(candidate_entity, auto_suggest=False)
            entity_data = {
                "title": page.title,
                "url": page.url,
                "summary": page.summary
            }
            return entity_data
        except Exception:
            # Best effort: disambiguation pages, missing pages and network
            # errors all mean "not groundable".
            return None

    def add_entity(self, e):
        '''Store entity ``e`` keyed by its title.'''
        self.entities[e["title"]] = {k:v for k,v in e.items() if k != "title"}

    def add_relation(self, r, article_title, article_publish_date):
        '''Add relation ``r`` if both its head and tail resolve to Wikipedia pages.

        Head and tail are renamed to their Wikipedia titles. The source article
        is recorded, and a duplicate relation has its metadata merged instead.
        '''
        # check on wikipedia
        candidate_entities = [r["head"], r["tail"]]
        entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]

        # if one entity does not exist, stop
        if any(ent is None for ent in entities):
            return

        # manage new entities
        for e in entities:
            self.add_entity(e)

        # rename relation entities with their wikipedia titles
        r["head"] = entities[0]["title"]
        r["tail"] = entities[1]["title"]

        # add source if not in kb
        article_url = list(r["meta"].keys())[0]
        if article_url not in self.sources:
            self.sources[article_url] = {
                "article_title": article_title,
                "article_publish_date": article_publish_date
            }

        # manage new relation
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

    def get_textual_representation(self):
        '''Return a markdown listing of entities (summaries cut to 100 chars), relations and sources.'''
        res = ""
        res += "### Entities\n"
        for e in self.entities.items():
            # shorten summary
            e_temp = (e[0], {k:(v[:100] + "..." if k == "summary" else v) for k,v in e[1].items()})
            res += f"- {e_temp}\n"
        res += "\n"
        res += "### Relations\n"
        for r in self.relations:
            res += f"- {r}\n"
        res += "\n"
        res += "### Sources\n"
        for s in self.sources.items():
            res += f"- {s}\n"
        return res


def save_network_html(kb, filename="network.html"):
    '''Render the KB as an interactive pyvis graph and write it to ``filename``.

    (Previously defined twice in this module; the later definition shown here
    was the one in effect.)
    '''
    # create network
    net = Network(directed=True, width="700px", height="700px", bgcolor="#eeeeee")

    # nodes
    color_entity = "#00FF00"
    for e in kb.entities:
        net.add_node(e, shape="circle", color=color_entity)

    # edges
    for r in kb.relations:
        net.add_edge(r["head"], r["tail"],
                     title=r["type"], label=r["type"])

    # save network
    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09
    )
    net.set_edge_smooth('dynamic')
    net.show(filename)