|
import os |
|
|
|
import pytest as pytest |
|
from grouped_sampling import GroupedSamplingPipeLine, get_full_models_list, UnsupportedModelNameException |
|
|
|
from on_server_start import download_useful_models |
|
from hanlde_form_submit import create_pipeline, on_form_submit |
|
from prompt_engeneering import rewrite_prompt |
|
|
|
# Root of the Hugging Face Hub download cache that the startup test inspects.
#
# Resolved from the environment instead of being hard-coded to one
# developer's home directory. A fixed path made the cache assertions fail on
# any other machine or CI runner, and it ignored user overrides.
#
# The lookup order mirrors the one used by `huggingface_hub`:
#   1. HF_HUB_CACHE
#   2. the legacy HUGGINGFACE_HUB_CACHE
#   3. $HF_HOME/hub, where HF_HOME defaults to $XDG_CACHE_HOME/huggingface,
#      and XDG_CACHE_HOME itself defaults to ~/.cache.
# With no overrides set this is ~/.cache/huggingface/hub, i.e. the previous
# value on the original author's machine.
_HF_HOME = os.path.expanduser(
    os.getenv(
        "HF_HOME",
        os.path.join(
            os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")),
            "huggingface",
        ),
    )
)
HUGGING_FACE_CACHE_DIR = os.getenv(
    "HF_HUB_CACHE",
    os.getenv("HUGGINGFACE_HUB_CACHE", os.path.join(_HF_HOME, "hub")),
)
|
|
|
|
|
def test_prompt_engineering():
    """Check that rewrite_prompt frames a query with the expected header,
    section markers, and the original query as the trailing line."""
    query = "Answer yes or no, is the sky blue?"
    result = rewrite_prompt(query)

    # The rewritten prompt opens with the search-results header and closes
    # with the original query.
    assert result.startswith("Web search results:")
    assert result.endswith("Query: Answer yes or no, is the sky blue?")

    # Both section markers must appear somewhere in the body.
    for section_marker in ("Current date: ", "Instructions: "):
        assert section_marker in result
|
|
|
|
|
def test_get_supported_model_names():
    """Check that the model list is non-empty, includes "gpt2", and
    contains only strings."""
    names = get_full_models_list()

    assert len(names) > 0
    assert "gpt2" in names

    # Collect offenders so a failure shows which entries are not strings.
    non_string_names = [candidate for candidate in names if not isinstance(candidate, str)]
    assert not non_string_names
|
|
|
|
|
def test_on_server_start():
    """Check that the startup download hook leaves a non-empty cache
    directory at HUGGING_FACE_CACHE_DIR."""
    download_useful_models()

    cache_dir = HUGGING_FACE_CACHE_DIR
    assert os.path.exists(cache_dir)

    # NOTE(review): this also passes if the cache was already populated
    # before download_useful_models() ran; it does not prove that this call
    # downloaded anything.
    cache_entries = os.listdir(cache_dir)
    assert cache_entries
|
|
|
|
|
def test_on_form_submit():
    """Exercise on_form_submit on one valid request and two invalid ones."""
    model_name = "gpt2"
    output_length = 10
    prompt = "Answer yes or no, is the sky blue?"

    # A supported model with a real prompt must produce some output.
    result = on_form_submit(model_name, output_length, prompt, web_search=False)
    assert result is not None
    assert len(result) > 0

    # Invalid requests, checked in order:
    #   1. an empty prompt must raise ValueError;
    #   2. an unknown model name must raise UnsupportedModelNameException.
    invalid_cases = (
        (model_name, "", ValueError),
        ("unsupported_model_name", prompt, UnsupportedModelNameException),
    )
    for bad_model, bad_prompt, expected_error in invalid_cases:
        with pytest.raises(expected_error):
            on_form_submit(bad_model, output_length, bad_prompt, web_search=False)
|
|
|
|
|
@pytest.mark.parametrize("model_name", get_full_models_list()[:3])
def test_create_pipeline(model_name: str):
    """Check that create_pipeline builds a pipeline for model_name with a
    group size of 5 and end-of-sentence stopping disabled.

    Parametrized over the first three entries of get_full_models_list().
    """
    group_size = 5
    pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, group_size)

    assert pipeline is not None
    assert pipeline.model_name == model_name

    wrapped = pipeline.wrapped_model
    assert wrapped.group_size == group_size
    assert wrapped.end_of_sentence_stop is False

    # Explicitly drop the references before returning. This is presumably
    # meant to release the loaded model early between parametrized cases.
    del wrapped, pipeline
|
|
|
|
|
if __name__ == "__main__":
    # When executed directly, run this file's tests. Any extra command-line
    # arguments (e.g. -k, -x) are still forwarded.
    #
    # A bare pytest.main() reads sys.argv, and with no arguments it collects
    # from the current directory instead of this file. It also discarded
    # pytest's exit code, so failing runs exited with status 0.
    import sys

    raise SystemExit(pytest.main([__file__, *sys.argv[1:]]))
|
|