# import torch
# from typing import Dict, Any, List
# from PIL import Image
# import base64
# from io import BytesIO


# class EndpointHandler:
#     """
#     A handler class for processing image data, generating embeddings using a specified model and processor.

#     Attributes:
#         model: The pre-trained model used for generating embeddings.
#         processor: The pre-trained processor used to process images before model inference.
#         device: The device (CPU or CUDA) used to run model inference.
#         default_batch_size: The default batch size for processing images in batches.
#     """

#     def __init__(self, path: str = "", default_batch_size: int = 4):
#         """
#         Initializes the EndpointHandler with a specified model path and default batch size.

#         Args:
#             path (str): Path to the pre-trained model and processor.
#             default_batch_size (int): Default batch size for image processing.
#         """
#         from colpali_engine.models import ColQwen2, ColQwen2Processor

#         self.model = ColQwen2.from_pretrained(
#             path,
#             torch_dtype=torch.bfloat16,
#         ).eval()
#         self.processor = ColQwen2Processor.from_pretrained(path)

#         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#         self.model.to(self.device)
#         self.default_batch_size = default_batch_size

#     def _process_batch(self, images: List[Image.Image]) -> List[List[float]]:
#         """
#         Processes a batch of images and generates embeddings.

#         Args:
#             images (List[Image.Image]): List of images to process.

#         Returns:
#             List[List[float]]: List of embeddings for each image.
#         """
#         batch_images = self.processor.process_images(images)
#         batch_images = {k: v.to(self.device) for k, v in batch_images.items()}

#         with torch.no_grad():
#             image_embeddings = self.model(**batch_images)

#         return image_embeddings.cpu().tolist()

#     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
#         """
#         Processes input data containing base64-encoded images, decodes them, and generates embeddings.

#         Args:
#             data (Dict[str, Any]): Dictionary containing input images and optional batch size.

#         Returns:
#             Dict[str, Any]: Dictionary containing generated embeddings or error messages.
#         """
#         images_data = data.get("inputs", [])
#         batch_size = data.get("batch_size", self.default_batch_size)

#         if not images_data:
#             return {"error": "No images provided in 'inputs'."}

#         images = []
#         for img_data in images_data:
#             if isinstance(img_data, str):
#                 try:
#                     image_bytes = base64.b64decode(img_data)
#                     image = Image.open(BytesIO(image_bytes)).convert("RGB")
#                     images.append(image)
#                 except Exception as e:
#                     return {"error": f"Invalid image data: {e}"}
#             else:
#                 return {"error": "Images should be base64-encoded strings."}

#         embeddings = []
#         for i in range(0, len(images), batch_size):
#             batch_images = images[i : i + batch_size]
#             batch_embeddings = self._process_batch(batch_images)
#             embeddings.extend(batch_embeddings)

#         return {"embeddings": embeddings}

import torch
from typing import Dict, Any, List
from PIL import Image
import base64
from io import BytesIO


class EndpointHandler:
    """
    A handler class that processes image and text data and generates embeddings using a specified model and processor.

    Attributes:
        model: The pre-trained model used for generating embeddings.
        processor: The pre-trained processor used to process images and text before model inference.
        device: The device (CPU or CUDA) used to run model inference.
        default_batch_size: The default batch size for processing images and text in batches.
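
    Example:
        A minimal usage sketch (the checkpoint path is an assumption, not pinned
        by this file):

            handler = EndpointHandler(path="vidore/colqwen2-v1.0")
            result = handler({"image": ["<base64 image>"], "text": ["a query"]})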
    """

    def __init__(self, path: str = "", default_batch_size: int = 4):
        """
        Initializes the EndpointHandler with a specified model path and default batch size.

        Args:
            path (str): Path to the pre-trained model and processor.
            default_batch_size (int): Default batch size for processing images and text data.
        """
        from colpali_engine.models import ColQwen2, ColQwen2Processor

        self.model = ColQwen2.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map=(
                "cuda:0" if torch.cuda.is_available() else "cpu"
            ),  # Set device map based on availability
        ).eval()
        self.processor = ColQwen2Processor.from_pretrained(path)

        # device_map above already places the model; keep a device handle for moving inputs.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.default_batch_size = default_batch_size

    def _process_image_batch(self, images: List[Image.Image]) -> List[List[float]]:
        """
        Processes a batch of images and generates embeddings.

        Args:
            images (List[Image.Image]): List of images to process.

        Returns:
            List[List[float]]: List of embeddings for each image.
        """
        batch_images = self.processor.process_images(images).to(self.device)

        with torch.no_grad():
            image_embeddings = self.model(**batch_images)

        return image_embeddings.cpu().tolist()

    def _process_text_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Processes a batch of text queries and generates embeddings.

        Args:
            texts (List[str]): List of text queries to process.

        Returns:
            List[List[float]]: List of embeddings for each text query.
        """
        batch_queries = self.processor.process_queries(texts).to(self.device)

        with torch.no_grad():
            query_embeddings = self.model(**batch_queries)

        return query_embeddings.cpu().tolist()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Processes input data containing base64-encoded images and text queries, decodes them, and generates embeddings.

        Args:
            data (Dict[str, Any]): Dictionary with optional keys "image" (list of
                base64-encoded image strings), "text" (list of text queries), and
                "batch_size" (int, defaults to the handler's default batch size).

        Returns:
            Dict[str, Any]: Dictionary with keys "image", "text", and "scores"
                holding the image embeddings, text embeddings, and query-to-image
                similarity scores, or an "error" message if decoding fails.
        """
        images_data = data.get("image", [])
        text_data = data.get("text", [])
        batch_size = data.get("batch_size", self.default_batch_size)

        # Decode and process images
        images = []
        if images_data:
            for img_data in images_data:
                if isinstance(img_data, str):
                    try:
                        image_bytes = base64.b64decode(img_data)
                        image = Image.open(BytesIO(image_bytes)).convert("RGB")
                        images.append(image)
                    except Exception as e:
                        return {"error": f"Invalid image data: {e}"}
                else:
                    return {"error": "Images should be base64-encoded strings."}

        image_embeddings = []
        for i in range(0, len(images), batch_size):
            batch_images = images[i : i + batch_size]
            batch_embeddings = self._process_image_batch(batch_images)
            image_embeddings.extend(batch_embeddings)

        # Process text data
        text_embeddings = []
        if text_data:
            for i in range(0, len(text_data), batch_size):
                batch_texts = text_data[i : i + batch_size]
                batch_text_embeddings = self._process_text_batch(batch_texts)
                text_embeddings.extend(batch_text_embeddings)

        # Compute similarity scores if both image and text embeddings are available
        scores = []
        if image_embeddings and text_embeddings:
            # One tensor per item: ColQwen2 embeddings can have different
            # sequence lengths across images and queries, so stacking them
            # into a single tensor may fail on ragged lists.
            image_embeddings_tensors = [
                torch.tensor(embedding) for embedding in image_embeddings
            ]
            text_embeddings_tensors = [
                torch.tensor(embedding) for embedding in text_embeddings
            ]

            with torch.no_grad():
                # Late-interaction (MaxSim) scoring of every query against every image.
                scores = (
                    self.processor.score_multi_vector(
                        text_embeddings_tensors, image_embeddings_tensors
                    )
                    .cpu()
                    .tolist()
                )

        return {"image": image_embeddings, "text": text_embeddings, "scores": scores}