yonikremer
commited on
Commit
•
c9089bd
1
Parent(s):
30f253f
Improver user messages
Browse files- hanlde_form_submit.py +27 -5
hanlde_form_submit.py
CHANGED
@@ -14,13 +14,34 @@ def create_pipeline(model_name: str, group_size) -> GroupedSamplingPipeLine:
|
|
14 |
:param group_size: The size of the groups to use.
|
15 |
:return: A pipeline with the given model name and group size.
|
16 |
"""
|
17 |
-
|
|
|
18 |
model_name=model_name,
|
19 |
group_size=group_size,
|
20 |
end_of_sentence_stop=True,
|
21 |
temp=0.5,
|
22 |
top_p=0.6,
|
23 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
|
26 |
@st.cache
|
@@ -40,7 +61,8 @@ def on_form_submit(model_name: str, output_length: int, prompt: str) -> str:
|
|
40 |
model_name=model_name,
|
41 |
group_size=output_length,
|
42 |
)
|
43 |
-
return
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
14 |
:param group_size: The size of the groups to use.
|
15 |
:return: A pipeline with the given model name and group size.
|
16 |
"""
|
17 |
+
print(f"Starts downloading model: {model_name} from the internet.")
|
18 |
+
pipeline = GroupedSamplingPipeLine(
|
19 |
model_name=model_name,
|
20 |
group_size=group_size,
|
21 |
end_of_sentence_stop=True,
|
22 |
temp=0.5,
|
23 |
top_p=0.6,
|
24 |
)
|
25 |
+
print(f"Finished downloading model: {model_name} from the internet.")
|
26 |
+
return pipeline
|
27 |
+
|
28 |
+
|
29 |
+
def generate_text(
|
30 |
+
pipeline: GroupedSamplingPipeLine,
|
31 |
+
prompt: str,
|
32 |
+
output_length: int,
|
33 |
+
) -> str:
|
34 |
+
"""
|
35 |
+
Generates text using the given pipeline.
|
36 |
+
:param pipeline: The pipeline to use.
|
37 |
+
:param prompt: The prompt to use.
|
38 |
+
:param output_length: The size of the groups to use.
|
39 |
+
:return: The generated text.
|
40 |
+
"""
|
41 |
+
return pipeline(
|
42 |
+
prompt_s=prompt,
|
43 |
+
max_new_tokens=output_length,
|
44 |
+
)["generated_text"]
|
45 |
|
46 |
|
47 |
@st.cache
|
|
|
61 |
model_name=model_name,
|
62 |
group_size=output_length,
|
63 |
)
|
64 |
+
return generate_text(
|
65 |
+
pipeline=pipeline,
|
66 |
+
prompt=prompt,
|
67 |
+
output_length=output_length,
|
68 |
+
)
|