# grouped-sampling-demo / hanlde_form_submit.py
import os
from time import time
import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine, is_supported, UnsupportedModelNameException
from download_repo import download_pytorch_model
from prompt_engeneering import rewrite_prompt


def is_downloaded(model_name: str) -> bool:
    """
    Checks if the model is already downloaded to the local cache.
    :param model_name: The name of the model to check.
    :return: True if the model is downloaded, False otherwise.
    """
    # The default Hugging Face hub cache location; the hub stores each model
    # in a directory named "models--{organization}--{model_name}".
    models_dir = "/root/.cache/huggingface/hub"
    model_dir = os.path.join(models_dir, f"models--{model_name.replace('/', '--')}")
    return os.path.isdir(model_dir)
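
# Usage sketch (illustrative; "facebook/opt-125m" is an arbitrary example
# model name, not one the app pins):
#   is_downloaded("facebook/opt-125m") checks for the directory
#   /root/.cache/huggingface/hub/models--facebook--opt-125m
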
def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Creates a pipeline with the given model name and group size.
    :param model_name: The name of the model to use.
    :param group_size: The number of tokens to sample per group.
    :return: A pipeline with the given model name and group size.
    """
    if not is_downloaded(model_name):
        download_start_time = time()
        st.write(f"Downloading model {model_name}...")
        download_pytorch_model(model_name)
        download_time = time() - download_start_time
        st.write(f"Finished downloading model {model_name} in {download_time:,.2f} seconds.")
    st.write(f"Creating pipeline with model {model_name}...")
    pipeline_start_time = time()
    # top_k=1 presumably makes decoding greedy (always take the most likely
    # token), and end_of_sentence_stop=False lets generation run until
    # max_new_tokens instead of stopping at an end-of-sentence token.
    pipeline = GroupedSamplingPipeLine(
        model_name=model_name,
        group_size=group_size,
        end_of_sentence_stop=False,
        top_k=1,
    )
    pipeline_time = time() - pipeline_start_time
    st.write(f"Finished creating pipeline with model {model_name} in {pipeline_time:,.2f} seconds.")
    return pipeline
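
# Usage sketch (illustrative; "gpt2" is an arbitrary example model name):
#   pipe = create_pipeline(model_name="gpt2", group_size=100)
# on_form_submit below passes the requested output length as group_size, so
# the whole output is presumably sampled as a single group.
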
def generate_text(
    pipeline: GroupedSamplingPipeLine,
    prompt: str,
    output_length: int,
    web_search: bool,
) -> str:
    """
    Generates text using the given pipeline.
    :param pipeline: The GroupedSamplingPipeLine to use.
    :param prompt: The prompt to complete. str.
    :param output_length: The length of the text to generate, in tokens. int > 0.
    :param web_search: Whether to rewrite the prompt using web search. bool.
    :return: The generated text. str.
    """
    if web_search:
        better_prompt = rewrite_prompt(prompt)
    else:
        better_prompt = prompt
    return pipeline(
        prompt_s=better_prompt,
        max_new_tokens=output_length,
        return_text=True,
        return_full_text=False,
    )["generated_text"]
def on_form_submit(
    model_name: str,
    output_length: int,
    prompt: str,
    web_search: bool,
) -> str:
    """
    Called when the user submits the form.
    :param model_name: The name of the model to use.
    :param output_length: The number of tokens to generate.
    :param prompt: The prompt to use.
    :param web_search: Whether to use web search or not.
    :return: The output of the model.
    :raises ValueError: If the prompt is empty.
    :raises UnsupportedModelNameException: If the model name is not supported.
    :raises RuntimeError: If the model did not generate any text.
    """
    if len(prompt) == 0:
        raise ValueError("The prompt must not be empty.")
    if not is_supported(model_name):
        raise UnsupportedModelNameException(model_name)
    st.write(f"Loading model {model_name}...")
    loading_start_time = time()
    # The group size is set to the full output length, so the output is
    # presumably sampled as a single group.
    pipeline = create_pipeline(
        model_name=model_name,
        group_size=output_length,
    )
    loading_time = time() - loading_start_time
    st.write(f"Finished loading model {model_name} in {loading_time:,.2f} seconds.")
    st.write("Generating text...")
    generation_start_time = time()
    generated_text = generate_text(
        pipeline=pipeline,
        prompt=prompt,
        output_length=output_length,
        web_search=web_search,
    )
    generation_time = time() - generation_start_time
    st.write(f"Finished generating text in {generation_time:,.2f} seconds.")
    if not isinstance(generated_text, str) or len(generated_text) == 0:
        raise RuntimeError(f"The model {model_name} did not generate any text.")
    return generated_text
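
# A minimal smoke test (a hedged sketch, not part of the original app).
# Running this module directly would download the example model, hence the
# guard; "gpt2" is an arbitrary model name assumed to be supported.
if __name__ == "__main__":
    print(on_form_submit(
        model_name="gpt2",
        output_length=16,
        prompt="Hello, my name is",
        web_search=False,
    ))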