amaye15 committed
Commit e165930
1 Parent(s): 64262c3

handler clean up & readme updated

Files changed (2):
1. README.md +114 -0
2. handler.py +0 -136
README.md CHANGED
@@ -1,3 +1,117 @@
  ---
  license: mit
  ---
+
+
+ # EndpointHandler
+
+ `EndpointHandler` is a Python class that processes image and text data to generate embeddings and similarity scores using the ColQwen2 model, a visual retriever based on Qwen2-VL-2B-Instruct that follows the ColBERT strategy. The handler is optimized for retrieving documents and visual information by their visual and textual features.
+
+ ## Overview
+
+ - **Efficient Document Retrieval**: Uses the ColQwen2 model to produce embeddings for images and text, enabling accurate document retrieval.
+ - **Multi-vector Representation**: Generates ColBERT-style multi-vector embeddings for improved similarity search.
+ - **Flexible Image Resolution**: Supports dynamic image resolution without altering the aspect ratio, capped at 768 patches for memory efficiency.
+ - **Device Compatibility**: Automatically uses an available CUDA device and falls back to CPU otherwise.
+
+ ## Model Details
+
+ The **ColQwen2** model extends Qwen2-VL-2B with a focus on vision-language tasks, making it suitable for content indexing and retrieval. Key features include:
+ - **Training**: Pre-trained with a batch size of 256 over 5 epochs, with a modified pad token.
+ - **Input Flexibility**: Handles various image resolutions without resizing, preserving an accurate multi-vector representation.
+ - **Similarity Scoring**: Uses a ColBERT-style late-interaction scoring approach for efficient retrieval across image and text modalities (sketched below).
+
+ This base version is untrained; the projection layer is initialized deterministically, giving a reproducible starting point for further fine-tuning.
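+
+ For intuition, ColBERT-style "late interaction" scores a query against a document by taking, for each query-token embedding, its best dot-product match among the document's patch embeddings, then summing those maxima. Below is a minimal sketch of that MaxSim computation; it is illustrative only (the handler delegates actual scoring to the processor's `score_multi_vector` method), and the tensor shapes are made up for the example:
+
+ ```python
+ import torch
+
+ def maxsim_score(query_emb: torch.Tensor, doc_emb: torch.Tensor) -> float:
+     """Late interaction: best match per query token, summed over tokens."""
+     # query_emb: (num_query_tokens, dim); doc_emb: (num_doc_patches, dim)
+     sim = query_emb @ doc_emb.T  # pairwise token/patch similarities
+     return sim.max(dim=1).values.sum().item()
+
+ # Toy tensors: 12 query tokens, 768 image patches, embedding dim 128
+ query_emb = torch.randn(12, 128)
+ doc_emb = torch.randn(768, 128)
+ print(maxsim_score(query_emb, doc_emb))
+ ```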
+
+ ## How to Use
+
+ The following example demonstrates how to use `EndpointHandler` to process PDF documents and text. PDF pages are converted to base64-encoded images, which are then sent alongside text queries to the handler.
+
+ ### Example Script
+
+ ```python
+ from pdf2image import convert_from_path
+ import base64
+ from io import BytesIO
+ import requests
+
+
+ # Function to convert a PIL Image to a base64 string
+ def pil_image_to_base64(image):
+     """Converts a PIL Image to a base64-encoded string."""
+     buffer = BytesIO()
+     image.save(buffer, format="PNG")
+     return base64.b64encode(buffer.getvalue()).decode()
+
+
+ # Function to convert PDF pages to base64 images
+ def convert_pdf_to_base64_images(pdf_path):
+     """Converts PDF pages to base64-encoded images."""
+     pages = convert_from_path(pdf_path)
+     return [pil_image_to_base64(page) for page in pages]
+
+
+ # Function to send the payload to the API and retrieve the response
+ def query_api(payload, api_url, headers):
+     """Sends a POST request to the API and returns the JSON response."""
+     response = requests.post(api_url, headers=headers, json=payload)
+     return response.json()
+
+
+ # Main execution
+ if __name__ == "__main__":
+     # Convert PDF pages to base64-encoded images
+     encoded_images = convert_pdf_to_base64_images("document.pdf")
+
+     # Prepare the payload
+     payload = {
+         "inputs": [],
+         "image": encoded_images,
+         "text": ["example query text"],
+     }
+
+     # API configuration
+     API_URL = "https://your-api-url"
+     headers = {
+         "Accept": "application/json",
+         "Authorization": "Bearer your_access_token",
+         "Content-Type": "application/json",
+     }
+
+     # Query the API and print the output
+     output = query_api(payload=payload, api_url=API_URL, headers=headers)
+     print(output)
+ ```
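+
+ If you are running `handler.py` locally rather than behind a deployed endpoint, you can skip the HTTP layer and call the handler directly with the same payload. A minimal sketch, reusing `convert_pdf_to_base64_images` from the script above; the model path below is a placeholder, not a published checkpoint name:
+
+ ```python
+ from handler import EndpointHandler
+
+ # Placeholder path: point this at the directory or repo id holding the weights.
+ handler = EndpointHandler(path="path/to/model", default_batch_size=4)
+
+ payload = {
+     "image": convert_pdf_to_base64_images("document.pdf"),  # helper from the script above
+     "text": ["example query text"],
+ }
+
+ output = handler(payload)  # same {"image", "text", "scores"} dict as the API returns
+ print(output["scores"])
+ ```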
+
+ ## Inputs and Outputs
+
+ ### Input Format
+ The `EndpointHandler` expects a dictionary containing:
+ - **image**: A list of base64-encoded strings for images (e.g., PDF pages converted to images).
+ - **text**: A list of text strings representing queries or document contents.
+ - **batch_size** (optional): The batch size for processing images and text. Defaults to `4`.
+
+ Example payload:
+ ```json
+ {
+     "image": ["base64_image_string_1", "base64_image_string_2"],
+     "text": ["sample text 1", "sample text 2"],
+     "batch_size": 4
+ }
+ ```
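+
+ Internally, the handler decodes each base64 string back into an RGB `PIL.Image` before batching, i.e., the inverse of the `pil_image_to_base64` helper above. The helper name here is illustrative; the handler performs this decode inline:
+
+ ```python
+ import base64
+ from io import BytesIO
+ from PIL import Image
+
+ def base64_to_pil_image(image_b64: str) -> Image.Image:
+     """Decode a base64 payload entry back into an RGB PIL image."""
+     image_bytes = base64.b64decode(image_b64)
+     return Image.open(BytesIO(image_bytes)).convert("RGB")
+ ```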
+
+ ### Output Format
+ The handler returns a dictionary with the following keys:
+ - **image**: List of multi-vector embeddings, one entry per image.
+ - **text**: List of multi-vector embeddings, one entry per text query.
+ - **scores**: Matrix of similarity scores between the text and image embeddings.
+
+ Example output:
+ ```json
+ {
+     "image": [[0.12, 0.34, ...], [0.56, 0.78, ...]],
+     "text": [[0.11, 0.22, ...], [0.33, 0.44, ...]],
+     "scores": [[0.87, 0.45], [0.23, 0.67]]
+ }
+ ```
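+
+ Judging from the handler implementation, which passes text embeddings first to `score_multi_vector`, each row of `scores` corresponds to a text query and each column to an image, so picking the best page per query is a row-wise argmax:
+
+ ```python
+ scores = [[0.87, 0.45], [0.23, 0.67]]  # from the example output above
+
+ for i, row in enumerate(scores):
+     best = max(range(len(row)), key=row.__getitem__)
+     print(f"text query {i}: best image is {best} (score {row[best]:.2f})")
+ ```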
+
+ ### Error Handling
+ If any issue occurs during processing (e.g., while decoding images or running model inference), the handler logs the error and returns an error message under an `error` key in the output dictionary; see the sketch below.
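+
+ A caller can therefore check for the `error` key before using the embeddings, reusing `query_api`, `payload`, `API_URL`, and `headers` from the example script:
+
+ ```python
+ output = query_api(payload=payload, api_url=API_URL, headers=headers)
+
+ if "error" in output:
+     raise RuntimeError(f"Endpoint returned an error: {output['error']}")
+
+ image_embeddings = output["image"]  # safe to use from here on
+ ```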
+
handler.py CHANGED
@@ -1,139 +1,3 @@
- # import torch
- # from typing import Dict, Any, List
- # from PIL import Image
- # import base64
- # from io import BytesIO
-
-
- # class EndpointHandler:
- #     """
- #     A handler class for processing image and text data, generating embeddings using a specified model and processor.
-
- #     Attributes:
- #         model: The pre-trained model used for generating embeddings.
- #         processor: The pre-trained processor used to process images and text before model inference.
- #         device: The device (CPU or CUDA) used to run model inference.
- #         default_batch_size: The default batch size for processing images and text in batches.
- #     """
-
- #     def __init__(self, path: str = "", default_batch_size: int = 4):
- #         """
- #         Initializes the EndpointHandler with a specified model path and default batch size.
-
- #         Args:
- #             path (str): Path to the pre-trained model and processor.
- #             default_batch_size (int): Default batch size for processing images and text data.
- #         """
- #         from colpali_engine.models import ColQwen2, ColQwen2Processor
-
- #         self.model = ColQwen2.from_pretrained(
- #             path,
- #             torch_dtype=torch.bfloat16,
- #             device_map=(
- #                 "cuda:0" if torch.cuda.is_available() else "cpu"
- #             ),  # Set device map based on availability
- #         ).eval()
- #         self.processor = ColQwen2Processor.from_pretrained(path)
-
- #         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- #         self.model.to(self.device)
- #         self.default_batch_size = default_batch_size
-
- #     def _process_image_batch(self, images: List[Image.Image]) -> List[List[float]]:
- #         """
- #         Processes a batch of images and generates embeddings.
-
- #         Args:
- #             images (List[Image.Image]): List of images to process.
-
- #         Returns:
- #             List[List[float]]: List of embeddings for each image.
- #         """
- #         batch_images = self.processor.process_images(images).to(self.device)
-
- #         with torch.no_grad():
- #             image_embeddings = self.model(**batch_images)
-
- #         return image_embeddings.cpu().tolist()
-
- #     def _process_text_batch(self, texts: List[str]) -> List[List[float]]:
- #         """
- #         Processes a batch of text queries and generates embeddings.
-
- #         Args:
- #             texts (List[str]): List of text queries to process.
-
- #         Returns:
- #             List[List[float]]: List of embeddings for each text query.
- #         """
- #         batch_queries = self.processor.process_queries(texts).to(self.device)
-
- #         with torch.no_grad():
- #             query_embeddings = self.model(**batch_queries)
-
- #         return query_embeddings.cpu().tolist()
-
- #     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
- #         """
- #         Processes input data containing base64-encoded images and text queries, decodes them, and generates embeddings.
-
- #         Args:
- #             data (Dict[str, Any]): Dictionary containing input images, text queries, and optional batch size.
-
- #         Returns:
- #             Dict[str, Any]: Dictionary containing generated embeddings for images and text or error messages.
- #         """
- #         images_data = data.get("image", [])
- #         text_data = data.get("text", [])
- #         batch_size = data.get("batch_size", self.default_batch_size)
-
- #         # Decode and process images
- #         images = []
- #         if images_data:
- #             for img_data in images_data:
- #                 if isinstance(img_data, str):
- #                     try:
- #                         image_bytes = base64.b64decode(img_data)
- #                         image = Image.open(BytesIO(image_bytes)).convert("RGB")
- #                         images.append(image)
- #                     except Exception as e:
- #                         return {"error": f"Invalid image data: {e}"}
- #                 else:
- #                     return {"error": "Images should be base64-encoded strings."}
-
- #         image_embeddings = []
- #         for i in range(0, len(images), batch_size):
- #             batch_images = images[i : i + batch_size]
- #             batch_embeddings = self._process_image_batch(batch_images)
- #             image_embeddings.extend(batch_embeddings)
-
- #         # Process text data
- #         text_embeddings = []
- #         if text_data:
- #             for i in range(0, len(text_data), batch_size):
- #                 batch_texts = text_data[i : i + batch_size]
- #                 batch_text_embeddings = self._process_text_batch(batch_texts)
- #                 text_embeddings.extend(batch_text_embeddings)
-
- #         # Compute similarity scores if both image and text embeddings are available
- #         scores = []
- #         if image_embeddings and text_embeddings:
- #             # Convert embeddings to tensors for scoring
- #             image_embeddings_tensor = torch.tensor(image_embeddings).to(self.device)
- #             text_embeddings_tensor = torch.tensor(text_embeddings).to(self.device)
-
- #             with torch.no_grad():
- #                 scores = (
- #                     self.processor.score_multi_vector(
- #                         text_embeddings_tensor, image_embeddings_tensor
- #                     )
- #                     .cpu()
- #                     .tolist()
- #                 )
-
- #         return {"image": image_embeddings, "text": text_embeddings, "scores": scores}
-
-
  import torch
  from typing import Dict, Any, List
  from PIL import Image