Spaces:
Runtime error
Runtime error
File size: 4,983 Bytes
cb92d2b ff9325e cb92d2b ff9325e cb92d2b ff9325e cb92d2b ff9325e cb92d2b ff9325e cb92d2b ff9325e cb92d2b ff9325e cb92d2b ff9325e cb92d2b ff9325e cb92d2b ff9325e cb92d2b ff9325e cb92d2b ff9325e cb92d2b ff9325e cb92d2b ff9325e cb92d2b ff9325e cb92d2b ff9325e cb92d2b 43148fd cb92d2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
import logging
import traceback
from config import Args
from user_queue import UserDataEventMap, UserDataEvent
import uuid
from asyncio import Event, sleep
import time
from PIL import Image
import io
from types import SimpleNamespace
def init_app(app: FastAPI, user_data_events: UserDataEventMap, args: Args, pipeline):
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
print("Init app", app)
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
if args.max_queue_size > 0 and len(user_data_events) >= args.max_queue_size:
print("Server is full")
await websocket.send_json({"status": "error", "message": "Server is full"})
await websocket.close()
return
try:
uid = str(uuid.uuid4())
print(f"New user connected: {uid}")
await websocket.send_json(
{"status": "success", "message": "Connected", "userId": uid}
)
user_data_events[uid] = UserDataEvent()
print(f"User data events: {user_data_events}")
await websocket.send_json(
{"status": "start", "message": "Start Streaming", "userId": uid}
)
await handle_websocket_data(websocket, uid)
except WebSocketDisconnect as e:
logging.error(f"WebSocket Error: {e}, {uid}")
traceback.print_exc()
finally:
print(f"User disconnected: {uid}")
del user_data_events[uid]
@app.get("/queue_size")
async def get_queue_size():
queue_size = len(user_data_events)
return JSONResponse({"queue_size": queue_size})
@app.get("/stream/{user_id}")
async def stream(user_id: uuid.UUID):
uid = str(user_id)
try:
async def generate():
last_prompt: str = None
while True:
data = await user_data_events[uid].wait_for_data()
params = data["params"]
# input_image = data["image"]
# if input_image is None:
# continue
image = pipeline.predict(params)
if image is None:
continue
frame_data = io.BytesIO()
image.save(frame_data, format="JPEG")
frame_data = frame_data.getvalue()
if frame_data is not None and len(frame_data) > 0:
yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
await sleep(1.0 / 120.0)
return StreamingResponse(
generate(), media_type="multipart/x-mixed-replace;boundary=frame"
)
except Exception as e:
logging.error(f"Streaming Error: {e}, {user_data_events}")
traceback.print_exc()
return HTTPException(status_code=404, detail="User not found")
async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
uid = str(user_id)
if uid not in user_data_events:
return HTTPException(status_code=404, detail="User not found")
last_time = time.time()
try:
while True:
params = await websocket.receive_json()
params = pipeline.InputParams(**params)
params = SimpleNamespace(**params.dict())
if hasattr(params, "image"):
image_data = await websocket.receive_bytes()
pil_image = Image.open(io.BytesIO(image_data))
params.image = pil_image
user_data_events[uid].update_data({"params": params})
if args.timeout > 0 and time.time() - last_time > args.timeout:
await websocket.send_json(
{
"status": "timeout",
"message": "Your session has ended",
"userId": uid,
}
)
await websocket.close()
return
except Exception as e:
logging.error(f"Error: {e}")
traceback.print_exc()
# route to setup frontend
@app.get("/settings")
async def settings():
info = pipeline.Info.schema()
input_params = pipeline.InputParams.schema()
return JSONResponse({"info": info, "input_params": input_params})
app.mount("/", StaticFiles(directory="public", html=True), name="public")
|