ajv009 committed on
Commit 558490f • 1 Parent(s): 4d2b94e

Update app.py

Files changed (1): app.py +44 -2
app.py CHANGED
@@ -5,12 +5,53 @@ import time
 import os
 import spaces
 import torch
+import subprocess
+import signal
 
 zero = torch.Tensor([0]).cuda()
 print(zero.device) # <-- 'cpu' 🤔
 
 names = ['prompt', 'negative_prompt', 'subject', 'number_of_outputs', 'number_of_images_per_pose', 'randomise_poses', 'output_format', 'output_quality', 'seed']
 
+def check_cog_server():
+    try:
+        # Start the Cog server in the background
+        cog_process = subprocess.Popen(["python3", "-m", "cog.server.http", "--threads=10"], cwd="/src")
+
+        # Wait for the Cog server to start on port 5000
+        counter1 = 0
+        while True:
+            try:
+                requests.get("http://localhost:5000")
+                print("Cog server is running on port 5000.")
+                break
+            except requests.exceptions.ConnectionError:
+                print("Waiting for Cog server to start on port 5000...")
+                time.sleep(5)
+                counter1 += 1
+                if counter1 >= 250:
+                    raise Exception("Error: Cog server did not start on port 5000 after 250 attempts.")
+
+        # Wait for the Cog server to be fully ready
+        counter2 = 0
+        while True:
+            response = requests.get("http://localhost:5000/health-check")
+            status = response.json().get("status")
+            if status == "READY":
+                print("Cog server is fully ready.")
+                break
+            else:
+                print("Waiting for Cog server (models loading) on port 5000...")
+                time.sleep(5)
+                counter2 += 1
+                if counter2 >= 250:
+                    raise Exception("Error: Cog server did not become fully ready after 250 attempts.")
+
+    except Exception as e:
+        print(f"Error: {str(e)}")
+        cog_process.send_signal(signal.SIGINT)  # Send interrupt signal to the Cog process
+        raise e
+
 @spaces.GPU
 def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):
     print(zero.device) # <-- 'cuda:0' 🤗
@@ -18,7 +59,6 @@ def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):
 
     payload = {"input": {}}
 
-
     base_url = "http://0.0.0.0:7860"
     for i, key in enumerate(names):
         value = args[i]
@@ -29,7 +69,6 @@
 
     response = requests.post("http://0.0.0.0:5000/predictions", headers=headers, json=payload)
 
-
     if response.status_code == 201:
         follow_up_url = response.json()["urls"]["get"]
         response = requests.get(follow_up_url, headers=headers)
@@ -132,4 +171,7 @@ with gr.Blocks(css=css) as app:
         show_api = False
     )
 
+# Check the Cog server's readiness before launching the Gradio app
+check_cog_server()
+
 app.queue(max_size=12, api_open=False).launch(share=False, show_api=False)
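
One caveat in the new check_cog_server: if subprocess.Popen itself raises, cog_process is never bound, so the except handler's cog_process.send_signal(signal.SIGINT) fails with a NameError of its own. Below is a minimal sketch of a hardened variant, assuming the same command, port, and /health-check "READY" contract shown in the diff; the helper names start_cog_server / wait_for_cog and the timeout_seconds parameter are hypothetical, not part of this commit:

import signal
import subprocess
import time

import requests


def wait_for_cog(timeout_seconds):
    # Poll /health-check until it reports READY (the status string used in the commit).
    deadline = time.monotonic() + timeout_seconds
    while time.monotonic() < deadline:
        try:
            response = requests.get("http://localhost:5000/health-check", timeout=5)
            if response.json().get("status") == "READY":
                print("Cog server is fully ready.")
                return
            print("Waiting for Cog server (models loading) on port 5000...")
        except requests.exceptions.RequestException:
            # Covers connection refused while the server is still booting, plus timeouts.
            print("Waiting for Cog server to start on port 5000...")
        time.sleep(5)
    raise RuntimeError(f"Cog server not ready after {timeout_seconds} seconds.")


def start_cog_server(timeout_seconds=1250):
    cog_process = None  # stays None if Popen itself fails
    try:
        # Start the Cog HTTP server in the background, as the commit does.
        cog_process = subprocess.Popen(
            ["python3", "-m", "cog.server.http", "--threads=10"], cwd="/src"
        )
        wait_for_cog(timeout_seconds)
        return cog_process
    except Exception:
        # Only signal a process that was actually started.
        if cog_process is not None:
            cog_process.send_signal(signal.SIGINT)
        raise

Folding the two polling loops into one deadline-based wait also removes the duplicated counter1/counter2 bookkeeping while keeping roughly the same 250 × 5 s upper bound as the committed code.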