yonikremer
commited on
Commit
•
7924ca5
1
Parent(s):
45e9ce6
added a supported model check
Browse files- hanlde_form_submit.py +11 -0
- supported_models.py +17 -2
hanlde_form_submit.py
CHANGED
@@ -6,6 +6,8 @@ from grouped_sampling import GroupedSamplingPipeLine
|
|
6 |
|
7 |
from download_repo import download_repository
|
8 |
from prompt_engeneering import rewrite_prompt
|
|
|
|
|
9 |
|
10 |
|
11 |
def is_downloaded(model_name: str) -> bool:
|
@@ -93,6 +95,15 @@ def on_form_submit(
|
|
93 |
TypeError: If the output length is not an integer or the prompt is not a string.
|
94 |
RuntimeError: If the model is not found.
|
95 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
if len(prompt) == 0:
|
97 |
raise ValueError(f"The prompt must not be empty.")
|
98 |
st.write(f"Loading model: {model_name}...")
|
|
|
6 |
|
7 |
from download_repo import download_repository
|
8 |
from prompt_engeneering import rewrite_prompt
|
9 |
+
from supported_models import is_supported, SUPPORTED_MODEL_NAME_PAGES_FORMAT, BLACKLISTED_MODEL_NAMES, \
|
10 |
+
BLACKLISTED_ORGANIZATIONS
|
11 |
|
12 |
|
13 |
def is_downloaded(model_name: str) -> bool:
|
|
|
95 |
TypeError: If the output length is not an integer or the prompt is not a string.
|
96 |
RuntimeError: If the model is not found.
|
97 |
"""
|
98 |
+
if not is_supported(model_name, 1, 1):
|
99 |
+
raise ValueError(
|
100 |
+
f"The model: {model_name} is not supported."
|
101 |
+
f"The supported models are the models from {SUPPORTED_MODEL_NAME_PAGES_FORMAT}"
|
102 |
+
f" that satisfy the following conditions:\n"
|
103 |
+
f"1. The model has at least one like and one download.\n"
|
104 |
+
f"2. The model is not one of: {BLACKLISTED_MODEL_NAMES}.\n"
|
105 |
+
f"3. The model was not created any of those organizations: {BLACKLISTED_ORGANIZATIONS}.\n"
|
106 |
+
)
|
107 |
if len(prompt) == 0:
|
108 |
raise ValueError(f"The prompt must not be empty.")
|
109 |
st.write(f"Loading model: {model_name}...")
|
supported_models.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from typing import Generator, Set, Union, List, Optional
|
2 |
|
3 |
import requests
|
@@ -137,14 +138,28 @@ def generate_supported_model_names(
|
|
137 |
)
|
138 |
|
139 |
|
|
|
140 |
def get_supported_model_names(
|
141 |
min_number_of_downloads: int = DEFAULT_MIN_NUMBER_OF_DOWNLOADS,
|
142 |
min_number_of_likes: int = DEFAULT_MIN_NUMBER_OF_LIKES,
|
143 |
) -> Set[str]:
|
144 |
-
return set(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
min_number_of_downloads=min_number_of_downloads,
|
146 |
min_number_of_likes=min_number_of_likes,
|
147 |
-
)
|
148 |
|
149 |
|
150 |
if __name__ == "__main__":
|
|
|
1 |
+
from functools import lru_cache
|
2 |
from typing import Generator, Set, Union, List, Optional
|
3 |
|
4 |
import requests
|
|
|
138 |
)
|
139 |
|
140 |
|
141 |
+
@lru_cache
|
142 |
def get_supported_model_names(
|
143 |
min_number_of_downloads: int = DEFAULT_MIN_NUMBER_OF_DOWNLOADS,
|
144 |
min_number_of_likes: int = DEFAULT_MIN_NUMBER_OF_LIKES,
|
145 |
) -> Set[str]:
|
146 |
+
return set(
|
147 |
+
generate_supported_model_names(
|
148 |
+
min_number_of_downloads=min_number_of_downloads,
|
149 |
+
min_number_of_likes=min_number_of_likes,
|
150 |
+
)
|
151 |
+
)
|
152 |
+
|
153 |
+
|
154 |
+
def is_supported(
|
155 |
+
model_name: str,
|
156 |
+
min_number_of_downloads: int = DEFAULT_MIN_NUMBER_OF_DOWNLOADS,
|
157 |
+
min_number_of_likes: int = DEFAULT_MIN_NUMBER_OF_LIKES,
|
158 |
+
) -> bool:
|
159 |
+
return model_name in get_supported_model_names(
|
160 |
min_number_of_downloads=min_number_of_downloads,
|
161 |
min_number_of_likes=min_number_of_likes,
|
162 |
+
)
|
163 |
|
164 |
|
165 |
if __name__ == "__main__":
|