import torch
from typing import Dict, Any, List
from PIL import Image
import base64
from io import BytesIO


class EndpointHandler:
    """
    A handler class for processing image and text data, generating embeddings
    using a specified model and processor.

    Attributes:
        model: The pre-trained model used for generating embeddings.
        processor: The pre-trained processor used to process images and text
            before model inference.
        device: The device (CPU or CUDA) used to run model inference.
        default_batch_size: The default batch size for processing images and
            text in batches.
""" def __init__(self, path: str = "", default_batch_size: int = 4): """ Initializes the EndpointHandler with a specified model path and default batch size. Args: path (str): Path to the pre-trained model and processor. default_batch_size (int): Default batch size for processing images and text data. """ from colpali_engine.models import ColQwen2, ColQwen2Processor self.model = ColQwen2.from_pretrained( path, torch_dtype=torch.bfloat16, device_map=( "cuda:0" if torch.cuda.is_available() else "cpu" ), # Set device map based on availability ).eval() self.processor = ColQwen2Processor.from_pretrained(path) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model.to(self.device) self.default_batch_size = default_batch_size def _process_image_batch(self, images: List[Image.Image]) -> List[List[float]]: """ Processes a batch of images and generates embeddings. Args: images (List[Image.Image]): List of images to process. Returns: List[List[float]]: List of embeddings for each image. """ batch_images = self.processor.process_images(images).to(self.device) with torch.no_grad(): image_embeddings = self.model(**batch_images) return image_embeddings.cpu().tolist() def _process_text_batch(self, texts: List[str]) -> List[List[float]]: """ Processes a batch of text queries and generates embeddings. Args: texts (List[str]): List of text queries to process. Returns: List[List[float]]: List of embeddings for each text query. """ batch_queries = self.processor.process_queries(texts).to(self.device) with torch.no_grad(): query_embeddings = self.model(**batch_queries) return query_embeddings.cpu().tolist() def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: """ Processes input data containing base64-encoded images and text queries, decodes them, and generates embeddings. Args: data (Dict[str, Any]): Dictionary containing input images, text queries, and optional batch size. Returns: Dict[str, Any]: Dictionary containing generated embeddings for images and text or error messages. """ images_data = data.get("image", []) text_data = data.get("text", []) batch_size = data.get("batch_size", self.default_batch_size) # Decode and process images images = [] if images_data: for img_data in images_data: if isinstance(img_data, str): try: image_bytes = base64.b64decode(img_data) image = Image.open(BytesIO(image_bytes)).convert("RGB") images.append(image) except Exception as e: return {"error": f"Invalid image data: {e}"} else: return {"error": "Images should be base64-encoded strings."} image_embeddings = [] for i in range(0, len(images), batch_size): batch_images = images[i : i + batch_size] batch_embeddings = self._process_image_batch(batch_images) image_embeddings.extend(batch_embeddings) # Process text data text_embeddings = [] if text_data: for i in range(0, len(text_data), batch_size): batch_texts = text_data[i : i + batch_size] batch_text_embeddings = self._process_text_batch(batch_texts) text_embeddings.extend(batch_text_embeddings) # Compute similarity scores if both image and text embeddings are available scores = [] if image_embeddings and text_embeddings: # Convert embeddings to tensors for scoring image_embeddings_tensor = torch.tensor(image_embeddings).to(self.device) text_embeddings_tensor = torch.tensor(text_embeddings).to(self.device) with torch.no_grad(): scores = ( self.processor.score_multi_vector( text_embeddings_tensor, image_embeddings_tensor ) .cpu() .tolist() ) return {"image": image_embeddings, "text": text_embeddings, "scores": scores}