Spaces:
Paused
Paused
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import asyncio
|
3 |
+
from aiohttp import web, WSMsgType
|
4 |
+
import json
|
5 |
+
from json import JSONEncoder
|
6 |
+
import numpy as np
|
7 |
+
import uuid
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
import signal
|
11 |
+
from typing import Dict, Any, List, Optional
|
12 |
+
import base64
|
13 |
+
import io
|
14 |
+
|
15 |
+
from PIL import Image
|
16 |
+
|
17 |
+
import pillow_avif
|
18 |
+
|
19 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
def SIGSEGV_signal_arises(signalNum, stack):
|
23 |
+
logger.critical(f"{signalNum} : SIGSEGV arises")
|
24 |
+
logger.critical(f"Stack trace: {stack}")
|
25 |
+
|
26 |
+
signal.signal(signal.SIGSEGV, SIGSEGV_signal_arises)
|
27 |
+
|
28 |
+
from loader import initialize_models
|
29 |
+
from engine import Engine, base64_data_uri_to_PIL_Image
|
30 |
+
|
31 |
+
# Global constants
|
32 |
+
DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
|
33 |
+
MODELS_DIR = os.path.join(DATA_ROOT, "models")
|
34 |
+
|
35 |
+
class NumpyEncoder(json.JSONEncoder):
|
36 |
+
def default(self, obj):
|
37 |
+
if isinstance(obj, np.integer):
|
38 |
+
return int(obj)
|
39 |
+
elif isinstance(obj, np.floating):
|
40 |
+
return float(obj)
|
41 |
+
elif isinstance(obj, np.ndarray):
|
42 |
+
return obj.tolist()
|
43 |
+
else:
|
44 |
+
return super(NumpyEncoder, self).default(obj)
|
45 |
+
|
46 |
+
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
47 |
+
ws = web.WebSocketResponse()
|
48 |
+
await ws.prepare(request)
|
49 |
+
engine = request.app['engine']
|
50 |
+
try:
|
51 |
+
#logger.info("New WebSocket connection established")
|
52 |
+
while True:
|
53 |
+
msg = await ws.receive()
|
54 |
+
|
55 |
+
if msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
|
56 |
+
#logger.warning(f"WebSocket connection closed: {msg.type}")
|
57 |
+
break
|
58 |
+
|
59 |
+
try:
|
60 |
+
if msg.type == WSMsgType.BINARY:
|
61 |
+
res = await engine.load_image(msg.data)
|
62 |
+
json_res = json.dumps(res, cls=NumpyEncoder)
|
63 |
+
await ws.send_str(json_res)
|
64 |
+
|
65 |
+
elif msg.type == WSMsgType.TEXT:
|
66 |
+
data = json.loads(msg.data)
|
67 |
+
webp_bytes = await engine.transform_image(data.get('uuid'), data.get('params'))
|
68 |
+
await ws.send_bytes(webp_bytes)
|
69 |
+
|
70 |
+
except Exception as e:
|
71 |
+
logger.error(f"Error in engine: {str(e)}")
|
72 |
+
logger.exception("Full traceback:")
|
73 |
+
await ws.send_json({"error": str(e)})
|
74 |
+
|
75 |
+
except Exception as e:
|
76 |
+
logger.error(f"Error in websocket_handler: {str(e)}")
|
77 |
+
logger.exception("Full traceback:")
|
78 |
+
return ws
|
79 |
+
|
80 |
+
async def handle_upload(request: web.Request) -> web.Response:
|
81 |
+
"""Recebe uma imagem e retorna informações sobre ela."""
|
82 |
+
engine = request.app['engine']
|
83 |
+
data = await request.content.read()
|
84 |
+
res = await engine.load_image(data)
|
85 |
+
return web.json_response(res)
|
86 |
+
|
87 |
+
async def handle_modify(request: web.Request) -> web.Response:
|
88 |
+
"""Recebe uma imagem e retorna informações sobre ela."""
|
89 |
+
engine = request.app['engine']
|
90 |
+
data = await request.json()
|
91 |
+
webp_bytes = await engine.transform_image(data.get('uuid'), data.get('params'))
|
92 |
+
return web.Response(body=webp_bytes, content_type="image/webp")
|
93 |
+
|
94 |
+
async def initialize_app() -> web.Application:
|
95 |
+
"""Initialize and configure the web application."""
|
96 |
+
try:
|
97 |
+
logger.info("Initializing application...")
|
98 |
+
live_portrait = await initialize_models()
|
99 |
+
|
100 |
+
logger.info("🚀 Creating Engine instance...")
|
101 |
+
engine = Engine(live_portrait=live_portrait)
|
102 |
+
logger.info("✅ Engine instance created.")
|
103 |
+
|
104 |
+
app = web.Application()
|
105 |
+
app['engine'] = engine
|
106 |
+
|
107 |
+
# Configure routes
|
108 |
+
app.router.add_post("/upload", handle_upload)
|
109 |
+
app.router.add_post("/modify", handle_modify)
|
110 |
+
|
111 |
+
logger.info("Application routes configured")
|
112 |
+
|
113 |
+
return app
|
114 |
+
except Exception as e:
|
115 |
+
logger.error(f"🚨 Error during application initialization: {str(e)}")
|
116 |
+
logger.exception("Full traceback:")
|
117 |
+
raise
|
118 |
+
|
119 |
+
if __name__ == "__main__":
|
120 |
+
try:
|
121 |
+
logger.info("Starting FacePoke application")
|
122 |
+
app = asyncio.run(initialize_app())
|
123 |
+
logger.info("Application initialized, starting web server")
|
124 |
+
web.run_app(app, host="0.0.0.0", port=8080)
|
125 |
+
except Exception as e:
|
126 |
+
logger.critical(f"🚨 FATAL: Failed to start the app: {str(e)}")
|
127 |
+
logger.exception("Full traceback:")
|