import torch
from typing import Dict, Any, List
from PIL import Image
import base64
from io import BytesIO
import logging


class EndpointHandler:
    """
    A handler class for processing image and text data, generating embeddings using a
    specified model and processor.

    Attributes:
        model: The pre-trained model used for generating embeddings.
        processor: The pre-trained processor used to process images and text before model inference.
        device: The device (CPU or CUDA) used to run model inference.
        default_batch_size: The default batch size for processing images and text in batches.
    """

    def __init__(self, path: str = "", default_batch_size: int = 4):
        """
        Initializes the EndpointHandler with a specified model path and default batch size.

        Args:
            path (str): Path to the pre-trained model and processor.
            default_batch_size (int): Default batch size for processing images and text data.
        """
        # Initialize logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        from colpali_engine.models import ColQwen2, ColQwen2Processor

        self.logger.info("Initializing model and processor.")
        try:
            self.model = ColQwen2.from_pretrained(
                path,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            ).eval()
            self.processor = ColQwen2Processor.from_pretrained(path)
            # device_map="auto" already places the model, so only record the
            # device for moving inputs; calling .to() on a dispatched model
            # would raise an error.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.default_batch_size = default_batch_size
            self.logger.info("Initialization complete.")
        except Exception as e:
            self.logger.error(f"Failed to initialize model or processor: {e}")
            raise

    def _process_image_batch(self, images: List[Image.Image]) -> List[List[float]]:
        """
        Processes a batch of images and generates embeddings.

        Args:
            images (List[Image.Image]): List of images to process.

        Returns:
            List[List[float]]: List of embeddings for each image.
        """
        self.logger.debug(f"Processing batch of {len(images)} images.")
        try:
            batch_images = self.processor.process_images(images).to(self.device)
            # torch.amp.autocast requires an explicit device_type.
            with torch.no_grad(), torch.amp.autocast(device_type=self.device.type):
                image_embeddings = self.model(**batch_images)
            self.logger.debug("Image batch processing complete.")
            return image_embeddings.cpu().tolist()
        except Exception as e:
            self.logger.error(f"Error processing image batch: {e}")
            raise

    def _process_text_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Processes a batch of text queries and generates embeddings.

        Args:
            texts (List[str]): List of text queries to process.

        Returns:
            List[List[float]]: List of embeddings for each text query.
        """
        self.logger.debug(f"Processing batch of {len(texts)} text queries.")
        try:
            batch_queries = self.processor.process_queries(texts).to(self.device)
            with torch.no_grad(), torch.amp.autocast(device_type=self.device.type):
                query_embeddings = self.model(**batch_queries)
            self.logger.debug("Text batch processing complete.")
            return query_embeddings.cpu().tolist()
        except Exception as e:
            self.logger.error(f"Error processing text batch: {e}")
            raise

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Processes input data containing base64-encoded images and text queries,
        decodes them, and generates embeddings.

        Args:
            data (Dict[str, Any]): Dictionary containing input images, text queries, and optional batch size.

        Returns:
            Dict[str, Any]: Dictionary containing generated embeddings for images and text or error messages.
""" images_data = data.get("image", []) text_data = data.get("text", []) batch_size = data.get("batch_size", self.default_batch_size) # Decode and process images images = [] if images_data: self.logger.info("Decoding images from base64.") for img_data in images_data: if isinstance(img_data, str): try: image_bytes = base64.b64decode(img_data) image = Image.open(BytesIO(image_bytes)).convert("RGB") images.append(image) except Exception as e: self.logger.error(f"Invalid image data: {e}") return {"error": f"Invalid image data: {e}"} else: self.logger.error("Images should be base64-encoded strings.") return {"error": "Images should be base64-encoded strings."} image_embeddings = [] if images: self.logger.info("Processing image embeddings.") try: for i in range(0, len(images), batch_size): batch_images = images[i : i + batch_size] batch_embeddings = self._process_image_batch(batch_images) image_embeddings.extend(batch_embeddings) except Exception as e: self.logger.error(f"Error generating image embeddings: {e}") return {"error": f"Error generating image embeddings: {e}"} # Process text data text_embeddings = [] if text_data: self.logger.info("Processing text embeddings.") try: for i in range(0, len(text_data), batch_size): batch_texts = text_data[i : i + batch_size] batch_text_embeddings = self._process_text_batch(batch_texts) text_embeddings.extend(batch_text_embeddings) except Exception as e: self.logger.error(f"Error generating text embeddings: {e}") return {"error": f"Error generating text embeddings: {e}"} # Compute similarity scores if both image and text embeddings are available scores = [] if image_embeddings and text_embeddings: self.logger.info("Computing similarity scores.") try: image_embeddings_tensor = torch.tensor(image_embeddings).to(self.device) text_embeddings_tensor = torch.tensor(text_embeddings).to(self.device) with torch.no_grad(), torch.amp.autocast(): scores = ( self.processor.score_multi_vector( text_embeddings_tensor, image_embeddings_tensor ) .cpu() .tolist() ) self.logger.info("Similarity scoring complete.") except Exception as e: self.logger.error(f"Error computing similarity scores: {e}") return {"error": f"Error computing similarity scores: {e}"} return {"image": image_embeddings, "text": text_embeddings, "scores": scores} # import torch # from typing import Dict, Any, List # from PIL import Image # import base64 # from io import BytesIO # import logging # from torch.utils.data import DataLoader, Dataset # import threading # class ImageDataset(Dataset): # def __init__(self, images: List[Image.Image], processor): # self.images = images # self.processor = processor # def __len__(self): # return len(self.images) # def __getitem__(self, idx): # image = self.processor.process_images([self.images[idx]]) # return image # class TextDataset(Dataset): # def __init__(self, texts: List[str], processor): # self.texts = texts # self.processor = processor # def __len__(self): # return len(self.texts) # def __getitem__(self, idx): # text = self.processor.process_queries([self.texts[idx]]) # return text # class EndpointHandler: # """ # A handler class for processing image and text data, generating embeddings using a specified model and processor. # Attributes: # model: The pre-trained model used for generating embeddings. # processor: The pre-trained processor used to process images and text before model inference. # device: The device (CPU or CUDA) used to run model inference. # default_batch_size: The default batch size for processing images and text in batches. 
# """ # def __init__(self, path: str = "", default_batch_size: int = 4): # """ # Initializes the EndpointHandler with a specified model path and default batch size. # """ # # Initialize logging # logging.basicConfig(level=logging.INFO) # self.logger = logging.getLogger(__name__) # from colpali_engine.models import ColQwen2, ColQwen2Processor # self.logger.info("Initializing model and processor.") # try: # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # self.model = ( # ColQwen2.from_pretrained( # path, # torch_dtype=torch.bfloat16, # device_map="auto", # ) # .to(self.device) # .eval() # ) # self.processor = ColQwen2Processor.from_pretrained(path) # self.default_batch_size = default_batch_size # self.logger.info("Initialization complete.") # except Exception as e: # self.logger.error(f"Failed to initialize model or processor: {e}") # raise # def _process_image_embeddings( # self, images: List[Image.Image], batch_size: int # ) -> torch.Tensor: # """ # Processes images and generates embeddings. # Args: # images (List[Image.Image]): List of images to process. # batch_size (int): Batch size for processing images. # Returns: # torch.Tensor: Tensor containing embeddings for each image. # """ # self.logger.debug(f"Processing {len(images)} images.") # try: # image_dataset = ImageDataset(images, self.processor) # image_loader = DataLoader( # image_dataset, batch_size=batch_size, num_workers=4, pin_memory=True # ) # all_embeddings = [] # with torch.no_grad(): # for batch in image_loader: # batch_images = batch[0].to(self.device, non_blocking=True) # with torch.cuda.amp.autocast(): # embeddings = self.model(**batch_images) # all_embeddings.append(embeddings) # image_embeddings = torch.cat(all_embeddings, dim=0) # self.logger.debug("Image processing complete.") # return image_embeddings # except Exception as e: # self.logger.error(f"Error processing images: {e}") # raise # def _process_text_embeddings( # self, texts: List[str], batch_size: int # ) -> torch.Tensor: # """ # Processes text queries and generates embeddings. # Args: # texts (List[str]): List of text queries to process. # batch_size (int): Batch size for processing texts. # Returns: # torch.Tensor: Tensor containing embeddings for each text query. # """ # self.logger.debug(f"Processing {len(texts)} text queries.") # try: # text_dataset = TextDataset(texts, self.processor) # text_loader = DataLoader( # text_dataset, batch_size=batch_size, num_workers=4, pin_memory=True # ) # all_embeddings = [] # with torch.no_grad(): # for batch in text_loader: # batch_texts = batch[0].to(self.device, non_blocking=True) # with torch.amp.autocast(): # embeddings = self.model(**batch_texts) # all_embeddings.append(embeddings) # text_embeddings = torch.cat(all_embeddings, dim=0) # self.logger.debug("Text processing complete.") # return text_embeddings # except Exception as e: # self.logger.error(f"Error processing texts: {e}") # raise # def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: # """ # Processes input data containing base64-encoded images and text queries, decodes them, and generates embeddings. # Args: # data (Dict[str, Any]): Dictionary containing input images, text queries, and optional batch size. # Returns: # Dict[str, Any]: Dictionary containing generated embeddings for images and text or error messages. 
# """ # images_data = data.get("image", []) # text_data = data.get("text", []) # batch_size = data.get("batch_size", self.default_batch_size) # images = [] # if images_data: # self.logger.info("Decoding images from base64.") # for img_data in images_data: # if isinstance(img_data, str): # try: # image_bytes = base64.b64decode(img_data) # image = Image.open(BytesIO(image_bytes)).convert("RGB") # images.append(image) # except Exception as e: # self.logger.error(f"Invalid image data: {e}") # return {"error": f"Invalid image data: {e}"} # else: # self.logger.error("Images should be base64-encoded strings.") # return {"error": "Images should be base64-encoded strings."} # image_embeddings = None # text_embeddings = None # scores = None # def process_images(): # nonlocal image_embeddings # if images: # self.logger.info("Processing image embeddings.") # try: # image_embeddings = self._process_image_embeddings( # images, batch_size # ) # except Exception as e: # self.logger.error(f"Error generating image embeddings: {e}") # def process_texts(): # nonlocal text_embeddings # if text_data: # self.logger.info("Processing text embeddings.") # try: # text_embeddings = self._process_text_embeddings( # text_data, batch_size # ) # except Exception as e: # self.logger.error(f"Error generating text embeddings: {e}") # # Process images and texts in parallel if both are present # threads = [] # if images_data and text_data: # image_thread = threading.Thread(target=process_images) # text_thread = threading.Thread(target=process_texts) # threads.extend([image_thread, text_thread]) # image_thread.start() # text_thread.start() # for thread in threads: # thread.join() # else: # process_images() # process_texts() # # Compute similarity scores if both embeddings are available # if image_embeddings is not None and text_embeddings is not None: # self.logger.info("Computing similarity scores.") # try: # with torch.no_grad(), torch.amp.autocast(): # scores = self.processor.score_multi_vector( # text_embeddings, image_embeddings # ) # self.logger.info("Similarity scoring complete.") # except Exception as e: # self.logger.error(f"Error computing similarity scores: {e}") # return {"error": f"Error computing similarity scores: {e}"} # result = {} # if image_embeddings is not None: # result["image"] = image_embeddings.cpu().tolist() # if text_embeddings is not None: # result["text"] = text_embeddings.cpu().tolist() # if scores is not None: # result["scores"] = scores.cpu().tolist() # return result