ishworrsubedii commited on
Commit
dfee28c
·
1 Parent(s): f33f113

add: nto mannequin, rt_nto

Browse files
app.py CHANGED
@@ -9,6 +9,7 @@ from datetime import datetime, timedelta, timezone
9
  from fastapi import FastAPI, Depends, HTTPException
10
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
11
  from starlette.middleware.cors import CORSMiddleware
 
12
 
13
  from src.api.image_prep_api import preprocessing_router
14
  from src.api.nto_api import nto_cto_router
@@ -35,7 +36,7 @@ async def verify_login_token(credentials: HTTPAuthorizationCredentials = Depends
35
  time_difference = current_time - created_at
36
 
37
  if time_difference <= timedelta(minutes=30):
38
- return {"status": "Authorized", "message": "Token is valid"}
39
 
40
  raise HTTPException(status_code=401, detail="Unauthorized: Token expired")
41
 
 
9
  from fastapi import FastAPI, Depends, HTTPException
10
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
11
  from starlette.middleware.cors import CORSMiddleware
12
+ from starlette.responses import JSONResponse
13
 
14
  from src.api.image_prep_api import preprocessing_router
15
  from src.api.nto_api import nto_cto_router
 
36
  time_difference = current_time - created_at
37
 
38
  if time_difference <= timedelta(minutes=30):
39
+ return JSONResponse({"status": "Authorized", "message": "Token is valid"}, status_code=200)
40
 
41
  raise HTTPException(status_code=401, detail="Unauthorized: Token expired")
42
 
examples/rt_nto_example.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ project @ CTO_TCP_ZERO_GPU
3
+ created @ 2024-11-18
4
+ author @ github.com/ishworrsubedii
5
+ """
6
+ import asyncio
7
+ import aiohttp
8
+ import json
9
+ from pathlib import Path
10
+
11
+
12
+ async def test_rt_nto(
13
+ image_path: str,
14
+ necklace_ids: list,
15
+ categories: list,
16
+ storename: str,
17
+ api_url: str = "http://localhost:8000/rt_nto"
18
+ ):
19
+ image_content = None
20
+ with open(image_path, 'rb') as f:
21
+ image_content = f.read()
22
+
23
+ data = aiohttp.FormData()
24
+ data.add_field('necklace_id_list', ','.join(necklace_ids))
25
+ data.add_field('category_list', ','.join(categories))
26
+ data.add_field('storename', storename)
27
+ data.add_field('image',
28
+ image_content,
29
+ filename=Path(image_path).name,
30
+ content_type='image/jpeg')
31
+
32
+ async with aiohttp.ClientSession() as session:
33
+ try:
34
+ async with session.post(api_url, data=data) as response:
35
+ if response.status != 200:
36
+ print(f"Server returned status code: {response.status}")
37
+ try:
38
+ error_content = await response.text()
39
+ print(f"Error content: {error_content}")
40
+ except Exception as e:
41
+ print(f"Could not read error content: {e}")
42
+ return
43
+
44
+ async for line in response.content:
45
+ if line:
46
+ try:
47
+ result = json.loads(line.decode('utf-8'))
48
+ if result.get('code') == 200:
49
+ print(f"Success - Progress: {result['progress']}")
50
+ print(f"Necklace ID: {result['necklace_id']}")
51
+ print(f"Category: {result['category']}")
52
+ print(f"Output URL: {result['output']}")
53
+ print(f"Mask URL: {result['mask']}")
54
+ print(f"Inference Time: {result['inference_time']}s")
55
+ else:
56
+ print(f"Error - Progress: {result.get('progress')}")
57
+ print(f"Error Details: {result.get('error')}")
58
+ print("-" * 50)
59
+ except json.JSONDecodeError as e:
60
+ print(f"Error decoding response: {e}")
61
+ print(f"Raw line: {line}")
62
+ except aiohttp.ClientError as e:
63
+ print(f"Connection error: {e}")
64
+ except Exception as e:
65
+ print(f"Unexpected error: {e}")
66
+
67
+
68
+ async def main():
69
+ image_path = "/home/ishwor/Downloads/download (7).png"
70
+ necklace_ids = ["ST01SHO001", "ST01SHO002", "ST01SHO003"]
71
+ categories = ["Short Necklaces", "Short Necklaces", "Short Necklaces"]
72
+ storename = "Store1"
73
+
74
+ try:
75
+ await test_rt_nto(
76
+ image_path=image_path,
77
+ necklace_ids=necklace_ids,
78
+ categories=categories,
79
+ storename=storename
80
+ )
81
+ except Exception as e:
82
+ print(f"Error in main: {e}")
83
+
84
+
85
+ if __name__ == "__main__":
86
+ asyncio.run(main())
src/api/batch_api.py CHANGED
@@ -4,6 +4,7 @@ created @ 2024-11-14
4
  author @ github.com/ishworrsubedii
5
  """
6
  import base64
 
7
  import time
8
  from io import BytesIO
9
  import json
@@ -15,8 +16,12 @@ from fastapi.routing import APIRouter
15
  from fastapi.responses import StreamingResponse
16
  from pydantic import BaseModel
17
  from typing import List
 
 
 
 
18
  from src.utils.logger import logger
19
- from src.api.nto_api import pipeline, replicate_run_cto
20
 
21
  batch_router = APIRouter()
22
 
@@ -36,7 +41,7 @@ async def rt_cto(
36
  try:
37
  clothing_list = [item.strip() for item in c_list.split(",")]
38
  logger.info(f">>> CLOTHING LIST: {clothing_list} <<<")
39
-
40
  image_bytes = await image.read()
41
  pil_image = Image.open(BytesIO(image_bytes)).convert("RGB")
42
  logger.info(">>> IMAGE LOADED SUCCESSFULLY <<<")
@@ -77,7 +82,7 @@ async def rt_cto(
77
  for idx, clothing_type in enumerate(clothing_list):
78
  if not clothing_type:
79
  continue
80
-
81
  input = {
82
  "mask": mask_data_uri,
83
  "image": image_data_uri,
@@ -89,10 +94,10 @@ async def rt_cto(
89
  try:
90
  output = replicate_run_cto(input)
91
  logger.info(f">>> REPLICATE PROCESSING COMPLETED FOR {clothing_type} <<<")
92
-
93
  output_url = str(output[0]) if output and output[0] else None
94
  total_inference_time = round((time.time() - start_time), 2)
95
-
96
  result = {
97
  "code": 200,
98
  "output": output_url,
@@ -102,7 +107,7 @@ async def rt_cto(
102
  }
103
  yield json.dumps(result) + "\n"
104
  await asyncio.sleep(0.1)
105
-
106
  except Exception as e:
107
  logger.error(f">>> REPLICATE PROCESSING ERROR: {str(e)} <<<")
108
  error_result = {
@@ -125,3 +130,114 @@ async def rt_cto(
125
  "Transfer-Encoding": "chunked"
126
  }
127
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  author @ github.com/ishworrsubedii
5
  """
6
  import base64
7
+ import gc
8
  import time
9
  from io import BytesIO
10
  import json
 
16
  from fastapi.responses import StreamingResponse
17
  from pydantic import BaseModel
18
  from typing import List
19
+
20
+ from fastapi.responses import JSONResponse
21
+
22
+ from src.utils import returnBytesData
23
  from src.utils.logger import logger
24
+ from src.api.nto_api import pipeline, replicate_run_cto, supabase_upload_and_return_url
25
 
26
  batch_router = APIRouter()
27
 
 
41
  try:
42
  clothing_list = [item.strip() for item in c_list.split(",")]
43
  logger.info(f">>> CLOTHING LIST: {clothing_list} <<<")
44
+
45
  image_bytes = await image.read()
46
  pil_image = Image.open(BytesIO(image_bytes)).convert("RGB")
47
  logger.info(">>> IMAGE LOADED SUCCESSFULLY <<<")
 
82
  for idx, clothing_type in enumerate(clothing_list):
83
  if not clothing_type:
84
  continue
85
+
86
  input = {
87
  "mask": mask_data_uri,
88
  "image": image_data_uri,
 
94
  try:
95
  output = replicate_run_cto(input)
96
  logger.info(f">>> REPLICATE PROCESSING COMPLETED FOR {clothing_type} <<<")
97
+
98
  output_url = str(output[0]) if output and output[0] else None
99
  total_inference_time = round((time.time() - start_time), 2)
100
+
101
  result = {
102
  "code": 200,
103
  "output": output_url,
 
107
  }
108
  yield json.dumps(result) + "\n"
109
  await asyncio.sleep(0.1)
110
+
111
  except Exception as e:
112
  logger.error(f">>> REPLICATE PROCESSING ERROR: {str(e)} <<<")
113
  error_result = {
 
130
  "Transfer-Encoding": "chunked"
131
  }
132
  )
133
+
134
+
135
+ @batch_router.post("/rt_nto")
136
+ async def rt_nto(
137
+ image: UploadFile = File(...),
138
+ necklace_id_list: str = Form(...),
139
+ category_list: str = Form(...),
140
+ storename: str = Form(...)
141
+ ):
142
+ logger.info("-" * 50)
143
+ logger.info(">>> REAL-TIME NECKLACE TRY ON STARTED <<<")
144
+ logger.info(f"Parameters: storename={storename}, categories={category_list}, necklace_ids={necklace_id_list}")
145
+
146
+ try:
147
+ # Parse the lists
148
+ necklace_ids = [id.strip() for id in necklace_id_list.split(",")]
149
+ categories = [cat.strip() for cat in category_list.split(",")]
150
+ if len(necklace_ids) != len(categories):
151
+ return JSONResponse(
152
+ content={"error": "Number of necklace IDs must match number of categories", "code": 400},
153
+ status_code=400
154
+ )
155
+
156
+ # Load the source image
157
+ image_bytes = await image.read()
158
+ source_image = Image.open(BytesIO(image_bytes))
159
+ logger.info(">>> SOURCE IMAGE LOADED SUCCESSFULLY <<<")
160
+ except Exception as e:
161
+ logger.error(f">>> INITIAL SETUP ERROR: {str(e)} <<<")
162
+ return JSONResponse(
163
+ content={"error": "Error in initial setup", "details": str(e), "code": 500},
164
+ status_code=500
165
+ )
166
+
167
+ async def generate():
168
+ start_time = time.time()
169
+
170
+ for idx, (necklace_id, category) in enumerate(zip(necklace_ids, categories)):
171
+ try:
172
+ jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{storename}/{category}/image/{necklace_id}.png"
173
+ jewellery = Image.open(returnBytesData(url=jewellery_url))
174
+ logger.info(f">>> JEWELLERY IMAGE {necklace_id} LOADED SUCCESSFULLY <<<")
175
+
176
+ # Process the necklace try-on
177
+ result, headetText, mask = await pipeline.necklaceTryOn_(
178
+ image=source_image,
179
+ jewellery=jewellery,
180
+ storename=storename
181
+ )
182
+
183
+ if result is None:
184
+ error_result = {
185
+ "error": "No face detected in the image",
186
+ "code": 400,
187
+ "necklace_id": necklace_id,
188
+ "category": category,
189
+ "progress": f"{idx + 1}/{len(necklace_ids)}"
190
+ }
191
+ yield json.dumps(error_result) + "\n"
192
+ continue
193
+
194
+ # Upload results concurrently
195
+ logger.info(">>> UPLOADING RESULTS <<<")
196
+ upload_tasks = [
197
+ supabase_upload_and_return_url(prefix="necklace_try_on", image=result),
198
+ supabase_upload_and_return_url(prefix="necklace_try_on_mask", image=mask)
199
+ ]
200
+ result_url, mask_url = await asyncio.gather(*upload_tasks)
201
+
202
+ total_inference_time = round((time.time() - start_time), 2)
203
+ logger.info(f">>> UPLOADING COMPLETED FOR {necklace_id} <<<")
204
+
205
+ result = {
206
+ "code": 200,
207
+ "output": result_url,
208
+ "mask": mask_url,
209
+ "inference_time": total_inference_time,
210
+ "necklace_id": necklace_id,
211
+ "category": category,
212
+ "progress": f"{idx + 1}/{len(necklace_ids)}"
213
+ }
214
+ yield json.dumps(result) + "\n"
215
+ await asyncio.sleep(0.1)
216
+
217
+ del result
218
+ del mask
219
+ gc.collect()
220
+
221
+ except Exception as e:
222
+ logger.error(f">>> PROCESSING ERROR FOR {necklace_id}: {str(e)} <<<")
223
+ error_result = {
224
+ "error": f"Error processing necklace {necklace_id}",
225
+ "details": str(e),
226
+ "code": 500,
227
+ "necklace_id": necklace_id,
228
+ "category": category,
229
+ "progress": f"{idx + 1}/{len(necklace_ids)}"
230
+ }
231
+ yield json.dumps(error_result) + "\n"
232
+ await asyncio.sleep(0.1)
233
+
234
+ return StreamingResponse(
235
+ generate(),
236
+ media_type="application/x-ndjson",
237
+ headers={
238
+ "Cache-Control": "no-cache",
239
+ "Connection": "keep-alive",
240
+ "X-Accel-Buffering": "no",
241
+ "Transfer-Encoding": "chunked"
242
+ }
243
+ )
src/api/nto_api.py CHANGED
@@ -749,7 +749,7 @@ async def mannequin_nto(necklace_try_on_id: NecklaceTryOnIDEntity = Depends(pars
749
  "code": 404}, status_code=404)
750
 
751
  try:
752
- result = await pipeline.necklaceTryOnMannequin_(image=image, jewellery=jewellery)
753
 
754
  if result is None:
755
  logger.error(">>> NO FACE DETECTED IN THE IMAGE <<<")
 
749
  "code": 404}, status_code=404)
750
 
751
  try:
752
+ result,resized_img = await pipeline.necklaceTryOnMannequin_(image=image, jewellery=jewellery)
753
 
754
  if result is None:
755
  logger.error(">>> NO FACE DETECTED IN THE IMAGE <<<")
src/components/necklaceTryOn.py CHANGED
@@ -107,6 +107,7 @@ class NecklaceTryOn:
107
  upper_half = image_np[:middle_point, :]
108
  lower_half = image_np[middle_point:, :]
109
 
 
110
  copy_upper = upper_half.copy()
111
 
112
  # Apply pose detection on upper half
@@ -145,10 +146,12 @@ class NecklaceTryOn:
145
 
146
  result_upper = cvzone.overlayPNG(copy_upper, jewellery, (avg_x1, y_coordinate))
147
 
148
- final_result = np.vstack((result_upper, lower_half))
 
 
149
 
150
  gc.collect()
151
- return Image.fromarray(final_result.astype(np.uint8))
152
 
153
  except Exception as e:
154
  print(f"Error: {e}")
 
107
  upper_half = image_np[:middle_point, :]
108
  lower_half = image_np[middle_point:, :]
109
 
110
+ upper_half = cv2.resize(upper_half, (upper_half.shape[1] * 3, upper_half.shape[0] * 3))
111
  copy_upper = upper_half.copy()
112
 
113
  # Apply pose detection on upper half
 
146
 
147
  result_upper = cvzone.overlayPNG(copy_upper, jewellery, (avg_x1, y_coordinate))
148
 
149
+ final_result = cv2.resize(result_upper, (image_np.shape[1], image_np.shape[0] - middle_point))
150
+
151
+ final_result = np.vstack((final_result, lower_half))
152
 
153
  gc.collect()
154
+ return Image.fromarray(final_result.astype(np.uint8)), Image.fromarray(result_upper.astype(np.uint8))
155
 
156
  except Exception as e:
157
  print(f"Error: {e}")
src/pipelines/completePipeline.py CHANGED
@@ -30,7 +30,7 @@ class Pipeline:
30
  storename=storename)
31
  return [result, headerText, mask]
32
 
33
- async def necklaceTryOnMannequin_(self, image: Image.Image, jewellery: Image.Image):
34
- result = self.necklaceTryOnObj.necklaceTryOnMannequin(image, jewellery)
35
 
36
- return result
 
30
  storename=storename)
31
  return [result, headerText, mask]
32
 
33
+ def necklaceTryOnMannequin_(self, image: Image.Image, jewellery: Image.Image):
34
+ result, resized_image = self.necklaceTryOnObj.necklaceTryOnMannequin(image, jewellery)
35
 
36
+ return result, resized_image