yonikremer commited on
Commit
c9089bd
1 Parent(s): 30f253f

Improver user messages

Browse files
Files changed (1) hide show
  1. hanlde_form_submit.py +27 -5
hanlde_form_submit.py CHANGED
@@ -14,13 +14,34 @@ def create_pipeline(model_name: str, group_size) -> GroupedSamplingPipeLine:
14
  :param group_size: The size of the groups to use.
15
  :return: A pipeline with the given model name and group size.
16
  """
17
- return GroupedSamplingPipeLine(
 
18
  model_name=model_name,
19
  group_size=group_size,
20
  end_of_sentence_stop=True,
21
  temp=0.5,
22
  top_p=0.6,
23
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  @st.cache
@@ -40,7 +61,8 @@ def on_form_submit(model_name: str, output_length: int, prompt: str) -> str:
40
  model_name=model_name,
41
  group_size=output_length,
42
  )
43
- return pipeline(
44
- prompt_s=prompt,
45
- max_new_tokens=output_length,
46
- )["generated_text"]
 
 
14
  :param group_size: The size of the groups to use.
15
  :return: A pipeline with the given model name and group size.
16
  """
17
+ print(f"Starts downloading model: {model_name} from the internet.")
18
+ pipeline = GroupedSamplingPipeLine(
19
  model_name=model_name,
20
  group_size=group_size,
21
  end_of_sentence_stop=True,
22
  temp=0.5,
23
  top_p=0.6,
24
  )
25
+ print(f"Finished downloading model: {model_name} from the internet.")
26
+ return pipeline
27
+
28
+
29
+ def generate_text(
30
+ pipeline: GroupedSamplingPipeLine,
31
+ prompt: str,
32
+ output_length: int,
33
+ ) -> str:
34
+ """
35
+ Generates text using the given pipeline.
36
+ :param pipeline: The pipeline to use.
37
+ :param prompt: The prompt to use.
38
+ :param output_length: The size of the groups to use.
39
+ :return: The generated text.
40
+ """
41
+ return pipeline(
42
+ prompt_s=prompt,
43
+ max_new_tokens=output_length,
44
+ )["generated_text"]
45
 
46
 
47
  @st.cache
 
61
  model_name=model_name,
62
  group_size=output_length,
63
  )
64
+ return generate_text(
65
+ pipeline=pipeline,
66
+ prompt=prompt,
67
+ output_length=output_length,
68
+ )