# Source: Hugging Face Spaces page (Space status at capture time: "Sleeping").
import base64
import os
import shutil
import traceback
import uuid

import aiofiles
from fastapi import FastAPI, UploadFile, Form, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from gradio_client import Client, file
app = FastAPI()

# Cross-origin access is left fully open: any origin, any method, any header,
# with credentials allowed.
# NOTE(review): FastAPI's CORS docs say allow_origins/allow_methods/allow_headers
# should not be ["*"] when allow_credentials=True. Nothing in the visible code
# reads cookies or auth headers, so allow_credentials=False (or an explicit
# origin list) is likely safer; confirm with the frontend before changing it.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Gradio client for the remote IDM-VTON Space that performs the try-on.
# The Space id can be overridden through the IDM_VTON_SPACE environment
# variable. Other ids this file previously kept as commented-out alternatives:
# "yisol/IDM-VTON" and "kadirnar/IDM-VTON". An unset or empty variable falls
# back to the Space used before.
# This runs at import time, so module import presumably requires the Space to
# be reachable.
client = Client(os.environ.get("IDM_VTON_SPACE") or "tuan2308/IDM-VTON")
# Storage locations for incoming model photos and generated try-on images.
# Both paths are relative, so they resolve against the process's working
# directory.
UPLOAD_FOLDER = 'static/uploads'
RESULT_FOLDER = 'static/results'

# Create both directories at import time; exist_ok keeps restarts harmless.
for _directory in (UPLOAD_FOLDER, RESULT_FOLDER):
    os.makedirs(_directory, exist_ok=True)
del _directory
async def hello():
    """Return a fixed status payload showing the service is up.

    NOTE(review): no route decorator (e.g. ``@app.get(...)``) is visible in
    this copy of the file, so as shown this handler is never registered on
    ``app``. Presumably the decorator was lost; confirm.
    """
    payload = {}
    payload["Wearon"] = "wearon model is running"
    return payload
async def predict(product_image_url: str = Form(...), model_image: UploadFile = File(...)):
    """Run a virtual try-on of a garment image on an uploaded model photo.

    The uploaded photo is written to UPLOAD_FOLDER and sent, together with
    ``product_image_url``, to the remote Space's ``/tryon`` endpoint. The
    first element of the result is copied into RESULT_FOLDER and returned
    base64-encoded.

    NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible in
    this copy of the file, so as shown this handler is never registered on
    ``app``. Presumably the decorator was lost; confirm.

    Args:
        product_image_url: Garment image reference, wrapped with gradio_client's
            ``file()``. Presumably a URL; verify against the caller.
        model_image: Uploaded photo of the person to dress.

    Returns:
        JSONResponse ``{"image": <base64 str>}`` with status 200.

    Raises:
        HTTPException: 400 when no model image is supplied; 500 when saving,
            predicting, copying or reading the result fails.
    """
    # Kept outside the try so a client error is not re-wrapped as a 500.
    if not model_image:
        raise HTTPException(status_code=400, detail="No model image file provided")

    # Never use the client-supplied name as a path. Keep only its extension
    # and use a random stem, so "../" cannot escape UPLOAD_FOLDER and
    # concurrent uploads with the same name cannot overwrite or delete each
    # other's files.
    _, extension = os.path.splitext(os.path.basename(model_image.filename or ""))
    upload_path = os.path.join(UPLOAD_FOLDER, uuid.uuid4().hex + extension)

    try:
        async with aiofiles.open(upload_path, "wb") as buffer:
            await buffer.write(await model_image.read())

        full_filename = os.path.normpath(os.path.join(os.getcwd(), upload_path))
        print("Product image =", product_image_url)
        print("Model image =", full_filename)

        # client.predict is a synchronous remote call. Run it in a worker
        # thread so the event loop keeps serving other requests meanwhile.
        result = await run_in_threadpool(
            client.predict,
            dict={"background": file(full_filename), "layers": [], "composite": None},
            garm_img=file(product_image_url),
            garment_des="Hello!!",
            is_checked=True,
            is_checked_crop=False,
            denoise_steps=30,
            seed=42,
            api_name="/tryon",
        )
        print(result)

        # The first element of the result is treated as the output image path.
        output_image_path = result[0]
        local_output_path = os.path.join(RESULT_FOLDER, os.path.basename(output_image_path))
        shutil.copy(output_image_path, local_output_path)

        async with aiofiles.open(local_output_path, "rb") as image_file:
            encoded_image = base64.b64encode(await image_file.read()).decode('utf-8')
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Remove the upload whether or not the prediction succeeded. Log
        # cleanup failures rather than failing an otherwise finished request.
        try:
            os.remove(upload_path)
        except OSError:
            traceback.print_exc()

    return JSONResponse(content={"image": encoded_image}, status_code=200)
async def uploaded_file(filename: str):
    """Serve a file stored under UPLOAD_FOLDER.

    NOTE(review): no route decorator is visible in this copy of the file;
    ``filename`` is presumably a URL path parameter. Also, the visible
    ``predict`` handler deletes each upload after processing, so this
    endpoint may rarely find anything. Confirm it is meant to serve
    UPLOAD_FOLDER.

    Args:
        filename: Name of the file, relative to UPLOAD_FOLDER.

    Returns:
        FileResponse for the resolved file.

    Raises:
        HTTPException: 404 when the name does not resolve to a regular file
            inside UPLOAD_FOLDER.
    """
    upload_root = os.path.realpath(UPLOAD_FOLDER)
    file_path = os.path.realpath(os.path.join(upload_root, filename))
    # Reject names such as "..", "../secret" or absolute paths that resolve
    # outside the upload directory. Also reject directories, which exists()
    # accepted but FileResponse cannot stream.
    if not file_path.startswith(upload_root + os.sep) or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path)