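"""Tests for hanlde_form_submit: pipeline creation and form-submission handling."""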
import pytest
from grouped_sampling import (
    GroupedSamplingPipeLine,
    UnsupportedModelNameException,
    get_full_models_list,
)

from hanlde_form_submit import create_pipeline, on_form_submit

# Machine-specific path of the local Hugging Face model cache.
HUGGING_FACE_CACHE_DIR = "/home/yoni/.cache/huggingface/hub"


def test_on_form_submit():
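    """Submitting a valid form returns non-empty text; an empty prompt raises
    ValueError and an unsupported model name raises UnsupportedModelNameException."""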
    model_name = "gpt2"
    output_length = 10
    prompt = "Answer yes or no, is the sky blue?"
    output = on_form_submit(model_name, output_length, prompt)
    assert output is not None
    assert len(output) > 0
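    # An empty prompt should be rejected with a ValueError.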
    empty_prompt = ""
    with pytest.raises(ValueError):
        on_form_submit(model_name, output_length, empty_prompt)
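    # A model name unknown to grouped_sampling should raise UnsupportedModelNameException.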
    unsupported_model_name = "unsupported_model_name"
    with pytest.raises(UnsupportedModelNameException):
        on_form_submit(unsupported_model_name, output_length, prompt)


@pytest.mark.parametrize(
    "model_name",
    get_full_models_list()[:3]
)
def test_create_pipeline(model_name: str):
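    """Create a pipeline for each of the first three supported models and check
    its model name, group size, and end-of-sentence stopping flag."""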
    pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, 5)
    assert pipeline is not None
    assert pipeline.model_name == model_name
    assert pipeline.wrapped_model.group_size == 5
    assert pipeline.wrapped_model.end_of_sentence_stop is False
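    # Drop the reference so the model can be garbage-collected between parametrized runs.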
    del pipeline


if __name__ == "__main__":
    pytest.main()