# Scraped page header, not code: "Spaces: Sleeping"
import base64
import io
import logging
import mimetypes
import os
import tempfile
from typing import Optional

import uvicorn
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from gradio_client import Client
from PIL import Image
app = FastAPI()

# Add CORS middleware.
# Allowed origins come from the comma-separated ALLOWED_ORIGINS environment
# variable (in production, set it to your frontend URL(s)); when unset, every
# origin is allowed.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    # Credentialed CORS must not be combined with a wildcard origin: the CORS
    # spec forbids it, and recent Starlette versions work around it by echoing
    # back whatever origin asked, letting any site make credentialed requests.
    # The endpoints in this file do not read cookies, so credentials are only
    # enabled when an explicit origin list is configured.
    allow_credentials="*" not in ALLOWED_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
def encode_image_to_base64(image_bytes: bytes) -> str:
    """Convert image bytes to a standard base64 text string.

    Args:
        image_bytes: Raw bytes to encode (any bytes-like object is accepted).

    Returns:
        The base64 encoding as a ``str`` (no data-URI prefix).
    """
    return base64.b64encode(image_bytes).decode('utf-8')
def process_uploaded_image(file: UploadFile) -> bytes:
    """Decode an uploaded image and re-encode it as RGB PNG bytes.

    Args:
        file: The upload; its underlying ``file.file`` stream is read from the
            current position to EOF.

    Returns:
        PNG-encoded bytes of the image converted to RGB mode (any alpha
        channel is discarded by the conversion).

    Raises:
        Whatever PIL raises for data it cannot decode
        (e.g. ``PIL.UnidentifiedImageError``).
    """
    contents = file.file.read()
    # Context manager closes the decoder's resources once we are done with it.
    with Image.open(io.BytesIO(contents)) as image:
        # Normalise palette/RGBA/greyscale inputs to plain RGB before encoding.
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        buffer = io.BytesIO()
        rgb_image.save(buffer, format='PNG')
    return buffer.getvalue()
logger = logging.getLogger(__name__)


def _write_temp_png(data: bytes) -> str:
    """Write *data* to a new ``.png`` temp file and return its path; the caller must delete it."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        tmp.write(data)
        return tmp.name


def _predict_face_swap(source_path: str, target_path: str, do_face_enhancer: Optional[bool]):
    """Call the ``tuan2308/face-swap`` Space's ``/predict`` endpoint and return its raw result.

    This is blocking: ``Client(...)`` connects to the Space and ``predict(...)``
    waits for the job and returns the result directly (it is not a coroutine).
    Run it outside the event loop.
    """
    try:
        # gradio_client >= 1.0 requires file inputs to be wrapped explicitly.
        from gradio_client import handle_file
    except ImportError:
        # Older gradio_client releases accepted plain file-path strings.
        def handle_file(path):
            return path

    client = Client("tuan2308/face-swap")
    return client.predict(
        handle_file(source_path),  # Source image
        handle_file(target_path),  # Target image
        do_face_enhancer,  # Face enhancement option
        api_name="/predict",
    )


# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible for this
# handler, so FastAPI never registers it. Restore the decorator with the path
# your frontend actually calls.
async def face_swap(
    source_file: UploadFile = File(...),
    target_file: UploadFile = File(...),
    do_face_enhancer: Optional[bool] = Form(True),
):
    """Swap the face from *source_file* onto *target_file* using a hosted Gradio Space.

    Both uploads are normalised to RGB PNG and sent to the Space's
    ``/predict`` endpoint. The resulting image is returned inline as a base64
    data URI.

    Args:
        source_file: Uploaded image supplying the face.
        target_file: Uploaded image whose face is replaced.
        do_face_enhancer: Forwarded unchanged as the Space's enhancer flag.

    Returns:
        ``{"status": "success", "image": "data:<mime>;base64,..."}`` on success,
        or ``{"status": "error", "message": str(exc)}`` if any step fails.
    """
    temp_paths = []
    try:
        # Image decode/encode is blocking work; keep it off the event loop.
        source_bytes = await run_in_threadpool(process_uploaded_image, source_file)
        target_bytes = await run_in_threadpool(process_uploaded_image, target_file)

        # gradio_client uploads image inputs from file paths, not raw bytes.
        for data in (source_bytes, target_bytes):
            temp_paths.append(_write_temp_png(data))
        source_path, target_path = temp_paths

        # predict() is synchronous; awaiting its return value directly would
        # raise TypeError, so run the whole remote call in a worker thread.
        result = await run_in_threadpool(
            _predict_face_swap, source_path, target_path, do_face_enhancer
        )

        if isinstance(result, bytes):
            image_bytes = result
            mime_type = "image/png"
        else:
            # Presumably a local path to the downloaded output file.
            with open(result, 'rb') as f:
                image_bytes = f.read()
            # The Space is not guaranteed to emit PNG; label the data URI by
            # the output file's extension, falling back to PNG.
            guessed, _ = mimetypes.guess_type(str(result))
            mime_type = guessed if guessed and guessed.startswith("image/") else "image/png"

        result_base64 = encode_image_to_base64(image_bytes)
        return {
            "status": "success",
            "image": f"data:{mime_type};base64,{result_base64}"
        }
    except Exception as e:
        # Boundary handler: keep the JSON error contract, but keep the traceback too.
        logger.exception("Face swap failed")
        return {
            "status": "error",
            "message": str(e)
        }
    finally:
        # Best-effort removal of the temporary input files.
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError:
                pass
# Health check endpoint
# NOTE(review): no route decorator (e.g. ``@app.get(...)``) is visible for this
# handler, so FastAPI never registers it. Restore it with the intended path.
async def health_check():
    """Report liveness with a constant JSON payload."""
    return {"status": "healthy"}
if __name__ == "__main__":
    # Serve on all interfaces. PORT lets the hosting platform choose the
    # listening port; without it the previous default of 8000 is used.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))