import whisper
import os
import random
import openai
import yt_dlp
from pytube import YouTube, extract
import pandas as pd
import plotly_express as px
import nltk
import plotly.graph_objects as go
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import streamlit as st
import en_core_web_lg
import validators
import re
import itertools
import numpy as np
from bs4 import BeautifulSoup
import base64, time
from annotated_text import annotated_text
import pickle, math
import wikipedia
from pyvis.network import Network
import torch
from pydub import AudioSegment
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings, HuggingFaceInstructEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chat_models import ChatOpenAI
from langchain.chains import QAGenerationChain
from langchain.callbacks import StreamlitCallbackHandler
from langchain.agents import OpenAIFunctionsAgent, AgentExecutor
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
AgentTokenBufferMemory,
)
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
AIMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage
)
from langchain.prompts import PromptTemplate, MessagesPlaceholder
from newspaper import Article
from newspaper.article import ArticleException
from GoogleNews import GoogleNews
nltk.download('punkt')
from nltk import sent_tokenize
OPEN_AI_KEY = os.environ.get('OPEN_AI_KEY')
time_str = time.strftime("%d%m%Y-%H%M%S")
HTML_WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;
margin-bottom: 2.5rem">{}</div> """
###################### Functions #######################################################################################
#load all required models and cache
@st.cache_resource
def load_models():
'''Load and cache all the models to be used'''
q_model = ORTModelForSequenceClassification.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
ner_model = AutoModelForTokenClassification.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
kg_model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
kg_tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
q_tokenizer = AutoTokenizer.from_pretrained("nickmuchi/quantized-optimum-finbert-tone")
ner_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large-finetuned-conll03-english")
emb_tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-xl')
sent_pipe = pipeline("text-classification",model=q_model, tokenizer=q_tokenizer)
sum_pipe = pipeline("summarization",model="philschmid/flan-t5-base-samsum",clean_up_tokenization_spaces=True)
ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, grouped_entities=True)
cross_encoder = CrossEncoder('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1') #cross-encoder/ms-marco-MiniLM-L-12-v2
sbert = SentenceTransformer('all-MiniLM-L6-v2')
return sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer, sbert
@st.cache_resource
def get_spacy():
nlp = en_core_web_lg.load()
return nlp
nlp = get_spacy()
sent_pipe, sum_pipe, ner_pipe, cross_encoder, kg_model, kg_tokenizer, emb_tokenizer, sbert = load_models()
@st.cache_data
def get_yt_audio(url):
    '''Download the audio/video stream of a YouTube video with pytube and return the downloaded file path and title'''
yt = YouTube(url)
title = yt.title
# Get the first available audio stream and download it
audio_stream = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first().download()
return audio_stream, title
@st.cache_data
def get_yt_audio_dl(url):
'''Back up for when pytube is down'''
temp_audio_file = os.path.join('output', 'audio')
ydl_opts = {
'format': 'bestaudio/best',
'postprocessors': [{
'key': 'FFmpegExtractAudio',
'preferredcodec': 'mp3',
'preferredquality': '192',
}],
'outtmpl': temp_audio_file,
'quiet': True,
}
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
info = ydl.extract_info(url, download=False)
title = info.get('title', None)
ydl.download([url])
#with open(temp_audio_file+'.mp3', 'rb') as file:
audio_file = os.path.join('output', 'audio.mp3')
return audio_file, title
@st.cache_resource
def load_asr_model(model_name):
'''Load the open source whisper model in cases where the API is not working'''
model = whisper.load_model(model_name)
return model
@st.cache_data
def load_whisper_api(audio):
    '''Transcribe audio to English text using the OpenAI Whisper API'''

    with open(audio, "rb") as file:
        transcript = openai.Audio.translate("whisper-1", file)

    return transcript
@st.cache_data
def transcribe_yt_video(url, py_tube=True):
    '''Transcribe a YouTube video, chunking the audio first when it exceeds the Whisper API size limit'''

    if py_tube:
        audio_file, title = get_yt_audio(url)
        audio_format = 'mp4'
    else:
        audio_file, title = get_yt_audio_dl(url)
        audio_format = 'mp3'

    print(f'audio_file:{audio_file}')

    st.session_state['audio'] = audio_file

    print(f"audio_file_session_state:{st.session_state['audio']}")

    #Get size of audio file in MB
    audio_size = round(os.path.getsize(st.session_state['audio'])/(1024*1024),1)

    #The Whisper API accepts files of up to 25MB, otherwise chunk the audio first
    if audio_size <= 25:
        st.info("`Transcribing YT audio...`")
        #Use whisper API
        results = load_whisper_api(st.session_state['audio'])['text']
    else:
        st.warning('File size larger than 25MB, applying chunking and transcription', icon="⚠️")

        song = AudioSegment.from_file(st.session_state['audio'], format=audio_format)
        # PyDub handles time in milliseconds
        twenty_minutes = 20 * 60 * 1000
        chunks = song[::twenty_minutes]
        transcriptions = []
        video_id = extract.video_id(url)

        for i, chunk in enumerate(chunks):
            chunk.export(f'output/chunk_{i}_{video_id}.{audio_format}', format=audio_format)
            transcriptions.append(load_whisper_api(f'output/chunk_{i}_{video_id}.{audio_format}')['text'])

        results = ','.join(transcriptions)

    st.info("`YT Video transcription process complete...`")
    return results, title
@st.cache_data
def inference(link, _upload, _asr_model):
'''Convert Youtube video or Audio upload to text'''
try:
if validators.url(link):
st.info("`Downloading YT audio...`")
results, title = transcribe_yt_video(link)
return results, title
elif _upload:
#Get size of audio file
audio_size = round(os.path.getsize(_upload)/(1024*1024),1)
#Check if file is > 24mb, if not then use Whisper API
if audio_size <= 25:
st.info("`Transcribing uploaded audio...`")
#Use whisper API
results = load_whisper_api(_upload)['text']
else:
st.write('File size larger than 24mb, applying chunking and transcription')
song = AudioSegment.from_file(_upload)
# PyDub handles time in milliseconds
twenty_minutes = 20 * 60 * 1000
chunks = song[::twenty_minutes]
transcriptions = []
st.info("`Transcribing uploaded audio...`")
for i, chunk in enumerate(chunks):
chunk.export(f'output/chunk_{i}.mp4', format='mp4')
transcriptions.append(load_whisper_api(f'output/chunk_{i}.mp4')['text'])
results = ','.join(transcriptions)
st.info("`Uploaded audio transcription process complete...`")
return results, "Transcribed Earnings Audio"
except Exception as e:
        st.error(f'''PyTube Error: {e},
        Using yt_dlp module, might take longer than expected''', icon="🚨")
results, title = transcribe_yt_video(link, py_tube=False)
# results = _asr_model.transcribe(st.session_state['audio'], task='transcribe', language='en')
return results, title
@st.cache_data
def clean_text(text):
'''Clean all text after inference'''
text = text.encode("ascii", "ignore").decode() # unicode
text = re.sub(r"https*\S+", " ", text) # url
text = re.sub(r"@\S+", " ", text) # mentions
text = re.sub(r"#\S+", " ", text) # hastags
text = re.sub(r"\s{2,}", " ", text) # over spaces
return text
@st.cache_data
def chunk_long_text(text,threshold,window_size=3,stride=2):
'''Preprocess text and chunk for sentiment analysis'''
#Convert cleaned text into sentences
sentences = sent_tokenize(text)
out = []
#Limit the length of each sentence to a threshold
for chunk in sentences:
if len(chunk.split()) < threshold:
out.append(chunk)
else:
words = chunk.split()
num = int(len(words)/threshold)
for i in range(0,num*threshold+1,threshold):
out.append(' '.join(words[i:threshold+i]))
passages = []
#Combine sentences into a window of size window_size
for paragraph in [out]:
for start_idx in range(0, len(paragraph), stride):
end_idx = min(start_idx+window_size, len(paragraph))
passages.append(" ".join(paragraph[start_idx:end_idx]))
return passages
@st.cache_data
def sentiment_pipe(earnings_text):
'''Determine the sentiment of the text'''
earnings_sentences = chunk_long_text(earnings_text,150,1,1)
earnings_sentiment = sent_pipe(earnings_sentences)
return earnings_sentiment, earnings_sentences
@st.cache_data
def chunk_and_preprocess_text(text, model_name= 'philschmid/flan-t5-base-samsum'):
'''Chunk and preprocess text for summarization'''
tokenizer = AutoTokenizer.from_pretrained(model_name)
sentences = sent_tokenize(text)
# initialize
length = 0
chunk = ""
chunks = []
count = -1
for sentence in sentences:
count += 1
combined_length = len(tokenizer.tokenize(sentence)) + length # add the no. of sentence tokens to the length counter
if combined_length <= tokenizer.max_len_single_sentence: # if it doesn't exceed
chunk += sentence + " " # add the sentence to the chunk
length = combined_length # update the length counter
# if it is the last sentence
if count == len(sentences) - 1:
chunks.append(chunk) # save the chunk
else:
chunks.append(chunk) # save the chunk
# reset
length = 0
chunk = ""
# take care of the overflow sentence
chunk += sentence + " "
length = len(tokenizer.tokenize(sentence))
return chunks
@st.cache_data
def summarize_text(text_to_summarize,max_len,min_len):
'''Summarize text with HF model'''
summarized_text = sum_pipe(text_to_summarize,
max_length=max_len,
min_length=min_len,
do_sample=False,
early_stopping=True,
num_beams=4)
summarized_text = ' '.join([summ['summary_text'] for summ in summarized_text])
return summarized_text
@st.cache_data
def get_all_entities_per_sentence(text):
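    '''Collect named entities for each sentence using both spaCy NER and the XLM-RoBERTa NER pipeline'''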
doc = nlp(''.join(text))
sentences = list(doc.sents)
entities_all_sentences = []
for sentence in sentences:
entities_this_sentence = []
# SPACY ENTITIES
for entity in sentence.ents:
entities_this_sentence.append(str(entity))
# XLM ENTITIES
entities_xlm = [entity["word"] for entity in ner_pipe(str(sentence))]
for entity in entities_xlm:
entities_this_sentence.append(str(entity))
entities_all_sentences.append(entities_this_sentence)
return entities_all_sentences
@st.cache_data
def get_all_entities(text):
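    '''Flatten the per-sentence entity lists into a single list of entities'''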
all_entities_per_sentence = get_all_entities_per_sentence(text)
return list(itertools.chain.from_iterable(all_entities_per_sentence))
@st.cache_data
def get_and_compare_entities(article_content,summary_output):
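    '''Match entities found in the summary against entities in the source text, by substring overlap or SBERT embedding similarity, and return the matched and unmatched lists'''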
all_entities_per_sentence = get_all_entities_per_sentence(article_content)
entities_article = list(itertools.chain.from_iterable(all_entities_per_sentence))
all_entities_per_sentence = get_all_entities_per_sentence(summary_output)
entities_summary = list(itertools.chain.from_iterable(all_entities_per_sentence))
matched_entities = []
unmatched_entities = []
for entity in entities_summary:
if any(entity.lower() in substring_entity.lower() for substring_entity in entities_article):
matched_entities.append(entity)
elif any(
np.inner(sbert.encode(entity, show_progress_bar=False),
sbert.encode(art_entity, show_progress_bar=False)) > 0.9 for
art_entity in entities_article):
matched_entities.append(entity)
else:
unmatched_entities.append(entity)
matched_entities = list(dict.fromkeys(matched_entities))
unmatched_entities = list(dict.fromkeys(unmatched_entities))
matched_entities_to_remove = []
unmatched_entities_to_remove = []
for entity in matched_entities:
for substring_entity in matched_entities:
if entity != substring_entity and entity.lower() in substring_entity.lower():
matched_entities_to_remove.append(entity)
for entity in unmatched_entities:
for substring_entity in unmatched_entities:
if entity != substring_entity and entity.lower() in substring_entity.lower():
unmatched_entities_to_remove.append(entity)
matched_entities_to_remove = list(dict.fromkeys(matched_entities_to_remove))
unmatched_entities_to_remove = list(dict.fromkeys(unmatched_entities_to_remove))
for entity in matched_entities_to_remove:
matched_entities.remove(entity)
for entity in unmatched_entities_to_remove:
unmatched_entities.remove(entity)
return matched_entities, unmatched_entities
@st.cache_data
def highlight_entities(article_content,summary_output):
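    '''Highlight summary entities in green when they are supported by the source text and in red when they are not (potential hallucinations)'''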
markdown_start_red = "<mark class=\"entity\" style=\"background: rgb(238, 135, 135);\">"
markdown_start_green = "<mark class=\"entity\" style=\"background: rgb(121, 236, 121);\">"
markdown_end = "</mark>"
matched_entities, unmatched_entities = get_and_compare_entities(article_content,summary_output)
for entity in matched_entities:
summary_output = re.sub(f'({entity})(?![^rgb\(]*\))',markdown_start_green + entity + markdown_end,summary_output)
for entity in unmatched_entities:
summary_output = re.sub(f'({entity})(?![^rgb\(]*\))',markdown_start_red + entity + markdown_end,summary_output)
print("")
print("")
soup = BeautifulSoup(summary_output, features="html.parser")
return HTML_WRAPPER.format(soup)
def summary_downloader(raw_text):
'''Download the summary generated'''
b64 = base64.b64encode(raw_text.encode()).decode()
new_filename = "new_text_file_{}_.txt".format(time_str)
st.markdown("#### Download Summary as a File ###")
href = f'<a href="data:file/txt;base64,{b64}" download="{new_filename}">Click to Download!!</a>'
st.markdown(href,unsafe_allow_html=True)
@st.cache_data
def generate_eval(raw_text, N, chunk):
# Generate N questions from context of chunk chars
# IN: text, N questions, chunk size to draw question from in the doc
# OUT: eval set as JSON list
# raw_text = ','.join(raw_text)
update = st.empty()
ques_update = st.empty()
update.info("`Generating sample questions ...`")
n = len(raw_text)
starting_indices = [random.randint(0, n-chunk) for _ in range(N)]
sub_sequences = [raw_text[i:i+chunk] for i in starting_indices]
chain = QAGenerationChain.from_llm(ChatOpenAI(temperature=0))
eval_set = []
for i, b in enumerate(sub_sequences):
try:
qa = chain.run(b)
eval_set.append(qa)
ques_update.info(f"Creating Question: {i+1}")
except Exception as e:
print(e)
            st.warning(f'Error in generating Question: {i+1}...', icon="⚠️")
continue
eval_set_full = list(itertools.chain.from_iterable(eval_set))
update.empty()
ques_update.empty()
return eval_set_full
@st.cache_resource
def create_prompt_and_llm():
    '''Create the system prompt and chat LLM for the earnings-call QA agent'''
llm = ChatOpenAI(temperature=0, streaming=True, model="gpt-4")
    message = SystemMessage(
        content=(
            "You are a helpful chatbot tasked with answering questions accurately about the earnings call transcript provided. "
            "Unless otherwise explicitly stated, it is probably fair to assume that questions are about the earnings call transcript. "
            "If there is any ambiguity, assume they are about the transcript. "
            "Do not use any information not provided in the earnings context, and remember you are to speak like a finance expert. "
            "If you don't know the answer, just say 'There is no relevant answer in the given earnings call transcript'; "
            "don't try to make up an answer."
        )
    )
prompt = OpenAIFunctionsAgent.create_prompt(
system_message=message,
extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
)
return prompt, llm
@st.cache_resource
def gen_embeddings(embedding_model):
'''Generate embeddings for given model'''
if 'hkunlp' in embedding_model:
embeddings = HuggingFaceInstructEmbeddings(model_name=embedding_model,
query_instruction='Represent the Financial question for retrieving supporting paragraphs: ',
embed_instruction='Represent the Financial paragraph for retrieval: ')
elif 'mpnet' in embedding_model:
embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
    elif 'FlagEmbedding' in embedding_model:
        encode_kwargs = {'normalize_embeddings': True}
        embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model,
                                              encode_kwargs=encode_kwargs)
    else:
        #fall back to the generic sentence-transformers wrapper so that embeddings is always defined
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

    return embeddings
@st.cache_data
def create_vectorstore(corpus, title, embedding_model, chunk_size=1000, overlap=50):
'''Process text for Semantic Search'''
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,chunk_overlap=overlap)
texts = text_splitter.split_text(corpus)
embeddings = gen_embeddings(embedding_model)
vectorstore = FAISS.from_texts(texts, embeddings, metadatas=[{"source": i} for i in range(len(texts))])
return vectorstore
@st.cache_resource
def create_memory_and_agent(query, _docsearch):
    '''Create the retriever tool, OpenAI-functions agent and token-buffer memory for chatting over the earnings transcript'''

    #create retriever from the FAISS vectorstore
    retriever = _docsearch.as_retriever(search_kwargs={"k": 4})

    #create retriever tool
    tool = create_retriever_tool(
        retriever,
"earnings_call_search",
"Searches and returns documents using the earnings context provided as a source, relevant to the user input question.",
)
tools = [tool]
prompt,llm = create_prompt_and_llm()
agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(
agent=agent,
tools=tools,
verbose=True,
return_intermediate_steps=True,
)
memory = AgentTokenBufferMemory(llm=llm)
return memory, agent_executor
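# Minimal usage sketch for the QA agent (the variable names, user query and embedding model below are
# illustrative assumptions, not part of this module):
# docsearch = create_vectorstore(clean_text(results), title, 'sentence-transformers/all-mpnet-base-v2')
# memory, agent_executor = create_memory_and_agent(user_query, docsearch)
# response = agent_executor({"input": user_query, "history": memory.buffer})
# memory.save_context({"input": user_query}, response)
# answer = response["output"]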
@st.cache_data
def gen_sentiment(text):
'''Generate sentiment of given text'''
return sent_pipe(text)[0]['label']
@st.cache_data
def gen_annotated_text(df):
'''Generate annotated text'''
tag_list=[]
for row in df.itertuples():
label = row[2]
text = row[1]
if label == 'Positive':
tag_list.append((text,label,'#8fce00'))
elif label == 'Negative':
tag_list.append((text,label,'#f44336'))
else:
tag_list.append((text,label,'#000000'))
return tag_list
def display_df_as_table(model, top_k, score='score'):
    '''Display the search hits with scores and text as a table; expects a global `passages` list to be defined by the calling app'''
    df = pd.DataFrame([(hit[score], passages[hit['corpus_id']]) for hit in model[0:top_k]], columns=['Score','Text'])
df['Score'] = round(df['Score'],2)
return df
def make_spans(text, results):
    '''Pair each sentence of the text with its predicted sentiment label'''
    results_list = []
    for i in range(len(results)):
        results_list.append(results[i]['label'])
    facts_spans = list(zip(sent_tokenize(text), results_list))
    return facts_spans

##Fiscal Sentiment by Sentence
def fin_ext(text):
    '''Classify sentiment sentence by sentence; `remote_clx` is expected to be a sentence-level classifier defined by the calling app'''
    results = remote_clx(sent_tokenize(text))
    return make_spans(text, results)
## Knowledge Graphs code
@st.cache_data
def extract_relations_from_model_output(text):
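    '''Parse REBEL's linearized output (<triplet>, <subj>, <obj> markers) into a list of {head, type, tail} relation dicts'''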
relations = []
    subject, relation, object_ = '', '', ''
text = text.strip()
current = 'x'
text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
for token in text_replaced.split():
if token == "<triplet>":
current = 't'
if relation != '':
relations.append({
'head': subject.strip(),
'type': relation.strip(),
'tail': object_.strip()
})
relation = ''
subject = ''
elif token == "<subj>":
current = 's'
if relation != '':
relations.append({
'head': subject.strip(),
'type': relation.strip(),
'tail': object_.strip()
})
object_ = ''
elif token == "<obj>":
current = 'o'
relation = ''
else:
if current == 't':
subject += ' ' + token
elif current == 's':
object_ += ' ' + token
elif current == 'o':
relation += ' ' + token
if subject != '' and relation != '' and object_ != '':
relations.append({
'head': subject.strip(),
'type': relation.strip(),
'tail': object_.strip()
})
return relations
def from_text_to_kb(text, model, tokenizer, article_url, span_length=128, article_title=None,
article_publish_date=None, verbose=False):
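    '''Split the tokenized text into overlapping spans of span_length tokens, run REBEL on each span and collect the extracted relations into a KB with span metadata'''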
# tokenize whole text
inputs = tokenizer([text], return_tensors="pt")
# compute span boundaries
num_tokens = len(inputs["input_ids"][0])
if verbose:
print(f"Input has {num_tokens} tokens")
num_spans = math.ceil(num_tokens / span_length)
if verbose:
print(f"Input has {num_spans} spans")
overlap = math.ceil((num_spans * span_length - num_tokens) /
max(num_spans - 1, 1))
spans_boundaries = []
start = 0
for i in range(num_spans):
spans_boundaries.append([start + span_length * i,
start + span_length * (i + 1)])
start -= overlap
if verbose:
print(f"Span boundaries are {spans_boundaries}")
# transform input with spans
tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
for boundary in spans_boundaries]
tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
for boundary in spans_boundaries]
inputs = {
"input_ids": torch.stack(tensor_ids),
"attention_mask": torch.stack(tensor_masks)
}
# generate relations
num_return_sequences = 3
gen_kwargs = {
"max_length": 256,
"length_penalty": 0,
"num_beams": 3,
"num_return_sequences": num_return_sequences
}
generated_tokens = model.generate(
**inputs,
**gen_kwargs,
)
# decode relations
decoded_preds = tokenizer.batch_decode(generated_tokens,
skip_special_tokens=False)
# create kb
kb = KB()
i = 0
for sentence_pred in decoded_preds:
current_span_index = i // num_return_sequences
relations = extract_relations_from_model_output(sentence_pred)
for relation in relations:
relation["meta"] = {
article_url: {
"spans": [spans_boundaries[current_span_index]]
}
}
kb.add_relation(relation, article_title, article_publish_date)
i += 1
return kb
def get_article(url):
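    '''Download and parse the article at the given URL with newspaper3k'''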
article = Article(url)
article.download()
article.parse()
return article
def from_url_to_kb(url, model, tokenizer):
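    '''Build a knowledge base of relations from a single article URL'''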
article = get_article(url)
config = {
"article_title": article.title,
"article_publish_date": article.publish_date
}
kb = from_text_to_kb(article.text, model, tokenizer, article.url, **config)
return kb
def get_news_links(query, lang="en", region="US", pages=1):
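    '''Search Google News for the query and return a de-duplicated list of article links'''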
googlenews = GoogleNews(lang=lang, region=region)
googlenews.search(query)
all_urls = []
for page in range(pages):
googlenews.get_page(page)
all_urls += googlenews.get_links()
return list(set(all_urls))
def from_urls_to_kb(urls, model, tokenizer, verbose=False):
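    '''Build and merge knowledge bases from a list of article URLs, skipping articles that fail to download'''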
kb = KB()
if verbose:
print(f"{len(urls)} links to visit")
for url in urls:
if verbose:
print(f"Visiting {url}...")
try:
kb_url = from_url_to_kb(url, model, tokenizer)
kb.merge_with_kb(kb_url)
except ArticleException:
if verbose:
print(f" Couldn't download article at url {url}")
return kb
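# Minimal usage sketch for the knowledge-graph pipeline (the search query is an illustrative assumption;
# kg_model and kg_tokenizer come from load_models() above):
# urls = get_news_links("Apple earnings call", pages=1)
# kb = from_urls_to_kb(urls, kg_model, kg_tokenizer, verbose=True)
# save_network_html(kb, filename="network.html")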
def save_kb(kb, filename):
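    '''Pickle the knowledge base to disk'''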
with open(filename, "wb") as f:
pickle.dump(kb, f)
class CustomUnpickler(pickle.Unpickler):
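    '''Unpickler that resolves the KB class name to this module's KB definition when loading a saved knowledge base'''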
def find_class(self, module, name):
if name == 'KB':
return KB
return super().find_class(module, name)
def load_kb(filename):
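    '''Load a pickled knowledge base from disk'''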
res = None
with open(filename, "rb") as f:
res = CustomUnpickler(f).load()
return res
class KB():
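    '''Lightweight in-memory knowledge base of Wikipedia-grounded entities, relations and their article sources'''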
def __init__(self):
self.entities = {} # { entity_title: {...} }
self.relations = [] # [ head: entity_title, type: ..., tail: entity_title,
# meta: { article_url: { spans: [...] } } ]
self.sources = {} # { article_url: {...} }
def merge_with_kb(self, kb2):
for r in kb2.relations:
article_url = list(r["meta"].keys())[0]
source_data = kb2.sources[article_url]
self.add_relation(r, source_data["article_title"],
source_data["article_publish_date"])
def are_relations_equal(self, r1, r2):
return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])
def exists_relation(self, r1):
return any(self.are_relations_equal(r1, r2) for r2 in self.relations)
def merge_relations(self, r2):
r1 = [r for r in self.relations
if self.are_relations_equal(r2, r)][0]
# if different article
article_url = list(r2["meta"].keys())[0]
if article_url not in r1["meta"]:
r1["meta"][article_url] = r2["meta"][article_url]
# if existing article
else:
spans_to_add = [span for span in r2["meta"][article_url]["spans"]
if span not in r1["meta"][article_url]["spans"]]
r1["meta"][article_url]["spans"] += spans_to_add
def get_wikipedia_data(self, candidate_entity):
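        '''Fetch the Wikipedia page for a candidate entity and return its title, url and summary, or None if no page is found'''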
try:
page = wikipedia.page(candidate_entity, auto_suggest=False)
entity_data = {
"title": page.title,
"url": page.url,
"summary": page.summary
}
return entity_data
except:
return None
def add_entity(self, e):
self.entities[e["title"]] = {k:v for k,v in e.items() if k != "title"}
def add_relation(self, r, article_title, article_publish_date):
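        '''Ground the head and tail entities on Wikipedia, register the source article, then add the relation or merge it with an existing one'''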
# check on wikipedia
candidate_entities = [r["head"], r["tail"]]
entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]
# if one entity does not exist, stop
if any(ent is None for ent in entities):
return
# manage new entities
for e in entities:
self.add_entity(e)
# rename relation entities with their wikipedia titles
r["head"] = entities[0]["title"]
r["tail"] = entities[1]["title"]
# add source if not in kb
article_url = list(r["meta"].keys())[0]
if article_url not in self.sources:
self.sources[article_url] = {
"article_title": article_title,
"article_publish_date": article_publish_date
}
# manage new relation
if not self.exists_relation(r):
self.relations.append(r)
else:
self.merge_relations(r)
def get_textual_representation(self):
res = ""
res += "### Entities\n"
for e in self.entities.items():
# shorten summary
e_temp = (e[0], {k:(v[:100] + "..." if k == "summary" else v) for k,v in e[1].items()})
res += f"- {e_temp}\n"
res += "\n"
res += "### Relations\n"
for r in self.relations:
res += f"- {r}\n"
res += "\n"
res += "### Sources\n"
for s in self.sources.items():
res += f"- {s}\n"
return res
def save_network_html(kb, filename="network.html"):
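    '''Render the knowledge base as an interactive pyvis network and write it to an HTML file'''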
# create network
net = Network(directed=True, width="700px", height="700px", bgcolor="#eeeeee")
# nodes
color_entity = "#00FF00"
for e in kb.entities:
net.add_node(e, shape="circle", color=color_entity)
# edges
for r in kb.relations:
net.add_edge(r["head"], r["tail"],
title=r["type"], label=r["type"])
# save network
net.repulsion(
node_distance=200,
central_gravity=0.2,
spring_length=200,
spring_strength=0.05,
damping=0.09
)
net.set_edge_smooth('dynamic')
net.show(filename)