# Generate references
# 1. Select the most relevant references from the "references" dataset or the arXiv / Semantic Scholar search APIs.
# 2. Generate BibTeX entries from the selected papers. --> to_bibtex()
# 3. Generate prompts from the selected papers:        --> to_prompts()
#    {"paper_id": "paper summary"}

import requests
import re


#########################################################
# Some basic tools
#########################################################
def remove_newlines(serie):
    # Strip literal and escaped newlines and collapse doubled spaces.
    serie = serie.replace('\n', ' ')
    serie = serie.replace('\\n', ' ')
    serie = serie.replace('  ', ' ')
    serie = serie.replace('  ', ' ')
    return serie


#########################################################
# Semantic Scholar (SS) API
#########################################################
def ss_search(keywords, limit=20, fields=None):
    # Spaces in the query are replaced with '+' before building the request URL.
    if fields is None:
        fields = ["title", "abstract", "venue", "year", "authors", "tldr", "embedding", "externalIds"]

    keywords = keywords.lower()
    keywords = keywords.replace(" ", "+")
    url = f'https://api.semanticscholar.org/graph/v1/paper/search?query={keywords}&limit={limit}&fields={",".join(fields)}'
    # headers = {"Accept": "*/*", "x-api-key": constants.S2_KEY}
    headers = {"Accept": "*/*"}

    response = requests.get(url, headers=headers, timeout=30)
    return response.json()


def _collect_papers_ss(keyword, counts=3, tldr=False):
    def externalIds2link(externalIds):
        # externalIds looks like:
        # {'MAG': '2932819148', 'DBLP': 'conf/icml/HaarnojaZAL18', 'ArXiv': '1801.01290', 'CorpusId': 28202810}
        if externalIds:
            # Supported sources: ArXiv, MAG, ACL, PubMed, Medline, PubMedCentral, DBLP, DOI
            # Priority: DBLP > arXiv > (todo: MAG > CorpusId > DOI > ACL > PubMed > Medline > PubMedCentral)
            # DBLP
            dblp_id = externalIds.get('DBLP')
            if dblp_id is not None:
                dblp_link = f"dblp.org/rec/{dblp_id}"
                return dblp_link
            # arXiv
            arxiv_id = externalIds.get('ArXiv')
            if arxiv_id is not None:
                arxiv_link = f"arxiv.org/abs/{arxiv_id}"
                return arxiv_link
            return ""
        else:
            # If this is an empty dictionary, return an empty string.
            return ""

    def extract_paper_id(last_name, year_str, title):
        return last_name + year_str + title.split(' ', 1)[0]

    def extract_author_info(raw_authors):
        authors = [author['name'] for author in raw_authors]
        authors_str = " and ".join(authors)
        last_name = authors[0].split()[-1]
        return authors_str, last_name

    def parse_search_results(search_results):
        # Turn the search results into a list of paper dictionaries.
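        # Each returned dictionary carries the same fields used throughout this
        # module: paper_id, title, abstract, link, authors, year, journal.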
        papers = []
        for raw_paper in search_results:
            if raw_paper["abstract"] is None:
                continue

            authors_str, last_name = extract_author_info(raw_paper['authors'])
            year_str = str(raw_paper['year'])
            title = raw_paper['title']
            journal = raw_paper['venue']
            if not journal:
                journal = "arXiv preprint"
            paper_id = extract_paper_id(last_name, year_str, title).lower()
            link = externalIds2link(raw_paper['externalIds'])
            if tldr and raw_paper['tldr'] is not None:
                abstract = raw_paper['tldr']['text']
            else:
                abstract = remove_newlines(raw_paper['abstract'])
            result = {
                "paper_id": paper_id,
                "title": title,
                "abstract": abstract,  # todo: compare results with tldr
                "link": link,
                "authors": authors_str,
                "year": year_str,
                "journal": journal
            }
            papers.append(result)
        return papers

    raw_results = ss_search(keyword, limit=counts)
    if raw_results is not None:
        search_results = raw_results.get('data', [])
    else:
        search_results = []
    results = parse_search_results(search_results)
    return results


#########################################################
# arXiv API
#########################################################
def _collect_papers_arxiv(keyword, counts=3, tldr=False):
    # `tldr` is accepted for interface compatibility with _collect_papers_ss but is not used here.

    # Build the arXiv API query URL with the given keyword and other parameters.
    def build_query_url(keyword, results_limit=3, sort_by="relevance", sort_order="descending"):
        base_url = "http://export.arxiv.org/api/query?"
        query = f"search_query=all:{keyword}&start=0&max_results={results_limit}"
        query += f"&sortBy={sort_by}&sortOrder={sort_order}"
        return base_url + query

    # Fetch search results from the arXiv API using the constructed URL.
    def fetch_search_results(query_url):
        response = requests.get(query_url, timeout=30)
        return response.text

    # Parse the XML content of the API response to extract paper information.
    def parse_results(content):
        from xml.etree import ElementTree as ET

        root = ET.fromstring(content)
        namespace = "{http://www.w3.org/2005/Atom}"
        entries = root.findall(f"{namespace}entry")

        results = []
        for entry in entries:
            title = entry.find(f"{namespace}title").text
            link = entry.find(f"{namespace}id").text
            summary = entry.find(f"{namespace}summary").text
            summary = remove_newlines(summary)

            # Extract the authors
            authors = entry.findall(f"{namespace}author")
            author_list = []
            for author in authors:
                name = author.find(f"{namespace}name").text
                author_list.append(name)
            authors_str = " and ".join(author_list)

            # Extract the year
            published = entry.find(f"{namespace}published").text
            year = published.split("-")[0]

            founds = re.search(r'\d+\.\d+', link)
            if founds is None:
                # Some links are non-standard, e.g. "https://arxiv.org/abs/cs/0603127v1";
                # they will be handled in the future.
                continue
            else:
                arxiv_id = founds.group(0)
            journal = f"arXiv preprint arXiv:{arxiv_id}"

            result = {
                "paper_id": arxiv_id,
                "title": title,
                "link": link,
                "abstract": summary,
                "authors": authors_str,
                "year": year,
                "journal": journal
            }
            results.append(result)
        return results

    query_url = build_query_url(keyword, counts)
    content = fetch_search_results(query_url)
    results = parse_results(content)
    return results


# Each `paper` is a dictionary containing:
# (1) paper_id (2) title (3) authors (4) year (5) link (6) abstract (7) journal
class References:
    def __init__(self, load_papers=""):
        self.papers = []
        if load_papers:
            # todo: read a json file from the given path;
            #       this could be used to support pre-defined references.
            pass

    def collect_papers(self, keywords_dict, method="arxiv", tldr=False):
        """
        keywords_dict:
            {"machine learning": 5, "language model": 2};
            the key is the keyword, the value is how many references are needed.
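        method:
            "arxiv" (default) queries the arXiv API; "ss" queries Semantic Scholar.
        tldr:
            if True, prefer the Semantic Scholar TLDR summary over the full
            abstract when one is available (only affects method="ss").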
""" match method: case "arxiv": process =_collect_papers_arxiv case "ss": process = _collect_papers_ss case _: raise NotImplementedError("Other sources have not been not supported yet.") for key, counts in keywords_dict.items(): self.papers = self.papers + process(key, counts, tldr) seen = set() papers = [] for paper in self.papers: paper_id = paper["paper_id"] if paper_id not in seen: seen.add(paper_id) papers.append(paper) self.papers = papers def to_bibtex(self, path_to_bibtex="ref.bib"): """ Turn the saved paper list into bibtex file "ref.bib". Return a list of all `paper_id`. """ papers = self.papers # clear the bibtex file with open(path_to_bibtex, "w", encoding="utf-8") as file: file.write("") bibtex_entries = [] paper_ids = [] for paper in papers: bibtex_entry = f"""@article{{{paper["paper_id"]}, title = {{{paper["title"]}}}, author = {{{paper["authors"]}}}, journal={{{paper["journal"]}}}, year = {{{paper["year"]}}}, url = {{{paper["link"]}}} }}""" bibtex_entries.append(bibtex_entry) paper_ids.append(paper["paper_id"]) # Save the generated BibTeX entries to a file with open(path_to_bibtex, "a", encoding="utf-8") as file: file.write(bibtex_entry) file.write("\n\n") return paper_ids def to_prompts(self): # `prompts`: # {"paper1_bibtex_id": "paper_1_abstract", "paper2_bibtex_id": "paper2_abstract"} # this will be used to instruct GPT model to cite the correct bibtex entry. prompts = {} for paper in self.papers: prompts[paper["paper_id"]] = paper["abstract"] return prompts if __name__ == "__main__": refs = References() keywords_dict = { "Deep Q-Networks": 15, "Policy Gradient Methods": 24, "Actor-Critic Algorithms": 4, "Model-Based Reinforcement Learning": 13, "Exploration-Exploitation Trade-off": 7 } refs.collect_papers(keywords_dict, method="ss", tldr=True) for p in refs.papers: print(p["paper_id"]) print(len(refs.papers))