import logging
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from decimal import Decimal
from typing import Generator, List, Optional, Set, Union

import requests
from bs4 import BeautifulSoup, NavigableString, PageElement, Tag

logger = logging.getLogger(__name__)

SUPPORTED_MODEL_NAME_PAGES_FORMAT = "https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch"
MAX_WORKERS = 10
# Number of listing pages requested by default (page indices 0 .. NUMBER_OF_PAGES - 1).
NUMBER_OF_PAGES = 100
# Per-request timeout in seconds; without a timeout requests.get may block indefinitely.
REQUEST_TIMEOUT_SECONDS = 30
BLACKLISTED_MODEL_NAMES = {
    "ykilcher/gpt-4chan",
    "bigscience/mt0-xxl",
    "bigscience/mt0-xl",
    "bigscience/mt0-large",
    "bigscience/mt0-base",
    "bigscience/mt0-small",
}
MIN_NUMBER_OF_DOWNLOADS = 100
MIN_NUMBER_OF_LIKES = 20

# Multipliers for the abbreviated-count suffixes understood by _parse_count.
_COUNT_SUFFIX_MULTIPLIERS = {"": 1, "k": 1_000, "m": 1_000_000, "b": 1_000_000_000}
# Matches e.g. "950", "1.2k", ".5m", "7 b" (applied after lowercasing and removing commas).
_COUNT_PATTERN = re.compile(r"([0-9]+\.?[0-9]*|\.[0-9]+)\s*([kmb]?)")


def _parse_count(text: str) -> Optional[int]:
    """Parse an abbreviated count such as ``"1,234"`` or ``"1.2k"``; return None if *text* is not one.

    Both ``is_a_number`` and ``convert_to_int`` go through this function so
    that everything accepted as a number can also be converted. Decimal
    arithmetic makes e.g. ``"2.3k"`` exactly 2300 instead of relying on float
    rounding before truncation.
    """
    match = _COUNT_PATTERN.fullmatch(text.strip().lower().replace(",", ""))
    if match is None:
        return None
    number, suffix = match.groups()
    return int(Decimal(number) * _COUNT_SUFFIX_MULTIPLIERS[suffix])


def get_model_name(model_card: Tag) -> str:
    """Return the model name (e.g. ``"bigscience/mt0-xl"``) from the card's name ``<h4>``.

    Raises:
        ValueError: if the card contains no ``<h4>`` with the expected classes.
    """
    h4_class = "text-md truncate font-mono text-black dark:group-hover:text-yellow-500 group-hover:text-indigo-600"
    h4_tag = model_card.find("h4", class_=h4_class)
    if h4_tag is None:
        raise ValueError("Model card has no model-name <h4> heading")
    # Strip so stray whitespace in the markup cannot defeat the blacklist lookup.
    return h4_tag.text.strip()


def is_a_number(s: PageElement) -> bool:
    """Return True if the element's text is a count that ``convert_to_int`` can convert."""
    return _parse_count(s.text) is not None


def get_numeric_contents(model_card: Tag) -> List[PageElement]:
    """Return the card's stats-``<div>`` text nodes that parse as counts, in document order.

    Only direct children that are not tags (bare text nodes) are considered.
    Returns an empty list when the stats ``<div>`` is absent.
    """
    div: Union[Tag, NavigableString, None] = model_card.find(
        "div",
        class_="mr-1 flex items-center overflow-hidden whitespace-nowrap text-sm leading-tight text-gray-400",
        recursive=True,
    )
    if not isinstance(div, Tag):
        return []
    return [
        content
        for content in div.contents
        if not isinstance(content, Tag) and is_a_number(content)
    ]


def convert_to_int(element: PageElement) -> int:
    """Convert an element's count text (``"950"``, ``"1,234"``, ``"1.2k"``, ``"3M"``, ``"7B"``) to an int.

    Raises:
        ValueError: if the text is not a recognizable count.
    """
    count = _parse_count(element.text)
    if count is None:
        raise ValueError(f"Not a numeric count: {element.text!r}")
    return count


