"""FastAPI service that proxies face-swap requests to a hosted Gradio Space."""

import base64
import contextlib
import functools
import io
import logging
import mimetypes
import os
import tempfile
from typing import Optional

import uvicorn
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from gradio_client import Client
from PIL import Image

try:
    from gradio_client import handle_file
except ImportError:  # gradio_client < 1.0 takes plain file paths as file inputs
    def handle_file(path):
        return path

logger = logging.getLogger(__name__)

# Gradio Space that performs the swap, and the endpoint called on it.
GRADIO_SPACE = "tuan2308/face-swap"
GRADIO_API_NAME = "/predict"

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    # NOTE: a wildcard origin combined with allow_credentials=True is unsafe;
    # replace with your frontend URL in production.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def encode_image_to_base64(image_bytes):
    """Return *image_bytes* encoded as a base64 ASCII string."""
    return base64.b64encode(image_bytes).decode("utf-8")


def process_uploaded_image(file: UploadFile) -> bytes:
    """Read an uploaded image, convert it to RGB and return it re-encoded as PNG.

    Raises whatever PIL raises (e.g. ``UnidentifiedImageError``) when the
    upload cannot be decoded as an image.
    """
    contents = file.file.read()
    with Image.open(io.BytesIO(contents)) as image:
        # Normalise any other mode (RGBA, palette, greyscale, ...) to RGB.
        rgb_image = image if image.mode == "RGB" else image.convert("RGB")
        buffer = io.BytesIO()
        rgb_image.save(buffer, format="PNG")
    return buffer.getvalue()


@functools.lru_cache(maxsize=1)
def _get_client() -> Client:
    """Create the Gradio client once and reuse it; failures are not cached."""
    return Client(GRADIO_SPACE)


def _write_temp_png(image_bytes: bytes) -> str:
    """Write PNG bytes to a new temporary file and return its path (caller deletes it)."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        tmp.write(image_bytes)
    return tmp.name


def _result_to_data_url(result) -> str:
    """Turn a Gradio prediction result (raw bytes or a file path) into a data URL."""
    if isinstance(result, (bytes, bytearray)):
        return f"data:image/png;base64,{encode_image_to_base64(result)}"
    path = os.fspath(result)
    # The returned file is not guaranteed to be PNG; label it from its extension.
    mime_type = mimetypes.guess_type(path)[0] or "image/png"
    with open(path, "rb") as f:
        return f"data:{mime_type};base64,{encode_image_to_base64(f.read())}"


@app.post("/api/face-swap")
async def face_swap(
    source_file: UploadFile = File(...),
    target_file: UploadFile = File(...),
    do_face_enhancer: Optional[bool] = Form(True),
):
    """Swap the face from *source_file* onto *target_file* via the Gradio Space.

    Returns ``{"status": "success", "image": <data URL>}`` on success, or
    ``{"status": "error", "message": <text>}`` if any step fails. The HTTP
    status is 200 in both cases.
    """
    temp_paths = []
    try:
        # PIL decoding is blocking; keep it off the event loop.
        source_bytes = await run_in_threadpool(process_uploaded_image, source_file)
        target_bytes = await run_in_threadpool(process_uploaded_image, target_file)

        # Gradio file inputs are given as file paths, not raw bytes.
        for image_bytes in (source_bytes, target_bytes):
            temp_paths.append(_write_temp_png(image_bytes))
        source_path, target_path = temp_paths

        client = await run_in_threadpool(_get_client)
        # Client.predict is a blocking call, not a coroutine, so it cannot be
        # awaited directly; run it in a worker thread instead.
        result = await run_in_threadpool(
            client.predict,
            handle_file(source_path),  # Source image
            handle_file(target_path),  # Target image
            do_face_enhancer,  # Face enhancement option
            api_name=GRADIO_API_NAME,
        )

        image_data_url = await run_in_threadpool(_result_to_data_url, result)
        return {
            "status": "success",
            "image": image_data_url,
        }
    except Exception as e:
        logger.exception("Face swap failed")
        return {
            "status": "error",
            "message": str(e),
        }
    finally:
        # Always remove the temporary input files, even on failure.
        for path in temp_paths:
            with contextlib.suppress(OSError):
                os.remove(path)


# Health check endpoint
@app.get("/health")
async def health_check():
    """Liveness probe; always reports healthy."""
    return {"status": "healthy"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)