yonikremer committed on
Commit
05393a3
1 Parent(s): 3da4a00

deleted supported_models.py

Browse files
Files changed (1) hide show
  1. supported_models.py +0 -168
supported_models.py DELETED
@@ -1,168 +0,0 @@
1
- from functools import lru_cache
2
- from typing import Generator, Set, Union, List, Optional
3
-
4
- import requests
5
- from bs4 import BeautifulSoup, Tag, NavigableString, PageElement
6
- from concurrent.futures import ThreadPoolExecutor, as_completed
7
-
8
# Base URL of the Hugging Face listing of PyTorch text-generation models.
# get_page() appends "&p=<page index>" to it to select one result page.
SUPPORTED_MODEL_NAME_PAGES_FORMAT = "https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch"

# Number of threads used to download listing pages concurrently.
MAX_WORKERS = 10

# Individual models that are rejected regardless of their statistics.
BLACKLISTED_MODEL_NAMES = {
    "ykilcher/gpt-4chan",
    "bigscience/mt0-xxl",
    "bigscience/mt0-xl",
    "bigscience/mt0-large",
    "bigscience/mt0-base",
    "bigscience/mt0-small",
}

# Organizations (the part of the model name before "/") that are rejected.
BLACKLISTED_ORGANIZATIONS = {"huggingtweets"}

# Thresholds used when the caller does not pass its own.
DEFAULT_MIN_NUMBER_OF_DOWNLOADS = 100
DEFAULT_MIN_NUMBER_OF_LIKES = 20
23
-
24
-
25
def get_model_name(model_card: Tag) -> str:
    """Return the model name displayed in a model card.

    Args:
        model_card: The tag of a single model card taken from a listing page.

    Returns:
        The text of the card's title ``<h4>`` element.

    Raises:
        ValueError: If the card contains no ``<h4>`` element with the
            expected CSS classes.
    """
    h4_class = "text-md truncate font-mono text-black dark:group-hover:text-yellow-500 group-hover:text-indigo-600"
    h4_tag = model_card.find("h4", class_=h4_class)
    if h4_tag is None:
        # find() returns None when nothing matches; fail with a clear message
        # instead of an AttributeError on ``None.text``.
        raise ValueError("model card does not contain a model name <h4> element")
    return h4_tag.text
30
-
31
-
32
def is_a_number(element: Union[PageElement, Tag]) -> bool:
    """Return True if the element's text is a count such as "12" or "1.5k".

    The accepted form is: decimal digits with at most one decimal point,
    optional "," thousands separators, and at most one trailing magnitude
    suffix ("k", "m" or "b", case-insensitive). Surrounding whitespace is
    ignored.

    Args:
        element: A child of a model card; anything that is a ``Tag`` is
            rejected outright, other elements are judged by their ``text``.

    Returns:
        True if the text has the form described above, False otherwise.
    """
    if isinstance(element, Tag):
        return False
    text = element.text.strip().lower().replace(",", "")
    # Only ONE suffix, and only at the end, is meaningful ("1k2" is not a number).
    if text.endswith(("k", "m", "b")):
        text = text[:-1]
    integer_part, _, fraction_part = text.partition(".")
    digits = integer_part + fraction_part
    # A second "." would survive inside fraction_part and fail isdigit().
    # isascii() rejects Unicode digit look-alikes that float() cannot parse.
    # Words such as "inf"/"nan" and exponents such as "1e5" are rejected too.
    return digits.isascii() and digits.isdigit()
45
-
46
-
47
def get_numeric_contents(model_card: Tag) -> List[PageElement]:
    """Return the numeric children of the model card's statistics line.

    The statistics line is the ``<div>`` with the gray, single-line CSS
    classes searched for below. ``card_filter`` reads the first returned
    element as the number of downloads and the second as the number of likes.

    Args:
        model_card: The tag of a single model card.

    Returns:
        The direct children of the statistics ``<div>`` for which
        ``is_a_number`` is true, in document order. An empty list if the
        card has no such ``<div>``.
    """
    div: Optional[Tag] = model_card.find(
        "div",
        class_="mr-1 flex items-center overflow-hidden whitespace-nowrap text-sm leading-tight text-gray-400",
        recursive=True,
    )
    if div is None:
        # No statistics line at all: report "no numbers" instead of crashing
        # on ``None.contents``.
        return []
    return [content for content in div.contents if is_a_number(content)]
57
-
58
-
59
def convert_to_int(element: PageElement) -> int:
    """Convert an abbreviated count such as "1.5k" or "2M" to an integer.

    Surrounding whitespace, letter case and "," thousands separators are
    ignored. A trailing "k", "m" or "b" multiplies the value by one thousand,
    one million or one billion respectively.

    Args:
        element: An element whose ``text`` holds the count.

    Returns:
        The count as an integer.

    Raises:
        ValueError: If the text is not a number in the form described above
            (a value without a suffix must be a plain integer).
    """
    multipliers = {"k": 1_000, "m": 1_000_000, "b": 1_000_000_000}
    text = element.text.strip().lower().replace(",", "")
    multiplier = multipliers.get(text[-1:])
    if multiplier is None:
        return int(text)
    # round() instead of int(): binary floating point can land just below the
    # exact product (e.g. 1004.9999999999999), which int() would truncate.
    return round(float(text[:-1]) * multiplier)
