import streamlit as st | |
from grouped_sampling import GroupedSamplingPipeLine | |
def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Build a grouped-sampling pipeline for the given model and group size.

    A new ``GroupedSamplingPipeLine`` is constructed on every call; nothing
    is cached here.

    :param model_name: The name of the model to use; forwarded unchanged to
        ``GroupedSamplingPipeLine``.
    :param group_size: The size of the groups to use; forwarded unchanged to
        ``GroupedSamplingPipeLine``.
    :return: A pipeline configured with the given model name and group size.
    """
    return GroupedSamplingPipeLine(
        model_name=model_name,
        group_size=group_size,
        # Always enabled by this helper. Presumably makes generation stop at
        # an end-of-sentence token -- confirm against the library docs.
        end_of_sentence_stop=True,
    )
def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
    """
    Handle a form submission by generating text for the given prompt.

    A pipeline is built via ``create_pipeline`` on each call, invoked on the
    prompt, and the ``"generated_text"`` entry of its result is returned.

    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :param prompt: The prompt to pass to the pipeline.
    :return: The ``"generated_text"`` entry of the pipeline's output.
    """
    sampler = create_pipeline(model_name=model_name, group_size=group_size)
    completion = sampler(prompt)
    return completion["generated_text"]