File size: 4,349 Bytes
23ed1d9 a85b6f3 23ed1d9 a85b6f3 23ed1d9 a85b6f3 23ed1d9 a85b6f3 23ed1d9 a85b6f3 23ed1d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import cv2
import numpy as np
import requests
from requests_toolbelt.multipart.encoder import MultipartEncoder
from urllib.parse import urlparse
import logging
import json
from io import BytesIO
from dataclasses import dataclass
@dataclass
class TryOnDiffusionAPIResponse:
    """Outcome of one request issued by ``TryOnDiffusionClient.try_on_file``.

    Only ``status_code`` is mandatory. Every other field defaults to ``None``
    and is populated by the client depending on how the request ended.
    """

    # HTTP status code of the response. The client stores 0 here when the
    # request raised before any response was received.
    status_code: int
    # Image decoded with cv2.imdecode(..., cv2.IMREAD_COLOR); set only for a
    # 200 response when the caller did not ask for the raw body.
    image: np.ndarray = None
    # Raw response body; set for non-200 responses or when the raw body was
    # explicitly requested.
    response_data: bytes = None
    # Value stored under the "detail" key of a JSON error body, when present.
    error_details: str = None
    # Integer parsed from the "X-Seed" response header of a 200 response.
    seed: int = None
class TryOnDiffusionClient:
    """Small HTTP client that posts multipart requests to ``<base_url>/try-on-file``.

    Authentication headers depend on the host in ``base_url``: when the host
    ends with ``.rapidapi.com`` the ``X-RapidAPI-Key`` / ``X-RapidAPI-Host``
    headers are sent, otherwise a single ``X-API-Key`` header is sent.
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8000/",
        api_key: str = "",
        timeout: float = None,
    ):
        """Store connection settings and detect a RapidAPI host.

        Args:
            base_url: Root URL of the API. One trailing ``/`` is removed.
            api_key: Key sent in the authentication header.
            timeout: Passed unchanged as ``timeout=`` to ``requests.post``.
                The default ``None`` keeps the previous behaviour (no timeout).
        """
        self._logger = logging.getLogger("try_on_diffusion_client")
        # endswith() instead of indexing [-1], so an empty base_url cannot
        # raise IndexError.
        self._base_url = base_url[:-1] if base_url.endswith("/") else base_url
        self._api_key = api_key
        self._timeout = timeout

        netloc = urlparse(self._base_url).netloc
        self._rapidapi_host = netloc if netloc.endswith(".rapidapi.com") else None

        if self._rapidapi_host is not None:
            self._logger.info("Using RapidAPI proxy: %s", self._rapidapi_host)

    @staticmethod
    def _image_to_upload_file(image: np.ndarray) -> tuple:
        """Encode ``image`` as JPEG (quality 99) for a multipart file field.

        Returns:
            A ``(filename, file object, content type)`` tuple.

        Raises:
            ValueError: If ``cv2.imencode`` reports that encoding failed.
        """
        encoded_ok, jpeg_data = cv2.imencode(".jpg", image, [int(cv2.IMWRITE_JPEG_QUALITY), 99])
        if not encoded_ok:
            # Previously the flag was ignored and whatever buffer came back
            # was uploaded as if it were a valid JPEG.
            raise ValueError("Failed to encode image as JPEG")
        return "image.jpg", BytesIO(jpeg_data.tobytes()), "image/jpeg"

    @staticmethod
    def _extract_error_details(payload: bytes):
        """Return the ``"detail"`` entry of a UTF-8 JSON object body, else ``None``."""
        if payload is None:
            return None
        try:
            body = json.loads(payload.decode("utf-8"))
        except (ValueError, RecursionError):
            # ValueError covers both UnicodeDecodeError and JSONDecodeError.
            return None
        # Only a JSON object can carry a "detail" key; lists/strings/numbers
        # are treated as "no details".
        if isinstance(body, dict):
            return body.get("detail")
        return None

    def _parse_seed(self, headers):
        """Return the ``X-Seed`` header as an int, or ``None`` if absent or invalid."""
        raw_seed = headers.get("X-Seed")
        if raw_seed is None:
            return None
        try:
            return int(raw_seed)
        except (TypeError, ValueError):
            # A malformed header must not turn a successful call into a crash.
            self._logger.warning("Ignoring non-integer X-Seed header: %r", raw_seed)
            return None

    def try_on_file(
        self,
        clothing_image: np.ndarray = None,
        clothing_prompt: str = None,
        avatar_image: np.ndarray = None,
        avatar_prompt: str = None,
        avatar_sex: str = None,
        background_image: np.ndarray = None,
        background_prompt: str = None,
        seed: int = -1,
        raw_response: bool = False,
    ) -> "TryOnDiffusionAPIResponse":
        """POST the given inputs to ``/try-on-file`` and wrap the outcome.

        Every argument that is not ``None`` becomes a multipart form field of
        the same name. Image arguments are JPEG-encoded first; prompt/sex
        arguments are sent as given. ``seed`` is always sent, as a string.

        Args:
            clothing_image, avatar_image, background_image: Images accepted by
                ``cv2.imencode``.
            clothing_prompt, avatar_prompt, avatar_sex, background_prompt:
                Text form fields, forwarded unchanged.
            seed: Sent as ``str(seed)``.
            raw_response: When true, a 200 body is returned undecoded in
                ``response_data`` instead of being decoded into ``image``.

        Returns:
            A ``TryOnDiffusionAPIResponse``:
            * ``status_code == 0`` if the request raised.
            * On a 200 response: ``image`` (or ``response_data`` when
              ``raw_response`` is true) and ``seed`` from the ``X-Seed`` header.
            * On any other status: ``response_data`` plus ``error_details``
              taken from the ``"detail"`` key of a JSON object body.

        Raises:
            ValueError: If an input image cannot be JPEG-encoded.
        """
        url = self._base_url + "/try-on-file"

        # (field name, value, needs JPEG encoding) in the order the fields
        # are added to the multipart body.
        optional_fields = (
            ("clothing_image", clothing_image, True),
            ("clothing_prompt", clothing_prompt, False),
            ("avatar_image", avatar_image, True),
            ("avatar_prompt", avatar_prompt, False),
            ("avatar_sex", avatar_sex, False),
            ("background_image", background_image, True),
            ("background_prompt", background_prompt, False),
        )

        request_data = {"seed": str(seed)}
        for field_name, value, is_image in optional_fields:
            if value is None:
                continue
            request_data[field_name] = self._image_to_upload_file(value) if is_image else value

        multipart_data = MultipartEncoder(fields=request_data)

        headers = {"Content-Type": multipart_data.content_type}
        if self._rapidapi_host is not None:
            headers["X-RapidAPI-Key"] = self._api_key
            headers["X-RapidAPI-Host"] = self._rapidapi_host
        else:
            headers["X-API-Key"] = self._api_key

        try:
            response = requests.post(
                url,
                data=multipart_data,
                headers=headers,
                timeout=self._timeout,
            )
        except Exception as e:
            # Boundary handler: any transport failure is reported as status 0
            # rather than propagated to the caller.
            self._logger.exception(e)
            return TryOnDiffusionAPIResponse(status_code=0)

        result = TryOnDiffusionAPIResponse(status_code=response.status_code)

        if response.status_code != 200:
            self._logger.warning(
                "Request failed, status code: %s, response: %s",
                response.status_code,
                response.content,
            )
            result.response_data = response.content
            result.error_details = self._extract_error_details(result.response_data)
            return result

        if raw_response:
            result.response_data = response.content
        else:
            try:
                result.image = cv2.imdecode(np.frombuffer(response.content, np.uint8), cv2.IMREAD_COLOR)
            except Exception:
                # Best effort: an undecodable body leaves image as None.
                self._logger.warning("Failed to decode response image", exc_info=True)
                result.image = None

        result.seed = self._parse_seed(response.headers)
        return result
|