File size: 4,280 Bytes
5967916 3a9aacf 2fd3831 22e2fd1 7a75a15 17edf44 b1dd47e 5967916 4b6c061 7a75a15 5967916 17edf44 5967916 bf8a943 a4b0060 d102e03 7a75a15 1fac618 49ad8d3 7a75a15 a4b0060 bf8a943 c9089bd dfa084c c9089bd a671856 c9089bd d102e03 c9089bd 7a75a15 0499581 7a75a15 30f253f 7a75a15 d73a8e9 7a75a15 d73a8e9 22e2fd1 3a9aacf 7a75a15 30f253f 7a75a15 3a9aacf bf8a943 22e2fd1 3a9aacf e63724c c9089bd 3a9aacf bf8a943 e63724c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import os
from time import time
import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine, is_supported, UnsupportedModelNameException
from download_repo import download_pytorch_model
def is_downloaded(model_name: str, cache_dir: "str | None" = None) -> bool:
    """
    Checks if the model is present in the local Hugging Face hub cache.

    The hub cache stores each model repository in a directory named
    ``models--<org>--<name>``, i.e. the repo id with ``/`` replaced by ``--``.

    :param model_name: The name of the model to check, e.g. ``"org/model"``.
    :param cache_dir: The hub cache directory to look in. When None, it is
        resolved like huggingface_hub does: ``HF_HUB_CACHE``, then the legacy
        ``HUGGINGFACE_HUB_CACHE``, then ``$HF_HOME/hub``, then
        ``$XDG_CACHE_HOME/huggingface/hub``, and finally
        ``~/.cache/huggingface/hub``. For the root user with no overrides, the
        final fallback is ``/root/.cache/huggingface/hub``, the path this
        function previously hard-coded.
    :return: True if the model's cache directory exists, False otherwise.
    """
    if cache_dir is None:
        # Empty environment variables are treated as unset.
        hf_home = os.environ.get("HF_HOME") or os.path.join(
            os.environ.get("XDG_CACHE_HOME") or os.path.join("~", ".cache"),
            "huggingface",
        )
        cache_dir = (
            os.environ.get("HF_HUB_CACHE")
            or os.environ.get("HUGGINGFACE_HUB_CACHE")
            or os.path.join(hf_home, "hub")
        )
    # Expand "~" so the default (and user-supplied paths) point at the real home directory.
    cache_dir = os.path.expanduser(cache_dir)
    model_dir = os.path.join(cache_dir, "models--" + model_name.replace("/", "--"))
    # NOTE(review): only checks that the directory exists; an interrupted
    # download may leave a partial directory behind.
    return os.path.isdir(model_dir)
def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Builds a GroupedSamplingPipeLine for ``model_name``, downloading the model first if it
    is not already in the local cache.

    Progress and elapsed-time messages are written to the Streamlit page.

    :param model_name: The name of the model to load.
    :param group_size: The group size passed to the pipeline.
    :return: The pipeline, constructed with ``top_k=1`` and ``end_of_sentence_stop=False``.
    """
    if not is_downloaded(model_name):
        started_at = time()
        st.write(f"Starts downloading model: {model_name} from the internet.")
        download_pytorch_model(model_name)
        elapsed = time() - started_at
        st.write(f"Finished downloading model: {model_name} from the internet in {elapsed:,.2f} seconds.")

    pipeline_kwargs = {
        "model_name": model_name,
        "group_size": group_size,
        "end_of_sentence_stop": False,
        "top_k": 1,
    }
    st.write(f"Starts creating pipeline with model: {model_name}")
    started_at = time()
    pipeline = GroupedSamplingPipeLine(**pipeline_kwargs)
    elapsed = time() - started_at
    st.write(f"Finished creating pipeline with model: {model_name} in {elapsed:,.2f} seconds.")
    return pipeline
def generate_text(
    pipeline: GroupedSamplingPipeLine,
    prompt: str,
    output_length: int,
) -> str:
    """
    Runs ``pipeline`` on ``prompt`` and returns its generated text.

    The pipeline is called with ``return_text=True`` and ``return_full_text=False``,
    presumably so that the prompt is not echoed back in the result.

    :param pipeline: The GroupedSamplingPipeLine to call.
    :param prompt: The prompt to feed to the model.
    :param output_length: The maximum number of new tokens to generate (expected > 0).
    :return: The value stored under the ``"generated_text"`` key of the pipeline output.
    """
    result = pipeline(
        prompt_s=prompt,
        max_new_tokens=output_length,
        return_text=True,
        return_full_text=False,
    )
    return result["generated_text"]
def on_form_submit(
    model_name: str,
    output_length: int,
    prompt: str,
) -> str:
    """
    Called when the user submits the form: validates the input, loads the model and
    generates text.

    Progress and elapsed-time messages are written to the Streamlit page.

    :param model_name: The name of the model to use.
    :param output_length: The number of new tokens to generate. The same value is also
        used as the pipeline's group size.
    :param prompt: The prompt to use.
    :return: The generated text (never empty).
    :raises TypeError: If the prompt is not a string or the output length is not an integer.
    :raises ValueError: If the prompt is empty or the output length is <= 0.
    :raises UnsupportedModelNameException: If the model name is not supported.
    :raises RuntimeError: If the model did not generate any text.
    """
    # Check the cheap-to-validate inputs before any slow model download/loading happens.
    if not isinstance(prompt, str):
        raise TypeError(f"The prompt must be a string, got {type(prompt).__name__}.")
    if not prompt:
        raise ValueError("The prompt must not be empty.")
    if not isinstance(output_length, int):
        raise TypeError(f"The output length must be an integer, got {type(output_length).__name__}.")
    if output_length <= 0:
        raise ValueError(f"The output length must be positive, got {output_length}.")
    if not is_supported(model_name):
        raise UnsupportedModelNameException(model_name)

    st.write(f"Loading model: {model_name}...")
    loading_start_time = time()
    pipeline = create_pipeline(
        model_name=model_name,
        group_size=output_length,
    )
    loading_time = time() - loading_start_time
    st.write(f"Finished loading model: {model_name} in {loading_time:,.2f} seconds.")

    st.write("Generating text...")
    generation_start_time = time()
    generated_text = generate_text(
        pipeline=pipeline,
        prompt=prompt,
        output_length=output_length,
    )
    generation_time = time() - generation_start_time
    st.write(f"Finished generating text in {generation_time:,.2f} seconds.")

    # A non-string result and an empty string are both treated as "no text generated".
    if not isinstance(generated_text, str) or not generated_text:
        raise RuntimeError(f"The model {model_name} did not generate any text.")
    return generated_text
|