# VTdockfast / main.py
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from io import BytesIO
import cv2
import numpy as np
import torch

app = FastAPI()
# Load the VToonify model lazily on first use, with the pretrained 'cartoon1' style
model = None

def load_model():
    global model
    from vtoonify_model import Model
    model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
    model.load_model('cartoon1')
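# Alternative (a sketch, not part of the original file): eager-load the model when the
# app starts so the first request does not pay the loading cost. The startup hook below
# is standard FastAPI, but wiring it in is an assumption.
# @app.on_event("startup")
# def _warm_up():
#     load_model()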
# Define endpoints
@app.post("/upload/")
async def process_image(file: UploadFile = File(...), top: int = Form(...), bottom: int = Form(...), left: int = Form(...), right: int = Form(...)):
    global model
    if model is None:
        load_model()

    # Read the uploaded image file
    contents = await file.read()

    # Decode the uploaded bytes into an image array
    nparr = np.frombuffer(contents, np.uint8)
    frame_rgb = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

    # Align the face using the given crop margins, then apply the toonify model
    aligned_face, instyle, message = model.detect_and_align_image(frame_rgb, top, bottom, left, right)
    processed_image, message = model.image_toonify(aligned_face, instyle, model.exstyle, style_degree=0.5, style_type='cartoon1')

    # Encode the processed image as JPEG and stream it back to the client
    _, encoded_image = cv2.imencode('.jpg', processed_image)
    return StreamingResponse(BytesIO(encoded_image.tobytes()), media_type="image/jpeg")
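# Example client call (a sketch, not part of the app): the base URL and port 7860 are
# assumptions for a locally running instance; the margin values are arbitrary.
#
#   import requests
#   with open("face.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/upload/",
#           files={"file": f},
#           data={"top": 200, "bottom": 200, "left": 200, "right": 200},
#       )
#   with open("toonified.jpg", "wb") as out:
#       out.write(resp.content)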
# Define index route (served from the container path /app/AB/index.html)
@app.get("/")
def index():
    return FileResponse(path="/app/AB/index.html", media_type="text/html")

# Mount static files directory; registered last so it does not shadow the routes above
app.mount("/", StaticFiles(directory="AB", html=True), name="static")
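# Optional local entry point (a sketch): the Space presumably launches the app from its
# Dockerfile, so the host and port below are assumptions for local testing only.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)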