amaye15 committed
Commit aea7238 • Parent(s): a4af2d9
Automatic Batching
Browse files: handler.py +86 -16
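In short, as read from the diff below: the previous single-shot handler is kept, commented out, at the top of the file, and the live implementation gains automatic batching. __call__ now accepts an optional "batch_size" field in the request (falling back to default_batch_size, 4, set in __init__), decodes the images as before, and feeds them through a new _process_batch helper one chunk at a time, so large requests no longer have to fit on the device as a single batch.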
handler.py CHANGED
@@ -1,12 +1,72 @@
+# import torch
+# from typing import Dict, Any
+# from PIL import Image
+# import base64
+# from io import BytesIO
+
+
+# class EndpointHandler:
+#     def __init__(self, path: str = ""):
+#         # Import your model and processor inside the class
+#         from colpali_engine.models import ColQwen2, ColQwen2Processor
+
+#         # Load the model and processor
+#         self.model = ColQwen2.from_pretrained(
+#             path,
+#             torch_dtype=torch.bfloat16,
+#         ).eval()
+#         self.processor = ColQwen2Processor.from_pretrained(path)
+
+#         # Determine the device
+#         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#         self.model.to(self.device)
+
+#     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+#         # Extract images from the input data
+#         images_data = data.get("inputs", [])
+
+#         if not images_data:
+#             return {"error": "No images provided in 'inputs'."}
+
+#         # Process images
+#         images = []
+#         for img_data in images_data:
+#             if isinstance(img_data, str):
+#                 try:
+#                     # Assume base64-encoded image
+#                     image_bytes = base64.b64decode(img_data)
+#                     image = Image.open(BytesIO(image_bytes)).convert("RGB")
+#                     images.append(image)
+#                 except Exception as e:
+#                     return {"error": f"Invalid image data: {e}"}
+#             else:
+#                 return {"error": "Images should be base64-encoded strings."}
+
+#         # Prepare inputs
+#         batch_images = self.processor.process_images(images)
+
+#         # Move tensors to the device
+#         batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
+
+#         # Generate embeddings
+#         with torch.no_grad():
+#             image_embeddings = self.model(**batch_images)
+
+#         # Convert embeddings to a list
+#         embeddings_list = image_embeddings.cpu().tolist()
+
+#         return {"embeddings": embeddings_list}
+
+
 import torch
-from typing import Dict, Any
+from typing import Dict, Any, List
 from PIL import Image
 import base64
 from io import BytesIO


 class EndpointHandler:
-    def __init__(self, path: str = ""):
+    def __init__(self, path: str = "", default_batch_size: int = 4):
         # Import your model and processor inside the class
         from colpali_engine.models import ColQwen2, ColQwen2Processor

@@ -21,14 +81,30 @@ class EndpointHandler:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model.to(self.device)

+        # Set default batch size
+        self.default_batch_size = default_batch_size
+
+    def _process_batch(self, images: List[Image.Image]) -> List[List[float]]:
+        # Prepare inputs for a batch
+        batch_images = self.processor.process_images(images)
+        batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
+
+        # Generate embeddings
+        with torch.no_grad():
+            image_embeddings = self.model(**batch_images)
+
+        # Convert embeddings to list format
+        return image_embeddings.cpu().tolist()
+
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         # Extract images from the input data
         images_data = data.get("inputs", [])
+        batch_size = data.get("batch_size", self.default_batch_size)

         if not images_data:
             return {"error": "No images provided in 'inputs'."}

-        # Process images
+        # Decode and validate images
         images = []
         for img_data in images_data:
             if isinstance(img_data, str):
@@ -42,17 +118,11 @@ class EndpointHandler:
             else:
                 return {"error": "Images should be base64-encoded strings."}

-        # Prepare inputs
-        batch_images = self.processor.process_images(images)
-
-        # Move tensors to the device
-        batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
-
-        # Generate embeddings
-        with torch.no_grad():
-            image_embeddings = self.model(**batch_images)
-
-        # Convert embeddings to a list
-        embeddings_list = image_embeddings.cpu().tolist()
+        # Process in batches with the specified or default batch size
+        embeddings = []
+        for i in range(0, len(images), batch_size):
+            batch_images = images[i : i + batch_size]
+            batch_embeddings = self._process_batch(batch_images)
+            embeddings.extend(batch_embeddings)

-        return {"embeddings": embeddings_list}
+        return {"embeddings": embeddings}
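For orientation, a minimal sketch (not part of the commit) of how the batched handler might be exercised locally. The checkpoint location (path=".") and the synthetic test images are assumptions; on a Hugging Face Inference Endpoint, path points at the repository directory containing the ColQwen2 weights.

import base64
from io import BytesIO

from PIL import Image

from handler import EndpointHandler


def encode_image(image: Image.Image) -> str:
    # Serialize a PIL image to a base64-encoded PNG string,
    # the format __call__ expects for each entry in "inputs".
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


# Assumes the model weights live in the current repo directory.
handler = EndpointHandler(path=".")

payload = {
    "inputs": [encode_image(Image.new("RGB", (448, 448))) for _ in range(10)],
    "batch_size": 4,  # optional; __call__ falls back to default_batch_size
}

result = handler(payload)
# Ten inputs are processed as chunks of 4, 4, and 2,
# yielding one embedding per image.
print(len(result["embeddings"]))

One note on the design: because _process_batch moves each chunk to the device and embeds it independently, peak GPU memory scales with batch_size rather than with the size of the incoming request, at the cost of one forward pass per chunk.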