amaye15 committed
Commit aea7238
1 Parent(s): a4af2d9

Automatic Batching

Refactors handler.py to generate image embeddings in fixed-size chunks: __init__ gains a default_batch_size argument, a new _process_batch helper embeds one chunk at a time, and __call__ accepts an optional per-request "batch_size" field.

Files changed (1)
  1. handler.py  +86 -16
handler.py CHANGED
@@ -1,12 +1,72 @@
+ # import torch
+ # from typing import Dict, Any
+ # from PIL import Image
+ # import base64
+ # from io import BytesIO
+
+
+ # class EndpointHandler:
+ #     def __init__(self, path: str = ""):
+ #         # Import your model and processor inside the class
+ #         from colpali_engine.models import ColQwen2, ColQwen2Processor
+
+ #         # Load the model and processor
+ #         self.model = ColQwen2.from_pretrained(
+ #             path,
+ #             torch_dtype=torch.bfloat16,
+ #         ).eval()
+ #         self.processor = ColQwen2Processor.from_pretrained(path)
+
+ #         # Determine the device
+ #         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ #         self.model.to(self.device)
+
+ #     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ #         # Extract images from the input data
+ #         images_data = data.get("inputs", [])
+
+ #         if not images_data:
+ #             return {"error": "No images provided in 'inputs'."}
+
+ #         # Process images
+ #         images = []
+ #         for img_data in images_data:
+ #             if isinstance(img_data, str):
+ #                 try:
+ #                     # Assume base64-encoded image
+ #                     image_bytes = base64.b64decode(img_data)
+ #                     image = Image.open(BytesIO(image_bytes)).convert("RGB")
+ #                     images.append(image)
+ #                 except Exception as e:
+ #                     return {"error": f"Invalid image data: {e}"}
+ #             else:
+ #                 return {"error": "Images should be base64-encoded strings."}
+
+ #         # Prepare inputs
+ #         batch_images = self.processor.process_images(images)
+
+ #         # Move tensors to the device
+ #         batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
+
+ #         # Generate embeddings
+ #         with torch.no_grad():
+ #             image_embeddings = self.model(**batch_images)
+
+ #         # Convert embeddings to a list
+ #         embeddings_list = image_embeddings.cpu().tolist()
+
+ #         return {"embeddings": embeddings_list}
+
+
  import torch
- from typing import Dict, Any
+ from typing import Dict, Any, List
  from PIL import Image
  import base64
  from io import BytesIO


  class EndpointHandler:
-     def __init__(self, path: str = ""):
+     def __init__(self, path: str = "", default_batch_size: int = 4):
          # Import your model and processor inside the class
          from colpali_engine.models import ColQwen2, ColQwen2Processor

@@ -21,14 +81,30 @@ class EndpointHandler:
          self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
          self.model.to(self.device)

+         # Set default batch size
+         self.default_batch_size = default_batch_size
+
+     def _process_batch(self, images: List[Image.Image]) -> List[List[float]]:
+         # Prepare inputs for a batch
+         batch_images = self.processor.process_images(images)
+         batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
+
+         # Generate embeddings
+         with torch.no_grad():
+             image_embeddings = self.model(**batch_images)
+
+         # Convert embeddings to list format
+         return image_embeddings.cpu().tolist()
+
      def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
          # Extract images from the input data
          images_data = data.get("inputs", [])
+         batch_size = data.get("batch_size", self.default_batch_size)

          if not images_data:
              return {"error": "No images provided in 'inputs'."}

-         # Process images
+         # Decode and validate images
          images = []
          for img_data in images_data:
              if isinstance(img_data, str):
@@ -42,17 +118,11 @@
              else:
                  return {"error": "Images should be base64-encoded strings."}

-         # Prepare inputs
-         batch_images = self.processor.process_images(images)
-
-         # Move tensors to the device
-         batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
-
-         # Generate embeddings
-         with torch.no_grad():
-             image_embeddings = self.model(**batch_images)
-
-         # Convert embeddings to a list
-         embeddings_list = image_embeddings.cpu().tolist()
+         # Process in batches with the specified or default batch size
+         embeddings = []
+         for i in range(0, len(images), batch_size):
+             batch_images = images[i : i + batch_size]
+             batch_embeddings = self._process_batch(batch_images)
+             embeddings.extend(batch_embeddings)

-         return {"embeddings": embeddings_list}
+         return {"embeddings": embeddings}
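For reference, a minimal local smoke test of the batched handler might look like the sketch below. It is not part of the commit: the checkpoint id, the image filenames, and the encode_image helper are placeholder assumptions, and it presumes colpali_engine is installed and this handler.py is on the import path.

import base64

from handler import EndpointHandler


def encode_image(path: str) -> str:
    # The handler expects each image as a base64-encoded string.
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


# Placeholder path; in a deployed endpoint this is the model repository directory.
handler = EndpointHandler(path="vidore/colqwen2-v0.1", default_batch_size=4)

payload = {
    "inputs": [encode_image(p) for p in ["page1.png", "page2.png", "page3.png"]],
    "batch_size": 2,  # optional; falls back to default_batch_size when omitted
}

result = handler(payload)
if "error" in result:
    print(result["error"])
else:
    # One multi-vector embedding (token count x dim) per input image.
    print(len(result["embeddings"]))

Chunking the request this way bounds peak GPU memory to a single batch_size-sized forward pass, and any padding applied by processor.process_images is computed per chunk rather than against the longest image in the whole request.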