grouped-sampling-demo / hanlde_form_submit.py
yonikremer's picture
on_form_submit is now cached
2fd3831
raw
history blame
No virus
1 kB
import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine
def create_pipeline(
    model_name: str,
    group_size: int,
    *,
    end_of_sentence_stop: bool = True,
) -> GroupedSamplingPipeLine:
    """
    Creates a pipeline with the given model name and group size.

    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :param end_of_sentence_stop: Forwarded unchanged to
        ``GroupedSamplingPipeLine``. It defaults to ``True`` so that existing
        callers passing only ``model_name`` and ``group_size`` keep the
        original behaviour. (Presumably this makes generation stop at an
        end-of-sentence token; confirm against the grouped_sampling docs.)
    :return: A pipeline with the given model name and group size.
    """
    # Keyword arguments keep the call robust to parameter reordering
    # in GroupedSamplingPipeLine's constructor.
    return GroupedSamplingPipeLine(
        model_name=model_name,
        group_size=group_size,
        end_of_sentence_stop=end_of_sentence_stop,
    )
@st.cache
def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
    """
    Handle a form submission by running the requested model on the prompt.

    The result is memoised by Streamlit's ``st.cache`` decorator. A repeated
    submission with the same model name, group size and prompt reuses the
    stored output instead of rebuilding the pipeline.

    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :param prompt: The prompt to use.
    :return: The ``"generated_text"`` entry of the pipeline's output.
    """
    generator = create_pipeline(model_name, group_size)
    result = generator(prompt)
    return result["generated_text"]