# ncncncn / app.py
# Author: Yussifweb3
# Commit 3657bdc (verified): "Create app.py"
import base64
import io
import logging
import os
import tempfile
from typing import Optional

from fastapi import FastAPI, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from gradio_client import Client
from PIL import Image
import uvicorn
app = FastAPI()

# Add CORS middleware so browser frontends on other origins can call the API.
# NOTE: a wildcard origin must not be combined with allow_credentials=True.
# In that configuration Starlette's CORSMiddleware echoes the requesting
# Origin back for credentialed requests, which effectively lets any website
# make credentialed calls. The endpoints in this file read no cookies or
# auth headers, so credentials are disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with your frontend URL
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
def encode_image_to_base64(image_bytes):
    """Return the standard Base64 encoding of ``image_bytes`` as a ``str``.

    Any bytes-like object is accepted. The payload is not inspected, so it
    need not actually be an image.
    """
    encoded = base64.standard_b64encode(image_bytes)
    return encoded.decode('utf-8')
def process_uploaded_image(file: UploadFile) -> bytes:
    """Decode an uploaded image and re-encode it as RGB PNG bytes.

    Args:
        file: The uploaded file. Its underlying ``file.file`` stream is read
            synchronously from the current position to EOF.

    Returns:
        PNG-encoded bytes of the image in RGB mode. Other modes (palette,
        greyscale, RGBA, ...) are converted with ``Image.convert('RGB')``,
        which discards any alpha channel.

    Raises:
        PIL.UnidentifiedImageError: if the bytes are not a recognised image.
    """
    contents = file.file.read()
    # Use the image as a context manager so Pillow releases its resources
    # deterministically instead of relying on garbage collection.
    with Image.open(io.BytesIO(contents)) as image:
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        buffer = io.BytesIO()
        rgb_image.save(buffer, format='PNG')
    return buffer.getvalue()
_logger = logging.getLogger(__name__)


def _run_face_swap(source_bytes: bytes, target_bytes: bytes, do_face_enhancer: Optional[bool]) -> str:
    """Blocking helper: run the remote face swap and return the result as Base64.

    Both PNG payloads are written to a temporary directory because
    gradio_client uploads file inputs from local paths rather than raw bytes.
    The directory is removed when the call finishes.
    """
    with tempfile.TemporaryDirectory() as workdir:
        source_path = os.path.join(workdir, "source.png")
        target_path = os.path.join(workdir, "target.png")
        for path, data in ((source_path, source_bytes), (target_path, target_bytes)):
            with open(path, "wb") as fh:
                fh.write(data)

        try:
            # gradio_client >= 1.0 requires file inputs wrapped with handle_file().
            from gradio_client import handle_file
        except ImportError:
            # Older gradio_client releases take a plain local file path.
            def handle_file(path):
                return path

        # Client() and predict() both do blocking network I/O; this helper is
        # meant to run in a worker thread, not on the event loop.
        client = Client("tuan2308/face-swap")
        result = client.predict(
            handle_file(source_path),  # Source image
            handle_file(target_path),  # Target image
            do_face_enhancer,  # Face enhancement option
            api_name="/predict",
        )

        if isinstance(result, bytes):
            return encode_image_to_base64(result)
        # Otherwise the result is treated as a local path to the output file.
        with open(result, "rb") as fh:
            return encode_image_to_base64(fh.read())


@app.post("/api/face-swap")
async def face_swap(
    source_file: UploadFile = File(...),
    target_file: UploadFile = File(...),
    do_face_enhancer: Optional[bool] = Form(True)
):
    """Swap the face from ``source_file`` onto ``target_file``.

    Args:
        source_file: Image containing the face to transplant.
        target_file: Image that receives the face.
        do_face_enhancer: Forwarded to the remote model as its third input.

    Returns:
        ``{"status": "success", "image": "data:image/png;base64,..."}`` on
        success, or ``{"status": "error", "message": str(exc)}`` on any failure.
        Both are sent with HTTP 200.
    """
    try:
        # Image decoding and the Gradio round-trip are synchronous, so they
        # run in the threadpool to avoid blocking the event loop.
        source_bytes = await run_in_threadpool(process_uploaded_image, source_file)
        target_bytes = await run_in_threadpool(process_uploaded_image, target_file)
        result_base64 = await run_in_threadpool(
            _run_face_swap, source_bytes, target_bytes, do_face_enhancer
        )
    except Exception as e:
        # Top-level request boundary: keep the original error payload, but
        # record the traceback instead of silently discarding it.
        _logger.exception("Face swap request failed")
        return {
            "status": "error",
            "message": str(e)
        }
    # NOTE(review): the MIME type is always reported as PNG; confirm the
    # Space's actual output format.
    return {
        "status": "success",
        "image": f"data:image/png;base64,{result_base64}"
    }
@app.get("/health")
async def health_check():
    """Health check endpoint: return a constant payload whenever the server responds."""
    payload = dict(status="healthy")
    return payload
if __name__ == "__main__":
    # The listening port can be overridden with the PORT environment
    # variable (for hosts that assign the port); it defaults to 8000.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))