|
from pysolr import Solr
|
|
import os
|
|
import csv
|
|
from sentence_transformers import SentenceTransformer, util
|
|
import torch
|
|
from datetime import datetime
|
|
from get_keywords import get_keywords
|
|
import os
|
|
import re
|
|
"""
|
|
save_solr_articles_full: fetches the top num_articles articles from Solr and saves them in a CSV file
|
|
Input:
|
|
query: str
|
|
num_articles: int
|
|
keyword_type: str (openai, rake, or na)
|
|
Output: path to csv file
|
|
"""
|
|
|
|
def sanitize_query(text):
    """Sanitize the query text for Solr.

    Replaces each of the characters ``[ ] { } ( ) * + ? \\ ^ | ; : !`` with a
    space, then collapses all runs of whitespace into single spaces and trims
    the ends.

    Args:
        text: Raw query or keyword string.

    Returns:
        The sanitized string with single-space-separated tokens.
    """
    # The opening bracket is escaped inside the character class: an unescaped
    # "[[" is flagged by `re` as a possible nested set (FutureWarning). The set
    # of characters removed is unchanged.
    without_specials = re.sub(r'[\[\]{}()*+?\\^|;:!]', ' ', text)

    # split()/join normalises every whitespace run to one space.
    return ' '.join(without_specials.split())
|
|
|
|
def save_solr_articles_full(query: str, num_articles: int, keyword_type: str = "openai") -> str:
    """Search Solr for articles matching ``query`` and save them to a CSV file.

    Args:
        query: Free-text query.
        num_articles: Number of articles, forwarded to ``save_solr_articles``.
        keyword_type: Keyword extraction mode forwarded to ``get_keywords``
            (``"openai"``, ``"rake"``, or ``"na"``). With ``"na"`` the raw
            query is used as the search terms and ``get_keywords`` is not
            called at all.

    Returns:
        The path returned by ``save_solr_articles``.

    Raises:
        Any exception raised by ``get_keywords`` or ``save_solr_articles``
        propagates unchanged.
    """
    if keyword_type == "na":
        # No keyword extraction requested: skip the extractor instead of
        # calling it and throwing the result away.
        keywords = query
    else:
        keywords = get_keywords(query, keyword_type)

    return save_solr_articles(sanitize_query(keywords), num_articles)
|
|
|
|
|
|
"""
|
|
Removes spaces and newlines from text
|
|
Input: text: str
|
|
Output: text: str
|
|
"""
|
|
def remove_spaces_newlines(text: str) -> str:
    """Normalise whitespace in ``text`` to single spaces.

    Newlines and every other run of whitespace are replaced by one space, and
    leading/trailing whitespace is dropped. The previous implementation
    replaced newlines with spaces but never collapsed the resulting runs of
    spaces.

    Args:
        text: Input string.

    Returns:
        The text with tokens separated by exactly one space.
    """
    # str.split() with no argument splits on any whitespace run, so joining
    # the pieces back together collapses newlines and repeated spaces at once.
    return ' '.join(text.split())
|
|
|
|
|
|
|
|
def truncate_article(text: str, max_words: int = 1500) -> str:
    """Limit ``text`` to at most ``max_words`` whitespace-separated words.

    Args:
        text: Article text.
        max_words: Maximum number of words to keep. Defaults to 1500, the
            limit this function previously hard-coded.

    Returns:
        ``text`` unchanged (original whitespace included) when it has at most
        ``max_words`` words; otherwise the first ``max_words`` words joined by
        single spaces.

    Raises:
        ValueError: If ``max_words`` is negative.
    """
    if max_words < 0:
        raise ValueError("max_words must be non-negative")

    words = text.split()
    if len(words) <= max_words:
        # Short enough: hand back the input untouched.
        return text
    return ' '.join(words[:max_words])
|
|
|
|
|
|
"""
|
|
Searches Solr for articles based on keywords and saves them in a csv file
|
|
Input:
|
|
keywords: str
|
|
num_articles: int
|
|
Output: path to csv file
|
|
Minor details:
|
|
Removes duplicate articles to start with.
|
|
Articles with dead urls are removed since those articles are often weird.
|
|
Articles whose titles begin with the same five words as an already saved article are removed. They are usually duplicates with minor changes.
|
|
If any of title, uuid, cleaned_content, or url is missing, the article is skipped.
|
|
"""
|
|
def _format_publication_date(raw_date) -> str:
    """Convert a ``YYYY-MM-DD`` date string to ``MM/DD/YYYY``.

    The sentinel ``"Unknown Date"`` is returned unchanged. Any value that
    cannot be parsed as ``%Y-%m-%d`` yields ``"Invalid Date"``.
    """
    if raw_date == "Unknown Date":
        return raw_date
    try:
        return datetime.strptime(raw_date, "%Y-%m-%d").strftime("%m/%d/%Y")
    except (ValueError, TypeError):
        # TypeError covers non-string values, which strptime rejects before
        # it ever looks at the format.
        return "Invalid Date"


