"""Helpers that turn text, articles and transcripts into a knowledge base.

The module extracts relation triplets from generated model output, collects
them in ``KB`` objects, and offers utilities to fetch news articles, merge,
persist, reload and visualise the resulting knowledge bases.
"""
import math
import pickle
import re

import bs4
import torch
from GoogleNews import GoogleNews
from newspaper import Article, ArticleException
from pyvis.network import Network
from tqdm import tqdm

from kb import KB


def _make_relation(subject, relation, object_):
    """Return a relation dict with surrounding whitespace stripped."""
    return {
        'head': subject.strip(),
        'type': relation.strip(),
        'tail': object_.strip()
    }


def extract_relations_from_model_output(text):
    """Parse one decoded model output into a list of relation dicts.

    The text is expected to be a linearised sequence of triplets of the form
    ``<triplet> head <subj> tail <obj> relation type`` (the REBEL output
    format), possibly wrapped in ``<s>``, ``</s>`` and ``<pad>`` tokens.

    NOTE(review): in the previous revision every marker literal was the empty
    string, so no comparison could ever match and the function always
    returned ``[]``. The markers below are restored to the REBEL ones --
    confirm they match the tokenizer of the model in use.

    Args:
        text: Decoded model output, special tokens included.

    Returns:
        A list of dicts with the keys ``'head'``, ``'type'`` and ``'tail'``.
    """
    relations = []
    subject, relation, object_ = '', '', ''
    # 'x' means "no marker seen yet": tokens are ignored until the first one.
    current = 'x'
    text_replaced = (
        text.strip().replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    )
    for token in text_replaced.split():
        if token == "<obj>":
            # What follows is the relation type.
            current = 'o'
            relation = ''
        elif token == "<subj>":
            # What follows is the tail entity; a pending relation means the
            # previous (head, type, tail) is complete and must be flushed.
            current = 's'
            if relation != '':
                relations.append(_make_relation(subject, relation, object_))
            object_ = ''
        elif token == "<triplet>":
            # What follows is a new head entity; flush any pending relation.
            current = 't'
            if relation != '':
                relations.append(_make_relation(subject, relation, object_))
                relation = ''
            subject = ''
        elif current == 'o':
            relation += f' {token}'
        elif current == 's':
            object_ += f' {token}'
        elif current == 't':
            subject += f' {token}'
    # Flush the last triplet, which has no marker after it.
    if subject != '' and relation != '' and object_ != '':
        relations.append(_make_relation(subject, relation, object_))
    return relations


def from_text_to_kb(text, model, tokenizer, article_url, span_length=128,
                    article_title=None, article_publish_date=None,
                    verbose=False):
    """Extract relations from ``text`` and collect them in a new ``KB``.

    The tokenized text is cut into ``span_length``-token spans that overlap
    just enough to cover the input, the model generates three candidate
    sequences per span, and every relation parsed from them is added to the
    knowledge base together with the span it came from.

    Args:
        text: Text to process.
        model: Object whose ``generate(**inputs, **gen_kwargs)`` returns token
            ids (presumably a Hugging Face seq2seq model -- confirm).
        tokenizer: Callable returning ``"input_ids"`` and ``"attention_mask"``
            tensors, and providing ``batch_decode``.
        article_url: Key under which span metadata is stored on each relation.
        span_length: Number of tokens per span.
        article_title: Forwarded to ``KB.add_relation``.
        article_publish_date: Forwarded to ``KB.add_relation``.
        verbose: When true, print token, span and boundary diagnostics.

    Returns:
        A ``KB`` holding the extracted relations; empty when the tokenizer
        produces no tokens.
    """
    # tokenize whole text
    print('Start tokenizing')
    inputs = tokenizer([text], return_tensors="pt")
    print('End tokenizing')

    # compute span boundaries
    num_tokens = len(inputs["input_ids"][0])
    if verbose:
        print(f"Input has {num_tokens} tokens")
    num_spans = math.ceil(num_tokens / span_length)
    if verbose:
        print(f"Input has {num_spans} spans")
    if num_spans == 0:
        # Nothing to process; torch.stack would raise on an empty list.
        return KB()
    overlap = math.ceil(
        (num_spans * span_length - num_tokens) / max(num_spans - 1, 1)
    )
    # Each span starts `overlap` tokens before the previous one ends, so
    # consecutive starts are `span_length - overlap` apart.
    stride = span_length - overlap
    spans_boundaries = [
        [stride * i, stride * i + span_length]
        for i in tqdm(range(num_spans))
    ]
    if verbose:
        print(f"Span boundaries are {spans_boundaries}")

    # transform input with spans
    tensor_ids = [inputs["input_ids"][0][begin:end]
                  for begin, end in spans_boundaries]
    tensor_masks = [inputs["attention_mask"][0][begin:end]
                    for begin, end in spans_boundaries]
    inputs = {
        "input_ids": torch.stack(tensor_ids),
        "attention_mask": torch.stack(tensor_masks)
    }

    # generate relations
    num_return_sequences = 3
    gen_kwargs = {
        "max_length": 256,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": num_return_sequences
    }
    generated_tokens = model.generate(
        **inputs,
        **gen_kwargs,
    )

    # decode relations; special tokens are kept because the parser needs them
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)

    # create kb
    kb = KB()
    for i, sentence_pred in enumerate(decoded_preds):
        # Predictions come in groups of `num_return_sequences` per span.
        current_span_index = i // num_return_sequences
        for relation in extract_relations_from_model_output(sentence_pred):
            relation["meta"] = {
                article_url: {
                    "spans": [spans_boundaries[current_span_index]]
                }
            }
            kb.add_relation(relation, article_title, article_publish_date)

    return kb


