File size: 1,544 Bytes
9b37b1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import shutil

import pytest as pytest
from grouped_sampling import GroupedSamplingPipeLine

from hanlde_form_submit import create_pipeline
from prompt_engeneering import rewrite_prompt
from supported_models import get_supported_model_names

HUGGING_FACE_CACHE_DIR = "/home/yoni/.cache/huggingface/hub"


def test_prompt_engineering():
    """The rewritten prompt must wrap the original query in the web-search template."""
    query = "Answer yes or no, is the sky blue?"
    result = rewrite_prompt(query)
    assert result.startswith("Web search results:")
    assert result.endswith(f"Query: {query}")
    # Both template fields must be present somewhere in the body.
    for required_field in ("Current date: ", "Instructions: "):
        assert required_field in result


def test_get_supported_model_names():
    """get_supported_model_names returns a non-empty collection of model-name strings.

    Checks that the list is non-empty, contains the baseline "gpt2" model,
    and that every entry is a string.
    """
    supported_model_names = get_supported_model_names()
    assert len(supported_model_names) > 0
    assert "gpt2" in supported_model_names
    # Generator instead of a list comprehension: all() short-circuits on the
    # first non-string entry instead of materializing the whole list first.
    assert all(isinstance(name, str) for name in supported_model_names)


@pytest.mark.parametrize("model_name", get_supported_model_names())
def test_create_pipeline(model_name: str):
    """create_pipeline builds a GroupedSamplingPipeLine for *model_name* with group size 5.

    The downloaded model weights are deleted from the Hugging Face cache
    afterwards even when an assertion fails, so a broken run does not leave
    gigabytes of weights behind.
    """
    pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, 5)
    try:
        assert pipeline is not None
        assert pipeline.model_name == model_name
        assert pipeline.wrapped_model.group_size == 5
        assert pipeline.wrapped_model.end_of_sentence_stop is True
    finally:
        # HF cache layout: "models--{org}--{name}". ignore_errors covers the
        # case where the download never completed and the folder is absent.
        model_folder_name = "models--" + model_name.replace("/", "--")
        model_path = os.path.join(HUGGING_FACE_CACHE_DIR, model_folder_name)
        shutil.rmtree(model_path, ignore_errors=True)


if __name__ == "__main__":
    # Allow running this module directly (python <file>.py) without the pytest CLI;
    # pytest.main() discovers and runs the tests defined above.
    pytest.main()