def save_solr_articles(keywords: str, num_articles: int = 15) -> str:
    """Save top articles from Solr search to CSV.

    Searches the ``text`` field for ``keywords`` among documents whose
    ``dead_url`` is false, sorted by descending score, and writes up to
    ``num_articles`` of them to ``data/articles.csv`` (overwriting the file).

    A document is skipped when it lacks any of ``title``, ``uuid``,
    ``cleaned_content`` or ``url``, or when its title has at least five words
    and its first five words match those of an article already written.

    Args:
        keywords: Search terms inserted into the Solr query as-is; callers
            are expected to have sanitized them.
        num_articles: Maximum number of articles to write.

    Returns:
        Path of the CSV file. Columns: title, uuid, content, url, domain,
        published_date.

    Raises:
        Errors raised by the Solr client or by file I/O propagate unchanged.
    """
    # NOTE(review): if SOLR_KEY is unset the URL below contains the literal
    # text "None" as the password; confirm whether this should fail fast.
    solr_key = os.getenv("SOLR_KEY")
    solr_articles_url = f"https://website:{solr_key}@solr.machines.globalhealthwatcher.org:8080/solr/articles/"

    # NOTE(review): verify=False turns off TLS certificate verification for
    # this connection. Left as-is because the server certificate setup is not
    # visible from here -- confirm whether verification can be enabled.
    solr = Solr(solr_articles_url, verify=False)

    query = f'text:({keywords}) AND dead_url:(false)'

    # Twice the requested number of rows is fetched, presumably so that
    # documents skipped below can be replaced by later results.
    outputs = solr.search(query, fq=['-dups:0'], sort="score desc", rows=num_articles * 2)

    save_path = os.path.join("data", "articles.csv")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    fieldnames = ['title', 'uuid', 'content', 'url', 'domain', 'published_date']
    required_fields = ('title', 'uuid', 'cleaned_content', 'url')
    seen_title_prefixes = set()
    article_count = 0

    # Explicit UTF-8 so article text with non-ASCII characters is written the
    # same way on every platform.
    with open(save_path, 'w', newline='', encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC)
        writer.writeheader()

        for d in outputs.docs:
            if article_count == num_articles:
                break

            if any(field not in d for field in required_fields):
                continue

            title_cleaned = remove_spaces_newlines(d['title'])
            title_words = title_cleaned.split()

            # Titles shorter than five words are never treated as duplicates.
            if len(title_words) >= 5:
                five_words = ' '.join(title_words[:5])
                if five_words in seen_title_prefixes:
                    continue
                seen_title_prefixes.add(five_words)

            article_count += 1

            cleaned_content = truncate_article(remove_spaces_newlines(d['cleaned_content']))

            writer.writerow({
                'title': title_cleaned,
                'uuid': d['uuid'],
                'content': cleaned_content,
                'url': d['url'],
                'domain': d.get('domain', "Not Specified"),
                'published_date': _format_publication_date(d.get('year_month_day', "Unknown Date")),
            })

    return save_path
|
|
|
|
|
|
def save_embedding_base_articles(query, article_embeddings, titles, contents, uuids, urls, num_articles=15):
    """Write the articles most similar to ``query`` to ``data/articles.csv``.

    The query is encoded with the ``multi-qa-MiniLM-L6-cos-v1`` model and
    compared against ``article_embeddings`` with ``util.semantic_search``.
    The best matches are written in the order the search returns them.

    Args:
        query: Query text to encode.
        article_embeddings: Corpus embeddings passed to ``semantic_search``.
        titles, contents, uuids, urls: Sequences indexed by the ``corpus_id``
            values the search returns -- assumed to be aligned with
            ``article_embeddings``; confirm against the caller.
        num_articles: Maximum number of articles to write. Fewer rows are
            written when the search returns fewer hits.

    Returns:
        Path of the CSV file. Columns: title, uuid, content, url.
    """
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)

    if num_articles > 0:
        # Ask for exactly as many hits as will be written. The old code always
        # requested 15 and then indexed range(num_articles), which raised
        # IndexError for num_articles > 15 or for a smaller corpus.
        # semantic_search returns one hit list per query; there is one query.
        hits = util.semantic_search(query_embedding, article_embeddings, top_k=num_articles)[0]
    else:
        hits = []

    save_path = os.path.join("data", "articles.csv")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    with open(save_path, 'w', newline='', encoding="utf-8") as csvfile:
        fieldnames = ['title', 'uuid', 'content', 'url']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC)
        writer.writeheader()

        for hit in hits:
            idx = hit['corpus_id']
            writer.writerow({
                'title': titles[idx],
                'uuid': uuids[idx],
                'content': contents[idx],
                'url': urls[idx],
            })

    return save_path