downloading only the pytorch model and important files, not the other versions of the model
17edf44
import os | |
from time import time | |
import streamlit as st | |
from grouped_sampling import GroupedSamplingPipeLine | |
from download_repo import download_pytorch_model | |
from prompt_engeneering import rewrite_prompt | |
from supported_models import is_supported, SUPPORTED_MODEL_NAME_PAGES_FORMAT, BLACKLISTED_MODEL_NAMES, \ | |
BLACKLISTED_ORGANIZATIONS | |
def is_downloaded(model_name: str) -> bool:
    """
    Checks if the model is downloaded.

    The hub cache directory is taken from the environment, first match wins:
    ``HF_HUB_CACHE``, ``HUGGINGFACE_HUB_CACHE``, then ``$HF_HOME/hub``.
    When none of them is set (or they are empty), the directory
    ``/root/.cache/huggingface/hub`` is used.

    :param model_name: The name of the model to check.
    :return: True if the model is downloaded, False otherwise.
    """
    models_dir = os.environ.get("HF_HUB_CACHE") or os.environ.get("HUGGINGFACE_HUB_CACHE")
    if not models_dir:
        hf_home = os.environ.get("HF_HOME")
        if hf_home:
            models_dir = os.path.join(hf_home, "hub")
        else:
            # No override configured: keep the historical location.
            models_dir = "/root/.cache/huggingface/hub"
    # Cached repositories live in a folder named "models--<org>--<name>".
    model_dir = os.path.join(models_dir, f"models--{model_name.replace('/', '--')}")
    return os.path.isdir(model_dir)
def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Builds a GroupedSamplingPipeLine for the requested model.

    If the model is not found locally (see is_downloaded) it is downloaded first.
    Progress and elapsed times are reported to the page with st.write.

    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :return: A pipeline with the given model name and group size.
    """
    needs_download = not is_downloaded(model_name)
    if needs_download:
        download_started = time()
        st.write(f"Starts downloading model: {model_name} from the internet.")
        download_pytorch_model(model_name)
        download_elapsed = time() - download_started
        st.write(f"Finished downloading model: {model_name} from the internet in {download_elapsed:,.2f} seconds.")

    st.write(f"Starts creating pipeline with model: {model_name}")
    build_started = time()
    grouped_pipeline = GroupedSamplingPipeLine(
        model_name=model_name,
        group_size=group_size,
        end_of_sentence_stop=False,
        top_k=1,
    )
    build_elapsed = time() - build_started
    st.write(f"Finished creating pipeline with model: {model_name} in {build_elapsed:,.2f} seconds.")
    return grouped_pipeline
def generate_text(
        pipeline: GroupedSamplingPipeLine,
        prompt: str,
        output_length: int,
        web_search: bool,
) -> str:
    """
    Runs the pipeline on the prompt and returns only the newly generated text.

    :param pipeline: The pipeline to use. GroupedSamplingPipeLine.
    :param prompt: The prompt to use. str.
    :param output_length: The size of the text to generate in tokens. int > 0.
    :param web_search: When True, the prompt is first passed through rewrite_prompt. bool.
    :return: The generated text. str.
    """
    # rewrite_prompt is only invoked when web search was requested.
    final_prompt = rewrite_prompt(prompt) if web_search else prompt
    pipeline_output = pipeline(
        prompt_s=final_prompt,
        max_new_tokens=output_length,
        return_text=True,
        return_full_text=False,
    )
    return pipeline_output["generated_text"]
_MAX_PROMPT_LENGTH = 16384


def _validate_form_input(model_name: str, output_length: int, prompt: str) -> None:
    """
    Raises if the submitted form values cannot be used for generation.

    :param model_name: The name of the model to use.
    :param output_length: The requested number of tokens to generate.
    :param prompt: The prompt to use.
    :raises ValueError: If the model is not supported, the output length is <= 0,
        or the prompt is empty or longer than 16384 characters.
    :raises TypeError: If the output length is not an integer or the prompt is not a string.
    """
    # The model check stays first so an unsupported model always yields ValueError.
    # NOTE(review): the two 1s are presumably the minimum likes and downloads,
    # matching condition 1 in the message below - confirm against supported_models.
    if not is_supported(model_name, 1, 1):
        raise ValueError(
            f"The model: {model_name} is not supported. "
            f"The supported models are the models from {SUPPORTED_MODEL_NAME_PAGES_FORMAT}"
            f" that satisfy the following conditions:\n"
            f"1. The model has at least one like and one download.\n"
            f"2. The model is not one of: {BLACKLISTED_MODEL_NAMES}.\n"
            f"3. The model was not created by any of those organizations: {BLACKLISTED_ORGANIZATIONS}.\n"
        )
    if not isinstance(prompt, str):
        raise TypeError(f"The prompt must be a string, got {type(prompt).__name__}.")
    if not isinstance(output_length, int):
        raise TypeError(f"The output length must be an integer, got {type(output_length).__name__}.")
    if output_length <= 0:
        raise ValueError(f"The output length must be positive, got {output_length}.")
    if not prompt:
        raise ValueError("The prompt must not be empty.")
    if len(prompt) > _MAX_PROMPT_LENGTH:
        raise ValueError(
            f"The prompt must not be longer than {_MAX_PROMPT_LENGTH} characters, got {len(prompt)}."
        )


def on_form_submit(
        model_name: str,
        output_length: int,
        prompt: str,
        web_search: bool
) -> str:
    """
    Called when the user submits the form.

    Validates the input, creates the pipeline, generates the text and reports
    progress and timings to the page with st.write.

    :param model_name: The name of the model to use.
    :param output_length: The number of tokens to generate. It is also passed
        to the pipeline as its group size.
    :param prompt: The prompt to use.
    :param web_search: Whether to use web search or not.
    :return: The output of the model.
    :raises ValueError: If the model name is not supported, the output length is <= 0,
        or the prompt is empty or longer than 16384 characters.
    :raises TypeError: If the output length is not an integer or the prompt is not a string.
    :raises RuntimeError: If the model did not produce a non-empty string.
    """
    _validate_form_input(model_name, output_length, prompt)

    st.write(f"Loading model: {model_name}...")
    loading_start_time = time()
    pipeline = create_pipeline(
        model_name=model_name,
        group_size=output_length,
    )
    loading_time = time() - loading_start_time
    st.write(f"Finished loading model: {model_name} in {loading_time:,.2f} seconds.")

    st.write("Generating text...")
    generation_start_time = time()
    generated_text = generate_text(
        pipeline=pipeline,
        prompt=prompt,
        output_length=output_length,
        web_search=web_search,
    )
    generation_time = time() - generation_start_time
    st.write(f"Finished generating text in {generation_time:,.2f} seconds.")

    # A non-string result and an empty string are both treated as "nothing generated".
    if not isinstance(generated_text, str) or not generated_text:
        raise RuntimeError(f"The model {model_name} did not generate any text.")
    return generated_text