Files changed (1) hide show
  1. app.py +32 -61
app.py CHANGED
@@ -8,19 +8,6 @@ import re
8
  hf_token = os.environ.get("HF_TOKEN")
9
  from gradio_client import Client
10
 
11
-
12
- def safety_check(user_prompt):
13
-
14
- client = Client("fffiloni/safety-checker-bot", hf_token=hf_token)
15
- response = client.predict(
16
- source_space="consistent-character space",
17
- user_prompt=user_prompt,
18
- api_name="/infer"
19
- )
20
- print(response)
21
-
22
- return response
23
-
24
  from utils.gradio_helpers import parse_outputs, process_outputs
25
 
26
  names = ['prompt', 'negative_prompt', 'subject', 'number_of_outputs', 'number_of_images_per_pose', 'randomise_poses', 'output_format', 'output_quality', 'seed']
@@ -34,57 +21,42 @@ def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):
34
  raise gr.Error(f"You forgot to provide a prompt.")
35
 
36
  try:
 
37
 
38
- is_safe = safety_check(args[0])
39
- print(is_safe)
40
-
41
- match = re.search(r'\bYes\b', is_safe)
42
-
43
- if match:
44
- status = 'Yes'
45
- else:
46
- status = None
47
-
48
- if status == "Yes" :
49
- raise gr.Error("Do not ask for such things.")
50
- else:
51
-
52
- headers = {'Content-Type': 'application/json'}
53
-
54
- payload = {"input": {}}
55
-
56
-
57
- base_url = "http://0.0.0.0:7860"
58
- for i, key in enumerate(names):
59
- value = args[i]
60
- if value and (os.path.exists(str(value))):
61
- value = f"{base_url}/file=" + value
62
- if value is not None and value != "":
63
- payload["input"][key] = value
64
-
65
- response = requests.post("http://0.0.0.0:5000/predictions", headers=headers, json=payload)
66
 
 
67
 
68
- if response.status_code == 201:
69
- follow_up_url = response.json()["urls"]["get"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  response = requests.get(follow_up_url, headers=headers)
71
- while response.json()["status"] != "succeeded":
72
- if response.json()["status"] == "failed":
73
- raise gr.Error("The submission failed!")
74
- response = requests.get(follow_up_url, headers=headers)
75
- time.sleep(1)
76
- if response.status_code == 200:
77
- json_response = response.json()
78
- #If the output component is JSON return the entire output response
79
- if(outputs[0].get_config()["name"] == "json"):
80
- return json_response["output"]
81
- predict_outputs = parse_outputs(json_response["output"])
82
- processed_outputs = process_outputs(predict_outputs)
83
- return tuple(processed_outputs) if len(processed_outputs) > 1 else processed_outputs[0]
84
- else:
85
- if(response.status_code == 409):
86
- raise gr.Error(f"Sorry, the Cog image is still processing. Try again in a bit.")
87
- raise gr.Error(f"The submission failed! Error: {response.status_code}")
88
 
89
  except Exception as e:
90
  # Handle any other type of error
@@ -173,4 +145,3 @@ with gr.Blocks(css=css) as app:
173
  )
174
 
175
  app.queue(max_size=12, api_open=False).launch(share=False, show_api=False, show_error=True)
176
-
 
8
  hf_token = os.environ.get("HF_TOKEN")
9
  from gradio_client import Client
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  from utils.gradio_helpers import parse_outputs, process_outputs
12
 
13
  names = ['prompt', 'negative_prompt', 'subject', 'number_of_outputs', 'number_of_images_per_pose', 'randomise_poses', 'output_format', 'output_quality', 'seed']
 
21
  raise gr.Error(f"You forgot to provide a prompt.")
22
 
23
  try:
24
+ # Safety check is removed here
25
 
26
+ headers = {'Content-Type': 'application/json'}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ payload = {"input": {}}
29
 
30
+ base_url = "http://0.0.0.0:7860"
31
+ for i, key in enumerate(names):
32
+ value = args[i]
33
+ if value and (os.path.exists(str(value))):
34
+ value = f"{base_url}/file=" + value
35
+ if value is not None and value != "":
36
+ payload["input"][key] = value
37
+
38
+ response = requests.post("http://0.0.0.0:5000/predictions", headers=headers, json=payload)
39
+
40
+ if response.status_code == 201:
41
+ follow_up_url = response.json()["urls"]["get"]
42
+ response = requests.get(follow_up_url, headers=headers)
43
+ while response.json()["status"] != "succeeded":
44
+ if response.json()["status"] == "failed":
45
+ raise gr.Error("The submission failed!")
46
  response = requests.get(follow_up_url, headers=headers)
47
+ time.sleep(1)
48
+ if response.status_code == 200:
49
+ json_response = response.json()
50
+ # If the output component is JSON return the entire output response
51
+ if(outputs[0].get_config()["name"] == "json"):
52
+ return json_response["output"]
53
+ predict_outputs = parse_outputs(json_response["output"])
54
+ processed_outputs = process_outputs(predict_outputs)
55
+ return tuple(processed_outputs) if len(processed_outputs) > 1 else processed_outputs[0]
56
+ else:
57
+ if(response.status_code == 409):
58
+ raise gr.Error(f"Sorry, the Cog image is still processing. Try again in a bit.")
59
+ raise gr.Error(f"The submission failed! Error: {response.status_code}")
 
 
 
 
60
 
61
  except Exception as e:
62
  # Handle any other type of error
 
145
  )
146
 
147
  app.queue(max_size=12, api_open=False).launch(share=False, show_api=False, show_error=True)