arxiv_chatbot / chat /arxiv_bot /arxiv_bot_utils2.py
Artteiv's picture
Refactoring gemini functions (#9)
3826b3b verified
import json
import re
import sqlite3
import urllib
import urllib.parse
import urllib.request

import chromadb
import Levenshtein
from chromadb import Documents, EmbeddingFunction, Embeddings
from django.conf import settings
from numpy.linalg import norm
from transformers import AutoModel
# This module acts as a singleton: shared state is created once, at import time.
class JinaAIEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function that delegates to a model's ``encode`` method.

    Args:
        model: Object whose ``encode(texts)`` returns an array-like value
            exposing ``tolist()``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def __call__(self, input: Documents) -> Embeddings:
        """Embed the given documents and return them as plain Python lists."""
        vectors = self.model.encode(input)
        return vectors.tolist()
# Shared Jina embedding model, cached in the relative ``models`` directory and
# loaded once when this module is imported.
embedding_model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v2-base-en",
    cache_dir="models",
    trust_remote_code=True,
)
# Module-level embedding function wrapping the shared model.
ef = JinaAIEmbeddingFunction(model=embedding_model)
# Topic name -> description mapping, loaded from a JSON file in the working
# directory. The file is closed right after parsing.
with open("topic_descriptions.txt", encoding="utf-8") as _topic_file:
    topic_descriptions = json.load(_topic_file)
# Topic names, in the file's key order.
topics = list(topic_descriptions)
# Description embeddings, aligned index-for-index with ``topics``.
embeddings = [embedding_model.encode(topic_descriptions[key]) for key in topics]
def cos_sim(a, b):
    """Return the cosine similarity of vectors ``a`` and ``b``.

    Returns 0.0 when either vector has zero norm. Dividing by zero instead
    would produce ``nan`` and a RuntimeWarning.
    """
    denom = norm(a) * norm(b)
    if denom == 0:
        return 0.0
    return (a @ b.T) / denom
def lev_sim(a, b):
    """Return the Levenshtein edit distance between strings ``a`` and ``b``.

    Despite the name, this is a distance: 0 means identical, and larger
    values mean less similar.
    """
    distance = Levenshtein.distance(a, b)
    return distance
def choose_topic(summary):
    """Pick the topic whose description embedding best matches ``summary``.

    Args:
        summary: Text to classify.

    Returns:
        The topic name with the highest cosine similarity to the summary's
        embedding. On ties, the first such topic wins. Returns "" when no
        topic scores above 0.
    """
    summary_vec = embedding_model.encode(summary)
    best_topic, best_score = "", 0.0
    for name, topic_vec in zip(topics, embeddings):
        score = cos_sim(summary_vec, topic_vec)
        if score > best_score:
            best_topic, best_score = name, score
    return best_topic
def authors_list_to_str(authors):
    """Join a list of author names into one comma-separated string.

    Args:
        authors: Iterable of author-name strings.

    Returns:
        The names separated by ", " (e.g. ``"Alice, Bob"``), or "" for an
        empty list.
    """
    # The old manual loop sliced off 3 trailing characters for a 2-character
    # separator, truncating the last author's name.
    return ", ".join(authors)
def authors_str_to_list(string):
    """Split an "A and B and C" author string into a list of names.

    Splits only on the standalone word "and", so names that contain "and"
    (e.g. "Alexander") stay intact. Empty fragments and "et al." are dropped.

    Args:
        string: Author string with names separated by " and ".

    Returns:
        List of stripped author names.
    """
    authors = []
    for author in re.split(r"\s+and\s+", string):
        author = author.strip()
        # Compare after stripping so " et al." is filtered as well.
        if author and author != "et al.":
            authors.append(author)
    return authors
def chunk_texts(text, max_char=400):
    """
    Split a long text into whitespace-separated chunks of at most ``max_char``
    characters, never cutting a word in half.

    A single word longer than ``max_char`` becomes its own (oversized) chunk.
    Runs of whitespace in the input are collapsed to single spaces.

    Args:
        text: The long text to be chunked.
        max_char: The maximum number of characters per chunk (default: 400).
    Returns:
        A list of non-empty chunks (empty list for blank input).
    """
    chunks = []
    current_chunk = ""
    for word in text.split():
        if current_chunk and len(current_chunk) + 1 + len(word) > max_char:
            chunks.append(current_chunk)
            # Start the next chunk with this word. The old code dropped it here.
            current_chunk = word
        elif current_chunk:
            current_chunk += " " + word
        else:
            current_chunk = word
    if current_chunk:
        chunks.append(current_chunk)
    return chunks