def get_page(page_index: int, timeout: float = REQUEST_TIMEOUT_SECONDS) -> Optional[BeautifulSoup]:
    """Fetch and parse one model-listing page.

    Args:
        page_index: value sent as the ``p`` query parameter.
        timeout: seconds to wait for the HTTP request before giving up.

    Returns:
        The parsed page, or None if the request failed or the status was not 200.
    """
    curr_page_url = f"{SUPPORTED_MODEL_NAME_PAGES_FORMAT}&p={page_index}"
    try:
        response = requests.get(curr_page_url, timeout=timeout)
    except requests.RequestException as error:
        # Best effort: one unreachable page should not abort the whole scrape.
        logger.warning("Failed to fetch %s: %s", curr_page_url, error)
        return None
    if response.status_code != 200:
        return None
    return BeautifulSoup(response.content, "html.parser")


def card_filter(
    model_card: Tag,
    model_name: str,
    min_number_of_downloads: int,
    min_number_of_likes: int,
) -> bool:
    """Return True if the card is not blacklisted and meets both thresholds.

    The first numeric text node of the card is treated as the download count
    and the second as the like count.
    """
    if model_name in BLACKLISTED_MODEL_NAMES:
        return False
    numeric_contents = get_numeric_contents(model_card)
    if len(numeric_contents) < 2:
        # Fewer than two numeric values means the card shows no download/like
        # counts, so it is not treated as a valid model card.
        return False
    number_of_downloads = convert_to_int(numeric_contents[0])
    if number_of_downloads < min_number_of_downloads:
        return False
    number_of_likes = convert_to_int(numeric_contents[1])
    if number_of_likes < min_number_of_likes:
        return False
    return True


def get_model_names(
    soup: BeautifulSoup,
    min_number_of_downloads: int,
    min_number_of_likes: int,
) -> Generator[str, None, None]:
    """Yield the names of the model cards on *soup* that pass ``card_filter``.

    Cards without a recognizable name heading are skipped.
    """
    model_cards: List[Tag] = soup.find_all("article", class_="overview-card-wrapper group", recursive=True)
    for model_card in model_cards:
        try:
            model_name = get_model_name(model_card)
        except ValueError:
            continue
        if card_filter(
            model_card=model_card,
            model_name=model_name,
            min_number_of_downloads=min_number_of_downloads,
            min_number_of_likes=min_number_of_likes,
        ):
            yield model_name


def generate_supported_model_names(
    min_number_of_downloads: int,
    min_number_of_likes: int,
    number_of_pages: int = NUMBER_OF_PAGES,
) -> Generator[str, None, None]:
    """Fetch listing pages concurrently and yield model names passing the filters.

    Pages are fetched with up to MAX_WORKERS threads and processed in completion
    order, so the yield order is not deterministic. A name may be yielded more
    than once if it appears on several pages. Pages that could not be fetched
    are skipped.

    Args:
        min_number_of_downloads: minimum download count a model must have.
        min_number_of_likes: minimum like count a model must have.
        number_of_pages: how many pages (indices 0 .. number_of_pages - 1) to fetch.
    """
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = [executor.submit(get_page, index) for index in range(number_of_pages)]
        for future in as_completed(futures):
            soup = future.result()
            if soup is not None:
                yield from get_model_names(
                    soup=soup,
                    min_number_of_downloads=min_number_of_downloads,
                    min_number_of_likes=min_number_of_likes,
                )


def get_supported_model_names(
    min_number_of_downloads: int = MIN_NUMBER_OF_DOWNLOADS,
    min_number_of_likes: int = MIN_NUMBER_OF_LIKES,
    number_of_pages: int = NUMBER_OF_PAGES,
) -> Set[str]:
    """Return the de-duplicated set of model names passing the download/like thresholds."""
    return set(generate_supported_model_names(
        min_number_of_downloads=min_number_of_downloads,
        min_number_of_likes=min_number_of_likes,
        number_of_pages=number_of_pages,
    ))


if __name__ == "__main__":
    supported_model_names = get_supported_model_names()
    print(f"Number of supported model names: {len(supported_model_names)}")
    print(f"Supported model names: {supported_model_names}")