amaye15 committed on
Commit
18f8eec
1 Parent(s): 59f3026

Old handler with autocast
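For reference, the handler restored below exposes a single __call__ entry point that takes base64-encoded images and plain-text queries. A minimal local invocation might look like this sketch; the checkpoint path and image file are placeholders, not part of this commit:

    import base64
    from io import BytesIO

    from PIL import Image

    from handler import EndpointHandler

    # Placeholder checkpoint path; any ColQwen2-style checkpoint is assumed.
    handler = EndpointHandler(path="vidore/colqwen2-v0.1")

    # The handler expects base64 strings, so encode the raw image bytes.
    buffer = BytesIO()
    Image.open("page.png").convert("RGB").save(buffer, format="PNG")
    payload = {
        "image": [base64.b64encode(buffer.getvalue()).decode("utf-8")],
        "text": ["What does this page describe?"],
        "batch_size": 2,  # optional; falls back to default_batch_size
    }

    result = handler(payload)  # {"image": [...], "text": [...], "scores": [...]}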

Files changed (1)
  1. handler.py +308 -308
handler.py CHANGED
@@ -1,207 +1,9 @@
- # import torch
- # from typing import Dict, Any, List
- # from PIL import Image
- # import base64
- # from io import BytesIO
- # import logging
-
-
- # class EndpointHandler:
- #     """
- #     A handler class for processing image and text data, generating embeddings using a specified model and processor.
-
- #     Attributes:
- #         model: The pre-trained model used for generating embeddings.
- #         processor: The pre-trained processor used to process images and text before model inference.
- #         device: The device (CPU or CUDA) used to run model inference.
- #         default_batch_size: The default batch size for processing images and text in batches.
- #     """
-
- #     def __init__(self, path: str = "", default_batch_size: int = 4):
- #         """
- #         Initializes the EndpointHandler with a specified model path and default batch size.
-
- #         Args:
- #             path (str): Path to the pre-trained model and processor.
- #             default_batch_size (int): Default batch size for processing images and text data.
- #         """
- #         # Initialize logging
- #         logging.basicConfig(level=logging.INFO)
- #         self.logger = logging.getLogger(__name__)
-
- #         from colpali_engine.models import ColQwen2, ColQwen2Processor
-
- #         self.logger.info("Initializing model and processor.")
- #         try:
- #             self.model = ColQwen2.from_pretrained(
- #                 path,
- #                 torch_dtype=torch.bfloat16,
- #                 device_map="auto",
- #             ).eval()
- #             self.processor = ColQwen2Processor.from_pretrained(path)
- #             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- #             self.model.to(self.device)
- #             self.default_batch_size = default_batch_size
- #             self.logger.info("Initialization complete.")
- #         except Exception as e:
- #             self.logger.error(f"Failed to initialize model or processor: {e}")
- #             raise
-
- #     def _process_image_batch(self, images: List[Image.Image]) -> List[List[float]]:
- #         """
- #         Processes a batch of images and generates embeddings.
-
- #         Args:
- #             images (List[Image.Image]): List of images to process.
-
- #         Returns:
- #             List[List[float]]: List of embeddings for each image.
- #         """
- #         self.logger.debug(f"Processing batch of {len(images)} images.")
- #         try:
- #             batch_images = self.processor.process_images(images).to(self.device)
- #             with torch.no_grad(), torch.amp.autocast():
- #                 image_embeddings = self.model(**batch_images)
- #             self.logger.debug("Image batch processing complete.")
- #             return image_embeddings.cpu().tolist()
- #         except Exception as e:
- #             self.logger.error(f"Error processing image batch: {e}")
- #             raise
-
- #     def _process_text_batch(self, texts: List[str]) -> List[List[float]]:
- #         """
- #         Processes a batch of text queries and generates embeddings.
-
- #         Args:
- #             texts (List[str]): List of text queries to process.
-
- #         Returns:
- #             List[List[float]]: List of embeddings for each text query.
- #         """
- #         self.logger.debug(f"Processing batch of {len(texts)} text queries.")
- #         try:
- #             batch_queries = self.processor.process_queries(texts).to(self.device)
- #             with torch.no_grad(), torch.amp.autocast():
- #                 query_embeddings = self.model(**batch_queries)
- #             self.logger.debug("Text batch processing complete.")
- #             return query_embeddings.cpu().tolist()
- #         except Exception as e:
- #             self.logger.error(f"Error processing text batch: {e}")
- #             raise
-
- #     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
- #         """
- #         Processes input data containing base64-encoded images and text queries, decodes them, and generates embeddings.
-
- #         Args:
- #             data (Dict[str, Any]): Dictionary containing input images, text queries, and optional batch size.
-
- #         Returns:
- #             Dict[str, Any]: Dictionary containing generated embeddings for images and text or error messages.
- #         """
- #         images_data = data.get("image", [])
- #         text_data = data.get("text", [])
- #         batch_size = data.get("batch_size", self.default_batch_size)
-
- #         # Decode and process images
- #         images = []
- #         if images_data:
- #             self.logger.info("Decoding images from base64.")
- #             for img_data in images_data:
- #                 if isinstance(img_data, str):
- #                     try:
- #                         image_bytes = base64.b64decode(img_data)
- #                         image = Image.open(BytesIO(image_bytes)).convert("RGB")
- #                         images.append(image)
- #                     except Exception as e:
- #                         self.logger.error(f"Invalid image data: {e}")
- #                         return {"error": f"Invalid image data: {e}"}
- #                 else:
- #                     self.logger.error("Images should be base64-encoded strings.")
- #                     return {"error": "Images should be base64-encoded strings."}
-
- #         image_embeddings = []
- #         if images:
- #             self.logger.info("Processing image embeddings.")
- #             try:
- #                 for i in range(0, len(images), batch_size):
- #                     batch_images = images[i : i + batch_size]
- #                     batch_embeddings = self._process_image_batch(batch_images)
- #                     image_embeddings.extend(batch_embeddings)
- #             except Exception as e:
- #                 self.logger.error(f"Error generating image embeddings: {e}")
- #                 return {"error": f"Error generating image embeddings: {e}"}
-
- #         # Process text data
- #         text_embeddings = []
- #         if text_data:
- #             self.logger.info("Processing text embeddings.")
- #             try:
- #                 for i in range(0, len(text_data), batch_size):
- #                     batch_texts = text_data[i : i + batch_size]
- #                     batch_text_embeddings = self._process_text_batch(batch_texts)
- #                     text_embeddings.extend(batch_text_embeddings)
- #             except Exception as e:
- #                 self.logger.error(f"Error generating text embeddings: {e}")
- #                 return {"error": f"Error generating text embeddings: {e}"}
-
- #         # Compute similarity scores if both image and text embeddings are available
- #         scores = []
- #         if image_embeddings and text_embeddings:
- #             self.logger.info("Computing similarity scores.")
- #             try:
- #                 image_embeddings_tensor = torch.tensor(image_embeddings).to(self.device)
- #                 text_embeddings_tensor = torch.tensor(text_embeddings).to(self.device)
- #                 with torch.no_grad(), torch.amp.autocast():
- #                     scores = (
- #                         self.processor.score_multi_vector(
- #                             text_embeddings_tensor, image_embeddings_tensor
- #                         )
- #                         .cpu()
- #                         .tolist()
- #                     )
- #                 self.logger.info("Similarity scoring complete.")
- #             except Exception as e:
- #                 self.logger.error(f"Error computing similarity scores: {e}")
- #                 return {"error": f"Error computing similarity scores: {e}"}
-
- #         return {"image": image_embeddings, "text": text_embeddings, "scores": scores}
-
-
  import torch
  from typing import Dict, Any, List
  from PIL import Image
  import base64
  from io import BytesIO
  import logging
- from torch.utils.data import DataLoader, Dataset
- import threading
-
-
- class ImageDataset(Dataset):
-     def __init__(self, images: List[Image.Image], processor):
-         self.images = images
-         self.processor = processor
-
-     def __len__(self):
-         return len(self.images)
-
-     def __getitem__(self, idx):
-         image = self.processor.process_images([self.images[idx]])
-         return image
-
-
- class TextDataset(Dataset):
-     def __init__(self, texts: List[str], processor):
-         self.texts = texts
-         self.processor = processor
-
-     def __len__(self):
-         return len(self.texts)
-
-     def __getitem__(self, idx):
-         text = self.processor.process_queries([self.texts[idx]])
-         return text


  class EndpointHandler:
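This first hunk un-comments the plain batch handler and deletes the DataLoader/Dataset machinery: the restored code slices its inputs into fixed-size chunks itself rather than delegating to torch.utils.data. The chunking idiom it relies on is plain list slicing, as in this standalone sketch (process_batch is a hypothetical stand-in for _process_image_batch/_process_text_batch):

    from typing import List

    def chunks(items: List[str], batch_size: int) -> List[List[str]]:
        # Consecutive slices of at most batch_size items; the last may be shorter.
        return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]

    def process_batch(batch: List[str]) -> List[List[float]]:
        # Stand-in for the model call; returns one embedding per item.
        return [[float(len(item))] for item in batch]

    embeddings: List[List[float]] = []
    for batch in chunks(["a", "bb", "ccc", "dddd", "eeeee"], batch_size=2):
        embeddings.extend(process_batch(batch))
    assert len(embeddings) == 5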
@@ -218,6 +20,10 @@ class EndpointHandler:
      def __init__(self, path: str = "", default_batch_size: int = 4):
          """
          Initializes the EndpointHandler with a specified model path and default batch size.
+
+         Args:
+             path (str): Path to the pre-trained model and processor.
+             default_batch_size (int): Default batch size for processing images and text data.
          """
          # Initialize logging
          logging.basicConfig(level=logging.INFO)
@@ -227,91 +33,60 @@ class EndpointHandler:

          self.logger.info("Initializing model and processor.")
          try:
-             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-             self.model = (
-                 ColQwen2.from_pretrained(
-                     path,
-                     torch_dtype=torch.bfloat16,
-                     device_map="auto",
-                 )
-                 .to(self.device)
-                 .eval()
-             )
-
+             self.model = ColQwen2.from_pretrained(
+                 path,
+                 torch_dtype=torch.bfloat16,
+                 device_map="auto",
+             ).eval()
              self.processor = ColQwen2Processor.from_pretrained(path)
+             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+             self.model.to(self.device)
              self.default_batch_size = default_batch_size
              self.logger.info("Initialization complete.")
          except Exception as e:
              self.logger.error(f"Failed to initialize model or processor: {e}")
              raise

-     def _process_image_embeddings(
-         self, images: List[Image.Image], batch_size: int
-     ) -> torch.Tensor:
+     def _process_image_batch(self, images: List[Image.Image]) -> List[List[float]]:
          """
-         Processes images and generates embeddings.
+         Processes a batch of images and generates embeddings.

          Args:
              images (List[Image.Image]): List of images to process.
-             batch_size (int): Batch size for processing images.

          Returns:
-             torch.Tensor: Tensor containing embeddings for each image.
+             List[List[float]]: List of embeddings for each image.
          """
-         self.logger.debug(f"Processing {len(images)} images.")
+         self.logger.debug(f"Processing batch of {len(images)} images.")
          try:
-             image_dataset = ImageDataset(images, self.processor)
-             image_loader = DataLoader(
-                 image_dataset, batch_size=batch_size, num_workers=4, pin_memory=True
-             )
-
-             all_embeddings = []
-             with torch.no_grad():
-                 for batch in image_loader:
-                     batch_images = batch[0].to(self.device, non_blocking=True)
-                     with torch.cuda.amp.autocast():
-                         embeddings = self.model(**batch_images)
-                     all_embeddings.append(embeddings)
-             image_embeddings = torch.cat(all_embeddings, dim=0)
-             self.logger.debug("Image processing complete.")
-             return image_embeddings
+             batch_images = self.processor.process_images(images).to(self.device)
+             with torch.no_grad(), torch.amp.autocast():
+                 image_embeddings = self.model(**batch_images)
+             self.logger.debug("Image batch processing complete.")
+             return image_embeddings.cpu().tolist()
          except Exception as e:
-             self.logger.error(f"Error processing images: {e}")
+             self.logger.error(f"Error processing image batch: {e}")
              raise

-     def _process_text_embeddings(
-         self, texts: List[str], batch_size: int
-     ) -> torch.Tensor:
+     def _process_text_batch(self, texts: List[str]) -> List[List[float]]:
          """
-         Processes text queries and generates embeddings.
+         Processes a batch of text queries and generates embeddings.

          Args:
              texts (List[str]): List of text queries to process.
-             batch_size (int): Batch size for processing texts.

          Returns:
-             torch.Tensor: Tensor containing embeddings for each text query.
+             List[List[float]]: List of embeddings for each text query.
          """
-         self.logger.debug(f"Processing {len(texts)} text queries.")
+         self.logger.debug(f"Processing batch of {len(texts)} text queries.")
          try:
-             text_dataset = TextDataset(texts, self.processor)
-             text_loader = DataLoader(
-                 text_dataset, batch_size=batch_size, num_workers=4, pin_memory=True
-             )
-
-             all_embeddings = []
-             with torch.no_grad():
-                 for batch in text_loader:
-                     batch_texts = batch[0].to(self.device, non_blocking=True)
-                     with torch.amp.autocast():
-                         embeddings = self.model(**batch_texts)
-                     all_embeddings.append(embeddings)
-             text_embeddings = torch.cat(all_embeddings, dim=0)
-             self.logger.debug("Text processing complete.")
-             return text_embeddings
+             batch_queries = self.processor.process_queries(texts).to(self.device)
+             with torch.no_grad(), torch.amp.autocast():
+                 query_embeddings = self.model(**batch_queries)
+             self.logger.debug("Text batch processing complete.")
+             return query_embeddings.cpu().tolist()
          except Exception as e:
-             self.logger.error(f"Error processing texts: {e}")
+             self.logger.error(f"Error processing text batch: {e}")
              raise

      def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
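A caveat on the torch.amp.autocast() calls this hunk restores (and which the commit title references): in current PyTorch releases, torch.amp.autocast takes a required device_type argument, so the zero-argument form used here raises a TypeError when the context manager is entered. A corrected sketch follows; the bfloat16 dtype is an assumption chosen to match the checkpoint's torch_dtype, not part of this commit:

    import torch

    device_type = "cuda" if torch.cuda.is_available() else "cpu"

    # torch.amp.autocast requires an explicit device_type (dtype is optional).
    with torch.no_grad(), torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
        pass  # forward pass would go here, e.g. embeddings = model(**batch)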
@@ -328,6 +103,7 @@ class EndpointHandler:
          text_data = data.get("text", [])
          batch_size = data.get("batch_size", self.default_batch_size)

+         # Decode and process images
          images = []
          if images_data:
              self.logger.info("Decoding images from base64.")
@@ -344,65 +120,289 @@ class EndpointHandler:
                      self.logger.error("Images should be base64-encoded strings.")
                      return {"error": "Images should be base64-encoded strings."}

-         image_embeddings = None
-         text_embeddings = None
-         scores = None
-
-         def process_images():
-             nonlocal image_embeddings
-             if images:
-                 self.logger.info("Processing image embeddings.")
-                 try:
-                     image_embeddings = self._process_image_embeddings(
-                         images, batch_size
-                     )
-                 except Exception as e:
-                     self.logger.error(f"Error generating image embeddings: {e}")
-
-         def process_texts():
-             nonlocal text_embeddings
-             if text_data:
-                 self.logger.info("Processing text embeddings.")
-                 try:
-                     text_embeddings = self._process_text_embeddings(
-                         text_data, batch_size
-                     )
-                 except Exception as e:
-                     self.logger.error(f"Error generating text embeddings: {e}")
-
-         # Process images and texts in parallel if both are present
-         threads = []
-         if images_data and text_data:
-             image_thread = threading.Thread(target=process_images)
-             text_thread = threading.Thread(target=process_texts)
-             threads.extend([image_thread, text_thread])
-             image_thread.start()
-             text_thread.start()
-             for thread in threads:
-                 thread.join()
-         else:
-             process_images()
-             process_texts()
-
-         # Compute similarity scores if both embeddings are available
-         if image_embeddings is not None and text_embeddings is not None:
+         image_embeddings = []
+         if images:
+             self.logger.info("Processing image embeddings.")
+             try:
+                 for i in range(0, len(images), batch_size):
+                     batch_images = images[i : i + batch_size]
+                     batch_embeddings = self._process_image_batch(batch_images)
+                     image_embeddings.extend(batch_embeddings)
+             except Exception as e:
+                 self.logger.error(f"Error generating image embeddings: {e}")
+                 return {"error": f"Error generating image embeddings: {e}"}
+
+         # Process text data
+         text_embeddings = []
+         if text_data:
+             self.logger.info("Processing text embeddings.")
+             try:
+                 for i in range(0, len(text_data), batch_size):
+                     batch_texts = text_data[i : i + batch_size]
+                     batch_text_embeddings = self._process_text_batch(batch_texts)
+                     text_embeddings.extend(batch_text_embeddings)
+             except Exception as e:
+                 self.logger.error(f"Error generating text embeddings: {e}")
+                 return {"error": f"Error generating text embeddings: {e}"}
+
+         # Compute similarity scores if both image and text embeddings are available
+         scores = []
+         if image_embeddings and text_embeddings:
              self.logger.info("Computing similarity scores.")
              try:
+                 image_embeddings_tensor = torch.tensor(image_embeddings).to(self.device)
+                 text_embeddings_tensor = torch.tensor(text_embeddings).to(self.device)
                  with torch.no_grad(), torch.amp.autocast():
-                     scores = self.processor.score_multi_vector(
-                         text_embeddings, image_embeddings
+                     scores = (
+                         self.processor.score_multi_vector(
+                             text_embeddings_tensor, image_embeddings_tensor
+                         )
+                         .cpu()
+                         .tolist()
                      )
                  self.logger.info("Similarity scoring complete.")
              except Exception as e:
                  self.logger.error(f"Error computing similarity scores: {e}")
                  return {"error": f"Error computing similarity scores: {e}"}

-         result = {}
-         if image_embeddings is not None:
-             result["image"] = image_embeddings.cpu().tolist()
-         if text_embeddings is not None:
-             result["text"] = text_embeddings.cpu().tolist()
-         if scores is not None:
-             result["scores"] = scores.cpu().tolist()
+         return {"image": image_embeddings, "text": text_embeddings, "scores": scores}
+
+
+ # import torch
+ # from typing import Dict, Any, List
+ # from PIL import Image
+ # import base64
+ # from io import BytesIO
+ # import logging
+ # from torch.utils.data import DataLoader, Dataset
+ # import threading
+
+
+ # class ImageDataset(Dataset):
+ #     def __init__(self, images: List[Image.Image], processor):
+ #         self.images = images
+ #         self.processor = processor
+
+ #     def __len__(self):
+ #         return len(self.images)
+
+ #     def __getitem__(self, idx):
+ #         image = self.processor.process_images([self.images[idx]])
+ #         return image
+
+
+ # class TextDataset(Dataset):
+ #     def __init__(self, texts: List[str], processor):
+ #         self.texts = texts
+ #         self.processor = processor
+
+ #     def __len__(self):
+ #         return len(self.texts)
+
+ #     def __getitem__(self, idx):
+ #         text = self.processor.process_queries([self.texts[idx]])
+ #         return text
+
+
+ # class EndpointHandler:
+ #     """
+ #     A handler class for processing image and text data, generating embeddings using a specified model and processor.
+
+ #     Attributes:
+ #         model: The pre-trained model used for generating embeddings.
+ #         processor: The pre-trained processor used to process images and text before model inference.
+ #         device: The device (CPU or CUDA) used to run model inference.
+ #         default_batch_size: The default batch size for processing images and text in batches.
+ #     """
+
+ #     def __init__(self, path: str = "", default_batch_size: int = 4):
+ #         """
+ #         Initializes the EndpointHandler with a specified model path and default batch size.
+ #         """
+ #         # Initialize logging
+ #         logging.basicConfig(level=logging.INFO)
+ #         self.logger = logging.getLogger(__name__)
+
+ #         from colpali_engine.models import ColQwen2, ColQwen2Processor
+
+ #         self.logger.info("Initializing model and processor.")
+ #         try:
+ #             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ #             self.model = (
+ #                 ColQwen2.from_pretrained(
+ #                     path,
+ #                     torch_dtype=torch.bfloat16,
+ #                     device_map="auto",
+ #                 )
+ #                 .to(self.device)
+ #                 .eval()
+ #             )
+
+ #             self.processor = ColQwen2Processor.from_pretrained(path)
+ #             self.default_batch_size = default_batch_size
+ #             self.logger.info("Initialization complete.")
+ #         except Exception as e:
+ #             self.logger.error(f"Failed to initialize model or processor: {e}")
+ #             raise
+
+ #     def _process_image_embeddings(
+ #         self, images: List[Image.Image], batch_size: int
+ #     ) -> torch.Tensor:
+ #         """
+ #         Processes images and generates embeddings.
+
+ #         Args:
+ #             images (List[Image.Image]): List of images to process.
+ #             batch_size (int): Batch size for processing images.
+
+ #         Returns:
+ #             torch.Tensor: Tensor containing embeddings for each image.
+ #         """
+ #         self.logger.debug(f"Processing {len(images)} images.")
+ #         try:
+ #             image_dataset = ImageDataset(images, self.processor)
+ #             image_loader = DataLoader(
+ #                 image_dataset, batch_size=batch_size, num_workers=4, pin_memory=True
+ #             )
+
+ #             all_embeddings = []
+ #             with torch.no_grad():
+ #                 for batch in image_loader:
+ #                     batch_images = batch[0].to(self.device, non_blocking=True)
+ #                     with torch.cuda.amp.autocast():
+ #                         embeddings = self.model(**batch_images)
+ #                     all_embeddings.append(embeddings)
+ #             image_embeddings = torch.cat(all_embeddings, dim=0)
+ #             self.logger.debug("Image processing complete.")
+ #             return image_embeddings
+ #         except Exception as e:
+ #             self.logger.error(f"Error processing images: {e}")
+ #             raise
+
+ #     def _process_text_embeddings(
+ #         self, texts: List[str], batch_size: int
+ #     ) -> torch.Tensor:
+ #         """
+ #         Processes text queries and generates embeddings.
+
+ #         Args:
+ #             texts (List[str]): List of text queries to process.
+ #             batch_size (int): Batch size for processing texts.
+
+ #         Returns:
+ #             torch.Tensor: Tensor containing embeddings for each text query.
+ #         """
+ #         self.logger.debug(f"Processing {len(texts)} text queries.")
+ #         try:
+ #             text_dataset = TextDataset(texts, self.processor)
+ #             text_loader = DataLoader(
+ #                 text_dataset, batch_size=batch_size, num_workers=4, pin_memory=True
+ #             )
+
+ #             all_embeddings = []
+ #             with torch.no_grad():
+ #                 for batch in text_loader:
+ #                     batch_texts = batch[0].to(self.device, non_blocking=True)
+ #                     with torch.amp.autocast():
+ #                         embeddings = self.model(**batch_texts)
+ #                     all_embeddings.append(embeddings)
+ #             text_embeddings = torch.cat(all_embeddings, dim=0)
+ #             self.logger.debug("Text processing complete.")
+ #             return text_embeddings
+ #         except Exception as e:
+ #             self.logger.error(f"Error processing texts: {e}")
+ #             raise
+
+ #     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ #         """
+ #         Processes input data containing base64-encoded images and text queries, decodes them, and generates embeddings.
+
+ #         Args:
+ #             data (Dict[str, Any]): Dictionary containing input images, text queries, and optional batch size.
+
+ #         Returns:
+ #             Dict[str, Any]: Dictionary containing generated embeddings for images and text or error messages.
+ #         """
+ #         images_data = data.get("image", [])
+ #         text_data = data.get("text", [])
+ #         batch_size = data.get("batch_size", self.default_batch_size)
+
+ #         images = []
+ #         if images_data:
+ #             self.logger.info("Decoding images from base64.")
+ #             for img_data in images_data:
+ #                 if isinstance(img_data, str):
+ #                     try:
+ #                         image_bytes = base64.b64decode(img_data)
+ #                         image = Image.open(BytesIO(image_bytes)).convert("RGB")
+ #                         images.append(image)
+ #                     except Exception as e:
+ #                         self.logger.error(f"Invalid image data: {e}")
+ #                         return {"error": f"Invalid image data: {e}"}
+ #                 else:
+ #                     self.logger.error("Images should be base64-encoded strings.")
+ #                     return {"error": "Images should be base64-encoded strings."}
+
+ #         image_embeddings = None
+ #         text_embeddings = None
+ #         scores = None
+
+ #         def process_images():
+ #             nonlocal image_embeddings
+ #             if images:
+ #                 self.logger.info("Processing image embeddings.")
+ #                 try:
+ #                     image_embeddings = self._process_image_embeddings(
+ #                         images, batch_size
+ #                     )
+ #                 except Exception as e:
+ #                     self.logger.error(f"Error generating image embeddings: {e}")
+
+ #         def process_texts():
+ #             nonlocal text_embeddings
+ #             if text_data:
+ #                 self.logger.info("Processing text embeddings.")
+ #                 try:
+ #                     text_embeddings = self._process_text_embeddings(
+ #                         text_data, batch_size
+ #                     )
+ #                 except Exception as e:
+ #                     self.logger.error(f"Error generating text embeddings: {e}")
+
+ #         # Process images and texts in parallel if both are present
+ #         threads = []
+ #         if images_data and text_data:
+ #             image_thread = threading.Thread(target=process_images)
+ #             text_thread = threading.Thread(target=process_texts)
+ #             threads.extend([image_thread, text_thread])
+ #             image_thread.start()
+ #             text_thread.start()
+ #             for thread in threads:
+ #                 thread.join()
+ #         else:
+ #             process_images()
+ #             process_texts()
+
+ #         # Compute similarity scores if both embeddings are available
+ #         if image_embeddings is not None and text_embeddings is not None:
+ #             self.logger.info("Computing similarity scores.")
+ #             try:
+ #                 with torch.no_grad(), torch.amp.autocast():
+ #                     scores = self.processor.score_multi_vector(
+ #                         text_embeddings, image_embeddings
+ #                     )
+ #                 self.logger.info("Similarity scoring complete.")
+ #             except Exception as e:
+ #                 self.logger.error(f"Error computing similarity scores: {e}")
+ #                 return {"error": f"Error computing similarity scores: {e}"}
+
+ #         result = {}
+ #         if image_embeddings is not None:
+ #             result["image"] = image_embeddings.cpu().tolist()
+ #         if text_embeddings is not None:
+ #             result["text"] = text_embeddings.cpu().tolist()
+ #         if scores is not None:
+ #             result["scores"] = scores.cpu().tolist()

-         return result
+ #         return result
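One further caveat on the restored scoring path: torch.tensor(image_embeddings) only succeeds if every embedding in the list has the same sequence length, but multi-vector embeddings of this kind are padded per batch, so lists accumulated across batches can be ragged and fail to stack. A plausible workaround, assuming colpali_engine's score_multi_vector accepts a list of per-item 2-D tensors (a sketch, not part of this commit):

    import torch

    # Hypothetical per-item embeddings with unequal token counts, as can happen
    # when different batches pad to different maximum lengths.
    image_embeddings = [torch.randn(620, 128), torch.randn(700, 128)]
    query_embeddings = [torch.randn(24, 128)]

    # torch.tensor() on the equivalent nested lists would raise here; keeping
    # per-item tensors sidesteps stacking across unequal lengths:
    # scores = processor.score_multi_vector(query_embeddings, image_embeddings)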