yonikremer
commited on
Commit
•
826e275
1
Parent(s):
feb3275
created an initial app
Browse files
app.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
The Streamlit app for the project demo.
|
3 |
+
In the demo, the user can write a prompt and the model will generate a response using the grouped sampling algorithm.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import streamlit as st
|
7 |
+
from grouped_sampling import GroupedSamplingPipeLine
|
8 |
+
|
9 |
+
available_models_list = "https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads"
|
10 |
+
|
11 |
+
|
12 |
+
def create_pipeline(model_name: str, group_size) -> GroupedSamplingPipeLine:
|
13 |
+
"""
|
14 |
+
Creates a pipeline with the given model name and group size.
|
15 |
+
:param model_name: The name of the model to use.
|
16 |
+
:param group_size: The size of the groups to use.
|
17 |
+
:return: A pipeline with the given model name and group size.
|
18 |
+
"""
|
19 |
+
return GroupedSamplingPipeLine(model_name=model_name, group_size=group_size)
|
20 |
+
|
21 |
+
|
22 |
+
def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
|
23 |
+
"""
|
24 |
+
Called when the user submits the form.
|
25 |
+
:param model_name: The name of the model to use.
|
26 |
+
:param group_size: The size of the groups to use.
|
27 |
+
:param prompt: The prompt to use.
|
28 |
+
:return: The output of the model.
|
29 |
+
"""
|
30 |
+
pipeline = create_pipeline(model_name, group_size)
|
31 |
+
return pipeline(prompt)["generated_text"]
|
32 |
+
|
33 |
+
|
34 |
+
with st.form("request_form"):
|
35 |
+
selected_model_name: str = st.text_input(
|
36 |
+
label="Model name",
|
37 |
+
value="gpt2",
|
38 |
+
help=f"The name of the model to use. Must be a model from this list: {available_models_list}"
|
39 |
+
)
|
40 |
+
|
41 |
+
output_length: int = st.number_input(
|
42 |
+
label="Output Length in tokens",
|
43 |
+
min_value=1,
|
44 |
+
max_value=4096,
|
45 |
+
value=100,
|
46 |
+
help="The length of the output text in tokens (word pieces)."
|
47 |
+
)
|
48 |
+
|
49 |
+
submitted_prompt: str = st.text_area(
|
50 |
+
label="Input for the model",
|
51 |
+
help="Enter the prompt for the model. The model will generate a response based on this prompt.",
|
52 |
+
max_chars=16384,
|
53 |
+
)
|
54 |
+
|
55 |
+
submitted: bool = st.form_submit_button(
|
56 |
+
label="Generate",
|
57 |
+
help="Generate the output text.",
|
58 |
+
disabled=False
|
59 |
+
|
60 |
+
)
|
61 |
+
|
62 |
+
if submitted:
|
63 |
+
output = on_form_submit(selected_model_name, output_length, submitted_prompt)
|
64 |
+
st.write(f"Generated text: {output}")
|