def trimming(txt):
    """Extract the outermost ``{...}`` span of ``txt`` on a single line.

    Returns the text from the first "{" to the last "}" (inclusive), with
    newlines replaced by spaces. Returns "" when there is no "{" or no "}"
    after it.
    """
    start = txt.find("{")
    end = txt.rfind("}")
    # Without this guard a missing/misordered brace produces a junk slice
    # (e.g. "abc}" used to return "}").
    if start == -1 or end < start:
        return ""
    return txt[start:end + 1].replace("\n", " ")
# crawl data
def extract_tag(txt, tagname):
    """Return the text between the first ``<tagname>`` and the next ``</tagname>``.

    Only matches opening tags written without attributes. Returns "" when
    either tag is missing, instead of slicing with a -1 index.
    """
    open_tag = "<" + tagname + ">"
    close_tag = "</" + tagname + ">"
    start = txt.find(open_tag)
    if start == -1:
        return ""
    start += len(open_tag)
    end = txt.find(close_tag, start)
    if end == -1:
        return ""
    return txt[start:end]
def get_record(extract):
    """Parse the body of one arXiv Atom ``<entry>`` into a flat record.

    Args:
        extract: Text between ``<entry>`` and ``</entry>`` of an arXiv API
            response.

    Returns:
        ``[id, updated, published, title, authors, link, summary]``, where
        ``authors`` is a list of names and ``link`` is the PDF URL ("" when
        the entry has no ``<link title="pdf" ...>`` element).
    """
    paper_id = extract_tag(extract, "id")
    updated = extract_tag(extract, "updated")
    published = extract_tag(extract, "published")
    # Collapse line-wrapping whitespace. A newline separates two words, so it
    # must become a space (the old replace("\n", "") glued words together).
    title = " ".join(extract_tag(extract, "title").split())
    summary = " ".join(extract_tag(extract, "summary").split())

    authors = []
    rest = extract
    while "<author>" in rest:
        authors.append(extract_tag(rest, "name"))
        author_end = rest.find("</author>")
        if author_end == -1:
            # Malformed author block: stop rather than append garbage names.
            break
        rest = rest[author_end + len("</author>"):]

    pattern = '<link title="pdf" href="'
    link = ""
    link_start = extract.find(pattern)
    if link_start != -1:
        href_start = link_start + len(pattern)
        href_end = extract.find('"', href_start)
        if href_end != -1:
            link = extract[href_start:href_end]
    return [paper_id, updated, published, title, authors, link, summary]
def _parse_arxiv_entries(xml):
    """Turn an arXiv Atom response into ``[topic, *get_record(entry)]`` rows."""
    open_tag, close_tag = "<entry>", "</entry>"
    records = []
    pos = 0
    while True:
        start = xml.find(open_tag, pos)
        if start == -1:
            break
        start += len(open_tag)
        end = xml.find(close_tag, start)
        if end == -1:
            # Truncated response: ignore the incomplete trailing entry.
            break
        record = get_record(xml[start:end])
        # Index 6 of get_record's result is the summary.
        records.append([choose_topic(record[6]), *record])
        pos = end + len(close_tag)
    return records


def _fetch_arxiv_records(url):
    """Download ``url`` from the arXiv API and parse every entry it returns."""
    # Context manager so the HTTP response is always closed.
    with urllib.request.urlopen(url, timeout=100) as response:
        xml = response.read().decode("utf-8")
    return _parse_arxiv_entries(xml)


