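# arxiv_chatbot/chat/arxiv_bot/arxiv_bot_utils.py (10.5 kB)
# Last commit: 3826b3b (verified) "Refactoring gemini functions (#9)" -- author: Artteiv
# NOTE(review): these header lines were copied from a repository web page (file path,
# author avatar caption, commit info, "raw"/"history"/"blame" links) and are kept as
# comments so the file parses as Python. Every line below is commented out, so
# importing this module currently defines nothing.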
# import chromadb
# from chromadb import Documents, EmbeddingFunction, Embeddings
# from transformers import AutoModel
# import json
# from numpy.linalg import norm
# import sqlite3
# import urllib
# from django.conf import settings
# # This module acts as a singleton: importing it creates shared module-level instances (embedding_model, sqldb, db).
# # Adapter that lets chromadb embed documents with a local model: __call__ encodes
# # the batch with `model.encode` and converts the result to plain Python lists.
# # NOTE(review): assumes `model.encode(list_of_str)` returns an array-like with a
# # `.tolist()` method (e.g. a numpy array) -- TODO confirm for the Jina model.
# # NOTE(review): the parameter name `input` shadows the builtin; it is presumably kept
# # because chromadb's EmbeddingFunction interface expects that exact name -- confirm.
# class JinaAIEmbeddingFunction(EmbeddingFunction):
# def __init__(self, model):
# super().__init__()
# self.model = model
# def __call__(self, input: Documents) -> Embeddings:
# embeddings = self.model.encode(input)
# return embeddings.tolist()
# # instance of embedding_model
# # NOTE(review): this runs at import time and loads the Jina embedding model, with
# # the relative directory 'models' as cache_dir (presumably downloading on first run).
# # trust_remote_code=True lets transformers execute Python code shipped in the model
# # repository -- acceptable only because the repo is pinned by name; review on upgrade.
# embedding_model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en',
# trust_remote_code=True,
# cache_dir='models')
# # instance of JinaAIEmbeddingFunction
# # NOTE(review): `ef` is not used anywhere else in the code shown here; ArxivChroma
# # constructs its own JinaAIEmbeddingFunction around embedding_model.
# ef = JinaAIEmbeddingFunction(embedding_model)
# # list of topics
# # Topic table used by choose_topic(): a dict {topic_name: description} is loaded from
# # JSON, and one embedding per description is pre-computed at import time.
# # NOTE(review): the file object from open() is never closed, the path is relative to
# # the current working directory, and the file is parsed as JSON despite its .txt
# # extension. Prefer `with open(..., encoding="utf-8") as f: json.load(f)`.
# # NOTE(review): `dict.keys(topic_descriptions)` is simply `topic_descriptions.keys()`.
# # `topics` and `embeddings` are built by iterating the same dict, so index i refers
# # to the same topic in both lists.
# topic_descriptions = json.load(open("topic_descriptions.txt"))
# topics = list(dict.keys(topic_descriptions))
# embeddings = [embedding_model.encode(topic_descriptions[key]) for key in topic_descriptions]
# # Cosine similarity of two embedding vectors.
# # NOTE(review): assumes a and b are 1-D vectors (for a 1-D numpy array `.T` is a
# # no-op) -- confirm the shape returned by embedding_model.encode(str).
# cos_sim = lambda a,b: (a @ b.T) / (norm(a)*norm(b))
# # Return the key of the topic whose description embedding is most cosine-similar to
# # the embedding of `summary` (reads the module-level `topics` and `embeddings`).
# # NOTE(review): max_sim starts at 0., so if every similarity is <= 0 the result is ""
# # rather than a topic; on ties the earlier topic wins. The final `return` is presumably
# # after the loop (indentation was lost when this code was commented out).
# def choose_topic(summary):
# embed = embedding_model.encode(summary)
# topic = ""
# max_sim = 0.
# for i,key in enumerate(topics):
# sim = cos_sim(embed,embeddings[i])
# if sim > max_sim:
# topic = key
# max_sim = sim
# return topic
# # NOTE(review): BUG -- every name is followed by ", " (2 characters), but the result is
# # cut with text[:-3], which also removes the last character of the final name, e.g.
# # ["Alice", "Bob"] -> "Alice, Bo". `", ".join(authors)` gives the intended result.
# # crawl_exact_paper and ArxivSQL.add call this function, so the truncation reaches
# # arXiv queries and stored rows.
# def authors_list_to_str(authors):
# """Join a list of author names into one comma-separated string."""
# text = ""
# for author in authors:
# text+=author+", "
# return text[:-3]
# # NOTE(review): BUG -- split("and") splits on the substring, not the word, so names that
# # contain "and" (e.g. "Alexander" -> "Alex", "er") are broken apart; split on " and "
# # or use a word-boundary regex instead.
# # NOTE(review): BUG -- the "et al." comparison happens before strip(), so the usual
# # split result " et al." (from "X and et al.") is not filtered out.
# def authors_str_to_list(string):
# """Split an "A and B and C" author string into a list of names, skipping "et al."."""
# authors = []
# list_auth = string.split("and")
# for author in list_auth:
# if author != "et al.":
# authors.append(author.strip())
# return authors
# # NOTE(review): BUG -- when appending `word` would reach max_char, the current chunk is
# # flushed but `word` itself is discarded (current_chunk is reset to " " rather than to
# # the word), so one word is lost at every chunk boundary.
# # NOTE(review): flushed chunks are appended unstripped and start with a space; a single
# # word of max_char-1 characters or more is dropped and an empty chunk is emitted.
# # Assuming the final append sits after the loop (indentation was lost), empty text
# # yields [""].
# def chunk_texts(text, max_char=400):
# """
# Chunk a long text into several chunks, with each chunk about 300-400 characters long,
# but make sure no word is cut in half.
# Args:
# text: The long text to be chunked.
# max_char: The maximum number of characters per chunk (default: 400).
# Returns:
# A list of chunks.
# """
# chunks = []
# current_chunk = ""
# words = text.split()
# for word in words:
# if len(current_chunk) + len(word) + 1 >= max_char:
# chunks.append(current_chunk)
# current_chunk = " "
# else:
# current_chunk += " " + word
# chunks.append(current_chunk.strip())
# return chunks
# # Cut `txt` down to the span from the first "{" to the last "}" and replace newlines
# # with spaces -- presumably to pull a JSON object out of model output; confirm callers.
# # NOTE(review): if "}" is missing, rfind() returns -1 and the result is "". If only "{"
# # is missing, find() returns -1, so the slice starts at the last character and yields
# # at most that one character.
# def trimming(txt):
# start = txt.find("{")
# end = txt.rfind("}")
# return txt[start:end+1].replace("\n"," ")
# # crawl data
# # Return the text between the first "<tagname>" and the first "</tagname>" in `txt`.
# # This is a plain string search, not an XML parser: a tag written with attributes
# # (e.g. <tagname attr="...">) is not matched.
# # NOTE(review): if either tag is missing, find() returns -1 and the slice silently yields
# # wrong text instead of failing; xml.etree.ElementTree would parse the Atom feed reliably.
# def extract_tag(txt,tagname):
# return txt[txt.find("<"+tagname+">")+len(tagname)+2:txt.find("</"+tagname+">")]
# # Parse the body of one Atom <entry> into
# # [id, updated, published, title, authors, link, summary].
# # NOTE(review): `id` shadows the builtin. Title and summary have newlines removed
# # without inserting a space, so words split across lines are glued together.
# # NOTE(review): the author loop consumes `extract` up to the last </author>, so the pdf
# # <link> is found only if it appears after the authors. If no <link title="pdf" ...>
# # is present, find() returns -1 and `link` is garbage. The `- 2` assumes the href
# # value is followed by `" rel=` -- confirm against the feed format.
# # NOTE(review): `pattern` is defined, but the following find() repeats the literal
# # instead of reusing it.
# def get_record(extract):
# id = extract_tag(extract,"id")
# updated = extract_tag(extract,"updated")
# published = extract_tag(extract,"published")
# title = extract_tag(extract,"title").replace("\n ","").strip()
# summary = extract_tag(extract,"summary").replace("\n","").strip()
# authors = []
# while extract.find("<author>")!=-1:
# author = extract_tag(extract,"name")
# extract = extract[extract.find("</author>")+9:]
# authors.append(author)
# pattern = '<link title="pdf" href="'
# link_start = extract.find('<link title="pdf" href="')
# link = extract[link_start+len(pattern):extract.find("rel=",link_start)-2]
# return [id, updated, published, title, authors, link, summary]
# # Query the arXiv API for up to max_results papers matching `title` and the author list
# # `author`. Returns a list of [topic, id, updated, published, title, authors, link,
# # summary] records, or an "Error: ..." string if anything raises.
# # NOTE(review): BUG -- the module only does `import urllib`, which does not import the
# # `urllib.request` submodule. urlopen works only if some other module already imported
# # urllib.request; otherwise the AttributeError is caught and returned as "Error: ...".
# # Use `import urllib.request`.
# # NOTE(review): the mixed return types (list vs. str) force callers to type-check.
# # Only spaces are percent-encoded, so other reserved characters in title/author
# # (e.g. '&', '#') can corrupt the query; urllib.parse.quote would be safer.
# # NOTE(review): the author string comes from authors_list_to_str, which truncates the
# # last name by one character. The <entry> loop duplicates crawl_arxiv's; consider a
# # shared helper.
# def crawl_exact_paper(title,author,max_results=3):
# authors = authors_list_to_str(author)
# records = []
# url = 'http://export.arxiv.org/api/query?search_query=ti:{title}+AND+au:{author}&max_results={max_results}'.format(title=title,author=authors,max_results=max_results)
# url = url.replace(" ","%20")
# try:
# arxiv_page = urllib.request.urlopen(url,timeout=100).read()
# xml = str(arxiv_page,encoding="utf-8")
# while xml.find("<entry>") != -1:
# extract = xml[xml.find("<entry>")+7:xml.find("</entry>")]
# xml = xml[xml.find("</entry>")+8:]
# extract = get_record(extract)
# topic = choose_topic(extract[6])
# records.append([topic,*extract])
# return records
# except Exception as e:
# return "Error: "+str(e)
# # Search arXiv for up to max_results entries matching any keyword
# # (all:k1+OR+all:k2...). Returns a list of [topic, id, updated, published, title,
# # authors, link, summary] records, or an "Error: ..." string if the request/parse raises.
# # NOTE(review): with an empty keyword_list, `url` is never assigned, so appending
# # max_results raises UnboundLocalError outside the try. This assumes that line sits
# # after the loop; indentation was lost when the code was commented out.
# # NOTE(review): the module only does `import urllib`, which does not import
# # `urllib.request`, so urlopen may fail with AttributeError (returned as "Error: ...").
# # Only spaces are percent-encoded in the keywords.
# def crawl_arxiv(keyword_list, max_results=100):
# baseurl = 'http://export.arxiv.org/api/query?search_query='
# records = []
# for i,keyword in enumerate(keyword_list):
# if i ==0:
# url = baseurl + 'all:' + keyword
# else:
# url = url + '+OR+' + 'all:' + keyword
# url = url+ '&max_results=' + str(max_results)
# url = url.replace(' ', '%20')
# try:
# arxiv_page = urllib.request.urlopen(url,timeout=100).read()
# xml = str(arxiv_page,encoding="utf-8")
# while xml.find("<entry>") != -1:
# extract = xml[xml.find("<entry>")+7:xml.find("</entry>")]
# xml = xml[xml.find("</entry>")+8:]
# extract = get_record(extract)
# topic = choose_topic(extract[6])
# records.append([topic,*extract])
# return records
# except Exception as e:
# return "Error: "+str(e)
# # Thin wrapper around one SQLite table of arXiv papers; each instance opens its own
# # connection and cursor, and the connection is never closed.
# # NOTE(review): SECURITY -- every statement below is built with str.format on
# # caller-supplied values (title, authors, ids, record fields). Use sqlite3 "?"
# # placeholders instead. Double-quoted values rely on SQLite's legacy
# # double-quoted-string-literal fallback.
# class ArxivSQL:
# def __init__(self, table="arxivsql", name="db.sqlite3"):
# self.con = sqlite3.connect(name)
# self.cur = self.con.cursor()
# self.table = table
# # Return all rows whose title contains `title` and whose authors match `author` (a list).
# # NOTE(review): BUG -- for 2+ authors this builds "authors like '%a%' or '%b%'".
# # Only the first name gets a LIKE test; later names are bare string literals.
# # Because AND binds tighter than OR, the title filter does not apply to them.
# # Parenthesize and repeat "authors like" for each name.
# # NOTE(review): `author=[]` is a mutable default; it is not mutated here, but None is safer.
# def query(self, title="", author=[]):
# if len(title)>0:
# query_title = 'title like "%{}%"'.format(title)
# else:
# query_title = "True"
# if len(author)>0:
# query_author = 'authors like '
# for auth in author:
# query_author += "'%{}%' or ".format(auth)
# query_author = query_author[:-4]
# else:
# query_author = "True"
# query = "select * from {} where {} and {}".format(self.table,query_title,query_author)
# result = self.cur.execute(query)
# return result.fetchall()
# # Return all rows whose id is in `ids` (a list of str). Returns None for an empty list,
# # or after printing the error when the query fails.
# # NOTE(review): ids are spliced between single quotes without escaping. If `ids` has no
# # len() (e.g. None), the TypeError is raised before `query` is assigned, so the except
# # block itself fails with UnboundLocalError.
# def query_id(self, ids=[]):
# try:
# if len(ids) == 0:
# return None
# query = "select * from {} where id in (".format(self.table)
# for id in ids:
# query+="'"+id+"',"
# query = query[:-1] + ")"
# result = self.cur.execute(query)
# return result.fetchall()
# except Exception as e:
# print(e)
# print("Error query: ",query)
# def add(self, crawl_records):
# """
# Add crawl_records (list) obtained from arxiv_crawlers
# A record is a list of 8 columns:
# [topic, id, updated, published, title, author, link, summary]
# Return the final length of the database table
# """
# # NOTE(review): BUG -- contrary to the docstring, this returns `results`, a string.
# # The except branch appends to `result` (never defined -> NameError). The `return`
# # inside `finally` runs on the first loop iteration, so only the first record is
# # ever inserted, and any exception (including that NameError) is swallowed.
# # NOTE(review): the INSERT hard-codes the table `arxivsql` instead of self.table and
# # does not store the summary (record[7]). record[1][21:] presumably strips an
# # "http://arxiv.org/abs/" prefix from the id -- confirm the id format.
# results = ""
# for record in crawl_records:
# try:
# query = """insert into arxivsql values("{}","{}","{}","{}","{}","{}","{}")""".format(
# record[1][21:],
# record[0],
# record[4].replace('"',"'"),
# authors_list_to_str(record[5]),
# record[2][:10],
# record[3][:10],
# record[6]
# )
# self.cur.execute(query)
# self.con.commit()
# except Exception as e:
# result+=str(e)
# result+="\n" + query + "\n"
# finally:
# return results
# # instance of ArxivSQL
# # NOTE(review): created at import time. sqlite3.connect opens db.sqlite3 relative to
# # the current working directory, creating the file if it does not exist.
# sqldb = ArxivSQL()
# class ArxivChroma:
# """
# Create an interface to the arxivdb Chroma collection; it only supports querying and adding.
# This interface does not support editing or deleting records.
# """
# # Open a persistent Chroma client stored in the `name` directory and get or create the
# # `table` collection, embedding documents with the module-level embedding_model.
# def __init__(self, table="arxiv_records", name="arxivdb/"):
# self.client = chromadb.PersistentClient(name)
# self.model = embedding_model
# self.collection = self.client.get_or_create_collection(table,
# embedding_function=JinaAIEmbeddingFunction(
# model = self.model
# ))
# # NOTE(review): keywords are lower-cased, but documents are stored with their original
# # case (capitalized "Topic:"/"Title:" labels, raw title and summary). If $contains is
# # case-sensitive, capitalized occurrences are missed -- confirm chromadb semantics.
# # An empty `keywords` list produces {"$or": []}, which chromadb may reject.
# def query_relevant(self, keywords, query_texts, n_results=3):
# """
# Return the n_results chunks nearest to `query_texts`, restricted to documents
# that contain at least one of `keywords` (each lower-cased).
# """
# contains = []
# for keyword in keywords:
# contains.append({"$contains":keyword.lower()})
# return self.collection.query(
# query_texts=query_texts,
# where_document={
# "$or":contains
# },
# n_results=n_results,
# )
# # Fetch the stored chunks of one paper by requesting ids "<id>_0" .. "<id>_9".
# # NOTE(review): assumes at most 10 chunks per paper, but add() creates one id per
# # chunk_texts() chunk, so chunks 10+ of a long summary are never fetched.
# # `id` shadows the builtin.
# def query_exact(self, id):
# ids = ["{}_{}".format(id,j) for j in range(0,10)]
# return self.collection.get(ids=ids)
# # NOTE(review): chunk ids and metadata use record[1][21:], presumably the arXiv id
# # without its URL prefix -- confirm. Text passes through chunk_texts(), which drops the
# # word at each chunk boundary.
# def add(self, crawl_records):
# """
# Add crawl_records (list) obtained from arxiv_crawlers
# A record is a list of 8 columns:
# [topic, id, updated, published, title, author, link, summary]
# Return the final length of the database table
# """
# for record in crawl_records:
# embed_text = """
# Topic: {},
# Title: {},
# Summary: {}
# """.format(record[0],record[4],record[7])
# chunks = chunk_texts(embed_text)
# ids = [record[1][21:]+"_"+str(j) for j in range(len(chunks))]
# paper_ids = [{"paper_id":record[1][21:]} for _ in range(len(chunks))]
# self.collection.add(
# documents = chunks,
# metadatas=paper_ids,
# ids = ids
# )
# return self.collection.count()
# # instance of ArxivChroma
# # NOTE(review): created at import time. It opens a persistent Chroma client in
# # "arxivdb/" and gets or creates the "arxiv_records" collection.
# db = ArxivChroma()