File size: 4,349 Bytes
23ed1d9
 
 
 
a85b6f3
23ed1d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a85b6f3
 
 
 
 
 
 
23ed1d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a85b6f3
23ed1d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a85b6f3
 
 
 
 
 
 
 
23ed1d9
 
 
 
a85b6f3
23ed1d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import json
import logging
from dataclasses import dataclass
from io import BytesIO
from typing import Optional
from urllib.parse import urlparse

import cv2
import numpy as np
import requests
from requests_toolbelt.multipart.encoder import MultipartEncoder


@dataclass
class TryOnDiffusionAPIResponse:
    """Outcome of a single ``TryOnDiffusionClient.try_on_file`` call.

    Only ``status_code`` is always set; every other field defaults to ``None``
    and is filled in depending on the response.
    """

    # HTTP status code of the response, or 0 when the HTTP request itself
    # raised (no response was received).
    status_code: int
    # Result image decoded with cv2.IMREAD_COLOR; set only for a 200 response
    # when raw_response was False. None if the body could not be decoded.
    image: Optional[np.ndarray] = None
    # Raw response body; set for non-200 responses, or for a 200 response
    # when raw_response was True.
    response_data: Optional[bytes] = None
    # The "detail" value from a JSON error body, if present.
    # NOTE(review): this is copied verbatim from the JSON, so it may be a
    # non-string value if the server sends structured details — confirm.
    error_details: Optional[str] = None
    # Seed reported by the server in the X-Seed header of a 200 response.
    seed: Optional[int] = None