def crawl_exact_paper(title, author, max_results=3):
    """Search arXiv for papers matching both a title and a list of authors.

    Args:
        title: Paper title to search for (``ti:`` field).
        author: List of author names (``au:`` field).
        max_results: Maximum number of entries requested from the API.

    Returns:
        A list of ``[topic, id, updated, published, title, authors, link,
        summary]`` records, or a string starting with "Error: " if the
        request or parsing fails.
    """
    authors = authors_list_to_str(author)
    # Percent-encode the user values so characters like "&", "#" or
    # non-ASCII letters cannot break the query string.
    url = (
        "http://export.arxiv.org/api/query?search_query=ti:{title}+AND+au:{author}"
        "&max_results={max_results}"
    ).format(
        title=urllib.parse.quote(title, safe=""),
        author=urllib.parse.quote(authors, safe=""),
        max_results=max_results,
    )
    try:
        return _fetch_arxiv_records(url)
    except Exception as e:
        return "Error: " + str(e)


def crawl_arxiv(keyword_list, max_results=100):
    """Search arXiv for papers matching any of the given keywords.

    Args:
        keyword_list: Keywords, OR-ed together as ``all:`` queries.
        max_results: Maximum number of entries requested from the API.

    Returns:
        A list of ``[topic, id, updated, published, title, authors, link,
        summary]`` records ([] when no keywords are given), or a string
        starting with "Error: " if the request or parsing fails.
    """
    baseurl = "http://export.arxiv.org/api/query?search_query="
    keywords = list(keyword_list)
    if not keywords:
        return []
    query = "+OR+".join(
        "all:" + urllib.parse.quote(keyword, safe="") for keyword in keywords
    )
    url = baseurl + query + "&max_results=" + str(max_results)
    try:
        return _fetch_arxiv_records(url)
    except Exception as e:
        return "Error: " + str(e)
# This class acts as a module: all state lives in class attributes and every method is static.
class ArxivChroma:
    """
    Interface to the arxivdb Chroma collection, supporting only query and addition.

    Editing and deleting records is not supported. All state lives in class
    attributes, so call ``ArxivChroma.connect()`` before any other method.
    """

    client = None  # chromadb persistent client, set by connect()
    model = None  # embedding model used by the collection's embedding function
    collection = None  # Chroma collection, set by connect()

    @staticmethod
    def connect(table="arxiv_records", name="arxivdb/"):
        """Open the persistent store at ``name`` and get/create collection ``table``."""
        ArxivChroma.client = chromadb.PersistentClient(name)
        ArxivChroma.model = embedding_model
        ArxivChroma.collection = ArxivChroma.client.get_or_create_collection(
            table,
            embedding_function=JinaAIEmbeddingFunction(model=ArxivChroma.model),
        )

    @staticmethod
    def query_relevant(keywords, query_texts, n_results=3):
        """
        Query by similarity to ``query_texts``, restricted to documents
        containing at least one of ``keywords`` (lower-cased).

        With no keywords, no document filter is applied.
        """
        # NOTE(review): Chroma's $contains is case-sensitive, so lower-casing
        # the keywords may miss capitalized matches. Confirm this is intended.
        contains = [{"$contains": keyword.lower()} for keyword in keywords]
        query_kwargs = {"query_texts": query_texts, "n_results": n_results}
        # Chroma rejects "$or" with fewer than two operands, so use the single
        # filter directly and omit the filter entirely when there is none.
        if len(contains) == 1:
            query_kwargs["where_document"] = contains[0]
        elif contains:
            query_kwargs["where_document"] = {"$or": contains}
        return ArxivChroma.collection.query(**query_kwargs)

    @staticmethod
    def query_exact(id, max_chunks=10):
        """Fetch the stored chunks ``{id}_0`` .. ``{id}_{max_chunks-1}`` of one paper."""
        ids = ["{}_{}".format(id, j) for j in range(max_chunks)]
        return ArxivChroma.collection.get(ids=ids)

    @staticmethod
    def add(crawl_records):
        """
        Add crawl_records (list) obtained from arxiv_crawlers.

        A record is a list of 8 columns:
        [topic, id, updated, published, title, author, link, summary]
        Each record's topic/title/summary text is chunked and stored with ids
        ``{paper_id}_{chunk_index}`` and metadata ``{"paper_id": paper_id}``.
        Return the final length of the database table.
        """
        for record in crawl_records:
            embed_text = "\nTopic: {},\nTitle: {},\nSummary: {}\n".format(
                record[0], record[4], record[7]
            )
            chunks = chunk_texts(embed_text)
            if not chunks:
                continue
            # Presumably strips the 21-char "http://arxiv.org/abs/" URL prefix
            # from the entry id. Confirm against the crawler output.
            paper_id = record[1][21:]
            ArxivChroma.collection.add(
                documents=chunks,
                metadatas=[{"paper_id": paper_id} for _ in chunks],
                ids=["{}_{}".format(paper_id, j) for j in range(len(chunks))],
            )
        return ArxivChroma.collection.count()

    @staticmethod
    def close_connection():
        """No-op; nothing needs to be released explicitly."""
        pass
