File size: 2,671 Bytes
3657bdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import base64
import io
import logging
import mimetypes
import os
import tempfile
from typing import Optional

import uvicorn
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from gradio_client import Client
from PIL import Image

app = FastAPI()

# Allowed CORS origins, as a comma-separated environment variable, e.g.
#   CORS_ALLOW_ORIGINS="https://app.example.com,https://admin.example.com"
# Defaults to "*" (any origin); set explicit origins in production.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# The CORS spec forbids combining a wildcard origin with credentials, so
# credentials are only enabled when the allowed origins are listed explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

def encode_image_to_base64(image_bytes):
    """Encode a raw byte payload as base64 text.

    Args:
        image_bytes: Bytes-like data (typically an encoded image).

    Returns:
        The standard-alphabet base64 encoding as a ``str`` with no line breaks.
    """
    raw_b64 = base64.b64encode(image_bytes)
    return str(raw_b64, encoding='utf-8')

def process_uploaded_image(file: UploadFile) -> bytes:
    """Decode an uploaded image and re-encode it as an RGB PNG.

    Images in any mode other than RGB (e.g. RGBA, palette, grayscale) are
    converted to RGB before encoding.

    Args:
        file: The uploaded file; its underlying ``file.file`` stream is read
            from the current position to the end.

    Returns:
        PNG-encoded bytes of the RGB image.

    Raises:
        PIL.UnidentifiedImageError: if the upload cannot be decoded as an image.
    """
    contents = file.file.read()

    # The context manager releases the decoded image once the RGB copy has
    # been encoded, instead of leaving it for the garbage collector.
    with Image.open(io.BytesIO(contents)) as image:
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        png_buffer = io.BytesIO()
        rgb_image.save(png_buffer, format='PNG')

    return png_buffer.getvalue()

def _run_face_swap(source_bytes: bytes, target_bytes: bytes,
                   do_face_enhancer: Optional[bool]):
    """Call the ``tuan2308/face-swap`` Gradio Space and return its raw result.

    This is blocking (client construction and prediction both do network I/O),
    so it must run in a worker thread rather than on the event loop.
    gradio_client takes image inputs as file paths, not raw bytes, so both
    images are written to a temporary directory for the duration of the call.
    """
    try:
        # gradio_client >= 1.0 requires file inputs wrapped with handle_file();
        # older releases accept a plain file path.
        from gradio_client import handle_file
    except ImportError:
        def handle_file(path):
            return path

    with tempfile.TemporaryDirectory() as tmp_dir:
        source_path = os.path.join(tmp_dir, "source.png")
        target_path = os.path.join(tmp_dir, "target.png")
        for path, data in ((source_path, source_bytes), (target_path, target_bytes)):
            with open(path, "wb") as fh:
                fh.write(data)

        # NOTE(review): a new Client is built per request; caching one could
        # reduce latency if concurrent use of a shared Client is confirmed safe.
        client = Client("tuan2308/face-swap")
        # Client.predict is synchronous and returns the result itself; it is
        # not awaitable.
        return client.predict(
            handle_file(source_path),   # Source image
            handle_file(target_path),   # Target image
            do_face_enhancer,           # Face enhancement option
            api_name="/predict",
        )


def _result_to_data_uri(result) -> str:
    """Encode a Gradio result (raw bytes or a local file path) as a data URI."""
    if isinstance(result, bytes):
        return f"data:image/png;base64,{encode_image_to_base64(result)}"

    # Presumably a path to the file gradio_client downloaded; derive the MIME
    # type from its extension rather than assuming PNG.
    mime_type = mimetypes.guess_type(str(result))[0] or "image/png"
    with open(result, "rb") as fh:
        return f"data:{mime_type};base64,{encode_image_to_base64(fh.read())}"


@app.post("/api/face-swap")
async def face_swap(
    source_file: UploadFile = File(...),
    target_file: UploadFile = File(...),
    do_face_enhancer: Optional[bool] = Form(True)
):
    """Run a face swap on two uploaded images via a remote Gradio Space.

    Form fields:
        source_file: the source image (presumably the face to transplant).
        target_file: the target image.
        do_face_enhancer: forwarded to the Space as its third input.

    Returns:
        ``{"status": "success", "image": <data URI>}`` on success, or
        ``{"status": "error", "message": <str>}`` if anything inside the
        handler fails (returned with the default 200 status).
    """
    try:
        # PIL decoding/encoding, the remote call and reading the result file
        # are all blocking, so run them in the threadpool to keep the event
        # loop free for other requests.
        source_bytes = await run_in_threadpool(process_uploaded_image, source_file)
        target_bytes = await run_in_threadpool(process_uploaded_image, target_file)
        result = await run_in_threadpool(
            _run_face_swap, source_bytes, target_bytes, do_face_enhancer
        )
        image_uri = await run_in_threadpool(_result_to_data_uri, result)

        return {
            "status": "success",
            "image": image_uri
        }

    except Exception as e:
        # Top-level boundary: log the traceback server-side, report to client.
        logging.getLogger(__name__).exception("Face swap request failed")
        return {
            "status": "error",
            "message": str(e)
        }

# Lightweight endpoint for uptime / readiness checks
@app.get("/health")
async def health_check():
    """Return a constant payload so callers can verify the server responds."""
    payload = {"status": "healthy"}
    return payload

if __name__ == "__main__":
    # HOST and PORT environment variables override the defaults
    # (all interfaces, port 8000).
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )