|
import cv2 |
|
import numpy as np |
|
import requests |
|
from requests_toolbelt.multipart.encoder import MultipartEncoder |
|
from urllib.parse import urlparse |
|
import logging |
|
import json |
|
from io import BytesIO |
|
from dataclasses import dataclass |
|
|
|
|
|
@dataclass
class TryOnDiffusionAPIResponse:
    """Result of one call to ``TryOnDiffusionClient.try_on_file``.

    Only ``status_code`` is mandatory; every other field defaults to ``None``
    and is populated by the client depending on how the request went.
    """

    # HTTP status code of the response; the client uses 0 when the request
    # could not be sent at all.
    status_code: int

    # Image decoded from the body of a successful (200) non-raw response.
    image: np.ndarray = None

    # Undecoded body bytes: set for raw responses and for non-200 responses.
    response_data: bytes = None

    # Value of the "detail" key of a JSON error body, if present.
    error_details: str = None

    # Integer parsed from the "X-Seed" header of a successful response.
    seed: int = None
|
|
|
|
|
class TryOnDiffusionClient:
    """HTTP client for the try-on diffusion ``/try-on-file`` endpoint.

    Requests go either directly to the service, authenticated with an
    ``X-API-Key`` header, or through a RapidAPI proxy. The proxy is detected
    when the base URL's host ends with ``.rapidapi.com`` and is authenticated
    with the ``X-RapidAPI-Key`` / ``X-RapidAPI-Host`` headers.
    """

    def __init__(self, base_url: str = "http://localhost:8000/", api_key: str = "", timeout: float = None):
        """Create a client.

        Args:
            base_url: Root URL of the service. A single trailing slash is
                stripped so endpoint paths can be appended directly.
            api_key: Key sent with every request.
            timeout: Forwarded to ``requests.post`` as ``timeout``. ``None``
                (the default) waits indefinitely, as before.
        """
        self._logger = logging.getLogger("try_on_diffusion_client")
        # endswith() instead of base_url[-1] so an empty URL does not raise IndexError.
        self._base_url = base_url[:-1] if base_url.endswith("/") else base_url
        self._api_key = api_key
        self._timeout = timeout

        netloc = urlparse(self._base_url).netloc
        self._rapidapi_host = netloc if netloc.endswith(".rapidapi.com") else None

        if self._rapidapi_host is not None:
            self._logger.info("Using RapidAPI proxy: %s", self._rapidapi_host)

    @staticmethod
    def _image_to_upload_file(image: np.ndarray) -> tuple:
        """Encode *image* as a quality-99 JPEG wrapped as a multipart file field.

        Returns:
            ``(filename, file_object, content_type)`` for ``MultipartEncoder``.

        Raises:
            ValueError: if OpenCV reports that encoding failed.
        """
        ok, jpeg_data = cv2.imencode(".jpg", image, [int(cv2.IMWRITE_JPEG_QUALITY), 99])
        if not ok:
            # Previously ignored: would upload an empty/garbage file or crash later.
            raise ValueError("Failed to encode image as JPEG")
        return "image.jpg", BytesIO(jpeg_data.tobytes()), "image/jpeg"

    def _auth_headers(self) -> dict:
        """Return the authentication headers for the configured endpoint."""
        if self._rapidapi_host is not None:
            return {"X-RapidAPI-Key": self._api_key, "X-RapidAPI-Host": self._rapidapi_host}
        return {"X-API-Key": self._api_key}

    def try_on_file(
        self,
        clothing_image: np.ndarray = None,
        clothing_prompt: str = None,
        avatar_image: np.ndarray = None,
        avatar_prompt: str = None,
        avatar_sex: str = None,
        background_image: np.ndarray = None,
        background_prompt: str = None,
        seed: int = -1,
        raw_response: bool = False,
    ) -> "TryOnDiffusionAPIResponse":
        """POST a try-on request to ``<base_url>/try-on-file`` as multipart data.

        Only the arguments that are not ``None`` are sent. Image arguments are
        uploaded as JPEG files and prompt/sex arguments as plain form fields.
        ``seed`` is always sent as a string.

        Args:
            raw_response: If true, a successful response body is returned
                undecoded in ``response_data`` instead of as ``image``.

        Returns:
            A ``TryOnDiffusionAPIResponse``. ``status_code`` is 0 if the
            request could not be sent.
        """
        url = self._base_url + "/try-on-file"

        request_data = {"seed": str(seed)}
        # Same field order as the server has always received.
        optional_fields = (
            ("clothing_image", clothing_image),
            ("clothing_prompt", clothing_prompt),
            ("avatar_image", avatar_image),
            ("avatar_prompt", avatar_prompt),
            ("avatar_sex", avatar_sex),
            ("background_image", background_image),
            ("background_prompt", background_prompt),
        )
        for name, value in optional_fields:
            if value is None:
                continue
            request_data[name] = self._image_to_upload_file(value) if name.endswith("_image") else value

        multipart_data = MultipartEncoder(fields=request_data)
        headers = {"Content-Type": multipart_data.content_type, **self._auth_headers()}

        try:
            response = requests.post(url, data=multipart_data, headers=headers, timeout=self._timeout)
        except Exception as e:
            # Deliberate best-effort boundary: any failure to send is reported
            # to the caller as status_code 0 rather than propagated.
            self._logger.exception("Request to %s failed: %s", url, e)
            return TryOnDiffusionAPIResponse(status_code=0)

        if response.status_code != 200:
            self._logger.warning(
                "Request failed, status code: %s, response: %s", response.status_code, response.content
            )

        return self._parse_response(response, raw_response)

    def _parse_response(self, response: "requests.Response", raw_response: bool) -> "TryOnDiffusionAPIResponse":
        """Convert an HTTP response into a ``TryOnDiffusionAPIResponse``."""
        result = TryOnDiffusionAPIResponse(status_code=response.status_code)
        succeeded = response.status_code == 200

        if succeeded and not raw_response:
            result.image = self._decode_image(response.content)
        else:
            result.response_data = response.content

        if succeeded:
            # requests' header mapping is case-insensitive, as with the original `in` check.
            result.seed = self._parse_seed(response.headers.get("X-Seed"))
        else:
            result.error_details = self._parse_error_details(result.response_data)

        return result

    def _decode_image(self, data: bytes):
        """Decode *data* as a colour image; return ``None`` if it cannot be decoded."""
        try:
            image = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
        except cv2.error as e:
            # e.g. empty body: imdecode asserts on an empty buffer.
            self._logger.warning("Could not decode response image: %s", e)
            return None
        if image is None:
            self._logger.warning("Response body is not a decodable image")
        return image

    def _parse_seed(self, value: str):
        """Parse the ``X-Seed`` header value; ``None`` if absent or not an integer."""
        if value is None:
            return None
        try:
            return int(value)
        except ValueError:
            self._logger.warning("Ignoring malformed X-Seed header: %r", value)
            return None

    @staticmethod
    def _parse_error_details(data: bytes):
        """Extract the ``"detail"`` value from a JSON error body.

        Returns ``None`` when *data* is ``None``, is not UTF-8 JSON, is not a
        JSON object, or has no ``"detail"`` key.
        """
        if data is None:
            return None
        try:
            payload = json.loads(data.decode("utf-8"))
        except ValueError:  # covers UnicodeDecodeError and json.JSONDecodeError
            return None
        return payload.get("detail") if isinstance(payload, dict) else None
|
|