yonikremer
commited on
Commit
•
22e2fd1
1
Parent(s):
05393a3
adaptation to new versions of grouped-sampling
Browse files- hanlde_form_submit.py +5 -14
- tests.py +4 -8
hanlde_form_submit.py
CHANGED
@@ -2,12 +2,10 @@ import os
|
|
2 |
from time import time
|
3 |
|
4 |
import streamlit as st
|
5 |
-
from grouped_sampling import GroupedSamplingPipeLine
|
6 |
|
7 |
from download_repo import download_pytorch_model
|
8 |
from prompt_engeneering import rewrite_prompt
|
9 |
-
from supported_models import is_supported, SUPPORTED_MODEL_NAME_PAGES_FORMAT, BLACKLISTED_MODEL_NAMES, \
|
10 |
-
BLACKLISTED_ORGANIZATIONS
|
11 |
|
12 |
|
13 |
def is_downloaded(model_name: str) -> bool:
|
@@ -94,17 +92,10 @@ def on_form_submit(
|
|
94 |
TypeError: If the output length is not an integer or the prompt is not a string.
|
95 |
RuntimeError: If the model is not found.
|
96 |
"""
|
97 |
-
if not is_supported(model_name, 1, 1):
|
98 |
-
raise ValueError(
|
99 |
-
f"The model: {model_name} is not supported."
|
100 |
-
f"The supported models are the models from {SUPPORTED_MODEL_NAME_PAGES_FORMAT}"
|
101 |
-
f" that satisfy the following conditions:\n"
|
102 |
-
f"1. The model has at least one like and one download.\n"
|
103 |
-
f"2. The model is not one of: {BLACKLISTED_MODEL_NAMES}.\n"
|
104 |
-
f"3. The model was not created any of those organizations: {BLACKLISTED_ORGANIZATIONS}.\n"
|
105 |
-
)
|
106 |
if len(prompt) == 0:
|
107 |
-
raise ValueError(
|
|
|
|
|
108 |
st.write(f"Loading model: {model_name}...")
|
109 |
loading_start_time = time()
|
110 |
pipeline = create_pipeline(
|
@@ -114,7 +105,7 @@ def on_form_submit(
|
|
114 |
loading_end_time = time()
|
115 |
loading_time = loading_end_time - loading_start_time
|
116 |
st.write(f"Finished loading model: {model_name} in {loading_time:,.2f} seconds.")
|
117 |
-
st.write(
|
118 |
generation_start_time = time()
|
119 |
generated_text = generate_text(
|
120 |
pipeline=pipeline,
|
|
|
2 |
from time import time
|
3 |
|
4 |
import streamlit as st
|
5 |
+
from grouped_sampling import GroupedSamplingPipeLine, is_supported, UnsupportedModelNameException
|
6 |
|
7 |
from download_repo import download_pytorch_model
|
8 |
from prompt_engeneering import rewrite_prompt
|
|
|
|
|
9 |
|
10 |
|
11 |
def is_downloaded(model_name: str) -> bool:
|
|
|
92 |
TypeError: If the output length is not an integer or the prompt is not a string.
|
93 |
RuntimeError: If the model is not found.
|
94 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
if len(prompt) == 0:
|
96 |
+
raise ValueError("The prompt must not be empty.")
|
97 |
+
if not is_supported(model_name):
|
98 |
+
raise UnsupportedModelNameException(model_name)
|
99 |
st.write(f"Loading model: {model_name}...")
|
100 |
loading_start_time = time()
|
101 |
pipeline = create_pipeline(
|
|
|
105 |
loading_end_time = time()
|
106 |
loading_time = loading_end_time - loading_start_time
|
107 |
st.write(f"Finished loading model: {model_name} in {loading_time:,.2f} seconds.")
|
108 |
+
st.write("Generating text...")
|
109 |
generation_start_time = time()
|
110 |
generated_text = generate_text(
|
111 |
pipeline=pipeline,
|
tests.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1 |
import os
|
2 |
|
3 |
import pytest as pytest
|
4 |
-
from grouped_sampling import GroupedSamplingPipeLine
|
5 |
|
6 |
from on_server_start import download_useful_models
|
7 |
from hanlde_form_submit import create_pipeline, on_form_submit
|
8 |
from prompt_engeneering import rewrite_prompt
|
9 |
-
from supported_models import get_supported_model_names
|
10 |
|
11 |
HUGGING_FACE_CACHE_DIR = "/home/yoni/.cache/huggingface/hub"
|
12 |
|
@@ -21,7 +20,7 @@ def test_prompt_engineering():
|
|
21 |
|
22 |
|
23 |
def test_get_supported_model_names():
|
24 |
-
supported_model_names =
|
25 |
assert len(supported_model_names) > 0
|
26 |
assert "gpt2" in supported_model_names
|
27 |
assert all(isinstance(name, str) for name in supported_model_names)
|
@@ -44,16 +43,13 @@ def test_on_form_submit():
|
|
44 |
with pytest.raises(ValueError):
|
45 |
on_form_submit(model_name, output_length, empty_prompt, web_search=False)
|
46 |
unsupported_model_name = "unsupported_model_name"
|
47 |
-
with pytest.raises(
|
48 |
on_form_submit(unsupported_model_name, output_length, prompt, web_search=False)
|
49 |
|
50 |
|
51 |
@pytest.mark.parametrize(
|
52 |
"model_name",
|
53 |
-
|
54 |
-
min_number_of_downloads=1000,
|
55 |
-
min_number_of_likes=100,
|
56 |
-
)
|
57 |
)
|
58 |
def test_create_pipeline(model_name: str):
|
59 |
pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, 5)
|
|
|
1 |
import os
|
2 |
|
3 |
import pytest as pytest
|
4 |
+
from grouped_sampling import GroupedSamplingPipeLine, get_full_models_list, UnsupportedModelNameException
|
5 |
|
6 |
from on_server_start import download_useful_models
|
7 |
from hanlde_form_submit import create_pipeline, on_form_submit
|
8 |
from prompt_engeneering import rewrite_prompt
|
|
|
9 |
|
10 |
HUGGING_FACE_CACHE_DIR = "/home/yoni/.cache/huggingface/hub"
|
11 |
|
|
|
20 |
|
21 |
|
22 |
def test_get_supported_model_names():
|
23 |
+
supported_model_names = get_full_models_list()
|
24 |
assert len(supported_model_names) > 0
|
25 |
assert "gpt2" in supported_model_names
|
26 |
assert all(isinstance(name, str) for name in supported_model_names)
|
|
|
43 |
with pytest.raises(ValueError):
|
44 |
on_form_submit(model_name, output_length, empty_prompt, web_search=False)
|
45 |
unsupported_model_name = "unsupported_model_name"
|
46 |
+
with pytest.raises(UnsupportedModelNameException):
|
47 |
on_form_submit(unsupported_model_name, output_length, prompt, web_search=False)
|
48 |
|
49 |
|
50 |
@pytest.mark.parametrize(
|
51 |
"model_name",
|
52 |
+
get_full_models_list()[:3]
|
|
|
|
|
|
|
53 |
)
|
54 |
def test_create_pipeline(model_name: str):
|
55 |
pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, 5)
|