File size: 4,281 Bytes
5967916
3a9aacf
 
2fd3831
22e2fd1
7a75a15
17edf44
b1dd47e
 
5967916
 
 
 
 
 
 
 
 
 
 
4b6c061
7a75a15
 
 
 
 
 
5967916
 
 
17edf44
5967916
 
bf8a943
a4b0060
 
d102e03
7a75a15
 
1fac618
98c1d3b
7a75a15
a4b0060
 
bf8a943
c9089bd
 
 
 
 
 
 
 
 
 
dfa084c
 
 
 
c9089bd
 
a671856
c9089bd
d102e03
 
c9089bd
7a75a15
 
0499581
 
 
 
 
7a75a15
 
 
30f253f
7a75a15
 
d73a8e9
 
 
 
 
7a75a15
d73a8e9
22e2fd1
 
 
3a9aacf
 
7a75a15
30f253f
 
7a75a15
3a9aacf
 
bf8a943
22e2fd1
3a9aacf
e63724c
c9089bd
 
 
 
3a9aacf
 
bf8a943
e63724c
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
from time import time

import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine, is_supported, UnsupportedModelNameException

from download_repo import download_pytorch_model


def is_downloaded(model_name: str, models_dir: str = "/root/.cache/huggingface/hub") -> bool:
    """
    Checks if the model is downloaded.
    :param model_name: The name of the model to check.
    :param models_dir: The directory in which the model's folder is looked up.
     Defaults to "/root/.cache/huggingface/hub".
    :return: True if a directory for the model exists inside models_dir, False otherwise.
    """
    # The folder name is "models--" followed by the model name with every "/" replaced by "--".
    model_dir = os.path.join(models_dir, f"models--{model_name.replace('/', '--')}")
    return os.path.isdir(model_dir)


def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
    """
    Builds a GroupedSamplingPipeLine for the given model and reports progress with st.write.
    When is_downloaded(model_name) is False, the model is first fetched with download_pytorch_model.
    The pipeline is always created with end_of_sentence_stop=False and top_k=50.
    :param model_name: The name of the model to use.
    :param group_size: The size of the groups to use.
    :return: A pipeline with the given model name and group size.
    """
    already_downloaded = is_downloaded(model_name)
    if not already_downloaded:
        # Time the download so the user sees how long it took.
        started_at = time()
        st.write(f"Starts downloading model: {model_name} from the internet.")
        download_pytorch_model(model_name)
        download_time = time() - started_at
        st.write(f"Finished downloading model: {model_name} from the internet in {download_time:,.2f} seconds.")

    st.write(f"Starts creating pipeline with model: {model_name}")
    # Time the construction of the pipeline separately from the download.
    started_at = time()
    pipeline_arguments = {
        "model_name": model_name,
        "group_size": group_size,
        "end_of_sentence_stop": False,
        "top_k": 50,
    }
    created_pipeline = GroupedSamplingPipeLine(**pipeline_arguments)
    pipeline_time = time() - started_at
    st.write(f"Finished creating pipeline with model: {model_name} in {pipeline_time:,.2f} seconds.")
    return created_pipeline


def generate_text(
        pipeline: GroupedSamplingPipeLine,
        prompt: str,
        output_length: int,
) -> str:
    """
    Calls the given pipeline on the prompt and returns the "generated_text" entry of its result.
    The pipeline is called with return_text=True and return_full_text=False.
    :param pipeline: The pipeline to call. GroupedSamplingPipeLine.
    :param prompt: The prompt passed to the pipeline as prompt_s. str.
    :param output_length: Passed to the pipeline as max_new_tokens. int > 0.
    :return: The value stored under "generated_text" in the pipeline's result. str.
    """
    call_arguments = {
        "prompt_s": prompt,
        "max_new_tokens": output_length,
        "return_text": True,
        "return_full_text": False,
    }
    pipeline_output = pipeline(**call_arguments)
    return pipeline_output["generated_text"]


def on_form_submit(
        model_name: str,
        output_length: int,
        prompt: str,
) -> str:
    """
    Called when the user submits the form.
    Validates the input, creates the pipeline, generates text and reports progress with st.write.
    :param model_name: The name of the model to use.
    :param output_length: The number of tokens to generate. Also used as the group size of the pipeline.
    :param prompt: The prompt to use.
    :return: The output of the model.
    :raises TypeError: If the output length is not an integer or the prompt is not a string.
     ValueError: If the output length is <= 0, or the prompt is empty or longer than 16384 characters.
     UnsupportedModelNameException: If is_supported(model_name) is false.
     RuntimeError: If the generated text is not a non-empty string.
    """
    max_prompt_length = 16384
    # Validate the cheap things first, so bad input fails before the model is downloaded and loaded.
    if not isinstance(prompt, str):
        raise TypeError(f"The prompt must be a string, got {type(prompt).__name__}.")
    if not isinstance(output_length, int):
        raise TypeError(f"The output length must be an integer, got {type(output_length).__name__}.")
    if not prompt:
        raise ValueError("The prompt must not be empty.")
    if len(prompt) > max_prompt_length:
        raise ValueError(f"The prompt must not be longer than {max_prompt_length} characters.")
    if output_length <= 0:
        raise ValueError("The output length must be greater than 0.")
    if not is_supported(model_name):
        raise UnsupportedModelNameException(model_name)
    st.write(f"Loading model: {model_name}...")
    loading_start_time = time()
    # The whole output is generated as one group: the group size equals the output length.
    pipeline = create_pipeline(
        model_name=model_name,
        group_size=output_length,
    )
    loading_time = time() - loading_start_time
    st.write(f"Finished loading model: {model_name}  in {loading_time:,.2f} seconds.")
    st.write("Generating text...")
    generation_start_time = time()
    generated_text = generate_text(
        pipeline=pipeline,
        prompt=prompt,
        output_length=output_length,
    )
    generation_time = time() - generation_start_time
    st.write(f"Finished generating text in {generation_time:,.2f} seconds.")
    # A non-string result and an empty string are both reported as "no text generated".
    if not isinstance(generated_text, str) or not generated_text:
        raise RuntimeError(f"The model {model_name} did not generate any text.")
    return generated_text