import torch
from typing import Dict, Any, List
from PIL import Image
import base64
from io import BytesIO


class EndpointHandler:
    """
    A handler class for processing image and text data, generating embeddings using a specified model and processor.

    Attributes:
        model: The pre-trained model used for generating embeddings.
        processor: The pre-trained processor used to process images and text before model inference.
        device: The device (CPU or CUDA) used to run model inference.
        default_batch_size: The default batch size for processing images and text in batches.
    """

    def __init__(self, path: str = "", default_batch_size: int = 4):
        """
        Initializes the EndpointHandler with a specified model path and default batch size.

        Args:
            path (str): Path to the pre-trained model and processor.
            default_batch_size (int): Default batch size for processing images and text data.
        """
        from colpali_engine.models import ColQwen2, ColQwen2Processor

        self.model = ColQwen2.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map=(
                "cuda:0" if torch.cuda.is_available() else "cpu"
            ),  # Set device map based on availability
        ).eval()
        self.processor = ColQwen2Processor.from_pretrained(path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.default_batch_size = default_batch_size

    def _process_image_batch(self, images: List[Image.Image]) -> List[List[float]]:
        """
        Processes a batch of images and generates embeddings.

        Args:
            images (List[Image.Image]): List of images to process.

        Returns:
            List[List[float]]: List of embeddings for each image.
        """
        batch_images = self.processor.process_images(images).to(self.device)
        with torch.no_grad():
            image_embeddings = self.model(**batch_images)
        return image_embeddings.cpu().tolist()

    def _process_text_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Processes a batch of text queries and generates embeddings.

        Args:
            texts (List[str]): List of text queries to process.

        Returns:
            List[List[float]]: List of embeddings for each text query.
        """
        batch_queries = self.processor.process_queries(texts).to(self.device)
        with torch.no_grad():
            query_embeddings = self.model(**batch_queries)
        return query_embeddings.cpu().tolist()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Processes input data containing base64-encoded images and text queries, decodes them, and generates embeddings.

        Args:
            data (Dict[str, Any]): Dictionary containing input images, text queries, and optional batch size.

        Returns:
            Dict[str, Any]: Dictionary containing generated embeddings for images and text or error messages.
        """
        images_data = data.get("image", [])
        text_data = data.get("text", [])
        batch_size = data.get("batch_size", self.default_batch_size)

        # Decode and process images
        images = []
        if images_data:
            for img_data in images_data:
                if isinstance(img_data, str):
                    try:
                        image_bytes = base64.b64decode(img_data)
                        image = Image.open(BytesIO(image_bytes)).convert("RGB")
                        images.append(image)
                    except Exception as e:
                        return {"error": f"Invalid image data: {e}"}
                else:
                    return {"error": "Images should be base64-encoded strings."}

        image_embeddings = []
        for i in range(0, len(images), batch_size):
            batch_images = images[i : i + batch_size]
            batch_embeddings = self._process_image_batch(batch_images)
            image_embeddings.extend(batch_embeddings)

        # Process text data
        text_embeddings = []
        if text_data:
            for i in range(0, len(text_data), batch_size):
                batch_texts = text_data[i : i + batch_size]
                batch_text_embeddings = self._process_text_batch(batch_texts)
                text_embeddings.extend(batch_text_embeddings)

        # Compute similarity scores if both image and text embeddings are available
        scores = []
        if image_embeddings and text_embeddings:
            # Convert embeddings to per-item tensors for scoring. Images (and
            # queries) can yield a different number of vectors per item, so a
            # single torch.tensor(...) over the whole list may be ragged;
            # score_multi_vector accepts lists of tensors and pads internally.
            image_embeddings_tensors = [
                torch.tensor(e, device=self.device) for e in image_embeddings
            ]
            text_embeddings_tensors = [
                torch.tensor(e, device=self.device) for e in text_embeddings
            ]
            with torch.no_grad():
                scores = (
                    self.processor.score_multi_vector(
                        text_embeddings_tensors, image_embeddings_tensors
                    )
                    .cpu()
                    .tolist()
                )

        return {"image": image_embeddings, "text": text_embeddings, "scores": scores}