# Tests for the app's server-side helpers: prompt rewriting, supported-model
# discovery, start-up model download, form submission and pipeline creation.
# (Recovered from a web blame view: 2,534 bytes, commits 9b37b1f / df1b7f8.)
import os
import shutil
import pytest as pytest
from grouped_sampling import GroupedSamplingPipeLine
from on_server_start import download_useful_models
from hanlde_form_submit import create_pipeline, on_form_submit
from prompt_engeneering import rewrite_prompt
from supported_models import get_supported_model_names
def get_hugging_face_cache_dir(environ=None) -> str:
    """Return the Hugging Face Hub cache directory for the current environment.

    Precedence follows huggingface_hub's own resolution:
    ``HF_HUB_CACHE``, then the legacy ``HUGGINGFACE_HUB_CACHE``, then
    ``$HF_HOME/hub``, then ``$XDG_CACHE_HOME/huggingface/hub``, and finally
    ``~/.cache/huggingface/hub``. Empty variables are treated as unset.

    Args:
        environ: Mapping to read environment variables from; defaults to
            ``os.environ``. Accepting a mapping keeps the function testable.

    Returns:
        The cache directory path, with ``~`` expanded.
    """
    env = os.environ if environ is None else environ
    for var in ("HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE"):
        explicit = env.get(var)
        if explicit:
            return os.path.expanduser(explicit)
    hf_home = env.get("HF_HOME")
    if not hf_home:
        cache_root = env.get("XDG_CACHE_HOME") or os.path.join(
            os.path.expanduser("~"), ".cache"
        )
        hf_home = os.path.join(cache_root, "huggingface")
    return os.path.join(os.path.expanduser(hf_home), "hub")


# Resolved at import time instead of being hard-coded to a single developer's
# home directory, so the tests work on any machine / CI runner.
HUGGING_FACE_CACHE_DIR = get_hugging_face_cache_dir()
def test_prompt_engineering():
    """rewrite_prompt embeds the user's query in the web-search prompt template."""
    query = "Answer yes or no, is the sky blue?"
    result = rewrite_prompt(query)

    # The template opens with the search-results header and closes with the query.
    assert result.startswith("Web search results:")
    assert result.endswith("Query: Answer yes or no, is the sky blue?")

    # Both sections must appear somewhere in the rewritten prompt.
    for section in ("Current date: ", "Instructions: "):
        assert section in result
def test_get_supported_model_names():
    """The supported-model list is non-empty, includes gpt2 and holds only strings."""
    names = get_supported_model_names()
    assert len(names) > 0
    assert "gpt2" in names
    non_strings = [name for name in names if not isinstance(name, str)]
    assert not non_strings
def test_on_server_start():
    """download_useful_models repopulates an emptied Hugging Face cache directory.

    WARNING: this test deletes the entire directory at HUGGING_FACE_CACHE_DIR
    before running, so any locally cached models are lost.
    """
    cache_dir = HUGGING_FACE_CACHE_DIR

    # Start from a clean slate so the checks below prove the download happened.
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)
    assert not os.path.exists(cache_dir)

    download_useful_models()

    assert os.path.exists(cache_dir)
    entries = os.listdir(cache_dir)
    assert len(entries) > 0
def test_on_form_submit():
    """on_form_submit yields a non-empty result for valid input and raises ValueError otherwise."""
    prompt = "Answer yes or no, is the sky blue?"
    output_length = 10

    # Happy path: a supported model plus a real prompt.
    result = on_form_submit("gpt2", output_length, prompt)
    assert result is not None
    assert len(result) > 0

    # An empty prompt must be rejected.
    with pytest.raises(ValueError):
        on_form_submit("gpt2", output_length, "")

    # A model name outside the supported set must be rejected.
    with pytest.raises(ValueError):
        on_form_submit("unsupported_model_name", output_length, prompt)
@pytest.mark.parametrize(
    "model_name",
    get_supported_model_names(
        min_number_of_downloads=1000,
        min_number_of_likes=100,
    ),
)
def test_create_pipeline(model_name: str):
    """create_pipeline builds a correctly configured pipeline for each popular model.

    Afterwards the model's folder in the Hugging Face cache (named
    ``models--<org>--<name>``) is removed. That folder is presumably populated
    by create_pipeline — TODO confirm. The cleanup happens whether or not the
    assertions pass.
    """
    group_size = 5
    model_folder_name = "models--" + model_name.replace("/", "--")
    model_path = os.path.join(HUGGING_FACE_CACHE_DIR, model_folder_name)
    try:
        pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, group_size)
        assert pipeline is not None
        assert pipeline.model_name == model_name
        assert pipeline.wrapped_model.group_size == group_size
        assert pipeline.wrapped_model.end_of_sentence_stop is True
    finally:
        # Clean up even when an assertion fails, so failed runs do not leak
        # cached model folders. Skip a missing folder so a cleanup error never
        # masks the real test result.
        if os.path.exists(model_path):
            shutil.rmtree(model_path)
if __name__ == "__main__":
    import sys

    # Run only this file's tests (plus any extra CLI arguments) instead of
    # everything under the cwd, and propagate pytest's exit status.
    sys.exit(pytest.main([__file__, *sys.argv[1:]]))