"""Tests for prompt rewriting, supported-model discovery and pipeline creation."""
import os
import shutil

import pytest

from grouped_sampling import GroupedSamplingPipeLine
from hanlde_form_submit import create_pipeline
from prompt_engeneering import rewrite_prompt
from supported_models import get_supported_model_names


def _default_hugging_face_cache_dir() -> str:
    """Return the directory in which huggingface_hub stores downloaded models.

    Resolution order: ``HF_HUB_CACHE``, then the legacy ``HUGGINGFACE_HUB_CACHE``,
    then ``$HF_HOME/hub``. ``HF_HOME`` itself defaults to
    ``$XDG_CACHE_HOME/huggingface``, and ``XDG_CACHE_HOME`` defaults to ``~/.cache``.
    This replaces a previously hard-coded, user-specific absolute path.
    """
    for env_var in ("HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE"):
        override = os.environ.get(env_var)
        if override:
            return os.path.expanduser(override)
    xdg_cache_home = os.environ.get("XDG_CACHE_HOME", os.path.join("~", ".cache"))
    hf_home = os.environ.get("HF_HOME", os.path.join(xdg_cache_home, "huggingface"))
    return os.path.join(os.path.expanduser(hf_home), "hub")


# Root of the Hugging Face hub cache; test_create_pipeline deletes the model
# folders it downloads from here.
HUGGING_FACE_CACHE_DIR = _default_hugging_face_cache_dir()


def test_prompt_engineering():
    """rewrite_prompt wraps the query with search results, date and instructions."""
    example_prompt = "Answer yes or no, is the sky blue?"
    rewritten_prompt = rewrite_prompt(example_prompt)
    assert rewritten_prompt.startswith("Web search results:")
    assert rewritten_prompt.endswith("Query: Answer yes or no, is the sky blue?")
    assert "Current date: " in rewritten_prompt
    assert "Instructions: " in rewritten_prompt


def test_get_supported_model_names():
    """The supported-model list is non-empty, contains gpt2 and holds only strings."""
    supported_model_names = get_supported_model_names()
    assert len(supported_model_names) > 0
    assert "gpt2" in supported_model_names
    assert all(isinstance(name, str) for name in supported_model_names)


@pytest.mark.parametrize("model_name", get_supported_model_names())
def test_create_pipeline(model_name: str):
    """Build a pipeline for ``model_name`` and check it is configured as requested.

    The model's hub-cache folder is removed afterwards, even when an assertion
    fails or ``create_pipeline`` raises, so downloaded weights do not accumulate
    across the parametrized cases.
    """
    group_size = 5
    # Hub cache layout: "org/name" is stored under "models--org--name".
    model_folder_name = "models--" + model_name.replace("/", "--")
    model_path = os.path.join(HUGGING_FACE_CACHE_DIR, model_folder_name)
    try:
        pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, group_size)
        assert pipeline is not None
        assert pipeline.model_name == model_name
        assert pipeline.wrapped_model.group_size == group_size
        assert pipeline.wrapped_model.end_of_sentence_stop is True
    finally:
        # The folder may not exist, e.g. if create_pipeline failed before
        # downloading anything. Skip the removal so the cleanup does not mask
        # the real error.
        if os.path.isdir(model_path):
            shutil.rmtree(model_path)


if __name__ == "__main__":
    # Run only this file and propagate pytest's exit status to the shell.
    raise SystemExit(pytest.main([__file__]))