class TryOnDiffusionClient:
    """Client for a try-on diffusion service's ``/try-on-file`` endpoint.

    Requests go either directly to ``base_url`` (authenticated with an
    ``X-API-Key`` header) or, when the URL's host ends with
    ``.rapidapi.com``, through RapidAPI (authenticated with
    ``X-RapidAPI-Key`` and ``X-RapidAPI-Host`` headers).
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8000/",
        api_key: str = "",
        timeout: Optional[float] = None,
    ):
        """Create a client.

        Args:
            base_url: Root URL of the service. Trailing slashes are removed.
            api_key: Key sent with every request.
            timeout: Passed to ``requests.post`` as ``timeout`` (seconds).
                ``None`` (the default) waits indefinitely.
        """
        self._logger = logging.getLogger("try_on_diffusion_client")
        # rstrip copes with an empty string and with repeated trailing slashes.
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._timeout = timeout

        parsed_url = urlparse(self._base_url)

        # Only hosts under rapidapi.com switch the client to RapidAPI headers.
        self._rapidapi_host = parsed_url.netloc if parsed_url.netloc.endswith(".rapidapi.com") else None

        if self._rapidapi_host is not None:
            self._logger.info(f"Using RapidAPI proxy: {self._rapidapi_host}")

    @staticmethod
    def _image_to_upload_file(image: np.ndarray, jpeg_quality: int = 99) -> tuple:
        """Encode *image* as JPEG and wrap it as a multipart file field.

        Args:
            image: Image array handed to ``cv2.imencode``
                (presumably OpenCV BGR order — confirm with callers).
            jpeg_quality: JPEG quality passed to OpenCV (default 99).

        Returns:
            ``(filename, file_object, content_type)`` tuple for ``MultipartEncoder``.

        Raises:
            ValueError: If OpenCV reports that encoding failed.
        """
        success, jpeg_data = cv2.imencode(".jpg", image, [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality])
        if not success:
            # Without this check a failed encode would upload an empty/invalid file.
            raise ValueError("Failed to encode image as JPEG")

        return "image.jpg", BytesIO(jpeg_data.tobytes()), "image/jpeg"

    def try_on_file(
        self,
        clothing_image: Optional[np.ndarray] = None,
        clothing_prompt: Optional[str] = None,
        avatar_image: Optional[np.ndarray] = None,
        avatar_prompt: Optional[str] = None,
        avatar_sex: Optional[str] = None,
        background_image: Optional[np.ndarray] = None,
        background_prompt: Optional[str] = None,
        seed: int = -1,
        raw_response: bool = False,
    ) -> "TryOnDiffusionAPIResponse":
        """POST a multipart request to ``<base_url>/try-on-file``.

        Image arguments are JPEG-encoded and uploaded as files. Prompt and
        ``avatar_sex`` arguments are sent as text fields. Any argument left as
        ``None`` is omitted from the request.

        Args:
            clothing_image, avatar_image, background_image: Optional images.
            clothing_prompt, avatar_prompt, background_prompt: Optional prompts.
            avatar_sex: Sent verbatim; accepted values are defined by the server.
            seed: Always sent, as a string. NOTE(review): -1 presumably asks the
                server to pick a seed — confirm against the API.
            raw_response: If True, a 200 body is stored in ``response_data``
                instead of being decoded into ``image``.

        Returns:
            A ``TryOnDiffusionAPIResponse``. Its ``status_code`` is 0 if the
            HTTP request raised.

        Raises:
            ValueError: If one of the input images cannot be JPEG-encoded.
        """
        url = self._base_url + "/try-on-file"

        # Field order here is the part order in the multipart body.
        request_data = {"seed": str(seed)}

        if clothing_image is not None:
            request_data["clothing_image"] = self._image_to_upload_file(clothing_image)

        if clothing_prompt is not None:
            request_data["clothing_prompt"] = clothing_prompt

        if avatar_image is not None:
            request_data["avatar_image"] = self._image_to_upload_file(avatar_image)

        if avatar_prompt is not None:
            request_data["avatar_prompt"] = avatar_prompt

        if avatar_sex is not None:
            request_data["avatar_sex"] = avatar_sex

        if background_image is not None:
            request_data["background_image"] = self._image_to_upload_file(background_image)

        if background_prompt is not None:
            request_data["background_prompt"] = background_prompt

        multipart_data = MultipartEncoder(fields=request_data)
        headers = self._build_headers(multipart_data.content_type)

        try:
            response = requests.post(
                url,
                data=multipart_data,
                headers=headers,
                timeout=self._timeout,
            )
        except Exception as e:
            # Transport-level failures are reported as status_code 0, not raised.
            self._logger.error(e, exc_info=True)
            return TryOnDiffusionAPIResponse(status_code=0)

        if response.status_code != 200:
            self._logger.warning(f"Request failed, status code: {response.status_code}, response: {response.content}")

        return self._parse_response(response, raw_response)

    # Build the request headers: multipart content type plus authentication.
    def _build_headers(self, content_type: str) -> dict:
        """Return the Content-Type header plus the API-key headers for this client."""
        headers = {"Content-Type": content_type}

        if self._rapidapi_host is not None:
            headers["X-RapidAPI-Key"] = self._api_key
            headers["X-RapidAPI-Host"] = self._rapidapi_host
        else:
            headers["X-API-Key"] = self._api_key

        return headers

    # Turn a requests response into the client's result object.
    def _parse_response(self, response, raw_response: bool) -> "TryOnDiffusionAPIResponse":
        """Fill a ``TryOnDiffusionAPIResponse`` from *response*."""
        result = TryOnDiffusionAPIResponse(status_code=response.status_code)

        if response.status_code == 200 and not raw_response:
            result.image = self._decode_image(response.content)
        else:
            result.response_data = response.content

        if response.status_code == 200:
            seed_header = response.headers.get("X-Seed")
            if seed_header is not None:
                try:
                    result.seed = int(seed_header)
                except ValueError:
                    self._logger.warning(f"Ignoring non-integer X-Seed header: {seed_header!r}")
        else:
            result.error_details = self._extract_error_details(result.response_data)

        return result

    # Decode an image response body; returns None on failure.
    def _decode_image(self, content: bytes) -> Optional[np.ndarray]:
        """Decode *content* with OpenCV, logging and returning None on failure."""
        try:
            image = cv2.imdecode(np.frombuffer(content, np.uint8), cv2.IMREAD_COLOR)
        except Exception:
            # cv2.imdecode may raise (e.g. cv2.error on an empty buffer).
            self._logger.warning("Failed to decode response image", exc_info=True)
            return None

        if image is None:
            # cv2.imdecode signals undecodable data by returning None.
            self._logger.warning("Response body could not be decoded as an image")

        return image

    # Pull the "detail" entry out of a JSON error body, if there is one.
    @staticmethod
    def _extract_error_details(response_data: Optional[bytes]) -> Optional[str]:
        """Return ``detail`` from a JSON object body, or None if absent/unparseable."""
        if response_data is None:
            return None

        try:
            response_json = json.loads(response_data.decode("utf-8"))
        except ValueError:  # UnicodeDecodeError and JSONDecodeError both subclass ValueError
            return None

        if isinstance(response_json, dict):
            return response_json.get("detail")

        return None