Spaces:
Running
on
A100
Running
on
A100
better app structure
Browse files- app.py +0 -17
- app_init.py +0 -163
- config.py +2 -1
- connection_manager.py +116 -0
- frontend/src/lib/lcmLive.ts +9 -6
- frontend/src/routes/+page.svelte +3 -3
- main.py +184 -0
- run.py +0 -12
- user_queue.py +0 -63
- util.py +0 -2
app.py
DELETED
@@ -1,17 +0,0 @@
|
|
1 |
-
from fastapi import FastAPI
|
2 |
-
|
3 |
-
from config import args
|
4 |
-
from device import device, torch_dtype
|
5 |
-
from app_init import init_app
|
6 |
-
from user_queue import user_data
|
7 |
-
from util import get_pipeline_class
|
8 |
-
|
9 |
-
print("DEVICE:", device)
|
10 |
-
print("TORCH_DTYPE:", torch_dtype)
|
11 |
-
args.pretty_print()
|
12 |
-
|
13 |
-
app = FastAPI()
|
14 |
-
|
15 |
-
pipeline_class = get_pipeline_class(args.pipeline)
|
16 |
-
pipeline = pipeline_class(args, device, torch_dtype)
|
17 |
-
init_app(app, user_data, args, pipeline)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app_init.py
DELETED
@@ -1,163 +0,0 @@
|
|
1 |
-
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
|
2 |
-
from fastapi.responses import StreamingResponse, JSONResponse
|
3 |
-
from fastapi.middleware.cors import CORSMiddleware
|
4 |
-
from fastapi.staticfiles import StaticFiles
|
5 |
-
from fastapi import Request
|
6 |
-
import markdown2
|
7 |
-
|
8 |
-
import logging
|
9 |
-
import traceback
|
10 |
-
from config import Args
|
11 |
-
from user_queue import UserData
|
12 |
-
import uuid
|
13 |
-
import time
|
14 |
-
from types import SimpleNamespace
|
15 |
-
from util import pil_to_frame, bytes_to_pil, is_firefox
|
16 |
-
import asyncio
|
17 |
-
import os
|
18 |
-
import time
|
19 |
-
|
20 |
-
THROTTLE = 1.0 / 120
|
21 |
-
|
22 |
-
|
23 |
-
def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
|
24 |
-
app.add_middleware(
|
25 |
-
CORSMiddleware,
|
26 |
-
allow_origins=["*"],
|
27 |
-
allow_credentials=True,
|
28 |
-
allow_methods=["*"],
|
29 |
-
allow_headers=["*"],
|
30 |
-
)
|
31 |
-
|
32 |
-
@app.websocket("/api/ws")
|
33 |
-
async def websocket_endpoint(websocket: WebSocket):
|
34 |
-
await websocket.accept()
|
35 |
-
user_count = user_data.get_user_count()
|
36 |
-
if args.max_queue_size > 0 and user_count >= args.max_queue_size:
|
37 |
-
print("Server is full")
|
38 |
-
await websocket.send_json({"status": "error", "message": "Server is full"})
|
39 |
-
await websocket.close()
|
40 |
-
return
|
41 |
-
try:
|
42 |
-
user_id = uuid.uuid4()
|
43 |
-
print(f"New user connected: {user_id}")
|
44 |
-
await user_data.create_user(user_id, websocket)
|
45 |
-
await websocket.send_json(
|
46 |
-
{"status": "connected", "message": "Connected", "userId": str(user_id)}
|
47 |
-
)
|
48 |
-
await websocket.send_json({"status": "send_frame"})
|
49 |
-
await handle_websocket_data(user_id, websocket)
|
50 |
-
except WebSocketDisconnect as e:
|
51 |
-
logging.error(f"WebSocket Error: {e}, {user_id}")
|
52 |
-
traceback.print_exc()
|
53 |
-
finally:
|
54 |
-
print(f"User disconnected: {user_id}")
|
55 |
-
user_data.delete_user(user_id)
|
56 |
-
|
57 |
-
async def handle_websocket_data(user_id: uuid.UUID, websocket: WebSocket):
|
58 |
-
if not user_data.check_user(user_id):
|
59 |
-
return HTTPException(status_code=404, detail="User not found")
|
60 |
-
last_time = time.time()
|
61 |
-
try:
|
62 |
-
while True:
|
63 |
-
if args.timeout > 0 and time.time() - last_time > args.timeout:
|
64 |
-
await websocket.send_json(
|
65 |
-
{
|
66 |
-
"status": "timeout",
|
67 |
-
"message": "Your session has ended",
|
68 |
-
"userId": str(user_id),
|
69 |
-
}
|
70 |
-
)
|
71 |
-
await websocket.close()
|
72 |
-
return
|
73 |
-
data = await websocket.receive_json()
|
74 |
-
if data["status"] != "next_frame":
|
75 |
-
asyncio.sleep(THROTTLE)
|
76 |
-
continue
|
77 |
-
|
78 |
-
params = await websocket.receive_json()
|
79 |
-
params = pipeline.InputParams(**params)
|
80 |
-
info = pipeline.Info()
|
81 |
-
params = SimpleNamespace(**params.dict())
|
82 |
-
if info.input_mode == "image":
|
83 |
-
image_data = await websocket.receive_bytes()
|
84 |
-
if len(image_data) == 0:
|
85 |
-
await websocket.send_json({"status": "send_frame"})
|
86 |
-
await asyncio.sleep(THROTTLE)
|
87 |
-
continue
|
88 |
-
params.image = bytes_to_pil(image_data)
|
89 |
-
await user_data.update_data(user_id, params)
|
90 |
-
await websocket.send_json({"status": "wait"})
|
91 |
-
|
92 |
-
except Exception as e:
|
93 |
-
logging.error(f"Error: {e}")
|
94 |
-
traceback.print_exc()
|
95 |
-
|
96 |
-
@app.get("/api/queue")
|
97 |
-
async def get_queue_size():
|
98 |
-
queue_size = user_data.get_user_count()
|
99 |
-
return JSONResponse({"queue_size": queue_size})
|
100 |
-
|
101 |
-
@app.get("/api/stream/{user_id}")
|
102 |
-
async def stream(user_id: uuid.UUID, request: Request):
|
103 |
-
try:
|
104 |
-
|
105 |
-
async def generate():
|
106 |
-
websocket = user_data.get_websocket(user_id)
|
107 |
-
last_params = SimpleNamespace()
|
108 |
-
while True:
|
109 |
-
last_time = time.time()
|
110 |
-
params = await user_data.get_latest_data(user_id)
|
111 |
-
if not vars(params) or params.__dict__ == last_params.__dict__:
|
112 |
-
await websocket.send_json({"status": "send_frame"})
|
113 |
-
await asyncio.sleep(THROTTLE)
|
114 |
-
continue
|
115 |
-
|
116 |
-
last_params = params
|
117 |
-
image = pipeline.predict(params)
|
118 |
-
|
119 |
-
if image is None:
|
120 |
-
await websocket.send_json({"status": "send_frame"})
|
121 |
-
await asyncio.sleep(THROTTLE)
|
122 |
-
continue
|
123 |
-
frame = pil_to_frame(image)
|
124 |
-
yield frame
|
125 |
-
# https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
|
126 |
-
if not is_firefox(request.headers["user-agent"]):
|
127 |
-
yield frame
|
128 |
-
await websocket.send_json({"status": "send_frame"})
|
129 |
-
if args.debug:
|
130 |
-
print(f"Time taken: {time.time() - last_time}")
|
131 |
-
|
132 |
-
return StreamingResponse(
|
133 |
-
generate(),
|
134 |
-
media_type="multipart/x-mixed-replace;boundary=frame",
|
135 |
-
headers={"Cache-Control": "no-cache"},
|
136 |
-
)
|
137 |
-
except Exception as e:
|
138 |
-
logging.error(f"Streaming Error: {e}, {user_id} ")
|
139 |
-
traceback.print_exc()
|
140 |
-
return HTTPException(status_code=404, detail="User not found")
|
141 |
-
|
142 |
-
# route to setup frontend
|
143 |
-
@app.get("/api/settings")
|
144 |
-
async def settings():
|
145 |
-
info_schema = pipeline.Info.schema()
|
146 |
-
info = pipeline.Info()
|
147 |
-
if info.page_content:
|
148 |
-
page_content = markdown2.markdown(info.page_content)
|
149 |
-
|
150 |
-
input_params = pipeline.InputParams.schema()
|
151 |
-
return JSONResponse(
|
152 |
-
{
|
153 |
-
"info": info_schema,
|
154 |
-
"input_params": input_params,
|
155 |
-
"max_queue_size": args.max_queue_size,
|
156 |
-
"page_content": page_content if info.page_content else "",
|
157 |
-
}
|
158 |
-
)
|
159 |
-
|
160 |
-
if not os.path.exists("public"):
|
161 |
-
os.makedirs("public")
|
162 |
-
|
163 |
-
app.mount("/", StaticFiles(directory="public", html=True), name="public")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config.py
CHANGED
@@ -124,4 +124,5 @@ parser.add_argument(
|
|
124 |
)
|
125 |
parser.set_defaults(taesd=USE_TAESD)
|
126 |
|
127 |
-
|
|
|
|
124 |
)
|
125 |
parser.set_defaults(taesd=USE_TAESD)
|
126 |
|
127 |
+
config = Args(**vars(parser.parse_args()))
|
128 |
+
config.pretty_print()
|
connection_manager.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Union
|
2 |
+
from uuid import UUID
|
3 |
+
import asyncio
|
4 |
+
from fastapi import WebSocket
|
5 |
+
from starlette.websockets import WebSocketState
|
6 |
+
import logging
|
7 |
+
from types import SimpleNamespace
|
8 |
+
|
9 |
+
Connections = Dict[UUID, Dict[str, Union[WebSocket, asyncio.Queue]]]
|
10 |
+
|
11 |
+
|
12 |
+
class ServerFullException(Exception):
|
13 |
+
"""Exception raised when the server is full."""
|
14 |
+
|
15 |
+
pass
|
16 |
+
|
17 |
+
|
18 |
+
class ConnectionManager:
|
19 |
+
def __init__(self):
|
20 |
+
self.active_connections: Connections = {}
|
21 |
+
|
22 |
+
async def connect(
|
23 |
+
self, user_id: UUID, websocket: WebSocket, max_queue_size: int = 0
|
24 |
+
):
|
25 |
+
await websocket.accept()
|
26 |
+
user_count = self.get_user_count()
|
27 |
+
print(f"User count: {user_count}")
|
28 |
+
if max_queue_size > 0 and user_count >= max_queue_size:
|
29 |
+
print("Server is full")
|
30 |
+
await websocket.send_json({"status": "error", "message": "Server is full"})
|
31 |
+
await websocket.close()
|
32 |
+
raise ServerFullException("Server is full")
|
33 |
+
print(f"New user connected: {user_id}")
|
34 |
+
self.active_connections[user_id] = {
|
35 |
+
"websocket": websocket,
|
36 |
+
"queue": asyncio.Queue(),
|
37 |
+
}
|
38 |
+
await websocket.send_json(
|
39 |
+
{"status": "connected", "message": "Connected"},
|
40 |
+
)
|
41 |
+
await websocket.send_json({"status": "wait"})
|
42 |
+
await websocket.send_json({"status": "send_frame"})
|
43 |
+
|
44 |
+
def check_user(self, user_id: UUID) -> bool:
|
45 |
+
return user_id in self.active_connections
|
46 |
+
|
47 |
+
async def update_data(self, user_id: UUID, new_data: SimpleNamespace):
|
48 |
+
user_session = self.active_connections.get(user_id)
|
49 |
+
if user_session:
|
50 |
+
queue = user_session["queue"]
|
51 |
+
while not queue.empty():
|
52 |
+
try:
|
53 |
+
queue.get_nowait()
|
54 |
+
except asyncio.QueueEmpty:
|
55 |
+
continue
|
56 |
+
await queue.put(new_data)
|
57 |
+
|
58 |
+
async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
|
59 |
+
user_session = self.active_connections.get(user_id)
|
60 |
+
if user_session:
|
61 |
+
queue = user_session["queue"]
|
62 |
+
try:
|
63 |
+
return await queue.get()
|
64 |
+
except asyncio.QueueEmpty:
|
65 |
+
return None
|
66 |
+
|
67 |
+
def delete_user(self, user_id: UUID):
|
68 |
+
user_session = self.active_connections.pop(user_id, None)
|
69 |
+
if user_session:
|
70 |
+
queue = user_session["queue"]
|
71 |
+
while not queue.empty():
|
72 |
+
try:
|
73 |
+
queue.get_nowait()
|
74 |
+
except asyncio.QueueEmpty:
|
75 |
+
continue
|
76 |
+
|
77 |
+
def get_user_count(self) -> int:
|
78 |
+
return len(self.active_connections)
|
79 |
+
|
80 |
+
def get_websocket(self, user_id: UUID) -> WebSocket:
|
81 |
+
user_session = self.active_connections.get(user_id)
|
82 |
+
if user_session:
|
83 |
+
websocket = user_session["websocket"]
|
84 |
+
if websocket.client_state == WebSocketState.CONNECTED:
|
85 |
+
return user_session["websocket"]
|
86 |
+
return None
|
87 |
+
|
88 |
+
async def disconnect(self, user_id: UUID):
|
89 |
+
websocket = self.get_websocket(user_id)
|
90 |
+
if websocket:
|
91 |
+
await websocket.close()
|
92 |
+
self.delete_user(user_id)
|
93 |
+
|
94 |
+
async def send_json(self, user_id: UUID, data: Dict):
|
95 |
+
try:
|
96 |
+
websocket = self.get_websocket(user_id)
|
97 |
+
if websocket:
|
98 |
+
await websocket.send_json(data)
|
99 |
+
except Exception as e:
|
100 |
+
logging.error(f"Error: Send json: {e}")
|
101 |
+
|
102 |
+
async def receive_json(self, user_id: UUID) -> Dict:
|
103 |
+
try:
|
104 |
+
websocket = self.get_websocket(user_id)
|
105 |
+
if websocket:
|
106 |
+
return await websocket.receive_json()
|
107 |
+
except Exception as e:
|
108 |
+
logging.error(f"Error: Receive json: {e}")
|
109 |
+
|
110 |
+
async def receive_bytes(self, user_id: UUID) -> bytes:
|
111 |
+
try:
|
112 |
+
websocket = self.get_websocket(user_id)
|
113 |
+
if websocket:
|
114 |
+
return await websocket.receive_bytes()
|
115 |
+
except Exception as e:
|
116 |
+
logging.error(f"Error: Receive bytes: {e}")
|
frontend/src/lib/lcmLive.ts
CHANGED
@@ -6,6 +6,7 @@ export enum LCMLiveStatus {
|
|
6 |
DISCONNECTED = "disconnected",
|
7 |
WAIT = "wait",
|
8 |
SEND_FRAME = "send_frame",
|
|
|
9 |
}
|
10 |
|
11 |
const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
|
@@ -19,8 +20,9 @@ export const lcmLiveActions = {
|
|
19 |
return new Promise((resolve, reject) => {
|
20 |
|
21 |
try {
|
|
|
22 |
const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
|
23 |
-
}:${window.location.host}/api/ws`;
|
24 |
|
25 |
websocket = new WebSocket(websocketURL);
|
26 |
websocket.onopen = () => {
|
@@ -37,9 +39,9 @@ export const lcmLiveActions = {
|
|
37 |
const data = JSON.parse(event.data);
|
38 |
switch (data.status) {
|
39 |
case "connected":
|
40 |
-
const userId = data.userId;
|
41 |
lcmLiveStatus.set(LCMLiveStatus.CONNECTED);
|
42 |
streamId.set(userId);
|
|
|
43 |
break;
|
44 |
case "send_frame":
|
45 |
lcmLiveStatus.set(LCMLiveStatus.SEND_FRAME);
|
@@ -54,14 +56,16 @@ export const lcmLiveActions = {
|
|
54 |
break;
|
55 |
case "timeout":
|
56 |
console.log("timeout");
|
57 |
-
lcmLiveStatus.set(LCMLiveStatus.
|
58 |
streamId.set(null);
|
59 |
-
|
|
|
60 |
case "error":
|
61 |
console.log(data.message);
|
62 |
lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
|
63 |
streamId.set(null);
|
64 |
reject(new Error(data.message));
|
|
|
65 |
}
|
66 |
};
|
67 |
|
@@ -85,12 +89,11 @@ export const lcmLiveActions = {
|
|
85 |
}
|
86 |
},
|
87 |
async stop() {
|
88 |
-
|
89 |
if (websocket) {
|
90 |
websocket.close();
|
91 |
}
|
92 |
websocket = null;
|
93 |
-
lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
|
94 |
streamId.set(null);
|
95 |
},
|
96 |
};
|
|
|
6 |
DISCONNECTED = "disconnected",
|
7 |
WAIT = "wait",
|
8 |
SEND_FRAME = "send_frame",
|
9 |
+
TIMEOUT = "timeout",
|
10 |
}
|
11 |
|
12 |
const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
|
|
|
20 |
return new Promise((resolve, reject) => {
|
21 |
|
22 |
try {
|
23 |
+
const userId = crypto.randomUUID();
|
24 |
const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
|
25 |
+
}:${window.location.host}/api/ws/${userId}`;
|
26 |
|
27 |
websocket = new WebSocket(websocketURL);
|
28 |
websocket.onopen = () => {
|
|
|
39 |
const data = JSON.parse(event.data);
|
40 |
switch (data.status) {
|
41 |
case "connected":
|
|
|
42 |
lcmLiveStatus.set(LCMLiveStatus.CONNECTED);
|
43 |
streamId.set(userId);
|
44 |
+
resolve({ status: "connected", userId });
|
45 |
break;
|
46 |
case "send_frame":
|
47 |
lcmLiveStatus.set(LCMLiveStatus.SEND_FRAME);
|
|
|
56 |
break;
|
57 |
case "timeout":
|
58 |
console.log("timeout");
|
59 |
+
lcmLiveStatus.set(LCMLiveStatus.TIMEOUT);
|
60 |
streamId.set(null);
|
61 |
+
reject(new Error("timeout"));
|
62 |
+
break;
|
63 |
case "error":
|
64 |
console.log(data.message);
|
65 |
lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
|
66 |
streamId.set(null);
|
67 |
reject(new Error(data.message));
|
68 |
+
break;
|
69 |
}
|
70 |
};
|
71 |
|
|
|
89 |
}
|
90 |
},
|
91 |
async stop() {
|
92 |
+
lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
|
93 |
if (websocket) {
|
94 |
websocket.close();
|
95 |
}
|
96 |
websocket = null;
|
|
|
97 |
streamId.set(null);
|
98 |
},
|
99 |
};
|
frontend/src/routes/+page.svelte
CHANGED
@@ -20,7 +20,6 @@
|
|
20 |
let currentQueueSize: number = 0;
|
21 |
let queueCheckerRunning: boolean = false;
|
22 |
let warningMessage: string = '';
|
23 |
-
|
24 |
onMount(() => {
|
25 |
getSettings();
|
26 |
});
|
@@ -59,7 +58,9 @@
|
|
59 |
}
|
60 |
|
61 |
$: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED;
|
62 |
-
|
|
|
|
|
63 |
let disabled = false;
|
64 |
async function toggleLcmLive() {
|
65 |
try {
|
@@ -70,7 +71,6 @@
|
|
70 |
}
|
71 |
disabled = true;
|
72 |
await lcmLiveActions.start(getSreamdata);
|
73 |
-
warningMessage = 'Timeout, please try again.';
|
74 |
disabled = false;
|
75 |
toggleQueueChecker(false);
|
76 |
} else {
|
|
|
20 |
let currentQueueSize: number = 0;
|
21 |
let queueCheckerRunning: boolean = false;
|
22 |
let warningMessage: string = '';
|
|
|
23 |
onMount(() => {
|
24 |
getSettings();
|
25 |
});
|
|
|
58 |
}
|
59 |
|
60 |
$: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED;
|
61 |
+
$: if ($lcmLiveStatus === LCMLiveStatus.TIMEOUT) {
|
62 |
+
warningMessage = 'Session timed out. Please try again.';
|
63 |
+
}
|
64 |
let disabled = false;
|
65 |
async function toggleLcmLive() {
|
66 |
try {
|
|
|
71 |
}
|
72 |
disabled = true;
|
73 |
await lcmLiveActions.start(getSreamdata);
|
|
|
74 |
disabled = false;
|
75 |
toggleQueueChecker(false);
|
76 |
} else {
|
main.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
|
2 |
+
from fastapi.responses import StreamingResponse, JSONResponse
|
3 |
+
from fastapi.middleware.cors import CORSMiddleware
|
4 |
+
from fastapi.staticfiles import StaticFiles
|
5 |
+
from fastapi import Request
|
6 |
+
import markdown2
|
7 |
+
|
8 |
+
import logging
|
9 |
+
from config import config, Args
|
10 |
+
from connection_manager import ConnectionManager
|
11 |
+
import uuid
|
12 |
+
import time
|
13 |
+
from types import SimpleNamespace
|
14 |
+
from util import pil_to_frame, bytes_to_pil, is_firefox, get_pipeline_class
|
15 |
+
from device import device, torch_dtype
|
16 |
+
import asyncio
|
17 |
+
import os
|
18 |
+
import time
|
19 |
+
import torch
|
20 |
+
|
21 |
+
|
22 |
+
THROTTLE = 1.0 / 120
|
23 |
+
|
24 |
+
|
25 |
+
class App:
|
26 |
+
def __init__(self, config: Args, pipeline):
|
27 |
+
self.args = config
|
28 |
+
self.pipeline = pipeline
|
29 |
+
self.app = FastAPI()
|
30 |
+
self.conn_manager = ConnectionManager()
|
31 |
+
self.init_app()
|
32 |
+
|
33 |
+
def init_app(self):
|
34 |
+
self.app.add_middleware(
|
35 |
+
CORSMiddleware,
|
36 |
+
allow_origins=["*"],
|
37 |
+
allow_credentials=True,
|
38 |
+
allow_methods=["*"],
|
39 |
+
allow_headers=["*"],
|
40 |
+
)
|
41 |
+
|
42 |
+
@self.app.websocket("/api/ws/{user_id}")
|
43 |
+
async def websocket_endpoint(user_id: uuid.UUID, websocket: WebSocket):
|
44 |
+
try:
|
45 |
+
await self.conn_manager.connect(
|
46 |
+
user_id, websocket, self.args.max_queue_size
|
47 |
+
)
|
48 |
+
await handle_websocket_data(user_id)
|
49 |
+
except ServerFullException as e:
|
50 |
+
logging.error(f"Server Full: {e}")
|
51 |
+
finally:
|
52 |
+
await self.conn_manager.disconnect(user_id)
|
53 |
+
logging.info(f"User disconnected: {user_id}")
|
54 |
+
|
55 |
+
async def handle_websocket_data(user_id: uuid.UUID):
|
56 |
+
if not self.conn_manager.check_user(user_id):
|
57 |
+
return HTTPException(status_code=404, detail="User not found")
|
58 |
+
last_time = time.time()
|
59 |
+
try:
|
60 |
+
while True:
|
61 |
+
if (
|
62 |
+
self.args.timeout > 0
|
63 |
+
and time.time() - last_time > self.args.timeout
|
64 |
+
):
|
65 |
+
await self.conn_manager.send_json(
|
66 |
+
user_id,
|
67 |
+
{
|
68 |
+
"status": "timeout",
|
69 |
+
"message": "Your session has ended",
|
70 |
+
},
|
71 |
+
)
|
72 |
+
await self.conn_manager.disconnect(user_id)
|
73 |
+
return
|
74 |
+
data = await self.conn_manager.receive_json(user_id)
|
75 |
+
if data["status"] != "next_frame":
|
76 |
+
asyncio.sleep(THROTTLE)
|
77 |
+
continue
|
78 |
+
|
79 |
+
params = await self.conn_manager.receive_json(user_id)
|
80 |
+
params = pipeline.InputParams(**params)
|
81 |
+
info = pipeline.Info()
|
82 |
+
params = SimpleNamespace(**params.dict())
|
83 |
+
if info.input_mode == "image":
|
84 |
+
image_data = await self.conn_manager.receive_bytes(user_id)
|
85 |
+
if len(image_data) == 0:
|
86 |
+
await self.conn_manager.send_json(
|
87 |
+
user_id, {"status": "send_frame"}
|
88 |
+
)
|
89 |
+
await asyncio.sleep(THROTTLE)
|
90 |
+
continue
|
91 |
+
params.image = bytes_to_pil(image_data)
|
92 |
+
await self.conn_manager.update_data(user_id, params)
|
93 |
+
await self.conn_manager.send_json(user_id, {"status": "wait"})
|
94 |
+
|
95 |
+
except Exception as e:
|
96 |
+
logging.error(f"Websocket Error: {e}, {user_id} ")
|
97 |
+
await self.conn_manager.disconnect(user_id)
|
98 |
+
|
99 |
+
@self.app.get("/api/queue")
|
100 |
+
async def get_queue_size():
|
101 |
+
queue_size = self.conn_manager.get_user_count()
|
102 |
+
return JSONResponse({"queue_size": queue_size})
|
103 |
+
|
104 |
+
@self.app.get("/api/stream/{user_id}")
|
105 |
+
async def stream(user_id: uuid.UUID, request: Request):
|
106 |
+
try:
|
107 |
+
|
108 |
+
async def generate():
|
109 |
+
last_params = SimpleNamespace()
|
110 |
+
while True:
|
111 |
+
last_time = time.time()
|
112 |
+
params = await self.conn_manager.get_latest_data(user_id)
|
113 |
+
if not vars(params) or params.__dict__ == last_params.__dict__:
|
114 |
+
await self.conn_manager.send_json(
|
115 |
+
user_id, {"status": "send_frame"}
|
116 |
+
)
|
117 |
+
continue
|
118 |
+
|
119 |
+
last_params = params
|
120 |
+
image = pipeline.predict(params)
|
121 |
+
if image is None:
|
122 |
+
await self.conn_manager.send_json(
|
123 |
+
user_id, {"status": "send_frame"}
|
124 |
+
)
|
125 |
+
continue
|
126 |
+
frame = pil_to_frame(image)
|
127 |
+
yield frame
|
128 |
+
# https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
|
129 |
+
if not is_firefox(request.headers["user-agent"]):
|
130 |
+
yield frame
|
131 |
+
await self.conn_manager.send_json(
|
132 |
+
user_id, {"status": "send_frame"}
|
133 |
+
)
|
134 |
+
if self.args.debug:
|
135 |
+
print(f"Time taken: {time.time() - last_time}")
|
136 |
+
|
137 |
+
return StreamingResponse(
|
138 |
+
generate(),
|
139 |
+
media_type="multipart/x-mixed-replace;boundary=frame",
|
140 |
+
headers={"Cache-Control": "no-cache"},
|
141 |
+
)
|
142 |
+
except Exception as e:
|
143 |
+
logging.error(f"Streaming Error: {e}, {user_id} ")
|
144 |
+
return HTTPException(status_code=404, detail="User not found")
|
145 |
+
|
146 |
+
# route to setup frontend
|
147 |
+
@self.app.get("/api/settings")
|
148 |
+
async def settings():
|
149 |
+
info_schema = pipeline.Info.schema()
|
150 |
+
info = pipeline.Info()
|
151 |
+
if info.page_content:
|
152 |
+
page_content = markdown2.markdown(info.page_content)
|
153 |
+
|
154 |
+
input_params = pipeline.InputParams.schema()
|
155 |
+
return JSONResponse(
|
156 |
+
{
|
157 |
+
"info": info_schema,
|
158 |
+
"input_params": input_params,
|
159 |
+
"max_queue_size": self.args.max_queue_size,
|
160 |
+
"page_content": page_content if info.page_content else "",
|
161 |
+
}
|
162 |
+
)
|
163 |
+
|
164 |
+
if not os.path.exists("public"):
|
165 |
+
os.makedirs("public")
|
166 |
+
|
167 |
+
self.app.mount("/", StaticFiles(directory="public", html=True), name="public")
|
168 |
+
|
169 |
+
|
170 |
+
pipeline_class = get_pipeline_class(config.pipeline)
|
171 |
+
pipeline = pipeline_class(config, device, torch_dtype)
|
172 |
+
app = App(config, pipeline).app
|
173 |
+
|
174 |
+
if __name__ == "__main__":
|
175 |
+
import uvicorn
|
176 |
+
|
177 |
+
uvicorn.run(
|
178 |
+
"main:app",
|
179 |
+
host=config.host,
|
180 |
+
port=config.port,
|
181 |
+
reload=config.reload,
|
182 |
+
ssl_certfile=config.ssl_certfile,
|
183 |
+
ssl_keyfile=config.ssl_keyfile,
|
184 |
+
)
|
run.py
DELETED
@@ -1,12 +0,0 @@
|
|
1 |
-
if __name__ == "__main__":
|
2 |
-
import uvicorn
|
3 |
-
from config import args
|
4 |
-
|
5 |
-
uvicorn.run(
|
6 |
-
"app:app",
|
7 |
-
host=args.host,
|
8 |
-
port=args.port,
|
9 |
-
reload=args.reload,
|
10 |
-
ssl_certfile=args.ssl_certfile,
|
11 |
-
ssl_keyfile=args.ssl_keyfile,
|
12 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
user_queue.py
DELETED
@@ -1,63 +0,0 @@
|
|
1 |
-
from typing import Dict
|
2 |
-
from uuid import UUID
|
3 |
-
import asyncio
|
4 |
-
from fastapi import WebSocket
|
5 |
-
from types import SimpleNamespace
|
6 |
-
from typing import Dict
|
7 |
-
from typing import Union
|
8 |
-
|
9 |
-
UserDataContent = Dict[UUID, Dict[str, Union[WebSocket, asyncio.Queue]]]
|
10 |
-
|
11 |
-
|
12 |
-
class UserData:
|
13 |
-
def __init__(self):
|
14 |
-
self.data_content: Dict[UUID, UserDataContent] = {}
|
15 |
-
|
16 |
-
async def create_user(self, user_id: UUID, websocket: WebSocket):
|
17 |
-
self.data_content[user_id] = {
|
18 |
-
"websocket": websocket,
|
19 |
-
"queue": asyncio.Queue(),
|
20 |
-
}
|
21 |
-
await asyncio.sleep(1)
|
22 |
-
|
23 |
-
def check_user(self, user_id: UUID) -> bool:
|
24 |
-
return user_id in self.data_content
|
25 |
-
|
26 |
-
async def update_data(self, user_id: UUID, new_data: SimpleNamespace):
|
27 |
-
user_session = self.data_content[user_id]
|
28 |
-
queue = user_session["queue"]
|
29 |
-
while not queue.empty():
|
30 |
-
try:
|
31 |
-
queue.get_nowait()
|
32 |
-
except asyncio.QueueEmpty:
|
33 |
-
continue
|
34 |
-
await queue.put(new_data)
|
35 |
-
|
36 |
-
async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
|
37 |
-
user_session = self.data_content[user_id]
|
38 |
-
queue = user_session["queue"]
|
39 |
-
|
40 |
-
try:
|
41 |
-
return await queue.get()
|
42 |
-
except asyncio.QueueEmpty:
|
43 |
-
return None
|
44 |
-
|
45 |
-
def delete_user(self, user_id: UUID):
|
46 |
-
user_session = self.data_content[user_id]
|
47 |
-
queue = user_session["queue"]
|
48 |
-
while not queue.empty():
|
49 |
-
try:
|
50 |
-
queue.get_nowait()
|
51 |
-
except asyncio.QueueEmpty:
|
52 |
-
continue
|
53 |
-
if user_id in self.data_content:
|
54 |
-
del self.data_content[user_id]
|
55 |
-
|
56 |
-
def get_user_count(self) -> int:
|
57 |
-
return len(self.data_content)
|
58 |
-
|
59 |
-
def get_websocket(self, user_id: UUID) -> WebSocket:
|
60 |
-
return self.data_content[user_id]["websocket"]
|
61 |
-
|
62 |
-
|
63 |
-
user_data = UserData()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
util.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
from importlib import import_module
|
2 |
from types import ModuleType
|
3 |
-
from typing import Dict, Any
|
4 |
-
from pydantic import BaseModel as PydanticBaseModel, Field
|
5 |
from PIL import Image
|
6 |
import io
|
7 |
|
|
|
1 |
from importlib import import_module
|
2 |
from types import ModuleType
|
|
|
|
|
3 |
from PIL import Image
|
4 |
import io
|
5 |
|