yonikremer commited on
Commit
0499581
·
1 Parent(s): 6bcf2e3

added a checkbox that can disable the web search

Browse files
Files changed (2) hide show
  1. app.py +12 -1
  2. hanlde_form_submit.py +14 -2
app.py CHANGED
@@ -40,6 +40,12 @@ with st.form("request_form"):
40
  max_chars=2048,
41
  )
42
 
 
 
 
 
 
 
43
  submitted: bool = st.form_submit_button(
44
  label="Generate",
45
  help="Generate the output text.",
@@ -48,7 +54,12 @@ with st.form("request_form"):
48
 
49
  if submitted:
50
  try:
51
- output = on_form_submit(selected_model_name, output_length, submitted_prompt)
 
 
 
 
 
52
  except CudaError as e:
53
  st.error("Out of memory. Please try a smaller model, shorter prompt, or a smaller output length.")
54
  except (ValueError, TypeError, RuntimeError) as e:
 
40
  max_chars=2048,
41
  )
42
 
43
+ web_search: bool = st.checkbox(
44
+ label="Web search",
45
+ value=True,
46
+ help="If checked, the model will get your prompt as well as some web search results."
47
+ )
48
+
49
  submitted: bool = st.form_submit_button(
50
  label="Generate",
51
  help="Generate the output text.",
 
54
 
55
  if submitted:
56
  try:
57
+ output = on_form_submit(
58
+ selected_model_name,
59
+ output_length,
60
+ submitted_prompt,
61
+ web_search,
62
+ )
63
  except CudaError as e:
64
  st.error("Out of memory. Please try a smaller model, shorter prompt, or a smaller output length.")
65
  except (ValueError, TypeError, RuntimeError) as e:
hanlde_form_submit.py CHANGED
@@ -52,15 +52,20 @@ def generate_text(
52
  pipeline: GroupedSamplingPipeLine,
53
  prompt: str,
54
  output_length: int,
 
55
  ) -> str:
56
  """
57
  Generates text using the given pipeline.
58
  :param pipeline: The pipeline to use. GroupedSamplingPipeLine.
59
  :param prompt: The prompt to use. str.
60
  :param output_length: The size of the text to generate in tokens. int > 0.
 
61
  :return: The generated text. str.
62
  """
63
- better_prompt = rewrite_prompt(prompt)
 
 
 
64
  return pipeline(
65
  prompt_s=better_prompt,
66
  max_new_tokens=output_length,
@@ -69,12 +74,18 @@ def generate_text(
69
  )["generated_text"]
70
 
71
 
72
- def on_form_submit(model_name: str, output_length: int, prompt: str) -> str:
 
 
 
 
 
73
  """
74
  Called when the user submits the form.
75
  :param model_name: The name of the model to use.
76
  :param output_length: The size of the groups to use.
77
  :param prompt: The prompt to use.
 
78
  :return: The output of the model.
79
  :raises ValueError: If the model name is not supported, the output length is <= 0,
80
  the prompt is empty or longer than
@@ -99,6 +110,7 @@ def on_form_submit(model_name: str, output_length: int, prompt: str) -> str:
99
  pipeline=pipeline,
100
  prompt=prompt,
101
  output_length=output_length,
 
102
  )
103
  generation_end_time = time()
104
  generation_time = generation_end_time - generation_start_time
 
52
  pipeline: GroupedSamplingPipeLine,
53
  prompt: str,
54
  output_length: int,
55
+ web_search: bool,
56
  ) -> str:
57
  """
58
  Generates text using the given pipeline.
59
  :param pipeline: The pipeline to use. GroupedSamplingPipeLine.
60
  :param prompt: The prompt to use. str.
61
  :param output_length: The size of the text to generate in tokens. int > 0.
62
+ :param web_search: Whether to use web search or not. bool.
63
  :return: The generated text. str.
64
  """
65
+ if web_search:
66
+ better_prompt = rewrite_prompt(prompt)
67
+ else:
68
+ better_prompt = prompt
69
  return pipeline(
70
  prompt_s=better_prompt,
71
  max_new_tokens=output_length,
 
74
  )["generated_text"]
75
 
76
 
77
+ def on_form_submit(
78
+ model_name: str,
79
+ output_length: int,
80
+ prompt: str,
81
+ web_search: bool
82
+ ) -> str:
83
  """
84
  Called when the user submits the form.
85
  :param model_name: The name of the model to use.
86
  :param output_length: The size of the groups to use.
87
  :param prompt: The prompt to use.
88
+ :param web_search: Whether to use web search or not.
89
  :return: The output of the model.
90
  :raises ValueError: If the model name is not supported, the output length is <= 0,
91
  the prompt is empty or longer than
 
110
  pipeline=pipeline,
111
  prompt=prompt,
112
  output_length=output_length,
113
+ web_search=web_search,
114
  )
115
  generation_end_time = time()
116
  generation_time = generation_end_time - generation_start_time