Spaces:
Sleeping
Sleeping
jwkirchenbauer
committed on
Commit
•
5b3e92c
1
Parent(s):
c3f4b90
API timeout incr to 60 sec
Browse files- demo_watermark.py +19 -10
demo_watermark.py
CHANGED
@@ -210,12 +210,13 @@ def load_model(args):
|
|
210 |
|
211 |
|
212 |
from text_generation import InferenceAPIClient
|
|
|
213 |
def generate_with_api(prompt, args):
|
214 |
hf_api_key = os.environ.get("HF_API_KEY")
|
215 |
if hf_api_key is None:
|
216 |
raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.")
|
217 |
|
218 |
-
client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key)
|
219 |
|
220 |
assert args.n_beams == 1, "HF API models do not support beam search."
|
221 |
generation_params = {
|
@@ -226,14 +227,22 @@ def generate_with_api(prompt, args):
|
|
226 |
generation_params["temperature"] = args.sampling_temp
|
227 |
generation_params["seed"] = args.generation_seed
|
228 |
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
return (output_text_without_watermark,
|
238 |
output_text_with_watermark)
|
239 |
|
@@ -737,7 +746,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
|
|
737 |
generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
738 |
detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
739 |
model_selector.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
740 |
-
# When the parameters change, display the update and fire detection, since some detection params dont change the model output.
|
741 |
delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
742 |
gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
743 |
gamma.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer])
|
|
|
210 |
|
211 |
|
212 |
from text_generation import InferenceAPIClient
|
213 |
+
from requests.exceptions import ReadTimeout
|
214 |
def generate_with_api(prompt, args):
|
215 |
hf_api_key = os.environ.get("HF_API_KEY")
|
216 |
if hf_api_key is None:
|
217 |
raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.")
|
218 |
|
219 |
+
client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key, timeout=60)
|
220 |
|
221 |
assert args.n_beams == 1, "HF API models do not support beam search."
|
222 |
generation_params = {
|
|
|
227 |
generation_params["temperature"] = args.sampling_temp
|
228 |
generation_params["seed"] = args.generation_seed
|
229 |
|
230 |
+
timeout_msg = "[Model API timeout error. Try reducing the max_new_tokens parameter or the prompt length.]"
|
231 |
+
try:
|
232 |
+
generation_params["watermark"] = False
|
233 |
+
output = client.generate(prompt, **generation_params)
|
234 |
+
output_text_without_watermark = output.generated_text
|
235 |
+
except ReadTimeout as e:
|
236 |
+
print(e)
|
237 |
+
output_text_without_watermark = timeout_msg
|
238 |
+
try:
|
239 |
+
generation_params["watermark"] = True
|
240 |
+
output = client.generate(prompt, **generation_params)
|
241 |
+
output_text_with_watermark = output.generated_text
|
242 |
+
except ReadTimeout as e:
|
243 |
+
print(e)
|
244 |
+
output_text_with_watermark = timeout_msg
|
245 |
+
|
246 |
return (output_text_without_watermark,
|
247 |
output_text_with_watermark)
|
248 |
|
|
|
746 |
generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
747 |
detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
748 |
model_selector.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
749 |
+
# When the parameters change, display the update and also fire detection, since some detection params dont change the model output.
|
750 |
delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
751 |
gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
752 |
gamma.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer])
|