yonikremer
commited on
Commit
•
7a75a15
1
Parent(s):
a95851c
moved functions to new file
Browse files- app.py +7 -24
- hanlde_form_submit.py +30 -0
app.py
CHANGED
@@ -1,41 +1,24 @@
|
|
1 |
"""
|
2 |
The Streamlit app for the project demo.
|
3 |
-
In the demo, the user can write a prompt
|
|
|
4 |
"""
|
5 |
|
6 |
import streamlit as st
|
7 |
-
from grouped_sampling import GroupedSamplingPipeLine
|
8 |
|
9 |
-
|
10 |
|
11 |
|
12 |
-
|
13 |
-
"""
|
14 |
-
Creates a pipeline with the given model name and group size.
|
15 |
-
:param model_name: The name of the model to use.
|
16 |
-
:param group_size: The size of the groups to use.
|
17 |
-
:return: A pipeline with the given model name and group size.
|
18 |
-
"""
|
19 |
-
return GroupedSamplingPipeLine(model_name=model_name, group_size=group_size)
|
20 |
-
|
21 |
-
|
22 |
-
def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
|
23 |
-
"""
|
24 |
-
Called when the user submits the form.
|
25 |
-
:param model_name: The name of the model to use.
|
26 |
-
:param group_size: The size of the groups to use.
|
27 |
-
:param prompt: The prompt to use.
|
28 |
-
:return: The output of the model.
|
29 |
-
"""
|
30 |
-
pipeline = create_pipeline(model_name, group_size)
|
31 |
-
return pipeline(prompt)["generated_text"]
|
32 |
|
33 |
|
34 |
with st.form("request_form"):
|
35 |
selected_model_name: str = st.text_input(
|
36 |
label="Model name",
|
37 |
value="gpt2",
|
38 |
-
help=f"The name of the model to use.
|
|
|
|
|
39 |
)
|
40 |
|
41 |
output_length: int = st.number_input(
|
|
|
1 |
"""
|
2 |
The Streamlit app for the project demo.
|
3 |
+
In the demo, the user can write a prompt
|
4 |
+
and the model will generate a response using the grouped sampling algorithm.
|
5 |
"""
|
6 |
|
7 |
import streamlit as st
|
|
|
8 |
|
9 |
+
from hanlde_form_submit import on_form_submit
|
10 |
|
11 |
|
12 |
+
AVAILABLE_MODEL_NAMES = "https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
|
15 |
with st.form("request_form"):
|
16 |
selected_model_name: str = st.text_input(
|
17 |
label="Model name",
|
18 |
value="gpt2",
|
19 |
+
help=f"The name of the model to use."
|
20 |
+
f" Must be a model from this list:"
|
21 |
+
f" {AVAILABLE_MODEL_NAMES}"
|
22 |
)
|
23 |
|
24 |
output_length: int = st.number_input(
|
hanlde_form_submit.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from grouped_sampling import GroupedSamplingPipeLine
|
2 |
+
|
3 |
+
|
4 |
+
def create_pipeline(model_name: str, group_size) -> GroupedSamplingPipeLine:
|
5 |
+
"""
|
6 |
+
Creates a pipeline with the given model name and group size.
|
7 |
+
:param model_name: The name of the model to use.
|
8 |
+
:param group_size: The size of the groups to use.
|
9 |
+
:return: A pipeline with the given model name and group size.
|
10 |
+
"""
|
11 |
+
return GroupedSamplingPipeLine(
|
12 |
+
model_name=model_name,
|
13 |
+
group_size=group_size,
|
14 |
+
end_of_sentence_stop=True,
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
def on_form_submit(model_name: str, group_size: int, prompt: str) -> str:
|
19 |
+
"""
|
20 |
+
Called when the user submits the form.
|
21 |
+
:param model_name: The name of the model to use.
|
22 |
+
:param group_size: The size of the groups to use.
|
23 |
+
:param prompt: The prompt to use.
|
24 |
+
:return: The output of the model.
|
25 |
+
"""
|
26 |
+
pipeline = create_pipeline(
|
27 |
+
model_name,
|
28 |
+
group_size,
|
29 |
+
)
|
30 |
+
return pipeline(prompt)["generated_text"]
|