SunderAli17 commited on
Commit
b2fb80e
1 Parent(s): 6f1c8f2

Create api.py

Browse files
Files changed (1) hide show
  1. flux/api.py +192 -0
flux/api.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import time
4
+ from pathlib import Path
5
+
6
+ import requests
7
+ from PIL import Image
8
+
9
+ API_ENDPOINT = "https://api.bfl.ml"
10
+
11
+
12
+ class ApiException(Exception):
13
+ def __init__(self, status_code: int, detail: str = None):
14
+ super().__init__()
15
+ self.detail = detail
16
+ self.status_code = status_code
17
+
18
+ def __str__(self) -> str:
19
+ return self.__repr__()
20
+
21
+ def __repr__(self) -> str:
22
+ if self.detail is None:
23
+ message = None
24
+ elif isinstance(self.detail, str):
25
+ message = self.detail
26
+ else:
27
+ message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
28
+ return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
29
+
30
+
31
+ class ImageRequest:
32
+ def __init__(
33
+ self,
34
+ prompt: str,
35
+ width: int = 1024,
36
+ height: int = 1024,
37
+ name: str = "flux.1-pro",
38
+ num_steps: int = 50,
39
+ prompt_upsampling: bool = False,
40
+ seed: int = None,
41
+ validate: bool = True,
42
+ launch: bool = True,
43
+ api_key: str = None,
44
+ ):
45
+ """
46
+ Manages an image generation request to the API.
47
+ Args:
48
+ prompt: Prompt to sample
49
+ width: Width of the image in pixel
50
+ height: Height of the image in pixel
51
+ name: Name of the model
52
+ num_steps: Number of network evaluations
53
+ prompt_upsampling: Use prompt upsampling
54
+ seed: Fix the generation seed
55
+ validate: Run input validation
56
+ launch: Directly launches request
57
+ api_key: Your API key if not provided by the environment
58
+ Raises:
59
+ ValueError: For invalid input
60
+ ApiException: For errors raised from the API
61
+ """
62
+ if validate:
63
+ if name not in ["flux.1-pro"]:
64
+ raise ValueError(f"Invalid model {name}")
65
+ elif width % 32 != 0:
66
+ raise ValueError(f"width must be divisible by 32, got {width}")
67
+ elif not (256 <= width <= 1440):
68
+ raise ValueError(f"width must be between 256 and 1440, got {width}")
69
+ elif height % 32 != 0:
70
+ raise ValueError(f"height must be divisible by 32, got {height}")
71
+ elif not (256 <= height <= 1440):
72
+ raise ValueError(f"height must be between 256 and 1440, got {height}")
73
+ elif not (1 <= num_steps <= 50):
74
+ raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
75
+
76
+ self.request_json = {
77
+ "prompt": prompt,
78
+ "width": width,
79
+ "height": height,
80
+ "variant": name,
81
+ "steps": num_steps,
82
+ "prompt_upsampling": prompt_upsampling,
83
+ }
84
+ if seed is not None:
85
+ self.request_json["seed"] = seed
86
+
87
+ self.request_id: str = None
88
+ self.result: dict = None
89
+ self._image_bytes: bytes = None
90
+ self._url: str = None
91
+ if api_key is None:
92
+ self.api_key = os.environ.get("BFL_API_KEY")
93
+ else:
94
+ self.api_key = api_key
95
+
96
+ if launch:
97
+ self.request()
98
+
99
+ def request(self):
100
+ """
101
+ Request to generate the image.
102
+ """
103
+ if self.request_id is not None:
104
+ return
105
+ response = requests.post(
106
+ f"{API_ENDPOINT}/v1/image",
107
+ headers={
108
+ "accept": "application/json",
109
+ "x-key": self.api_key,
110
+ "Content-Type": "application/json",
111
+ },
112
+ json=self.request_json,
113
+ )
114
+ result = response.json()
115
+ if response.status_code != 200:
116
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
117
+ self.request_id = response.json()["id"]
118
+
119
+ def retrieve(self) -> dict:
120
+ """
121
+ Wait for the generation to finish and retrieve response.
122
+ """
123
+ if self.request_id is None:
124
+ self.request()
125
+ while self.result is None:
126
+ response = requests.get(
127
+ f"{API_ENDPOINT}/v1/get_result",
128
+ headers={
129
+ "accept": "application/json",
130
+ "x-key": self.api_key,
131
+ },
132
+ params={
133
+ "id": self.request_id,
134
+ },
135
+ )
136
+ result = response.json()
137
+ if "status" not in result:
138
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
139
+ elif result["status"] == "Ready":
140
+ self.result = result["result"]
141
+ elif result["status"] == "Pending":
142
+ time.sleep(0.5)
143
+ else:
144
+ raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
145
+ return self.result
146
+
147
+ @property
148
+ def bytes(self) -> bytes:
149
+ """
150
+ Generated image as bytes.
151
+ """
152
+ if self._image_bytes is None:
153
+ response = requests.get(self.url)
154
+ if response.status_code == 200:
155
+ self._image_bytes = response.content
156
+ else:
157
+ raise ApiException(status_code=response.status_code)
158
+ return self._image_bytes
159
+
160
+ @property
161
+ def url(self) -> str:
162
+ """
163
+ Public url to retrieve the image from
164
+ """
165
+ if self._url is None:
166
+ result = self.retrieve()
167
+ self._url = result["sample"]
168
+ return self._url
169
+
170
+ @property
171
+ def image(self) -> Image.Image:
172
+ """
173
+ Load the image as a PIL Image
174
+ """
175
+ return Image.open(io.BytesIO(self.bytes))
176
+
177
+ def save(self, path: str):
178
+ """
179
+ Save the generated image to a local path
180
+ """
181
+ suffix = Path(self.url).suffix
182
+ if not path.endswith(suffix):
183
+ path = path + suffix
184
+ Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
185
+ with open(path, "wb") as file:
186
+ file.write(self.bytes)
187
+
188
+
189
+ if __name__ == "__main__":
190
+ from fire import Fire
191
+
192
+ Fire(ImageRequest)