Spaces:
Sleeping
Sleeping
Commit
·
869e26b
1
Parent(s):
69f10d4
refactor: cto for better optimization
Browse files- src/api/nto_api.py +30 -16
src/api/nto_api.py
CHANGED
@@ -148,7 +148,6 @@ async def clothing_try_on(image: UploadFile = File(...),
|
|
148 |
mask.read()
|
149 |
)
|
150 |
|
151 |
-
# Convert bytes to PIL Images
|
152 |
image = Image.open(BytesIO(image_data)).convert("RGB")
|
153 |
mask = Image.open(BytesIO(mask_data)).convert("RGB")
|
154 |
logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
|
@@ -157,17 +156,14 @@ async def clothing_try_on(image: UploadFile = File(...),
|
|
157 |
return JSONResponse(status_code=500, content={"error": "Error reading image or mask", "code": 500})
|
158 |
|
159 |
try:
|
160 |
-
# Process images
|
161 |
actual_image = image.copy()
|
162 |
jewellery_mask = Image.fromarray(np.bitwise_and(np.array(mask), np.array(image)))
|
163 |
arr_orig = np.array(grayscale(mask))
|
164 |
|
165 |
-
# Process image with inpainting
|
166 |
image = Image.fromarray(
|
167 |
cv2.inpaint(np.array(image), arr_orig, 15, cv2.INPAINT_TELEA)
|
168 |
).resize((512, 512))
|
169 |
|
170 |
-
# Process mask
|
171 |
arr = arr_orig.copy()
|
172 |
mask_y = np.where(arr == arr[arr != 0][0])[0][0]
|
173 |
arr[mask_y:, :] = 255
|
@@ -179,7 +175,6 @@ async def clothing_try_on(image: UploadFile = File(...),
|
|
179 |
return JSONResponse(status_code=500, content={"error": "Error processing image or mask", "code": 500})
|
180 |
|
181 |
try:
|
182 |
-
# Convert images to base64 more efficiently
|
183 |
mask_data_uri = image_to_base64(mask)
|
184 |
image_data_uri = image_to_base64(image)
|
185 |
logger.info(">>> IMAGE ENCODING COMPLETED <<<")
|
@@ -187,7 +182,6 @@ async def clothing_try_on(image: UploadFile = File(...),
|
|
187 |
logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
|
188 |
return JSONResponse(status_code=500, content={"error": "Error encoding images", "code": 500})
|
189 |
|
190 |
-
# Prepare replicate input
|
191 |
input_data = {
|
192 |
"mask": mask_data_uri,
|
193 |
"image": image_data_uri,
|
@@ -204,18 +198,30 @@ async def clothing_try_on(image: UploadFile = File(...),
|
|
204 |
return JSONResponse(content={"error": "Error running clothing try on", "code": 500}, status_code=500)
|
205 |
|
206 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
async with aiohttp.ClientSession() as session:
|
208 |
-
async with session.get(
|
|
|
|
|
|
|
209 |
output_bytes = await response.read()
|
210 |
|
|
|
211 |
output_image = Image.open(BytesIO(output_bytes)).resize(actual_image.size)
|
212 |
|
213 |
-
#
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
|
|
|
|
219 |
|
220 |
# Convert result to base64
|
221 |
result_base64 = image_to_base64(result)
|
@@ -228,13 +234,21 @@ async def clothing_try_on(image: UploadFile = File(...),
|
|
228 |
"code": 200,
|
229 |
"inference_time": total_inference_time
|
230 |
}
|
|
|
|
|
|
|
|
|
|
|
231 |
except Exception as e:
|
232 |
logger.error(f">>> OUTPUT IMAGE PROCESSING ERROR: {str(e)} <<<")
|
233 |
-
return JSONResponse(status_code=500,
|
|
|
|
|
234 |
finally:
|
235 |
# Clean up resources
|
236 |
-
|
237 |
-
|
|
|
238 |
gc.collect()
|
239 |
|
240 |
logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
|
|
|
148 |
mask.read()
|
149 |
)
|
150 |
|
|
|
151 |
image = Image.open(BytesIO(image_data)).convert("RGB")
|
152 |
mask = Image.open(BytesIO(mask_data)).convert("RGB")
|
153 |
logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
|
|
|
156 |
return JSONResponse(status_code=500, content={"error": "Error reading image or mask", "code": 500})
|
157 |
|
158 |
try:
|
|
|
159 |
actual_image = image.copy()
|
160 |
jewellery_mask = Image.fromarray(np.bitwise_and(np.array(mask), np.array(image)))
|
161 |
arr_orig = np.array(grayscale(mask))
|
162 |
|
|
|
163 |
image = Image.fromarray(
|
164 |
cv2.inpaint(np.array(image), arr_orig, 15, cv2.INPAINT_TELEA)
|
165 |
).resize((512, 512))
|
166 |
|
|
|
167 |
arr = arr_orig.copy()
|
168 |
mask_y = np.where(arr == arr[arr != 0][0])[0][0]
|
169 |
arr[mask_y:, :] = 255
|
|
|
175 |
return JSONResponse(status_code=500, content={"error": "Error processing image or mask", "code": 500})
|
176 |
|
177 |
try:
|
|
|
178 |
mask_data_uri = image_to_base64(mask)
|
179 |
image_data_uri = image_to_base64(image)
|
180 |
logger.info(">>> IMAGE ENCODING COMPLETED <<<")
|
|
|
182 |
logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
|
183 |
return JSONResponse(status_code=500, content={"error": "Error encoding images", "code": 500})
|
184 |
|
|
|
185 |
input_data = {
|
186 |
"mask": mask_data_uri,
|
187 |
"image": image_data_uri,
|
|
|
198 |
return JSONResponse(content={"error": "Error running clothing try on", "code": 500}, status_code=500)
|
199 |
|
200 |
try:
|
201 |
+
# Fetch the output image URL
|
202 |
+
output_url = output[0]
|
203 |
+
if not isinstance(output_url, str):
|
204 |
+
raise ValueError("Invalid output URL format from replicate")
|
205 |
+
|
206 |
+
# Download the output image
|
207 |
async with aiohttp.ClientSession() as session:
|
208 |
+
async with session.get(output_url) as response:
|
209 |
+
if response.status != 200:
|
210 |
+
raise HTTPException(status_code=response.status,
|
211 |
+
detail="Failed to fetch output image")
|
212 |
output_bytes = await response.read()
|
213 |
|
214 |
+
# Process the output image
|
215 |
output_image = Image.open(BytesIO(output_bytes)).resize(actual_image.size)
|
216 |
|
217 |
+
# Convert arrays and process final image
|
218 |
+
mask_array = np.array(Image.fromarray(arr_orig).convert("RGB"))
|
219 |
+
output_array = np.array(output_image)
|
220 |
+
|
221 |
+
# Perform bitwise operations
|
222 |
+
mask_inverse = np.bitwise_not(mask_array)
|
223 |
+
intermediate = np.bitwise_and(output_array, mask_inverse)
|
224 |
+
result = Image.fromarray(np.bitwise_or(intermediate, np.array(jewellery_mask)))
|
225 |
|
226 |
# Convert result to base64
|
227 |
result_base64 = image_to_base64(result)
|
|
|
234 |
"code": 200,
|
235 |
"inference_time": total_inference_time
|
236 |
}
|
237 |
+
except ValueError as ve:
|
238 |
+
logger.error(f">>> OUTPUT IMAGE PROCESSING ERROR: {str(ve)} <<<")
|
239 |
+
return JSONResponse(status_code=500,
|
240 |
+
content={"error": "Invalid response from image generation service",
|
241 |
+
"code": 500})
|
242 |
except Exception as e:
|
243 |
logger.error(f">>> OUTPUT IMAGE PROCESSING ERROR: {str(e)} <<<")
|
244 |
+
return JSONResponse(status_code=500,
|
245 |
+
content={"error": "Error processing output image",
|
246 |
+
"code": 500})
|
247 |
finally:
|
248 |
# Clean up resources
|
249 |
+
for var in ['output_image', 'output_array', 'intermediate', 'mask_array', 'mask_inverse']:
|
250 |
+
if var in locals():
|
251 |
+
del locals()[var]
|
252 |
gc.collect()
|
253 |
|
254 |
logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
|