yonikremer commited on
Commit
9b37b1f
1 Parent(s): dfa084c

added tests

Browse files
Files changed (2) hide show
  1. requirements.txt +2 -1
  2. tests.py +43 -0
requirements.txt CHANGED
@@ -7,4 +7,5 @@ beautifulsoup4~=4.11.2
7
  urllib3
8
  requests~=2.28.2
9
  accelerate
10
- bitsandbytes
 
 
7
  urllib3
8
  requests~=2.28.2
9
  accelerate
10
+ bitsandbytes
11
+ pytest
tests.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+
4
+ import pytest as pytest
5
+ from grouped_sampling import GroupedSamplingPipeLine
6
+
7
+ from hanlde_form_submit import create_pipeline
8
+ from prompt_engeneering import rewrite_prompt
9
+ from supported_models import get_supported_model_names
10
+
11
+ HUGGING_FACE_CACHE_DIR = "/home/yoni/.cache/huggingface/hub"
12
+
13
+
14
+ def test_prompt_engineering():
15
+ example_prompt = "Answer yes or no, is the sky blue?"
16
+ rewritten_prompt = rewrite_prompt(example_prompt)
17
+ assert rewritten_prompt.startswith("Web search results:")
18
+ assert rewritten_prompt.endswith("Query: Answer yes or no, is the sky blue?")
19
+ assert "Current date: " in rewritten_prompt
20
+ assert "Instructions: " in rewritten_prompt
21
+
22
+
23
+ def test_get_supported_model_names():
24
+ supported_model_names = get_supported_model_names()
25
+ assert len(supported_model_names) > 0
26
+ assert "gpt2" in supported_model_names
27
+ assert all([isinstance(name, str) for name in supported_model_names])
28
+
29
+
30
+ @pytest.mark.parametrize("model_name", get_supported_model_names())
31
+ def test_create_pipeline(model_name: str):
32
+ pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, 5)
33
+ assert pipeline is not None
34
+ assert pipeline.model_name == model_name
35
+ assert pipeline.wrapped_model.group_size == 5
36
+ assert pipeline.wrapped_model.end_of_sentence_stop is True
37
+ model_folder_name = "models--" + model_name.replace("/", "--")
38
+ model_path = os.path.join(HUGGING_FACE_CACHE_DIR, model_folder_name)
39
+ shutil.rmtree(model_path)
40
+
41
+
42
+ if __name__ == "__main__":
43
+ pytest.main()