import json
import sqlite3
import urllib.request

import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from numpy.linalg import norm
from transformers import AutoModel

class JinaAIEmbeddingFunction(EmbeddingFunction):
    """Wrap a sentence-embedding model so Chroma can call it as an embedding function."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def __call__(self, input: Documents) -> Embeddings:
        # encode() returns a numpy array; Chroma expects plain Python lists
        embeddings = self.model.encode(input)
        return embeddings.tolist()

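# Usage sketch (illustrative, not executed on import): the embedding function is
# built from the same jina model that ArxivChroma loads below.
#
#   model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en',
#                                     trust_remote_code=True, cache_dir='models')
#   ef = JinaAIEmbeddingFunction(model)
#   vectors = ef(["an example sentence to embed"])   # list of float lists
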
class ArxivSQL:
    """
    Lightweight SQLite interface to the arXiv records table.
    Only querying and adding records are supported.
    """
    def __init__(self, table="arxivsql", name="arxiv_records_sql"):
        self.con = sqlite3.connect(name)
        self.cur = self.con.cursor()
        self.table = table

    def query(self, title="", author=[]):
        if len(title) > 0:
            query_title = 'title like "%{}%"'.format(title)
        else:
            query_title = "True"
        if len(author) > 0:
            # each author needs its own "author like" clause, joined with OR
            query_author = " or ".join(
                "author like '%{}%'".format(auth) for auth in author
            )
            query_author = "({})".format(query_author)
        else:
            query_author = "True"
        query = "select * from {} where {} and {}".format(self.table, query_title, query_author)
        result = self.cur.execute(query)
        return result.fetchall()

    def query_id(self, ids=[]):
        if len(ids) == 0:
            return []
        query = "select * from {} where id in (".format(self.table)
        for id in ids:
            query += "'" + id + "',"
        query = query[:-1] + ")"
        result = self.cur.execute(query)
        return result.fetchall()

    def add(self, crawl_records):
        """
        Add crawl_records (list) obtained from arxiv_crawlers.
        A record is a list of 8 fields:
        [topic, id, updated, published, title, author, link, summary]
        Return a log of any records that failed to insert (empty string if none).
        """
        results = ""
        for record in crawl_records:
            try:
                # columns stored: id, topic, title, author, updated, published, link
                # record[1][21:] strips the "http://arxiv.org/abs/" prefix from the id URL
                query = """insert into {} values("{}","{}","{}","{}","{}","{}","{}")""".format(
                    self.table,
                    record[1][21:],
                    record[0],
                    record[4].replace('"', "'"),
                    process_authors_str(record[5]),
                    record[2][:10],
                    record[3][:10],
                    record[6]
                )
                self.cur.execute(query)
                self.con.commit()
            except Exception as e:
                results += str(e)
                results += "\n" + query + "\n"
        return results

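# Usage sketch (illustrative): assumes the SQLite file already contains the table
# with columns (id, topic, title, author, updated, published, link); creating that
# table is handled outside this module, and the id below is a placeholder.
#
#   sql_db = ArxivSQL()
#   hits = sql_db.query(title="diffusion", author=["Ho"])
#   records = sql_db.query_id(["2403.12345"])
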
class ArxivChroma:
    """
    Interface to arxivdb, which only supports querying and adding records.
    Editing and deleting records are not supported.
    """
    def __init__(self, table="arxiv_records", name="arxivdb/"):
        self.client = chromadb.PersistentClient(name)
        self.model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en',
                                               trust_remote_code=True,
                                               cache_dir='models')
        self.collection = self.client.get_or_create_collection(
            table,
            embedding_function=JinaAIEmbeddingFunction(model=self.model)
        )

    def query_relevant(self, keywords, query_texts, n_results=3):
        """
        Perform a semantic query with query_texts, restricted to documents
        that contain at least one of the given keywords (list of str).
        """
        contains = [{"$contains": keyword} for keyword in keywords]
        return self.collection.query(
            query_texts=query_texts,
            where_document={"$or": contains},
            n_results=n_results,
        )

    def query_exact(self, id):
        # chunks of a paper are stored with ids "<paper_id>_0" ... "<paper_id>_9"
        ids = ["{}_{}".format(id, j) for j in range(0, 10)]
        return self.collection.get(ids=ids)

    def add(self, crawl_records):
        """
        Add crawl_records (list) obtained from arxiv_crawlers.
        A record is a list of 8 fields:
        [topic, id, updated, published, title, author, link, summary]
        Return the final length of the database table.
        """
        for record in crawl_records:
            embed_text = """
            Topic: {},
            Title: {},
            Summary: {}
            """.format(record[0], record[4], record[7])
            chunks = chunk_text_with_overlap(embed_text)
            ids = [record[1][21:] + "_" + str(j) for j in range(len(chunks))]
            paper_ids = [{"paper_id": record[1][21:]} for _ in range(len(chunks))]
            self.collection.add(
                documents=chunks,
                metadatas=paper_ids,
                ids=ids
            )
        return self.collection.count()

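# Usage sketch (illustrative): query the vector store for chunks about a topic.
#
#   chroma_db = ArxivChroma()
#   result = chroma_db.query_relevant(
#       keywords=["diffusion", "generative"],
#       query_texts="How do diffusion models generate images?",
#   )
#   # result["metadatas"] carries the paper_id of each matching chunk, which can
#   # then be looked up in ArxivSQL via query_id().
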
def chunk_text_with_overlap(text, max_char=400, overlap=100):
    """
    Chunk a long text into several chunks of roughly max_char characters,
    making sure no word is cut in half, and keeping an overlap of about
    `overlap` characters between consecutive chunks.

    Args:
        text: The long text to be chunked.
        max_char: The maximum number of characters per chunk (default: 400).
        overlap: The desired overlap between consecutive chunks (default: 100).

    Returns:
        A list of chunks.
    """
    chunks = []
    current_chunk = ""
    words = text.split()
    for word in words:
        if len(current_chunk) + len(word) + 1 >= max_char:
            # close the current chunk, then start the next one from the last
            # `overlap` characters, aligned to a word boundary
            chunks.append(current_chunk)
            split_point = current_chunk.find(" ", len(current_chunk) - overlap)
            current_chunk = current_chunk[split_point:] + " " + word
        else:
            current_chunk += " " + word
    chunks.append(current_chunk.strip())
    return chunks

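# Example (illustrative): splitting a long text into overlapping chunks.
#
#   chunks = chunk_text_with_overlap("word " * 300, max_char=400, overlap=100)
#   # each chunk is at most ~400 characters and repeats roughly the last
#   # 100 characters of the previous chunk
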
def trimming(txt):
    """Return the substring of txt from the first '{' to the last '}'."""
    start = txt.find("{")
    end = txt.rfind("}")
    return txt[start:end + 1]

def extract_tag(txt, tagname):
    """Return the text between <tagname> and </tagname> in txt."""
    return txt[txt.find("<" + tagname + ">") + len(tagname) + 2:txt.find("</" + tagname + ">")]

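# Example (illustrative): pulling a field out of an Atom <entry> snippet.
#
#   extract_tag("<title>Attention Is All You Need</title>", "title")
#   # -> "Attention Is All You Need"
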
def get_record(extract):
    """
    Parse a single <entry> block from the arXiv Atom feed and return
    [id, updated, published, title, authors, link, summary].
    """
    id = extract_tag(extract, "id")
    updated = extract_tag(extract, "updated")
    published = extract_tag(extract, "published")
    title = extract_tag(extract, "title").replace("\n ", "").strip()
    summary = extract_tag(extract, "summary").replace("\n", "").strip()
    authors = []
    while extract.find("<author>") != -1:
        author = extract_tag(extract, "name")
        extract = extract[extract.find("</author>") + 9:]
        authors.append(author)
    pattern = '<link title="pdf" href="'
    link_start = extract.find(pattern)
    link = extract[link_start + len(pattern):extract.find("rel=", link_start) - 2]
    return [id, updated, published, title, authors, link, summary]

def choose_topic(summary):
    """
    Assign a topic to a paper summary by cosine similarity between the summary
    embedding and the embeddings of the descriptions in "topic_descriptions".
    """
    model_embedding = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en',
                                                trust_remote_code=True,
                                                cache_dir='models')
    embed = model_embedding.encode(summary)
    cos_sim = lambda a, b: (a @ b.T) / (norm(a) * norm(b))
    descriptions = json.load(open("topic_descriptions"))
    topic = ""
    max_sim = 0.
    for key in descriptions:
        sim = cos_sim(embed, model_embedding.encode(descriptions[key]))
        if sim > max_sim:
            topic = key
            max_sim = sim
    return topic

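# choose_topic expects "topic_descriptions" to be a JSON object mapping a topic
# name to a short description of that topic. The entries below are illustrative,
# not the actual file shipped with the project:
#
#   {
#       "Computer Vision": "Research on image recognition, detection, ...",
#       "Natural Language Processing": "Research on language models, ..."
#   }
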
def crawl_arxiv(keyword_list, max_results=100):
    """
    Query the arXiv API for papers matching any of the keywords and return a
    list of records [topic, id, updated, published, title, author, link, summary].
    On failure, return an error string instead.
    """
    baseurl = 'http://export.arxiv.org/api/query?search_query='
    records = []
    for i, keyword in enumerate(keyword_list):
        if i == 0:
            url = baseurl + 'all:' + keyword
        else:
            url = url + '+OR+' + 'all:' + keyword
    url = url + '&max_results=' + str(max_results)
    url = url.replace(' ', '%20')
    try:
        arxiv_page = urllib.request.urlopen(url, timeout=100).read()
        xml = str(arxiv_page, encoding="utf-8")
        # walk the Atom feed entry by entry
        while xml.find("<entry>") != -1:
            extract = xml[xml.find("<entry>") + 7:xml.find("</entry>")]
            xml = xml[xml.find("</entry>") + 8:]
            extract = get_record(extract)
            topic = choose_topic(extract[6])
            records.append([topic, *extract])
        return records
    except Exception as e:
        return "Error: " + str(e)

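# Usage sketch (illustrative): crawl a few keywords and push the results into
# both stores defined above.
#
#   records = crawl_arxiv(["retrieval augmented generation", "vector database"])
#   if not isinstance(records, str):      # crawl_arxiv returns an error string on failure
#       ArxivChroma().add(records)
#       ArxivSQL().add(records)
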
def process_authors_str(authors):
    """Input a list of authors, return a single comma-separated string."""
    text = ""
    for author in authors:
        text += author + ", "
    return text[:-2]  # drop the trailing ", "

def process_authors_list(string):
    """Input a string of authors joined by "and", return a list of authors."""
    authors = []
    list_auth = string.split("and")
    for author in list_auth:
        author = author.strip()
        if author != "et al.":
            authors.append(author)
    return authors
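
# Example (illustrative) of the two author helpers:
#
#   process_authors_str(["Alice Nguyen", "Bob Tran"])   # -> "Alice Nguyen, Bob Tran"
#   process_authors_list("Alice Nguyen and Bob Tran and et al.")
#   # -> ["Alice Nguyen", "Bob Tran"]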