import base64
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image


class EndpointHandler:
    """
    A handler class that processes image data and generates embeddings using a
    specified model and processor.

    Attributes:
        model: The pre-trained model used for generating embeddings.
        processor: The pre-trained processor used to process images before model inference.
        device: The device (CPU or CUDA) used to run model inference.
        default_batch_size: The default batch size for processing images in batches.
    """

    def __init__(self, path: str = "", default_batch_size: int = 4):
        """
        Initializes the EndpointHandler with a specified model path and default batch size.

        Args:
            path (str): Path to the pre-trained model and processor.
            default_batch_size (int): Default batch size for image processing.
        """
        from colpali_engine.models import ColQwen2, ColQwen2Processor

        self.model = ColQwen2.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        ).eval()
        self.processor = ColQwen2Processor.from_pretrained(path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.default_batch_size = default_batch_size
    def _process_batch(self, images: List[Image.Image]) -> List[List[float]]:
        """
        Processes a batch of images and generates embeddings.

        Args:
            images (List[Image.Image]): List of images to process.

        Returns:
            List[List[float]]: List of embeddings for each image.
        """
        batch_images = self.processor.process_images(images)
        batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
        with torch.no_grad():
            image_embeddings = self.model(**batch_images)
        return image_embeddings.cpu().tolist()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Processes input data containing base64-encoded images, decodes them, and generates embeddings.

        Args:
            data (Dict[str, Any]): Dictionary containing input images and optional batch size.

        Returns:
            Dict[str, Any]: Dictionary containing generated embeddings or error messages.
        """
        images_data = data.get("inputs", [])
        batch_size = data.get("batch_size", self.default_batch_size)

        if not images_data:
            return {"error": "No images provided in 'inputs'."}

        images = []
        for img_data in images_data:
            if isinstance(img_data, str):
                try:
                    image_bytes = base64.b64decode(img_data)
                    image = Image.open(BytesIO(image_bytes)).convert("RGB")
                    images.append(image)
                except Exception as e:
                    return {"error": f"Invalid image data: {e}"}
            else:
                return {"error": "Images should be base64-encoded strings."}

        embeddings = []
        for i in range(0, len(images), batch_size):
            batch_images = images[i : i + batch_size]
            batch_embeddings = self._process_batch(batch_images)
            embeddings.extend(batch_embeddings)

        return {"embeddings": embeddings}