File size: 7,237 Bytes
b54bf3c aea7238 b54bf3c 64262c3 b54bf3c 8d41aec 25a8604 8d41aec 4e7495e 25a8604 4e7495e 25a8604 8d41aec f4f3a3e 8d41aec 18f8eec 8d41aec 64262c3 b54bf3c 64262c3 b54bf3c 64262c3 18f8eec 2c8e3a0 18f8eec 64262c3 aea7238 18f8eec 8d41aec 18f8eec 8d41aec 4e7495e 8d41aec 4e7495e 18f8eec 8d41aec 18f8eec 64262c3 18f8eec a737583 18f8eec 64262c3 18f8eec 64262c3 aea7238 18f8eec 25a8604 18f8eec 25a8604 18f8eec 25a8604 18f8eec 64262c3 18f8eec a737583 18f8eec 64262c3 18f8eec 64262c3 25a8604 b54bf3c 8d41aec 25a8604 8d41aec 25a8604 8d41aec 4e7495e 25a8604 8d41aec 25a8604 aea7238 b54bf3c 18f8eec b54bf3c 25a8604 64262c3 25a8604 64262c3 25a8604 64262c3 25a8604 18f8eec 64262c3 18f8eec b0a7877 18f8eec dbabaf1 64262c3 dbabaf1 18f8eec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import torch
from typing import Dict, Any, List
from PIL import Image
import base64
from io import BytesIO
import logging
class EndpointHandler:
    """
    A handler class for processing image and text data, generating embeddings using a specified model and processor.

    Attributes:
        model: The pre-trained ColQwen2 model used for generating embeddings.
        processor: The pre-trained ColQwen2 processor used to prepare images and text before inference.
        device: The device reported by ``model.device``; input batches are moved there.
        default_batch_size: Batch size used when a request does not supply ``batch_size``.
        logger: Logger for progress and error messages.
    """

    def __init__(self, path: str = "", default_batch_size: int = 4):
        """
        Initializes the EndpointHandler with a specified model path and default batch size.

        Args:
            path (str): Path to the pre-trained model and processor.
            default_batch_size (int): Default batch size for processing images and text data.

        Raises:
            Exception: Any error raised while loading the model or processor is logged and re-raised.
        """
        # Initialize logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Imported here rather than at module level, so importing this module
        # does not itself require colpali_engine.
        from colpali_engine.models import ColQwen2, ColQwen2Processor

        self.logger.info("Initializing model and processor.")
        try:
            # device_map="auto" lets accelerate decide where the weights live (it may
            # shard across GPUs or offload). Calling .to() afterwards would fight that
            # placement, so instead we read back where the model landed and send the
            # inputs there.
            self.model = ColQwen2.from_pretrained(
                path,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            ).eval()
            self.processor = ColQwen2Processor.from_pretrained(path)
            self.device = self.model.device
            self.default_batch_size = default_batch_size
            self.logger.info("Initialization complete.")
        except Exception as e:
            self.logger.error(f"Failed to initialize model or processor: {e}")
            raise

    @staticmethod
    def _as_list(value: Any) -> Any:
        """
        Normalize a request field into something iterable item-by-item.

        ``None`` and ``""`` become ``[]``. A non-empty lone string becomes a
        one-element list, so it is not iterated character by character. Anything
        else is returned unchanged.
        """
        if value is None:
            return []
        if isinstance(value, str):
            return [value] if value else []
        return value

    @staticmethod
    def _embed_in_batches(items: List[Any], batch_size: int, embed_fn: Any) -> List[Any]:
        """
        Call ``embed_fn`` on consecutive slices of ``items`` and concatenate the results.

        Args:
            items: Sequence of inputs; must support slicing.
            batch_size: Positive slice length.
            embed_fn: Callable taking a slice and returning one result per element.

        Returns:
            The per-element results of every slice, in input order.
        """
        results: List[Any] = []
        for start in range(0, len(items), batch_size):
            results.extend(embed_fn(items[start : start + batch_size]))
        return results

    @staticmethod
    def _to_tensor_list(embeddings: List[Any], device: Any) -> List[torch.Tensor]:
        """
        Convert each per-item embedding (nested lists) into its own tensor on ``device``.

        Items are converted one at a time on purpose. Embeddings from different
        batches can have different row counts, and ``torch.tensor`` on the whole
        ragged list would raise.
        """
        return [torch.tensor(embedding).to(device) for embedding in embeddings]

    def _process_image_batch(self, images: List["Image.Image"]) -> List[List[List[float]]]:
        """
        Processes a batch of images and generates embeddings.

        Args:
            images: List of PIL images to process.

        Returns:
            One multi-vector embedding per image, as nested lists. All items in one
            batch share a row count because they come from one tensor; items from
            different batches may not.

        Raises:
            Exception: Processor or model errors are logged and re-raised.
        """
        self.logger.debug(f"Processing batch of {len(images)} images.")
        try:
            batch_images = self.processor.process_images(images).to(self.device)
            with torch.no_grad():
                image_embeddings = self.model(**batch_images)
            self.logger.debug("Image batch processing complete.")
            return image_embeddings.cpu().tolist()
        except Exception as e:
            self.logger.error(f"Error processing image batch: {e}")
            raise

    def _process_text_batch(self, texts: List[str]) -> List[List[List[float]]]:
        """
        Processes a batch of text queries and generates embeddings.

        Args:
            texts: List of text queries to process.

        Returns:
            One multi-vector embedding per query, as nested lists. All items in one
            batch share a row count because they come from one tensor; items from
            different batches may not.

        Raises:
            Exception: Processor or model errors are logged and re-raised.
        """
        self.logger.debug(f"Processing batch of {len(texts)} text queries.")
        try:
            batch_queries = self.processor.process_queries(texts).to(self.device)
            with torch.no_grad():
                query_embeddings = self.model(**batch_queries)
            self.logger.debug("Text batch processing complete.")
            return query_embeddings.cpu().tolist()
        except Exception as e:
            self.logger.error(f"Error processing text batch: {e}")
            raise

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Processes input data containing base64-encoded images and text queries, decodes them, and generates embeddings.

        Args:
            data: Request payload. Recognised keys:
                "image": a base64-encoded image string, or a list of them.
                "text": a query string, or a list of them.
                "batch_size": optional positive int; ``None``/absent means ``default_batch_size``.

        Returns:
            On success ``{"image": [...], "text": [...], "scores": [...]}``. "scores" is
            only filled when both images and texts were given; it is the output of
            ``processor.score_multi_vector(texts, images)``. On any failure, returns
            ``{"error": <message>}`` instead.
        """
        images_data = self._as_list(data.get("image"))
        text_data = self._as_list(data.get("text"))

        batch_size = data.get("batch_size")
        if batch_size is None:
            batch_size = self.default_batch_size
        # bool is a subclass of int, so reject it explicitly. range() needs a step >= 1:
        # 0 raises, and a negative step yields no batches and silently drops every input.
        if isinstance(batch_size, bool) or not isinstance(batch_size, int) or batch_size < 1:
            self.logger.error(f"Invalid batch_size: {batch_size!r}")
            return {"error": f"batch_size must be a positive integer, got {batch_size!r}"}

        # Decode and process images
        images = []
        if images_data:
            self.logger.info("Decoding images from base64.")
            for img_data in images_data:
                if not isinstance(img_data, str):
                    self.logger.error("Images should be base64-encoded strings.")
                    return {"error": "Images should be base64-encoded strings."}
                try:
                    image_bytes = base64.b64decode(img_data)
                    image = Image.open(BytesIO(image_bytes)).convert("RGB")
                except Exception as e:
                    self.logger.error(f"Invalid image data: {e}")
                    return {"error": f"Invalid image data: {e}"}
                images.append(image)

        image_embeddings = []
        if images:
            self.logger.info("Processing image embeddings.")
            try:
                image_embeddings = self._embed_in_batches(
                    images, batch_size, self._process_image_batch
                )
            except Exception as e:
                self.logger.error(f"Error generating image embeddings: {e}")
                return {"error": f"Error generating image embeddings: {e}"}

        # Process text data
        text_embeddings = []
        if text_data:
            self.logger.info("Processing text embeddings.")
            try:
                text_embeddings = self._embed_in_batches(
                    text_data, batch_size, self._process_text_batch
                )
            except Exception as e:
                self.logger.error(f"Error generating text embeddings: {e}")
                return {"error": f"Error generating text embeddings: {e}"}

        # Compute similarity scores if both image and text embeddings are available
        scores = []
        if image_embeddings and text_embeddings:
            self.logger.info("Computing similarity scores.")
            try:
                # Per-item tensors: batches may differ in row count, so the full list
                # cannot be stacked into one tensor. score_multi_vector accepts a list
                # of tensors and pads them itself.
                image_tensors = self._to_tensor_list(image_embeddings, self.device)
                text_tensors = self._to_tensor_list(text_embeddings, self.device)
                with torch.no_grad():
                    scores = (
                        self.processor.score_multi_vector(text_tensors, image_tensors)
                        .cpu()
                        .tolist()
                    )
                self.logger.info("Similarity scoring complete.")
            except Exception as e:
                self.logger.error(f"Error computing similarity scores: {e}")
                return {"error": f"Error computing similarity scores: {e}"}

        return {"image": image_embeddings, "text": text_embeddings, "scores": scores}
|