from grouped_sampling import GroupedSamplingPipeLine


def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Creates a pipeline with the given model name and group size.

    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :return: A pipeline with the given model name and group size.
    """
    return GroupedSamplingPipeLine(
        model_name=model_name,
        group_size=group_size,
        # NOTE(review): presumably makes generation stop at the end of a
        # sentence — confirm against the grouped_sampling documentation.
        end_of_sentence_stop=True,
    )


def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
    """
    Called when the user submits the form.

    Builds a pipeline for the requested model and group size, runs it on the
    prompt, and returns the ``"generated_text"`` entry of the pipeline's result.

    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :param prompt: The prompt to use.
    :return: The output of the model.
    """
    # NOTE(review): a new pipeline is constructed on every submission. If
    # constructing GroupedSamplingPipeLine loads model weights, caching it per
    # (model_name, group_size) would avoid repeated loading — confirm the cost
    # before changing this.
    pipeline = create_pipeline(model_name, group_size)
    return pipeline(prompt)["generated_text"]