def get_article(url):
    """Download and parse the article at ``url``; return the ``Article``."""
    article = Article(url)
    article.download()
    article.parse()
    return article


def from_url_to_kb(url, model, tokenizer):
    """Fetch the article at ``url`` and build a ``KB`` from its text."""
    article = get_article(url)
    config = {
        "article_title": article.title,
        "article_publish_date": article.publish_date
    }
    return from_text_to_kb(article.text, model, tokenizer, article.url,
                           **config)


def get_news_links(query, lang="en", region="US", pages=1):
    """Search Google News for ``query`` and return the result links.

    Args:
        query: Search string.
        lang: Language passed to ``GoogleNews``.
        region: Region passed to ``GoogleNews``.
        pages: Number of result pages to request.

    Returns:
        A list of unique URLs; the order is unspecified because duplicates
        are removed through a ``set``.
    """
    googlenews = GoogleNews(lang=lang, region=region)
    googlenews.search(query)
    all_urls = []
    for page in range(pages):
        googlenews.get_page(page)
        all_urls += googlenews.get_links()
    return list(set(all_urls))


def from_urls_to_kb(urls, model, tokenizer, verbose=False):
    """Build one ``KB`` by merging the knowledge bases of several articles.

    URLs whose article raises ``ArticleException`` are skipped.
    """
    kb = KB()
    if verbose:
        print(f"{len(urls)} links to visit")
    for url in urls:
        if verbose:
            print(f"Visiting {url}...")
        try:
            kb_url = from_url_to_kb(url, model, tokenizer)
            kb.merge_with_kb(kb_url)
        except ArticleException:
            if verbose:
                print(f"  Couldn't download article at url {url}")
    return kb


def save_network_html(kb, filename="network.html"):
    """Render ``kb`` as a directed pyvis network and show it as ``filename``."""
    # create network
    net = Network(directed=True, width="700px", height="700px")

    # nodes
    color_entity = "#00FF00"
    for e in kb.entities:
        net.add_node(e, shape="circle", color=color_entity)

    # edges
    for r in kb.relations:
        net.add_edge(r["head"], r["tail"],
                     title=r["type"], label=r["type"])

    # save network
    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09
    )
    net.set_edge_smooth('dynamic')
    net.show_buttons(filter_=['physics'])
    net.show(filename)


def save_kb(kb, filename):
    """Pickle ``kb`` to ``filename``."""
    with open(filename, "wb") as f:
        pickle.dump(kb, f)


class CustomUnpickler(pickle.Unpickler):
    """Unpickler that resolves any class named ``KB`` to this module's ``KB``.

    The module recorded in the pickle is ignored for that name, so a pickle
    still loads when ``KB`` was saved under a different module path.
    """

    def find_class(self, module, name):
        return KB if name == 'KB' else super().find_class(module, name)


def load_kb(filename):
    """Load and return a knowledge base pickled at ``filename``.

    SECURITY: unpickling can execute arbitrary code; only load files from a
    trusted source.
    """
    with open(filename, "rb") as f:
        return CustomUnpickler(f).load()


def process_transcript(src, dist):
    """Extract the plain text of an HTML transcript into a text file.

    Reads the HTML file at ``src``, takes the ``span.transcript-text`` of
    every ``div.transcript-line``, collapses each run of whitespace to a
    single space, joins the lines with spaces and writes the result to
    ``dist``.

    Args:
        src: Path of the HTML transcript to read.
        dist: Path of the text file to write.
    """
    with open(src, 'r') as src_file:
        # Name the parser explicitly so the result does not depend on which
        # optional parsers happen to be installed.
        html = bs4.BeautifulSoup(src_file.read(), "html.parser")
    lines = html.find_all('div', {'class': 'transcript-line'})
    # Build the whole text before opening the output, so a malformed line
    # does not leave a truncated file behind.
    text = ' '.join(
        re.sub(r'\s+', ' ',
               line.find('span', {'class': 'transcript-text'}).text)
        for line in lines
    )
    with open(dist, 'w') as dist_file:
        dist_file.write(text)