amaye15 committed
Commit 18f8eec · Parent(s): 59f3026

Old handler with autocast

1 file changed: handler.py (+308 −308)

handler.py CHANGED
@@ -1,207 +1,9 @@
-# import torch
-# from typing import Dict, Any, List
-# from PIL import Image
-# import base64
-# from io import BytesIO
-# import logging
-
-
-# class EndpointHandler:
-#     """
-#     A handler class for processing image and text data, generating embeddings using a specified model and processor.
-
-#     Attributes:
-#         model: The pre-trained model used for generating embeddings.
-#         processor: The pre-trained processor used to process images and text before model inference.
-#         device: The device (CPU or CUDA) used to run model inference.
-#         default_batch_size: The default batch size for processing images and text in batches.
-#     """
-
-#     def __init__(self, path: str = "", default_batch_size: int = 4):
-#         """
-#         Initializes the EndpointHandler with a specified model path and default batch size.
-
-#         Args:
-#             path (str): Path to the pre-trained model and processor.
-#             default_batch_size (int): Default batch size for processing images and text data.
-#         """
-#         # Initialize logging
-#         logging.basicConfig(level=logging.INFO)
-#         self.logger = logging.getLogger(__name__)
-
-#         from colpali_engine.models import ColQwen2, ColQwen2Processor
-
-#         self.logger.info("Initializing model and processor.")
-#         try:
-#             self.model = ColQwen2.from_pretrained(
-#                 path,
-#                 torch_dtype=torch.bfloat16,
-#                 device_map="auto",
-#             ).eval()
-#             self.processor = ColQwen2Processor.from_pretrained(path)
-#             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-#             self.model.to(self.device)
-#             self.default_batch_size = default_batch_size
-#             self.logger.info("Initialization complete.")
-#         except Exception as e:
-#             self.logger.error(f"Failed to initialize model or processor: {e}")
-#             raise
-
-#     def _process_image_batch(self, images: List[Image.Image]) -> List[List[float]]:
-#         """
-#         Processes a batch of images and generates embeddings.
-
-#         Args:
-#             images (List[Image.Image]): List of images to process.
-
-#         Returns:
-#             List[List[float]]: List of embeddings for each image.
-#         """
-#         self.logger.debug(f"Processing batch of {len(images)} images.")
-#         try:
-#             batch_images = self.processor.process_images(images).to(self.device)
-#             with torch.no_grad(), torch.amp.autocast():
-#                 image_embeddings = self.model(**batch_images)
-#             self.logger.debug("Image batch processing complete.")
-#             return image_embeddings.cpu().tolist()
-#         except Exception as e:
-#             self.logger.error(f"Error processing image batch: {e}")
-#             raise
-
-#     def _process_text_batch(self, texts: List[str]) -> List[List[float]]:
-#         """
-#         Processes a batch of text queries and generates embeddings.
-
-#         Args:
-#             texts (List[str]): List of text queries to process.
-
-#         Returns:
-#             List[List[float]]: List of embeddings for each text query.
-#         """
-#         self.logger.debug(f"Processing batch of {len(texts)} text queries.")
-#         try:
-#             batch_queries = self.processor.process_queries(texts).to(self.device)
-#             with torch.no_grad(), torch.amp.autocast():
-#                 query_embeddings = self.model(**batch_queries)
-#             self.logger.debug("Text batch processing complete.")
-#             return query_embeddings.cpu().tolist()
-#         except Exception as e:
-#             self.logger.error(f"Error processing text batch: {e}")
-#             raise
-
-#     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-#         """
-#         Processes input data containing base64-encoded images and text queries, decodes them, and generates embeddings.
-
-#         Args:
-#             data (Dict[str, Any]): Dictionary containing input images, text queries, and optional batch size.
-
-#         Returns:
-#             Dict[str, Any]: Dictionary containing generated embeddings for images and text or error messages.
-#         """
-#         images_data = data.get("image", [])
-#         text_data = data.get("text", [])
-#         batch_size = data.get("batch_size", self.default_batch_size)
-
-#         # Decode and process images
-#         images = []
-#         if images_data:
-#             self.logger.info("Decoding images from base64.")
-#             for img_data in images_data:
-#                 if isinstance(img_data, str):
-#                     try:
-#                         image_bytes = base64.b64decode(img_data)
-#                         image = Image.open(BytesIO(image_bytes)).convert("RGB")
-#                         images.append(image)
-#                     except Exception as e:
-#                         self.logger.error(f"Invalid image data: {e}")
-#                         return {"error": f"Invalid image data: {e}"}
-#                 else:
-#                     self.logger.error("Images should be base64-encoded strings.")
-#                     return {"error": "Images should be base64-encoded strings."}
-
-#         image_embeddings = []
-#         if images:
-#             self.logger.info("Processing image embeddings.")
-#             try:
-#                 for i in range(0, len(images), batch_size):
-#                     batch_images = images[i : i + batch_size]
-#                     batch_embeddings = self._process_image_batch(batch_images)
-#                     image_embeddings.extend(batch_embeddings)
-#             except Exception as e:
-#                 self.logger.error(f"Error generating image embeddings: {e}")
-#                 return {"error": f"Error generating image embeddings: {e}"}
-
-#         # Process text data
-#         text_embeddings = []
-#         if text_data:
-#             self.logger.info("Processing text embeddings.")
-#             try:
-#                 for i in range(0, len(text_data), batch_size):
-#                     batch_texts = text_data[i : i + batch_size]
-#                     batch_text_embeddings = self._process_text_batch(batch_texts)
-#                     text_embeddings.extend(batch_text_embeddings)
-#             except Exception as e:
-#                 self.logger.error(f"Error generating text embeddings: {e}")
-#                 return {"error": f"Error generating text embeddings: {e}"}
-
-#         # Compute similarity scores if both image and text embeddings are available
-#         scores = []
-#         if image_embeddings and text_embeddings:
-#             self.logger.info("Computing similarity scores.")
-#             try:
-#                 image_embeddings_tensor = torch.tensor(image_embeddings).to(self.device)
-#                 text_embeddings_tensor = torch.tensor(text_embeddings).to(self.device)
-#                 with torch.no_grad(), torch.amp.autocast():
-#                     scores = (
-#                         self.processor.score_multi_vector(
-#                             text_embeddings_tensor, image_embeddings_tensor
-#                         )
-#                         .cpu()
-#                         .tolist()
-#                     )
-#                 self.logger.info("Similarity scoring complete.")
-#             except Exception as e:
-#                 self.logger.error(f"Error computing similarity scores: {e}")
-#                 return {"error": f"Error computing similarity scores: {e}"}
-
-#         return {"image": image_embeddings, "text": text_embeddings, "scores": scores}
-
-
 import torch
 from typing import Dict, Any, List
 from PIL import Image
 import base64
 from io import BytesIO
 import logging
-from torch.utils.data import DataLoader, Dataset
-import threading
-
-
-class ImageDataset(Dataset):
-    def __init__(self, images: List[Image.Image], processor):
-        self.images = images
-        self.processor = processor
-
-    def __len__(self):
-        return len(self.images)
-
-    def __getitem__(self, idx):
-        image = self.processor.process_images([self.images[idx]])
-        return image
-
-
-class TextDataset(Dataset):
-    def __init__(self, texts: List[str], processor):
-        self.texts = texts
-        self.processor = processor
-
-    def __len__(self):
-        return len(self.texts)
-
-    def __getitem__(self, idx):
-        text = self.processor.process_queries([self.texts[idx]])
-        return text


 class EndpointHandler:
@@ -218,6 +20,10 @@ class EndpointHandler:
     def __init__(self, path: str = "", default_batch_size: int = 4):
         """
         Initializes the EndpointHandler with a specified model path and default batch size.
+
+        Args:
+            path (str): Path to the pre-trained model and processor.
+            default_batch_size (int): Default batch size for processing images and text data.
         """
         # Initialize logging
         logging.basicConfig(level=logging.INFO)
@@ -227,91 +33,60 @@ class EndpointHandler:

         self.logger.info("Initializing model and processor.")
         try:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-            self.model = (
-                ColQwen2.from_pretrained(
-                    path,
-                    torch_dtype=torch.bfloat16,
-                    device_map="auto",
-                )
-                .to(self.device)
-                .eval()
-            )
-
+            self.model = ColQwen2.from_pretrained(
+                path,
+                torch_dtype=torch.bfloat16,
+                device_map="auto",
+            ).eval()
             self.processor = ColQwen2Processor.from_pretrained(path)
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            self.model.to(self.device)
             self.default_batch_size = default_batch_size
             self.logger.info("Initialization complete.")
         except Exception as e:
             self.logger.error(f"Failed to initialize model or processor: {e}")
             raise

-    def _process_image_embeddings(
-        self, images: List[Image.Image], batch_size: int
-    ) -> torch.Tensor:
+    def _process_image_batch(self, images: List[Image.Image]) -> List[List[float]]:
         """
-        Processes images and generates embeddings.
+        Processes a batch of images and generates embeddings.

         Args:
             images (List[Image.Image]): List of images to process.
-            batch_size (int): Batch size for processing images.

         Returns:
-            torch.Tensor: Tensor containing embeddings for each image.
+            List[List[float]]: List of embeddings for each image.
         """
-        self.logger.debug(f"Processing {len(images)} images.")
+        self.logger.debug(f"Processing batch of {len(images)} images.")
         try:
-            image_dataset = ImageDataset(images, self.processor)
-            image_loader = DataLoader(
-                image_dataset, batch_size=batch_size, num_workers=4, pin_memory=True
-            )
-
-            all_embeddings = []
-            with torch.no_grad():
-                for batch in image_loader:
-                    batch_images = batch[0].to(self.device, non_blocking=True)
-                    with torch.cuda.amp.autocast():
-                        embeddings = self.model(**batch_images)
-                    all_embeddings.append(embeddings)
-            image_embeddings = torch.cat(all_embeddings, dim=0)
-            self.logger.debug("Image processing complete.")
-            return image_embeddings
+            batch_images = self.processor.process_images(images).to(self.device)
+            with torch.no_grad(), torch.amp.autocast():
+                image_embeddings = self.model(**batch_images)
+            self.logger.debug("Image batch processing complete.")
+            return image_embeddings.cpu().tolist()
         except Exception as e:
-            self.logger.error(f"Error processing images: {e}")
+            self.logger.error(f"Error processing image batch: {e}")
             raise

-    def _process_text_embeddings(
-        self, texts: List[str], batch_size: int
-    ) -> torch.Tensor:
+    def _process_text_batch(self, texts: List[str]) -> List[List[float]]:
         """
-        Processes text queries and generates embeddings.
+        Processes a batch of text queries and generates embeddings.

         Args:
             texts (List[str]): List of text queries to process.
-            batch_size (int): Batch size for processing texts.

         Returns:
-            torch.Tensor: Tensor containing embeddings for each text query.
+            List[List[float]]: List of embeddings for each text query.
         """
-        self.logger.debug(f"Processing {len(texts)} text queries.")
+        self.logger.debug(f"Processing batch of {len(texts)} text queries.")
         try:
-            text_dataset = TextDataset(texts, self.processor)
-            text_loader = DataLoader(
-                text_dataset, batch_size=batch_size, num_workers=4, pin_memory=True
-            )
-
-            all_embeddings = []
-            with torch.no_grad():
-                for batch in text_loader:
-                    batch_texts = batch[0].to(self.device, non_blocking=True)
-                    with torch.amp.autocast():
-                        embeddings = self.model(**batch_texts)
-                    all_embeddings.append(embeddings)
-            text_embeddings = torch.cat(all_embeddings, dim=0)
-            self.logger.debug("Text processing complete.")
-            return text_embeddings
+            batch_queries = self.processor.process_queries(texts).to(self.device)
+            with torch.no_grad(), torch.amp.autocast():
+                query_embeddings = self.model(**batch_queries)
+            self.logger.debug("Text batch processing complete.")
+            return query_embeddings.cpu().tolist()
         except Exception as e:
-            self.logger.error(f"Error processing texts: {e}")
+            self.logger.error(f"Error processing text batch: {e}")
             raise

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
@@ -328,6 +103,7 @@ class EndpointHandler:
         text_data = data.get("text", [])
         batch_size = data.get("batch_size", self.default_batch_size)

+        # Decode and process images
         images = []
         if images_data:
             self.logger.info("Decoding images from base64.")
@@ -344,65 +120,289 @@ class EndpointHandler:
                    self.logger.error("Images should be base64-encoded strings.")
                    return {"error": "Images should be base64-encoded strings."}

-        image_embeddings = None
-        text_embeddings = None
-        scores = None
-
-        def process_images():
-            nonlocal image_embeddings
-            if images:
-                self.logger.info("Processing image embeddings.")
-                try:
-                    image_embeddings = self._process_image_embeddings(
-                        images, batch_size
-                    )
-                except Exception as e:
-                    self.logger.error(f"Error generating image embeddings: {e}")
-
-        def process_texts():
-            nonlocal text_embeddings
-            if text_data:
-                self.logger.info("Processing text embeddings.")
-                try:
-                    text_embeddings = self._process_text_embeddings(
-                        text_data, batch_size
-                    )
-                except Exception as e:
-                    self.logger.error(f"Error generating text embeddings: {e}")
-
-        # Process images and texts in parallel if both are present
-        threads = []
-        if images_data and text_data:
-            image_thread = threading.Thread(target=process_images)
-            text_thread = threading.Thread(target=process_texts)
-            threads.extend([image_thread, text_thread])
-            image_thread.start()
-            text_thread.start()
-            for thread in threads:
-                thread.join()
-        else:
-            process_images()
-            process_texts()
-
-        # Compute similarity scores if both embeddings are available
-        if image_embeddings is not None and text_embeddings is not None:
+        image_embeddings = []
+        if images:
+            self.logger.info("Processing image embeddings.")
+            try:
+                for i in range(0, len(images), batch_size):
+                    batch_images = images[i : i + batch_size]
+                    batch_embeddings = self._process_image_batch(batch_images)
+                    image_embeddings.extend(batch_embeddings)
+            except Exception as e:
+                self.logger.error(f"Error generating image embeddings: {e}")
+                return {"error": f"Error generating image embeddings: {e}"}
+
+        # Process text data
+        text_embeddings = []
+        if text_data:
+            self.logger.info("Processing text embeddings.")
+            try:
+                for i in range(0, len(text_data), batch_size):
+                    batch_texts = text_data[i : i + batch_size]
+                    batch_text_embeddings = self._process_text_batch(batch_texts)
+                    text_embeddings.extend(batch_text_embeddings)
+            except Exception as e:
+                self.logger.error(f"Error generating text embeddings: {e}")
+                return {"error": f"Error generating text embeddings: {e}"}
+
+        # Compute similarity scores if both image and text embeddings are available
+        scores = []
+        if image_embeddings and text_embeddings:
             self.logger.info("Computing similarity scores.")
             try:
+                image_embeddings_tensor = torch.tensor(image_embeddings).to(self.device)
+                text_embeddings_tensor = torch.tensor(text_embeddings).to(self.device)
                 with torch.no_grad(), torch.amp.autocast():
-                    scores = self.processor.score_multi_vector(
-                        text_embeddings, image_embeddings
+                    scores = (
+                        self.processor.score_multi_vector(
+                            text_embeddings_tensor, image_embeddings_tensor
+                        )
+                        .cpu()
+                        .tolist()
                     )
                 self.logger.info("Similarity scoring complete.")
             except Exception as e:
                 self.logger.error(f"Error computing similarity scores: {e}")
                 return {"error": f"Error computing similarity scores: {e}"}

-        result = {}
-        if image_embeddings is not None:
-            result["image"] = image_embeddings.cpu().tolist()
-        if text_embeddings is not None:
-            result["text"] = text_embeddings.cpu().tolist()
-        if scores is not None:
-            result["scores"] = scores.cpu().tolist()
+        return {"image": image_embeddings, "text": text_embeddings, "scores": scores}
+
+
+# import torch
+# from typing import Dict, Any, List
+# from PIL import Image
+# import base64
+# from io import BytesIO
+# import logging
+# from torch.utils.data import DataLoader, Dataset
+# import threading
+
+
+# class ImageDataset(Dataset):
+#     def __init__(self, images: List[Image.Image], processor):
+#         self.images = images
+#         self.processor = processor
+
+#     def __len__(self):
+#         return len(self.images)
+
+#     def __getitem__(self, idx):
+#         image = self.processor.process_images([self.images[idx]])
+#         return image
+
+
+# class TextDataset(Dataset):
+#     def __init__(self, texts: List[str], processor):
+#         self.texts = texts
+#         self.processor = processor
+
+#     def __len__(self):
+#         return len(self.texts)
+
+#     def __getitem__(self, idx):
+#         text = self.processor.process_queries([self.texts[idx]])
+#         return text
+
+
+# class EndpointHandler:
+#     """
+#     A handler class for processing image and text data, generating embeddings using a specified model and processor.
+
+#     Attributes:
+#         model: The pre-trained model used for generating embeddings.
+#         processor: The pre-trained processor used to process images and text before model inference.
+#         device: The device (CPU or CUDA) used to run model inference.
+#         default_batch_size: The default batch size for processing images and text in batches.
+#     """
+
+#     def __init__(self, path: str = "", default_batch_size: int = 4):
+#         """
+#         Initializes the EndpointHandler with a specified model path and default batch size.
+#         """
+#         # Initialize logging
+#         logging.basicConfig(level=logging.INFO)
+#         self.logger = logging.getLogger(__name__)
+
+#         from colpali_engine.models import ColQwen2, ColQwen2Processor
+
+#         self.logger.info("Initializing model and processor.")
+#         try:
+#             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+#             self.model = (
+#                 ColQwen2.from_pretrained(
+#                     path,
+#                     torch_dtype=torch.bfloat16,
+#                     device_map="auto",
+#                 )
+#                 .to(self.device)
+#                 .eval()
+#             )
+
+#             self.processor = ColQwen2Processor.from_pretrained(path)
+#             self.default_batch_size = default_batch_size
+#             self.logger.info("Initialization complete.")
+#         except Exception as e:
+#             self.logger.error(f"Failed to initialize model or processor: {e}")
+#             raise
+
+#     def _process_image_embeddings(
+#         self, images: List[Image.Image], batch_size: int
+#     ) -> torch.Tensor:
+#         """
+#         Processes images and generates embeddings.
+
+#         Args:
+#             images (List[Image.Image]): List of images to process.
+#             batch_size (int): Batch size for processing images.
+
+#         Returns:
+#             torch.Tensor: Tensor containing embeddings for each image.
+#         """
+#         self.logger.debug(f"Processing {len(images)} images.")
+#         try:
+#             image_dataset = ImageDataset(images, self.processor)
+#             image_loader = DataLoader(
+#                 image_dataset, batch_size=batch_size, num_workers=4, pin_memory=True
+#             )
+
+#             all_embeddings = []
+#             with torch.no_grad():
+#                 for batch in image_loader:
+#                     batch_images = batch[0].to(self.device, non_blocking=True)
+#                     with torch.cuda.amp.autocast():
+#                         embeddings = self.model(**batch_images)
+#                     all_embeddings.append(embeddings)
+#             image_embeddings = torch.cat(all_embeddings, dim=0)
+#             self.logger.debug("Image processing complete.")
+#             return image_embeddings
+#         except Exception as e:
+#             self.logger.error(f"Error processing images: {e}")
+#             raise
+
+#     def _process_text_embeddings(
+#         self, texts: List[str], batch_size: int
+#     ) -> torch.Tensor:
+#         """
+#         Processes text queries and generates embeddings.
+
+#         Args:
+#             texts (List[str]): List of text queries to process.
+#             batch_size (int): Batch size for processing texts.
+
+#         Returns:
+#             torch.Tensor: Tensor containing embeddings for each text query.
+#         """
+#         self.logger.debug(f"Processing {len(texts)} text queries.")
+#         try:
+#             text_dataset = TextDataset(texts, self.processor)
+#             text_loader = DataLoader(
+#                 text_dataset, batch_size=batch_size, num_workers=4, pin_memory=True
+#             )
+
+#             all_embeddings = []
+#             with torch.no_grad():
+#                 for batch in text_loader:
+#                     batch_texts = batch[0].to(self.device, non_blocking=True)
+#                     with torch.amp.autocast():
+#                         embeddings = self.model(**batch_texts)
+#                     all_embeddings.append(embeddings)
+#             text_embeddings = torch.cat(all_embeddings, dim=0)
+#             self.logger.debug("Text processing complete.")
+#             return text_embeddings
+#         except Exception as e:
+#             self.logger.error(f"Error processing texts: {e}")
+#             raise
+
+#     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+#         """
+#         Processes input data containing base64-encoded images and text queries, decodes them, and generates embeddings.
+
+#         Args:
+#             data (Dict[str, Any]): Dictionary containing input images, text queries, and optional batch size.
+
+#         Returns:
+#             Dict[str, Any]: Dictionary containing generated embeddings for images and text or error messages.
+#         """
+#         images_data = data.get("image", [])
+#         text_data = data.get("text", [])
+#         batch_size = data.get("batch_size", self.default_batch_size)
+
+#         images = []
+#         if images_data:
+#             self.logger.info("Decoding images from base64.")
+#             for img_data in images_data:
+#                 if isinstance(img_data, str):
+#                     try:
+#                         image_bytes = base64.b64decode(img_data)
+#                         image = Image.open(BytesIO(image_bytes)).convert("RGB")
+#                         images.append(image)
+#                     except Exception as e:
+#                         self.logger.error(f"Invalid image data: {e}")
+#                         return {"error": f"Invalid image data: {e}"}
+#                 else:
+#                     self.logger.error("Images should be base64-encoded strings.")
+#                     return {"error": "Images should be base64-encoded strings."}
+
+#         image_embeddings = None
+#         text_embeddings = None
+#         scores = None
+
+#         def process_images():
+#             nonlocal image_embeddings
+#             if images:
+#                 self.logger.info("Processing image embeddings.")
+#                 try:
+#                     image_embeddings = self._process_image_embeddings(
+#                         images, batch_size
+#                     )
+#                 except Exception as e:
+#                     self.logger.error(f"Error generating image embeddings: {e}")
+
+#         def process_texts():
+#             nonlocal text_embeddings
+#             if text_data:
+#                 self.logger.info("Processing text embeddings.")
+#                 try:
+#                     text_embeddings = self._process_text_embeddings(
+#                         text_data, batch_size
+#                     )
+#                 except Exception as e:
+#                     self.logger.error(f"Error generating text embeddings: {e}")
+
+#         # Process images and texts in parallel if both are present
+#         threads = []
+#         if images_data and text_data:
+#             image_thread = threading.Thread(target=process_images)
+#             text_thread = threading.Thread(target=process_texts)
+#             threads.extend([image_thread, text_thread])
+#             image_thread.start()
+#             text_thread.start()
+#             for thread in threads:
+#                 thread.join()
+#         else:
+#             process_images()
+#             process_texts()
+
+#         # Compute similarity scores if both embeddings are available
+#         if image_embeddings is not None and text_embeddings is not None:
+#             self.logger.info("Computing similarity scores.")
+#             try:
+#                 with torch.no_grad(), torch.amp.autocast():
+#                     scores = self.processor.score_multi_vector(
+#                         text_embeddings, image_embeddings
+#                     )
+#                 self.logger.info("Similarity scoring complete.")
+#             except Exception as e:
+#                 self.logger.error(f"Error computing similarity scores: {e}")
+#                 return {"error": f"Error computing similarity scores: {e}"}
+
+#         result = {}
+#         if image_embeddings is not None:
+#             result["image"] = image_embeddings.cpu().tolist()
+#         if text_embeddings is not None:
+#             result["text"] = text_embeddings.cpu().tolist()
+#         if scores is not None:
+#             result["scores"] = scores.cpu().tolist()

-        return result
+#         return result
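
A minimal local smoke test of the handler restored by this commit; this is an illustrative sketch rather than part of the repository. It assumes handler.py is importable, that MODEL_PATH points at a ColQwen2-compatible checkpoint, and that page.png exists locally. The payload keys ("image", "text", "batch_size") and the returned keys ("image", "text", "scores") follow the handler's own docstrings. Note also that recent PyTorch releases expect a device_type argument for torch.amp.autocast (e.g. torch.amp.autocast("cuda")), so the bare calls in the restored handler may need that argument at runtime.

import base64

from handler import EndpointHandler  # assumes handler.py is on the PYTHONPATH

MODEL_PATH = "vidore/colqwen2-v0.1"  # assumption: any ColQwen2-style checkpoint path

# Build the handler once; it loads the model and processor from MODEL_PATH.
handler = EndpointHandler(path=MODEL_PATH, default_batch_size=4)

# Encode an image as a base64 string, the format the "image" field expects.
with open("page.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "image": [image_b64],                    # list of base64-encoded images
    "text": ["What is the total revenue?"],  # list of text queries
    "batch_size": 2,                         # optional; falls back to default_batch_size
}

result = handler(payload)
# result["image"] and result["text"] hold the multi-vector embeddings;
# result["scores"] holds query-vs-image similarity scores when both inputs are given.
print(len(result["image"]), len(result["text"]), result["scores"])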