yonikremer's picture
added tests
9b37b1f
raw
history blame
No virus
1.54 kB
import os
import shutil
import pytest as pytest
from grouped_sampling import GroupedSamplingPipeLine
from hanlde_form_submit import create_pipeline
from prompt_engeneering import rewrite_prompt
from supported_models import get_supported_model_names
def resolve_hugging_face_cache_dir(environ=None) -> str:
    """Return the directory used as the Hugging Face hub cache.

    The lookup order follows the hub's documented environment variables
    instead of assuming one particular user's home directory:

    1. ``HF_HUB_CACHE`` (or the legacy ``HUGGINGFACE_HUB_CACHE``), used as is.
    2. ``HF_HOME`` with ``hub`` appended.
    3. ``XDG_CACHE_HOME`` with ``huggingface/hub`` appended.
    4. ``~/.cache/huggingface/hub``.

    Args:
        environ: Mapping to read the variables from. Defaults to
            ``os.environ``; a plain dict can be passed to make the lookup
            deterministic.

    Returns:
        The cache directory path with a leading ``~`` expanded.
    """
    env = os.environ if environ is None else environ
    for variable in ("HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE"):
        explicit = env.get(variable)
        if explicit:  # unset and empty values both fall through
            return os.path.expanduser(explicit)
    hf_home = env.get("HF_HOME")
    if not hf_home:
        cache_root = env.get("XDG_CACHE_HOME") or os.path.join("~", ".cache")
        hf_home = os.path.join(cache_root, "huggingface")
    return os.path.expanduser(os.path.join(hf_home, "hub"))


# Directory the tests clean downloaded models out of.
HUGGING_FACE_CACHE_DIR = resolve_hugging_face_cache_dir()
def test_prompt_engineering():
    """Check the overall shape of the text produced by ``rewrite_prompt``.

    The rewritten prompt must open with the web-search header, close with
    the original query, and contain the date and instruction sections
    somewhere in between.
    """
    query = "Answer yes or no, is the sky blue?"
    result = rewrite_prompt(query)

    # Fixed opening and closing of the rewritten prompt.
    assert result.startswith("Web search results:")
    assert result.endswith("Query: Answer yes or no, is the sky blue?")

    # Sections that only need to appear somewhere in the text.
    for required_section in ("Current date: ", "Instructions: "):
        assert required_section in result
def test_get_supported_model_names():
    """Check that the supported-model listing is non-empty and well-formed.

    The listing must contain at least one entry, must include ``"gpt2"``,
    and every entry must be a string.
    """
    names = get_supported_model_names()
    assert len(names) > 0
    assert "gpt2" in names
    # Checking entries one by one makes a failure point at the bad entry.
    for candidate in names:
        assert isinstance(candidate, str)
@pytest.mark.parametrize("model_name", get_supported_model_names())
def test_create_pipeline(model_name: str):
    """Build a pipeline for every supported model and check its settings.

    Runs once per name returned by ``get_supported_model_names``. For each
    model the pipeline is created with a group size of 5 and must report
    the requested model name, that group size, and
    ``end_of_sentence_stop`` set to ``True``.

    Afterwards the folder ``models--<name with "/" replaced by "-->`` under
    ``HUGGING_FACE_CACHE_DIR`` is removed (presumably the hub's cached
    download for the model; verify against the hub cache layout).

    Args:
        model_name: Name of the model to build the pipeline for.
    """
    model_folder_name = "models--" + model_name.replace("/", "--")
    model_path = os.path.join(HUGGING_FACE_CACHE_DIR, model_folder_name)
    try:
        pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, 5)
        assert pipeline is not None
        assert pipeline.model_name == model_name
        assert pipeline.wrapped_model.group_size == 5
        assert pipeline.wrapped_model.end_of_sentence_stop is True
    finally:
        # Clean up even when creation or an assertion fails, so a failing
        # parametrized case does not leave its files behind. A missing
        # folder is not an error: the goal is only that it is gone.
        shutil.rmtree(model_path, ignore_errors=True)
if __name__ == "__main__":
    # pytest.main() returns the run's exit code; propagate it so running
    # this file as a script exits non-zero when a test fails.
    raise SystemExit(pytest.main())