import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine

from prompt_engeneering import rewrite_prompt
from supported_models import get_supported_model_names


SUPPORTED_MODEL_NAMES = get_supported_model_names()


def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Creates a pipeline with the given model name and group size.
    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :return: A pipeline with the given model name and group size.
    """
    print(f"Starts downloading model: {model_name} from the internet.")
    pipeline = GroupedSamplingPipeLine(
        model_name=model_name,
        group_size=group_size,
        end_of_sentence_stop=True,
        temp=0.5,
        top_p=0.6,
    )
    print(f"Finished downloading model: {model_name} from the internet.")
    return pipeline


def generate_text(
        pipeline: GroupedSamplingPipeLine,
        prompt: str,
        output_length: int,
) -> str:
    """
    Generates text using the given pipeline.
    :param pipeline: The pipeline to use. GroupedSamplingPipeLine.
    :param prompt: The prompt to use. str.
    :param output_length: The size of the text to generate in tokens. int > 0.
    :return: The generated text. str.
    """
    better_prompt = rewrite_prompt(prompt)
    return pipeline(
        prompt_s=better_prompt,
        max_new_tokens=output_length,
        return_text=True,
        return_full_text=False,
    )["generated_text"]


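# Cache results so that identical form submissions are not recomputed.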
@st.cache
def on_form_submit(model_name: str, output_length: int, prompt: str) -> str:
    """
    Called when the user submits the form.
    :param model_name: The name of the model to use.
    :param output_length: The size of the groups to use.
    :param prompt: The prompt to use.
    :return: The output of the model.
    """
    if model_name not in SUPPORTED_MODEL_NAMES:
        raise ValueError(
            f"The selected model {model_name} is not supported. "
            f"Supported models are the text-generation PyTorch models listed at: "
            f"https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch"
        )
    pipeline = create_pipeline(
        model_name=model_name,
        group_size=output_length,
    )
    return generate_text(
        pipeline=pipeline,
        prompt=prompt,
        output_length=output_length,
    )
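

# --- Illustrative usage sketch (an assumption, not part of the original module) ---
# A minimal example of how this module might be wired into a Streamlit form.
# The widget labels, keys, and default values below are assumptions; only
# `on_form_submit` and `SUPPORTED_MODEL_NAMES` come from this file.
with st.form("generation_form"):
    selected_model = st.selectbox("Model name", SUPPORTED_MODEL_NAMES)
    requested_length = st.number_input(
        "Output length (in tokens)",
        min_value=1,
        value=100,
    )
    user_prompt = st.text_area("Prompt")
    submitted = st.form_submit_button("Generate text")

if submitted:
    st.write(on_form_submit(
        model_name=selected_model,
        output_length=int(requested_length),
        prompt=user_prompt,
    ))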