# This class acts as a module: all state lives in class attributes and every method is static.
class ArxivSQL:
    """Static interface to the SQLite table of crawled arXiv paper metadata.

    Column order (as written by ``add``): id, topic, title, author, updated,
    published, link. All statements use bound parameters, never string-built SQL.
    """

    table = "arxivsql"
    con = None  # sqlite3.Connection, set by connect()
    cur = None  # cursor on ``con``, set by connect()

    @staticmethod
    def connect(name="db.sqlite3"):
        """Open the SQLite database file ``name`` and create a shared cursor."""
        ArxivSQL.con = sqlite3.connect(name, check_same_thread=False)
        ArxivSQL.cur = ArxivSQL.con.cursor()

    @staticmethod
    def query(title="", author=None, threshold=15):
        """Find papers by author substring and, optionally, by title.

        Args:
            title: If non-empty, keep only rows whose title is within
                ``threshold`` Levenshtein edits (strictly less) of it.
            author: Iterable of names. A row matches if its author column
                contains any of them. None or empty matches all rows.
            threshold: Maximum edit distance (exclusive) for title matches.

        Returns:
            All matching rows when ``title`` is empty. Otherwise the
            title-matching rows ordered closest first, or None when no
            title matches (mirroring ``query_id`` on an empty id list).
        """
        authors = list(author) if author else []
        if authors:
            where = " OR ".join(["author LIKE ?"] * len(authors))
            params = ["%" + a + "%" for a in authors]
        else:
            where = "1"
            params = []
        sql = "select * from {} where {}".format(ArxivSQL.table, where)
        results = ArxivSQL.cur.execute(sql, params).fetchall()
        if not title:
            return results
        scored = []
        for row in results:
            score = lev_sim(title, row[2])  # row[2] is the title column
            if score < threshold:
                scored.append((score, row))
        # Stable sort keeps the database order among equally close titles.
        scored.sort(key=lambda pair: pair[0])
        return [row for _, row in scored] or None

    @staticmethod
    def query_id(ids=None):
        """Return all rows whose id is in ``ids``.

        Returns None for an empty/None ``ids``, or when the query fails (the
        error is printed; this lookup is best-effort).
        """
        if not ids:
            return None
        ids = list(ids)
        placeholders = ",".join("?" * len(ids))
        sql = "select * from {} where id in ({})".format(ArxivSQL.table, placeholders)
        try:
            return ArxivSQL.cur.execute(sql, ids).fetchall()
        except Exception as e:
            print(e)
            print("Error query: ", sql, ids)
            return None

    @staticmethod
    def add(crawl_records):
        """
        Add crawl_records (list) obtained from arxiv_crawlers.

        A record is a list of 8 columns:
        [topic, id, updated, published, title, author, link, summary]
        Every record is attempted. A failing record is skipped and its error
        is reported.
        Return a string of accumulated error messages ("" if all succeeded).
        """
        results = ""
        sql = "insert into {} values (?, ?, ?, ?, ?, ?, ?)".format(ArxivSQL.table)
        for record in crawl_records:
            params = None
            try:
                params = (
                    # Presumably strips the 21-char "http://arxiv.org/abs/"
                    # prefix. Confirm against the crawler output.
                    record[1][21:],
                    record[0],
                    record[4],
                    authors_list_to_str(record[5]),
                    record[2][:10],  # keep only the date part
                    record[3][:10],
                    record[6],
                )
                ArxivSQL.cur.execute(sql, params)
                ArxivSQL.con.commit()
            except Exception as e:
                results += str(e)
                results += "\n" + sql + " " + repr(params) + "\n"
        # Return only after the loop. The old `finally: return` inside the
        # loop stopped after the first record.
        return results