|
import os |
|
from functools import lru_cache |
|
from time import time |
|
|
|
import streamlit as st |
|
from grouped_sampling import GroupedSamplingPipeLine |
|
|
|
from download_repo import download_pytorch_model |
|
|
|
|
|
def is_downloaded(model_name: str) -> bool:
    """
    Checks if the model is present in the local Hugging Face hub cache.

    The cache root is resolved the way huggingface_hub resolves it:
    ``$HF_HUB_CACHE``, then the legacy ``$HUGGINGFACE_HUB_CACHE``, then
    ``$HF_HOME/hub``, and finally ``~/.cache/huggingface/hub``. When none of
    these variables is set, this is exactly the previously hard-coded path.

    :param model_name: The name of the model to check, e.g. ``"org/model"``.
    :return: True if the model's cache directory exists, False otherwise.
    """
    hf_home = os.environ.get("HF_HOME", os.path.join("~", ".cache", "huggingface"))
    hub_cache = os.environ.get(
        "HF_HUB_CACHE",
        os.environ.get("HUGGINGFACE_HUB_CACHE", os.path.join(hf_home, "hub")),
    )
    # Expand "~" in both the default and any user-provided override.
    models_dir = os.path.expanduser(hub_cache)
    # The hub cache stores repo "org/name" under "models--org--name".
    model_dir = os.path.join(models_dir, f"models--{model_name.replace('/', '--')}")
    # NOTE(review): only the directory's existence is checked, so an
    # interrupted download may still be reported as downloaded.
    return os.path.isdir(model_dir)
|
|
|
|
|
@lru_cache(maxsize=10)
def create_pipeline(
    model_name: str,
    *,
    group_size: int = 512,
    top_k: int = 50,
    end_of_sentence_stop: bool = False,
    load_in_8bit: bool = False,
) -> GroupedSamplingPipeLine:
    """
    Creates a GroupedSamplingPipeLine for the given model, downloading it first if needed.

    If ``is_downloaded`` reports that the model is not in the local cache, it
    is fetched with ``download_pytorch_model`` before the pipeline is built.
    Progress and elapsed times are written to the Streamlit page.

    Results are memoized with ``lru_cache(maxsize=10)``, so up to 10 pipelines
    stay in memory. On a cache hit, none of the progress messages are written.
    Per the functools docs, different argument patterns (e.g. passing a
    default explicitly vs. omitting it) may get separate cache entries.

    :param model_name: The name of the model to use.
    :param group_size: Group size passed to the pipeline (default 512).
    :param top_k: top-k value passed to the pipeline (default 50).
    :param end_of_sentence_stop: Passed through to the pipeline (default False).
    :param load_in_8bit: Passed through to the pipeline (default False).
    :return: A pipeline built with the given model name and settings.
    """
    if not is_downloaded(model_name):
        download_repository_start_time = time()
        st.write(f"Starts downloading model: {model_name} from the internet.")
        download_pytorch_model(model_name)
        download_time = time() - download_repository_start_time
        st.write(f"Finished downloading model: {model_name} from the internet in {download_time:,.2f} seconds.")

    st.write(f"Starts creating pipeline with model: {model_name}")
    pipeline_start_time = time()
    pipeline = GroupedSamplingPipeLine(
        model_name=model_name,
        group_size=group_size,
        end_of_sentence_stop=end_of_sentence_stop,
        top_k=top_k,
        load_in_8bit=load_in_8bit,
    )
    pipeline_time = time() - pipeline_start_time
    st.write(f"Finished creating pipeline with model: {model_name} in {pipeline_time:,.2f} seconds.")
    return pipeline
|
|
|
|
|
def generate_text(
    pipeline: GroupedSamplingPipeLine,
    prompt: str,
    output_length: int,
) -> str:
    """
    Runs the pipeline on a prompt and returns the ``"generated_text"`` entry of its result.

    The pipeline is called with ``return_text=True`` and
    ``return_full_text=False``, which presumably makes it return only the
    continuation, not the prompt (semantics belong to GroupedSamplingPipeLine).

    :param pipeline: The GroupedSamplingPipeLine to run.
    :param prompt: The input text to continue.
    :param output_length: Maximum number of new tokens to generate, expected to be > 0.
    :return: The generated text.
    """
    call_kwargs = {
        "prompt_s": prompt,
        "max_new_tokens": output_length,
        "return_text": True,
        "return_full_text": False,
    }
    result = pipeline(**call_kwargs)
    return result["generated_text"]
|
|
|
|
|
def on_form_submit(
    model_name: str,
    output_length: int,
    prompt: str,
) -> str:
    """
    Called when the user submits the form: loads the model and generates text.

    The inputs are validated before any work is done. Loading and generation
    progress, with elapsed times, is written to the Streamlit page.
    Exceptions raised while loading the model or generating text propagate
    unchanged.

    :param model_name: The name of the model to use.
    :param output_length: The number of new tokens to generate. Must be > 0.
    :param prompt: The prompt to use. Must be a non-empty string.
    :return: The text generated by the model.
    :raises TypeError: If the prompt is not a string.
    :raises ValueError: If the prompt is empty or the output length is <= 0.
    :raises RuntimeError: If the model returns a non-string or an empty string.
    """
    # Validate cheaply up front so bad input never triggers a model download/load.
    if not isinstance(prompt, str):
        raise TypeError("The prompt must be a string.")
    if not prompt:
        raise ValueError("The prompt must not be empty.")
    if output_length <= 0:
        raise ValueError("The output length must be greater than 0.")

    st.write(f"Loading model: {model_name}...")
    loading_start_time = time()
    pipeline = create_pipeline(
        model_name=model_name,
    )
    loading_time = time() - loading_start_time
    st.write(f"Finished loading model: {model_name} in {loading_time:,.2f} seconds.")

    st.write("Generating text...")
    generation_start_time = time()
    generated_text = generate_text(
        pipeline=pipeline,
        prompt=prompt,
        output_length=output_length,
    )
    generation_time = time() - generation_start_time
    st.write(f"Finished generating text in {generation_time:,.2f} seconds.")

    # A non-string result and an empty string are reported the same way.
    if not isinstance(generated_text, str) or not generated_text:
        raise RuntimeError(f"The model {model_name} did not generate any text.")
    return generated_text
|
|