"""Streamlit helpers that generate text with a grouped-sampling pipeline."""

import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine

from prompt_engeneering import rewrite_prompt
from supported_models import get_supported_model_names

# Evaluated once at import time; on_form_submit validates model names against it.
SUPPORTED_MODEL_NAMES = get_supported_model_names()


def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Creates a pipeline with the given model name and group size.

    The pipeline is always built with ``end_of_sentence_stop=True``,
    ``temp=0.5`` and ``top_p=0.6``. A progress message is printed to stdout
    before and after the pipeline object is constructed.

    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :return: A pipeline with the given model name and group size.
    """
    # NOTE(review): the messages assume construction downloads the model;
    # presumably a locally cached model skips the download - confirm.
    print(f"Starts downloading model: {model_name} from the internet.")
    pipeline = GroupedSamplingPipeLine(
        model_name=model_name,
        group_size=group_size,
        end_of_sentence_stop=True,
        temp=0.5,
        top_p=0.6,
    )
    print(f"Finished downloading model: {model_name} from the internet.")
    return pipeline


def generate_text(
        pipeline: GroupedSamplingPipeLine,
        prompt: str,
        output_length: int,
) -> str:
    """
    Generates text using the given pipeline.

    The prompt is first passed through ``rewrite_prompt``; the rewritten
    prompt is what the pipeline receives. Only the newly generated text is
    requested (``return_full_text=False``).

    :param pipeline: The pipeline to use. GroupedSamplingPipeLine.
    :param prompt: The prompt to use. str.
    :param output_length: The size of the text to generate in tokens. int > 0.
    :return: The generated text. str.
    """
    better_prompt = rewrite_prompt(prompt)
    return pipeline(
        prompt_s=better_prompt,
        max_new_tokens=output_length,
        return_text=True,
        return_full_text=False,
    )["generated_text"]


# NOTE(review): newer Streamlit releases deprecate ``st.cache`` in favour of
# ``st.cache_data`` / ``st.cache_resource`` - confirm the installed version
# before switching decorators.
@st.cache
def on_form_submit(model_name: str, output_length: int, prompt: str) -> str:
    """
    Called when the user submits the form.

    Validates the model name, builds a fresh pipeline for it and generates
    text for the prompt.

    :param model_name: The name of the model to use.
    :param output_length: The number of tokens to generate. The same value is
        also used as the pipeline's group size.
    :param prompt: The prompt to use.
    :return: The output of the model.
    :raises ValueError: If ``model_name`` is not in ``SUPPORTED_MODEL_NAMES``.
    """
    if model_name not in SUPPORTED_MODEL_NAMES:
        # Adjacent literals are concatenated verbatim, so the separating
        # space must live inside one of them.
        raise ValueError(
            f"The selected model {model_name} is not supported. "
            "Supported models are all the models in:"
            " https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch"
        )
    pipeline = create_pipeline(
        model_name=model_name,
        group_size=output_length,
    )
    return generate_text(
        pipeline=pipeline,
        prompt=prompt,
        output_length=output_length,
    )