import json
import sqlite3
import urllib.request

import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from numpy.linalg import norm
from transformers import AutoModel


class JinaAIEmbeddingFunction(EmbeddingFunction):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def __call__(self, input: Documents) -> Embeddings:
        embeddings = self.model.encode(input)
        return embeddings.tolist()


class ArxivSQL:
    def __init__(self, table="arxivsql", name="arxiv_records_sql"):
        self.con = sqlite3.connect(name)
        self.cur = self.con.cursor()
        self.table = table

    def query(self, title="", author=[]):
        if len(title) > 0:
            query_title = 'title like "%{}%"'.format(title)
        else:
            query_title = "True"
        if len(author) > 0:
            conditions = ["authors like '%{}%'".format(auth) for auth in author]
            query_author = "(" + " or ".join(conditions) + ")"
        else:
            query_author = "True"
        query = "select * from {} where {} and {}".format(self.table, query_title, query_author)
        result = self.cur.execute(query)
        return result.fetchall()

    def query_id(self, ids=[]):
        try:
            query = "select * from {} where id in (".format(self.table)
            for id in ids:
                query += "'" + id + "',"
            query = query[:-1] + ")"
            result = self.cur.execute(query)
            return result.fetchall()
        except Exception as e:
            print(e)
            print("Error query: ", query)

    def add(self, crawl_records):
        """
        Add crawl_records (list) obtained from arxiv_crawlers.
        A record is a list of 8 columns:
        [topic, id, updated, published, title, author, link, summary]
        Returns a string of accumulated error messages (empty if all inserts succeed).
        """
        results = ""
        for record in crawl_records:
            try:
                query = """insert into arxivsql values("{}","{}","{}","{}","{}","{}","{}")""".format(
                    record[1][21:],
                    record[0],
                    record[4].replace('"', "'"),
                    process_authors_str(record[5]),
                    record[2][:10],
                    record[3][:10],
                    record[6]
                )
                self.cur.execute(query)
                self.con.commit()
            except Exception as e:
                results += str(e)
                results += "\n" + query + "\n"
        return results
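
# Hedged usage sketch (not part of the original module): shows how the SQLite
# index defined above might be queried. It assumes the default database file
# "arxiv_records_sql" exists and has been populated via ArxivSQL.add(); the
# title, author names, and arXiv id below are illustrative placeholders.
def _example_arxivsql_usage():
    db = ArxivSQL()
    # Search records whose title contains "diffusion" by any of the given authors.
    rows = db.query(title="diffusion", author=["Ho", "Song"])
    # Fetch specific records by id; the id column stores the part of the arXiv
    # URL after "http://arxiv.org/abs/" (record[1][21:] in ArxivSQL.add).
    exact = db.query_id(ids=["2401.00001v1"])
    return rows, exact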
""" def __init__(self, table="arxiv_records", name="arxivdb/"): self.client = chromadb.PersistentClient(name) self.model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True, cache_dir='models') self.collection = self.client.get_or_create_collection(table, embedding_function=JinaAIEmbeddingFunction( model = self.model )) def query_relevant(self, keywords, query_texts, n_results=3): """ Perform a query using a list of keywords (str), or using a relavant string """ contains = [] for keyword in keywords: contains.append({"$contains":keyword}) return self.collection.query( query_texts=query_texts, where_document={ "$or":contains }, n_results=n_results, ) def query_exact(self, id): ids = ["{}_{}".format(id,j) for j in range(0,10)] return self.collection.get(ids=ids) def add(self, crawl_records): """ Add crawl_records (list) obtained from arxiv_crawlers A record is a list of 8 columns: [topic, id, updated, published, title, author, link, summary] Return the final length of the database table """ for record in crawl_records: embed_text = """ Topic: {}, Title: {}, Summary: {} """.format(record[0],record[4],record[7]) chunks = chunk_text_with_overlap(embed_text) ids = [record[1][21:]+"_"+str(j) for j in range(len(chunks))] paper_ids = [{"paper_id":record[1][21:]} for _ in range(len(chunks))] self.collection.add( documents = chunks, metadatas=paper_ids, ids = ids ) return self.collection.count() def chunk_text_with_overlap(text, max_char=400, overlap=100): """ Chunk a long text into several chunks, with each chunk about 300-400 characters long, but make sure no word is cut in half. It also ensures an overlap of a specified length between consecutive chunks. Args: text: The long text to be chunked. max_char: The maximum number of characters per chunk (default: 400). overlap: The desired overlap between consecutive chunks (default: 70). Returns: A list of chunks. 
""" chunks = [] current_chunk = "" words = text.split() for word in words: # Check if adding the word would exceed the chunk limit (including overlap) if len(current_chunk) + len(word) + 1 >= max_char: chunks.append(current_chunk) split_point = current_chunk.find(" ",len(current_chunk)-overlap) current_chunk = current_chunk[split_point:] + " " + word else: current_chunk += " " + word # Add the last chunk (including potential overlap) chunks.append(current_chunk.strip()) return chunks def trimming(txt): start = txt.find("{") end = txt.rfind("}") return txt[start:end+1].replace("\n"," ") # crawl data def extract_tag(txt,tagname): return txt[txt.find("<"+tagname+">")+len(tagname)+2:txt.find("")] def get_record(extract): # id = extract[extract.find("")+4:extract.find("")] # updated = extract[extract.find("")+9:extract.find("")] # published = extract[extract.find("")+11:extract.find("")] # title = extract[extract.find("")+7:extract.find("")] # summary = extract[extract.find("")+9:extract.find("")] id = extract_tag(extract,"id") updated = extract_tag(extract,"updated") published = extract_tag(extract,"published") title = extract_tag(extract,"title").replace("\n ","").strip() summary = extract_tag(extract,"summary").replace("\n","").strip() authors = [] while extract.find("")!=-1: # author = extract[extract.find("")+6:extract.find("")] author = extract_tag(extract,"name") extract = extract[extract.find("")+9:] authors.append(author) pattern = ' max_sim: topic = key max_sim = sim return topic class TopicClassifier: def __init__(self): self.model_embedding = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True, cache_dir='models') topic_descriptions = json.load(open("topic_descriptions.txt")) self.topics = list(dict.keys(topic_descriptions)) self.embeddings = [self.model_embedding.encode(topic_descriptions[key]) for key in topic_descriptions] self.cos_sim = lambda a,b: (a @ b.T) / (norm(a)*norm(b)) def classifier(self,description): embed = self.model_embedding.encode(description) max_sim = 0. 
topic = "" for i, key in enumerate(self.topics): sim = self.cos_sim(embed,self.embeddings[i]) if sim > max_sim: topic = key max_sim = sim return topic def crawl_exact_paper(title,author,max_results=3): authors = process_authors_str(author) records = [] url = 'http://export.arxiv.org/api/query?search_query=ti:{title}+AND+au:{author}&max_results={max_results}'.format(title=title,author=authors,max_results=max_results) url = url.replace(" ","%20") try: arxiv_page = urllib.request.urlopen(url,timeout=100).read() xml = str(arxiv_page,encoding="utf-8") while xml.find("") != -1: extract = xml[xml.find("")+7:xml.find("")] xml = xml[xml.find("")+8:] extract = get_record(extract) topic = choose_topic(extract[6]) records.append([topic,*extract]) return records except Exception as e: return "Error: "+str(e) def crawl_arxiv(keyword_list, max_results=100): baseurl = 'http://export.arxiv.org/api/query?search_query=' records = [] for i,keyword in enumerate(keyword_list): if i ==0: url = baseurl + 'all:' + keyword else: url = url + '+OR+' + 'all:' + keyword url = url+ '&max_results=' + str(max_results) url = url.replace(' ', '%20') try: arxiv_page = urllib.request.urlopen(url,timeout=100).read() xml = str(arxiv_page,encoding="utf-8") while xml.find("") != -1: extract = xml[xml.find("")+7:xml.find("")] xml = xml[xml.find("")+8:] extract = get_record(extract) topic = choose_topic(extract[6]) records.append([topic,*extract]) return records except Exception as e: return "Error: "+str(e) def process_authors_str(authors): """input a list of authors, return a string represent authors""" text = "" for author in authors: text+=author+", " return text[:-3] def process_authors_list(string): """input a string of authors, return a list of authors""" authors = [] list_auth = string.split("and") for author in list_auth: if author != "et al.": authors.append(author.strip()) return authors