File size: 3,378 Bytes
b54bf3c
aea7238
b54bf3c
 
 
 
 
 
8d41aec
4e7495e
8d41aec
 
4e7495e
 
 
 
8d41aec
 
f4f3a3e
8d41aec
 
 
 
4e7495e
 
8d41aec
b54bf3c
 
 
 
 
 
 
 
 
 
aea7238
 
 
8d41aec
 
 
 
4e7495e
8d41aec
4e7495e
 
8d41aec
aea7238
 
 
 
 
 
 
 
b54bf3c
8d41aec
 
 
 
4e7495e
8d41aec
4e7495e
 
8d41aec
b54bf3c
aea7238
b54bf3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aea7238
 
 
 
 
b54bf3c
aea7238
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
from typing import Dict, Any, List
from PIL import Image
import base64
from io import BytesIO


class EndpointHandler:
    """
    A handler class for processing image data, generating embeddings using a specified model and processor.

    A request is a dictionary with an ``"inputs"`` key holding base64-encoded
    images (a list of strings, or a single string) and an optional
    ``"batch_size"`` key that overrides ``default_batch_size`` for that call.
    Failures are reported as ``{"error": <message>}`` rather than raised.

    Attributes:
        model: The pre-trained ColQwen2 model used for generating embeddings (bfloat16, eval mode).
        processor: The pre-trained ColQwen2 processor used to prepare images before model inference.
        device: The device (CPU or CUDA) used to run model inference.
        default_batch_size: The default batch size for processing images in batches.
    """

    def __init__(self, path: str = "", default_batch_size: int = 4):
        """
        Initializes the EndpointHandler with a specified model path and default batch size.

        Args:
            path (str): Path to the pre-trained model and processor.
            default_batch_size (int): Default batch size for image processing; must be a positive integer.

        Raises:
            ValueError: If ``default_batch_size`` is not a positive integer.
        """
        # Validate before the expensive model load so a bad configuration fails
        # here instead of as a ValueError from range() on the first request.
        if not self._is_valid_batch_size(default_batch_size):
            raise ValueError(
                f"default_batch_size must be a positive integer, got {default_batch_size!r}"
            )

        # Imported here rather than at module level (presumably to defer the heavy
        # dependency until a handler is actually constructed).
        from colpali_engine.models import ColQwen2, ColQwen2Processor

        self.model = ColQwen2.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        ).eval()
        self.processor = ColQwen2Processor.from_pretrained(path)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.default_batch_size = default_batch_size

    @staticmethod
    def _is_valid_batch_size(batch_size: Any) -> bool:
        """Return True if ``batch_size`` is a positive ``int`` (bools are rejected)."""
        # bool subclasses int, so True would otherwise be accepted as a batch size of 1.
        return (
            isinstance(batch_size, int)
            and not isinstance(batch_size, bool)
            and batch_size > 0
        )

    @staticmethod
    def _decode_image(img_data: str) -> Image.Image:
        """
        Decode one base64-encoded image into an RGB PIL image.

        Args:
            img_data (str): Base64-encoded image bytes.

        Returns:
            Image.Image: The decoded image converted to RGB.

        Raises:
            Exception: Whatever ``base64.b64decode`` / ``Image.open`` raise for malformed data.
        """
        image_bytes = base64.b64decode(img_data)
        # convert() returns a new, fully loaded image, so the opened source image
        # can be closed as soon as the conversion is done.
        with Image.open(BytesIO(image_bytes)) as image:
            return image.convert("RGB")

    def _process_batch(self, images: List[Image.Image]) -> List[Any]:
        """
        Processes a batch of images and generates embeddings.

        Args:
            images (List[Image.Image]): List of RGB images to process.

        Returns:
            List[Any]: The model output converted to nested Python lists, presumably
            one entry per input image (first output dimension).
        """
        batch_images = self.processor.process_images(images)
        # Move every processor output tensor onto the model's device.
        batch_images = {k: v.to(self.device) for k, v in batch_images.items()}

        # Inference only: disable autograd tracking to save memory.
        with torch.no_grad():
            image_embeddings = self.model(**batch_images)

        # NOTE(review): ColQwen2 is a multi-vector model, so each entry is presumably
        # a list of per-token vectors rather than a single flat vector. If the
        # processor pads images to a common length within a batch, padded positions
        # may be included and the vector count could depend on batch composition —
        # confirm whether trimming by attention_mask is needed.
        return image_embeddings.cpu().tolist()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Processes input data containing base64-encoded images, decodes them, and generates embeddings.

        Args:
            data (Dict[str, Any]): Request dictionary. ``"inputs"`` holds a list of
                base64-encoded image strings (a single string is also accepted);
                ``"batch_size"`` optionally overrides ``default_batch_size`` and must
                be a positive integer.

        Returns:
            Dict[str, Any]: ``{"embeddings": [...]}`` with one entry per input image in
            input order, or ``{"error": <message>}`` if the inputs are missing, the
            batch size is invalid, or any image fails to decode.
        """
        images_data = data.get("inputs", [])
        batch_size = data.get("batch_size", self.default_batch_size)

        if not images_data:
            return {"error": "No images provided in 'inputs'."}

        # Accept a lone base64 string; iterating it directly would treat each
        # character as a separate image.
        if isinstance(images_data, str):
            images_data = [images_data]

        # batch_size comes from the request: 0 would make range() raise and a
        # negative value would silently yield no embeddings.
        if not self._is_valid_batch_size(batch_size):
            return {"error": "'batch_size' must be a positive integer."}

        images = []
        for img_data in images_data:
            if not isinstance(img_data, str):
                return {"error": "Images should be base64-encoded strings."}
            try:
                images.append(self._decode_image(img_data))
            except Exception as e:
                # Request boundary: report any decode failure to the caller instead of crashing.
                return {"error": f"Invalid image data: {e}"}

        embeddings: List[Any] = []
        for start in range(0, len(images), batch_size):
            embeddings.extend(self._process_batch(images[start : start + batch_size]))

        return {"embeddings": embeddings}