|
""" |
|
The Streamlit app for the project demo. |
|
In the demo, the user can write a prompt and the model will generate a response using the grouped sampling algorithm. |
|
""" |
|
|
|
from functools import lru_cache

import streamlit as st

from grouped_sampling import GroupedSamplingPipeLine
|
|
|
# Hugging Face model-hub URL filtered to text-generation models, sorted by
# downloads; interpolated into the model-name field's help text below.
available_models_list: str = "https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads"
|
|
|
|
|
@lru_cache(maxsize=4)
def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Create (and cache) a pipeline with the given model name and group size.

    Loading a model is expensive, so results are memoized with
    ``functools.lru_cache``: resubmitting the form with the same model name
    and group size reuses the already-loaded pipeline instead of reloading
    the model on every submission. The bound ``maxsize`` keeps at most a few
    models in memory at once.

    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :return: A pipeline with the given model name and group size.
    """
    return GroupedSamplingPipeLine(model_name=model_name, group_size=group_size)
|
|
|
|
|
def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
    """
    Handle a submission of the request form.

    Builds a pipeline for the requested model and group size, runs the
    prompt through it, and returns only the generated text.

    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :param prompt: The prompt to use.
    :return: The output of the model.
    """
    result = create_pipeline(model_name, group_size)(prompt)
    return result["generated_text"]
|
|
|
|
|
# Request form: all widgets are batched so the model only runs when the
# user explicitly presses the submit button.
with st.form("request_form"):
    selected_model_name: str = st.text_input(
        label="Model name",
        value="gpt2",
        help=f"The name of the model to use. Must be a model from this list: {available_models_list}"
    )

    output_length: int = st.number_input(
        label="Output Length in tokens",
        min_value=1,
        max_value=4096,
        value=100,
        help="The length of the output text in tokens (word pieces)."
    )

    submitted_prompt: str = st.text_area(
        label="Input for the model",
        help="Enter the prompt for the model. The model will generate a response based on this prompt.",
        max_chars=16384,
    )

    # ``disabled=False`` is the Streamlit default and was dropped as redundant.
    submitted: bool = st.form_submit_button(
        label="Generate",
        help="Generate the output text.",
    )

if submitted:
    # NOTE(review): ``output_length`` is forwarded as the pipeline's
    # ``group_size`` argument — presumably the group size doubles as the
    # output length in the grouped-sampling algorithm; confirm against the
    # ``grouped_sampling`` package's documentation.
    output = on_form_submit(selected_model_name, output_length, submitted_prompt)
    st.write(f"Generated text: {output}")
|
|