"""
The Streamlit app for the project demo.
In the demo, the user can write a prompt and the model will generate a response using the grouped sampling algorithm.
"""

import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine
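# To run the demo locally (assuming this file is saved as app.py and that
# streamlit and grouped-sampling are installed, e.g. via pip):
#     streamlit run app.py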

# URL of the Hugging Face Hub page listing the models this demo supports.
AVAILABLE_MODELS_URL = "https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads"


def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Creates a pipeline with the given model name and group size.
    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :return: A pipeline with the given model name and group size.
    """
    return GroupedSamplingPipeLine(model_name=model_name, group_size=group_size)


def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
    """
    Called when the user submits the form.
    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :param prompt: The prompt to use.
    :return: The output of the model.
    """
    pipeline = create_pipeline(model_name, group_size)
    return pipeline(prompt)["generated_text"]
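# A minimal usage sketch of the helper above (bypassing the Streamlit form):
#     generated = on_form_submit("gpt2", 100, "Once upon a time")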


with st.form("request_form"):
    selected_model_name: str = st.text_input(
        label="Model name",
        value="gpt2",
        help=f"The name of the model to use. Must be one of the text-generation models listed at: {AVAILABLE_MODELS_URL}"
    )

    output_length: int = st.number_input(
        label="Output length in tokens",
        min_value=1,
        max_value=4096,
        value=100,
        help="The length of the output text in tokens (word pieces)."
    )

    submitted_prompt: str = st.text_area(
        label="Input for the model",
        help="Enter the prompt for the model. The model will generate a response based on this prompt.",
        max_chars=16384,
    )

    submitted: bool = st.form_submit_button(
        label="Generate",
        help="Generate the output text."
    )

    if submitted:
        # The requested output length is passed to the pipeline as its group size.
        output = on_form_submit(selected_model_name, output_length, submitted_prompt)
        st.write(f"Generated text: {output}")