69
-
70
-
71
def get_page(page_index: int, timeout: float = 10.0) -> Optional[BeautifulSoup]:
    """Download and parse one page of the model listing.

    Args:
        page_index: Value of the ``p`` query parameter appended to
            ``SUPPORTED_MODEL_NAME_PAGES_FORMAT``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may wait indefinitely, which would block a
            worker thread forever.

    Returns:
        The parsed page if the server answered with status 200, None otherwise.

    Raises:
        requests.RequestException: If the request fails or times out.
    """
    curr_page_url = f"{SUPPORTED_MODEL_NAME_PAGES_FORMAT}&p={page_index}"
    response = requests.get(curr_page_url, timeout=timeout)
    if response.status_code != 200:
        return None
    return BeautifulSoup(response.content, "html.parser")
79
-
80
-
81
def card_filter(
    model_card: Tag,
    model_name: str,
    min_number_of_downloads: int,
    min_number_of_likes: int,
) -> bool:
    """Tell whether a model card passes the blacklist and popularity checks.

    A card is rejected when its name is in BLACKLISTED_MODEL_NAMES, when the
    part of its name before the first "/" is in BLACKLISTED_ORGANIZATIONS,
    when it shows fewer than two numeric statistics, or when its downloads
    or likes are below the given minimums.
    """
    if model_name in BLACKLISTED_MODEL_NAMES:
        return False
    if model_name.partition("/")[0] in BLACKLISTED_ORGANIZATIONS:
        return False

    stats = get_numeric_contents(model_card)
    # Fewer than two numbers means the card does not show both a download
    # count and a like count, so it cannot meet the thresholds.
    if len(stats) < 2:
        return False

    # The first number is the download count, the second the like count.
    # The like count is only parsed once the download count has passed.
    if convert_to_int(stats[0]) < min_number_of_downloads:
        return False
    return convert_to_int(stats[1]) >= min_number_of_likes
105
-
106
-
107
def get_model_names(
    soup: BeautifulSoup,
    min_number_of_downloads: int,
    min_number_of_likes: int,
) -> Generator[str, None, None]:
    """Yield the name of every model card in the page that passes card_filter."""
    cards: List[Tag] = soup.find_all("article", class_="overview-card-wrapper group", recursive=True)
    for card in cards:
        name = get_model_name(card)
        accepted = card_filter(
            model_card=card,
            model_name=name,
            min_number_of_downloads=min_number_of_downloads,
            min_number_of_likes=min_number_of_likes,
        )
        if not accepted:
            continue
        yield name
123
-
124
-
125
def generate_supported_model_names(
    min_number_of_downloads: int,
    min_number_of_likes: int,
    max_pages: int = 300,
) -> Generator[str, None, None]:
    """Yield the names of supported models found on the listing pages.

    Pages 0 .. max_pages - 1 are downloaded concurrently by a pool of
    MAX_WORKERS threads and processed in the order their downloads complete,
    so the order of the yielded names is not fixed.

    Args:
        min_number_of_downloads: Minimum download count a model must have.
        min_number_of_likes: Minimum like count a model must have.
        max_pages: Number of listing pages to request.

    Yields:
        The name of every model card that passes ``card_filter``.

    Raises:
        Exception: Whatever ``get_page`` raised for a page; it is re-raised
            here by ``future.result()``.
    """
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = [executor.submit(get_page, index) for index in range(max_pages)]
        for future in as_completed(futures):
            soup = future.result()
            # get_page returns None for pages that did not answer with 200.
            if soup:
                yield from get_model_names(
                    soup=soup,
                    min_number_of_downloads=min_number_of_downloads,
                    min_number_of_likes=min_number_of_likes,
                )
139
-
140
-
141
@lru_cache
def get_supported_model_names(
    min_number_of_downloads: int = DEFAULT_MIN_NUMBER_OF_DOWNLOADS,
    min_number_of_likes: int = DEFAULT_MIN_NUMBER_OF_LIKES,
) -> Set[str]:
    """Collect every supported model name into a set.

    The result is memoized by ``lru_cache``, so the listing is scraped once
    per distinct way of calling this function; later calls get the same
    cached set object back.
    """
    names = generate_supported_model_names(
        min_number_of_downloads=min_number_of_downloads,
        min_number_of_likes=min_number_of_likes,
    )
    return set(names)
152
-
153
-
154
def is_supported(
    model_name: str,
    min_number_of_downloads: int = DEFAULT_MIN_NUMBER_OF_DOWNLOADS,
    min_number_of_likes: int = DEFAULT_MIN_NUMBER_OF_LIKES,
) -> bool:
    """Tell whether model_name is among the supported model names for the given thresholds."""
    supported = get_supported_model_names(
        min_number_of_downloads=min_number_of_downloads,
        min_number_of_likes=min_number_of_likes,
    )
    return model_name in supported
163
-
164
-
165
if __name__ == "__main__":
    # Manual check: scrape with the loosest thresholds and report the result.
    names = get_supported_model_names(1, 1)
    print(
        f"Number of supported model names: {len(names)}",
        f"Supported model names: {names}",
        sep="\n",
    )