"""
The Streamlit app for the project demo.

In the demo, the user can write a prompt
and the model will generate a response using the grouped sampling algorithm.
"""
import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine

# URL of the Hugging Face hub listing of text-generation models;
# it is interpolated into the help text of the "Model name" input below.
available_models_list = "https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads"


def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Creates a pipeline with the given model name and group size.
    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :return: A pipeline with the given model name and group size.
    """
    return GroupedSamplingPipeLine(model_name=model_name, group_size=group_size)


def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
    """
    Called when the user submits the form.
    Builds a new pipeline for the request and runs the prompt through it.
    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :param prompt: The prompt to use.
    :return: The output of the model (the "generated_text" entry of the pipeline result).
    """
    pipeline = create_pipeline(model_name, group_size)
    return pipeline(prompt)["generated_text"]


# All inputs live in one form so the script only acts on them
# when the "Generate" button is pressed.
with st.form("request_form"):
    selected_model_name: str = st.text_input(
        label="Model name",
        value="gpt2",
        help=f"The name of the model to use. Must be a model from this list: {available_models_list}"
    )

    output_length: int = st.number_input(
        label="Output Length in tokens",
        min_value=1,
        max_value=4096,
        value=100,
        help="The length of the output text in tokens (word pieces)."
    )

    submitted_prompt: str = st.text_area(
        label="Input for the model",
        help="Enter the prompt for the model. The model will generate a response based on this prompt.",
        max_chars=16384,
    )

    submitted: bool = st.form_submit_button(
        label="Generate",
        help="Generate the output text.",
        disabled=False
    )

    if submitted:
        if not submitted_prompt:
            # The text area yields an empty string when nothing was typed;
            # skip loading the model and generating in that case.
            st.warning("Please enter a prompt before generating.")
        else:
            # NOTE(review): the requested output length is passed as the group size,
            # presumably so the whole output is produced as a single group - confirm
            # against the grouped_sampling documentation.
            output = on_form_submit(selected_model_name, output_length, submitted_prompt)
            st.write(f"Generated text: {output}")