yonikremer
commited on
Commit
•
df1b7f8
1
Parent(s):
70130da
added tests
Browse files
tests.py
CHANGED
@@ -4,7 +4,8 @@ import shutil
|
|
4 |
import pytest as pytest
|
5 |
from grouped_sampling import GroupedSamplingPipeLine
|
6 |
|
7 |
-
from
|
|
|
8 |
from prompt_engeneering import rewrite_prompt
|
9 |
from supported_models import get_supported_model_names
|
10 |
|
@@ -24,10 +25,40 @@ def test_get_supported_model_names():
|
|
24 |
supported_model_names = get_supported_model_names()
|
25 |
assert len(supported_model_names) > 0
|
26 |
assert "gpt2" in supported_model_names
|
27 |
-
assert all(
|
28 |
|
29 |
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
def test_create_pipeline(model_name: str):
|
32 |
pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, 5)
|
33 |
assert pipeline is not None
|
|
|
4 |
import pytest as pytest
|
5 |
from grouped_sampling import GroupedSamplingPipeLine
|
6 |
|
7 |
+
from on_server_start import download_useful_models
|
8 |
+
from hanlde_form_submit import create_pipeline, on_form_submit
|
9 |
from prompt_engeneering import rewrite_prompt
|
10 |
from supported_models import get_supported_model_names
|
11 |
|
|
|
25 |
supported_model_names = get_supported_model_names()
|
26 |
assert len(supported_model_names) > 0
|
27 |
assert "gpt2" in supported_model_names
|
28 |
+
assert all(isinstance(name, str) for name in supported_model_names)
|
29 |
|
30 |
|
31 |
+
def test_on_server_start():
|
32 |
+
if os.path.exists(HUGGING_FACE_CACHE_DIR):
|
33 |
+
shutil.rmtree(HUGGING_FACE_CACHE_DIR)
|
34 |
+
assert not os.path.exists(HUGGING_FACE_CACHE_DIR)
|
35 |
+
download_useful_models()
|
36 |
+
assert os.path.exists(HUGGING_FACE_CACHE_DIR)
|
37 |
+
assert len(os.listdir(HUGGING_FACE_CACHE_DIR)) > 0
|
38 |
+
|
39 |
+
|
40 |
+
def test_on_form_submit():
|
41 |
+
model_name = "gpt2"
|
42 |
+
output_length = 10
|
43 |
+
prompt = "Answer yes or no, is the sky blue?"
|
44 |
+
output = on_form_submit(model_name, output_length, prompt)
|
45 |
+
assert output is not None
|
46 |
+
assert len(output) > 0
|
47 |
+
empty_prompt = ""
|
48 |
+
with pytest.raises(ValueError):
|
49 |
+
on_form_submit(model_name, output_length, empty_prompt)
|
50 |
+
unsupported_model_name = "unsupported_model_name"
|
51 |
+
with pytest.raises(ValueError):
|
52 |
+
on_form_submit(unsupported_model_name, output_length, prompt)
|
53 |
+
|
54 |
+
|
55 |
+
@pytest.mark.parametrize(
|
56 |
+
"model_name",
|
57 |
+
get_supported_model_names(
|
58 |
+
min_number_of_downloads=1000,
|
59 |
+
min_number_of_likes=100,
|
60 |
+
)
|
61 |
+
)
|
62 |
def test_create_pipeline(model_name: str):
|
63 |
pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, 5)
|
64 |
assert pipeline is not None
|