ishworrsubedii commited on
Commit
869e26b
·
1 Parent(s): 69f10d4

refactor: cto for better optimization

Browse files
Files changed (1) hide show
  1. src/api/nto_api.py +30 -16
src/api/nto_api.py CHANGED
@@ -148,7 +148,6 @@ async def clothing_try_on(image: UploadFile = File(...),
148
  mask.read()
149
  )
150
 
151
- # Convert bytes to PIL Images
152
  image = Image.open(BytesIO(image_data)).convert("RGB")
153
  mask = Image.open(BytesIO(mask_data)).convert("RGB")
154
  logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
@@ -157,17 +156,14 @@ async def clothing_try_on(image: UploadFile = File(...),
157
  return JSONResponse(status_code=500, content={"error": "Error reading image or mask", "code": 500})
158
 
159
  try:
160
- # Process images
161
  actual_image = image.copy()
162
  jewellery_mask = Image.fromarray(np.bitwise_and(np.array(mask), np.array(image)))
163
  arr_orig = np.array(grayscale(mask))
164
 
165
- # Process image with inpainting
166
  image = Image.fromarray(
167
  cv2.inpaint(np.array(image), arr_orig, 15, cv2.INPAINT_TELEA)
168
  ).resize((512, 512))
169
 
170
- # Process mask
171
  arr = arr_orig.copy()
172
  mask_y = np.where(arr == arr[arr != 0][0])[0][0]
173
  arr[mask_y:, :] = 255
@@ -179,7 +175,6 @@ async def clothing_try_on(image: UploadFile = File(...),
179
  return JSONResponse(status_code=500, content={"error": "Error processing image or mask", "code": 500})
180
 
181
  try:
182
- # Convert images to base64 more efficiently
183
  mask_data_uri = image_to_base64(mask)
184
  image_data_uri = image_to_base64(image)
185
  logger.info(">>> IMAGE ENCODING COMPLETED <<<")
@@ -187,7 +182,6 @@ async def clothing_try_on(image: UploadFile = File(...),
187
  logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
188
  return JSONResponse(status_code=500, content={"error": "Error encoding images", "code": 500})
189
 
190
- # Prepare replicate input
191
  input_data = {
192
  "mask": mask_data_uri,
193
  "image": image_data_uri,
@@ -204,18 +198,30 @@ async def clothing_try_on(image: UploadFile = File(...),
204
  return JSONResponse(content={"error": "Error running clothing try on", "code": 500}, status_code=500)
205
 
206
  try:
 
 
 
 
 
 
207
  async with aiohttp.ClientSession() as session:
208
- async with session.get(output[0]) as response:
 
 
 
209
  output_bytes = await response.read()
210
 
 
211
  output_image = Image.open(BytesIO(output_bytes)).resize(actual_image.size)
212
 
213
- # Process final image
214
- output_array = np.bitwise_and(
215
- np.array(output_image),
216
- np.bitwise_not(np.array(Image.fromarray(arr_orig).convert("RGB")))
217
- )
218
- result = Image.fromarray(np.bitwise_or(output_array, np.array(jewellery_mask)))
 
 
219
 
220
  # Convert result to base64
221
  result_base64 = image_to_base64(result)
@@ -228,13 +234,21 @@ async def clothing_try_on(image: UploadFile = File(...),
228
  "code": 200,
229
  "inference_time": total_inference_time
230
  }
 
 
 
 
 
231
  except Exception as e:
232
  logger.error(f">>> OUTPUT IMAGE PROCESSING ERROR: {str(e)} <<<")
233
- return JSONResponse(status_code=500, content={"error": "Error processing output image", "code": 500})
 
 
234
  finally:
235
  # Clean up resources
236
- if 'output_image' in locals(): del output_image
237
- if 'output_array' in locals(): del output_array
 
238
  gc.collect()
239
 
240
  logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
 
148
  mask.read()
149
  )
150
 
 
151
  image = Image.open(BytesIO(image_data)).convert("RGB")
152
  mask = Image.open(BytesIO(mask_data)).convert("RGB")
153
  logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
 
156
  return JSONResponse(status_code=500, content={"error": "Error reading image or mask", "code": 500})
157
 
158
  try:
 
159
  actual_image = image.copy()
160
  jewellery_mask = Image.fromarray(np.bitwise_and(np.array(mask), np.array(image)))
161
  arr_orig = np.array(grayscale(mask))
162
 
 
163
  image = Image.fromarray(
164
  cv2.inpaint(np.array(image), arr_orig, 15, cv2.INPAINT_TELEA)
165
  ).resize((512, 512))
166
 
 
167
  arr = arr_orig.copy()
168
  mask_y = np.where(arr == arr[arr != 0][0])[0][0]
169
  arr[mask_y:, :] = 255
 
175
  return JSONResponse(status_code=500, content={"error": "Error processing image or mask", "code": 500})
176
 
177
  try:
 
178
  mask_data_uri = image_to_base64(mask)
179
  image_data_uri = image_to_base64(image)
180
  logger.info(">>> IMAGE ENCODING COMPLETED <<<")
 
182
  logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
183
  return JSONResponse(status_code=500, content={"error": "Error encoding images", "code": 500})
184
 
 
185
  input_data = {
186
  "mask": mask_data_uri,
187
  "image": image_data_uri,
 
198
  return JSONResponse(content={"error": "Error running clothing try on", "code": 500}, status_code=500)
199
 
200
  try:
201
+ # Fetch the output image URL
202
+ output_url = output[0]
203
+ if not isinstance(output_url, str):
204
+ raise ValueError("Invalid output URL format from replicate")
205
+
206
+ # Download the output image
207
  async with aiohttp.ClientSession() as session:
208
+ async with session.get(output_url) as response:
209
+ if response.status != 200:
210
+ raise HTTPException(status_code=response.status,
211
+ detail="Failed to fetch output image")
212
  output_bytes = await response.read()
213
 
214
+ # Process the output image
215
  output_image = Image.open(BytesIO(output_bytes)).resize(actual_image.size)
216
 
217
+ # Convert arrays and process final image
218
+ mask_array = np.array(Image.fromarray(arr_orig).convert("RGB"))
219
+ output_array = np.array(output_image)
220
+
221
+ # Perform bitwise operations
222
+ mask_inverse = np.bitwise_not(mask_array)
223
+ intermediate = np.bitwise_and(output_array, mask_inverse)
224
+ result = Image.fromarray(np.bitwise_or(intermediate, np.array(jewellery_mask)))
225
 
226
  # Convert result to base64
227
  result_base64 = image_to_base64(result)
 
234
  "code": 200,
235
  "inference_time": total_inference_time
236
  }
237
+ except ValueError as ve:
238
+ logger.error(f">>> OUTPUT IMAGE PROCESSING ERROR: {str(ve)} <<<")
239
+ return JSONResponse(status_code=500,
240
+ content={"error": "Invalid response from image generation service",
241
+ "code": 500})
242
  except Exception as e:
243
  logger.error(f">>> OUTPUT IMAGE PROCESSING ERROR: {str(e)} <<<")
244
+ return JSONResponse(status_code=500,
245
+ content={"error": "Error processing output image",
246
+ "code": 500})
247
  finally:
248
  # Clean up resources
249
+ for var in ['output_image', 'output_array', 'intermediate', 'mask_array', 'mask_inverse']:
250
+ if var in locals():
251
+ del locals()[var]
252
  gc.collect()
253
 
254
  logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")