"""Look up an arXiv paper ID by searching Google for a paper title.

Usage:
    get_paper_id("8-bit matrix multiplication for transformers at scale")
    -> 2106.09680

The result depends on live search results, so it may vary over time.
"""
import logging
import re
from typing import Any, Optional, Union

import requests
from requests.adapters import HTTPAdapter, Retry

logger = logging.getLogger(__name__)

# Captures the numeric identifier of the first arXiv abstract link in a page.
paper_id_re = re.compile(r'https://arxiv\.org/abs/(\d+\.\d+)')


def retry_request_session(retries: Optional[int] = 5):
    """Build a ``requests.Session`` that retries transient HTTPS failures.

    :param retries: Value passed as ``total`` to the ``Retry`` policy,
        defaults to 5
    :type retries: int, optional
    :return: A session whose ``https://`` adapter applies the retry policy
    :rtype: requests.Session
    """
    # Use a distinct local name so the ``retries`` argument is not shadowed.
    retry_strategy = Retry(
        total=retries,
        backoff_factor=0.1,
        status_forcelist=[
            408,  # request timeout
            500,  # internal server error
            502,  # bad gateway
            503,  # service unavailable
            504,  # gateway timeout
        ],
    )

    session = requests.Session()
    session.mount('https://', HTTPAdapter(max_retries=retry_strategy))
    return session


def get_paper_id(
    query: str,
    handle_not_found: bool = True,
    *,
    timeout: float = 10.0,
) -> Optional[str]:
    """Get the arXiv paper ID for a search query.

    The query is sent to Google search and the first ``arxiv.org/abs/<id>``
    link found in the response body is used.

    :param query: The query to search with
    :type query: str
    :param handle_not_found: Whether to return None if no paper is found
        (otherwise an exception is raised), defaults to True
    :type handle_not_found: bool, optional
    :param timeout: Seconds to wait for the search request before giving up,
        defaults to 10.0
    :type timeout: float, optional
    :return: The paper ID, or None when nothing is found and
        ``handle_not_found`` is true
    :rtype: str or None
    :raises Exception: If no paper is found and ``handle_not_found`` is false
    """
    # Let requests URL-encode the query: every reserved character is escaped
    # (not just ':', '|', ',' and ' '), and spaces still become '+'.
    search_params = {"q": query, "sclient": "gws-wiz-serp"}

    # The context manager closes the session's connection pool when done.
    with retry_request_session() as session:
        res = session.get(
            "https://www.google.com/search",
            params=search_params,
            timeout=timeout,
        )

    if not res.ok:
        # A blocked or failed search otherwise looks like "no paper found".
        logger.warning(
            "Search request returned HTTP %s for query %r",
            res.status_code,
            query,
        )

    match = paper_id_re.search(res.text)
    if match is not None:
        return match.group(1)

    if handle_not_found:
        return None
    raise Exception(f'No paper found for query: {query}')