|
import os |
|
import shutil |
|
|
|
import pytest as pytest |
|
from grouped_sampling import GroupedSamplingPipeLine |
|
|
|
from hanlde_form_submit import create_pipeline |
|
from prompt_engeneering import rewrite_prompt |
|
from supported_models import get_supported_model_names |
|
|
|
# Hugging Face's default model cache location, resolved for the current user
# instead of a hard-coded /home/yoni path so the tests run on any machine.
HUGGING_FACE_CACHE_DIR = os.path.expanduser(os.path.join("~", ".cache", "huggingface", "hub"))
|
|
|
|
|
def test_prompt_engineering():
    """Check that rewrite_prompt wraps a query in the web-search prompt template."""
    query = "Answer yes or no, is the sky blue?"
    result = rewrite_prompt(query)
    # The rewritten prompt must open with search results and close with the query.
    assert result.startswith("Web search results:")
    assert result.endswith(f"Query: {query}")
    # The template also injects a date line and an instructions line.
    for fragment in ("Current date: ", "Instructions: "):
        assert fragment in result
|
|
|
|
|
def test_get_supported_model_names():
    """get_supported_model_names returns a non-empty collection of model-name strings."""
    supported_model_names = get_supported_model_names()
    assert len(supported_model_names) > 0
    assert "gpt2" in supported_model_names
    # Generator expression instead of a materialized list lets all() short-circuit.
    assert all(isinstance(name, str) for name in supported_model_names)
|
|
|
|
|
@pytest.mark.parametrize("model_name", get_supported_model_names())
def test_create_pipeline(model_name: str):
    """create_pipeline builds a grouped-sampling pipeline for every supported model.

    The downloaded weights are deleted afterwards so repeated runs do not fill
    the Hugging Face cache; cleanup runs even when an assertion fails, and is
    a no-op when the model directory was never created.
    """
    try:
        pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, 5)
        assert pipeline is not None
        assert pipeline.model_name == model_name
        assert pipeline.wrapped_model.group_size == 5
        assert pipeline.wrapped_model.end_of_sentence_stop is True
    finally:
        # Hub cache layout: "models--{org}--{name}" under the cache root.
        # ignore_errors covers the case where create_pipeline raised before
        # any weights were downloaded.
        model_folder_name = "models--" + model_name.replace("/", "--")
        model_path = os.path.join(HUGGING_FACE_CACHE_DIR, model_folder_name)
        shutil.rmtree(model_path, ignore_errors=True)
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so direct invocation (python this_file.py)
    # reports failures to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main())
|
|