# handler.py
import io
from typing import Any, Dict, List, Tuple
import numpy as np
import requests
import torch
from PIL import Image
from transformers import (
CLIPModel,
CLIPProcessor,
CLIPTokenizerFast,
pipeline,
AutoProcessor,
AutoModelForCausalLM,
)
from huggingface_hub import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import timeit
import easyocr
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# multi-model list
multi_model_list = [
{"model_id": "openai/clip-vit-base-patch32", "task": "Custom"},
{"model_id": "microsoft/git-large-coco", "task": "Custom"},
]
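# NOTE: multi_model_list is informational; it mirrors the two models loaded in
# EndpointHandler.__init__ below and is not referenced elsewhere in this file.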
class EndpointHandler:
def __init__(self, path=""):
clip_model_id = "openai/clip-vit-base-patch32"
# self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.processor_clip = CLIPProcessor.from_pretrained(clip_model_id)
self.model_clip = CLIPModel.from_pretrained(clip_model_id).to(self.device)
self.tokenizer_clip = CLIPTokenizerFast.from_pretrained(clip_model_id)
self.processor_git = AutoProcessor.from_pretrained("microsoft/git-large-coco")
self.model_git = AutoModelForCausalLM.from_pretrained(
"microsoft/git-large-coco"
)
self.model_git.to(device)
self.model_clip.to(device)
logging.set_verbosity_debug()
self.logger = logging.get_logger(__name__)
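        # OCR reader for German and English; easyocr downloads its detection and
        # recognition models on first use if they are not already cached.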
self.reader = easyocr.Reader(["de", "en"])
def download_image(self, url: str) -> bytes:
"""
Download an image from a given URL.
Parameters:
- url: str
The URL from where the image needs to be downloaded.
Returns:
- bytes
The downloaded image data in bytes.
Raises:
- Exception: If the image download request fails.
"""
response = requests.get(url)
if response.status_code == 200:
return response.content
else:
self.logger.error(f"Error downloading image from :{str(url)}")
raise Exception(
f"Failed to download image from {url}. Status code: {response.status_code}"
)
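    # Illustrative usage (hypothetical URL, not taken from the original code):
    #   image_bytes = handler.download_image("https://example.com/sample.jpg")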
def download_images_in_parallel(
self, urls: List[str], images_metadata_list: List[dict]
    ) -> Tuple[List[bytes], List[dict]]:
"""
Download multiple images in parallel and collect their metadata.
Parameters:
- urls: List[str]
A list of URLs from where the images need to be downloaded.
- images_metadata_list: List[dict]
A list of metadata corresponding to each image URL.
Returns:
- Tuple[List[bytes], List[dict]]
A tuple containing a list of downloaded image data in bytes and
a list of metadata for the successfully downloaded images.
"""
with ThreadPoolExecutor() as executor:
# Start the load operations and mark each future with its URL and metadata
future_to_metadata = {
executor.submit(self.download_image, url): (url, metadata)
for url, metadata in zip(urls, images_metadata_list)
}
results = []
successful_metadata = []
for future in as_completed(future_to_metadata):
url, metadata = future_to_metadata[future]
try:
data = future.result()
results.append(data)
metadata["url"] = url
successful_metadata.append(metadata)
except Exception as exc:
self.logger.error("%r generated an exception: %s" % (url, exc))
return results, successful_metadata
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Process the input data based on its type and return the embeddings.
This method accepts a dictionary with a 'process_type' key that can be either 'images' or 'text'.
If 'process_type' is 'images', the method expects a list of image URLs under the 'images_urls' key.
It downloads and processes these images, and returns their embeddings.
If 'process_type' is 'text', the method expects a string query under the 'query' key.
It processes this text and returns its embedding.
Parameters:
- data: Dict[str, Any]
A dictionary containing the data to be processed.
It must include a 'process_type' key with value either 'images' or 'text'.
        If 'process_type' is 'images', data should also include an 'images_urls' key with a list of image URLs and an 'images_metadata' key with one metadata dict per URL.
If 'process_type' is 'text', data should also include 'query' key with a string query.
Returns:
        - Dict[str, Any]
            A dictionary containing the embeddings of the processed data (plus metadata and generated captions for images).
            If an error occurs during processing, the dictionary will include an 'error' key with the error message.
Raises:
- ValueError: If the 'process_type' key is not present in data, or if the required keys for 'images' or 'text' are not present or are of the wrong type.
"""
if data["process_type"] == "images":
try:
                # Check that the 'images_urls' key is present and holds a list
if "images_urls" not in data or not isinstance(
data["images_urls"], list
):
raise ValueError(
"Data must contain 'images_urls' key with a list of images urls."
)
batch_size = 50
if "batch_size" in data:
batch_size = int(data["batch_size"])
                # Download and process the images in batches
images_batches = []
processed_metadata = []
for i in range(0, len(data["images_urls"]), batch_size):
# select batch of images
batches = data["images_urls"][i : i + batch_size]
batches_metadata = data["images_metadata"][i : i + batch_size]
download_start_time = timeit.default_timer()
# Download images in parallel along with their metadata
(
downloaded_images,
images_metadata,
) = self.download_images_in_parallel(batches, batches_metadata)
download_end_time = timeit.default_timer()
self.logger.info(
f"Image downloading took {download_end_time - download_start_time} seconds"
)
processing_start_time = timeit.default_timer()
for image_content, image_metadata in zip(
downloaded_images, images_metadata
):
try:
image = Image.open(io.BytesIO(image_content)).convert("RGB")
image_array = np.array(image)
images_batches.append(image_array)
complete_image_metadata = {
# "text": image_metadata["caption"],
# "source": image_metadata["url"],
"source_type": "images",
**image_metadata,
}
# Extract text from image using easyocr
                            extracted_text = self.reader.readtext(image_array, detail=0)
complete_image_metadata["extracted_text"] = extracted_text
processed_metadata.append(complete_image_metadata)
except Exception as e:
self.logger.error(f"Error image processing: {str(e)}")
print(e)
                    # images_batches now holds the decoded images as np.ndarray objects
processing_end_time = timeit.default_timer()
self.logger.info(
f"Image processing took {processing_end_time - processing_start_time} seconds"
)
embedding_start_time = timeit.default_timer()
with torch.no_grad(): # This line ensures that the code inside the block doesn't track gradients
batch = self.processor_clip(
text=None,
images=images_batches,
return_tensors="pt",
padding=True,
)["pixel_values"].to(self.model_clip.device)
batch_git = self.processor_git(
images=images_batches,
return_tensors="pt",
)
git_pixel_values = batch_git.pixel_values.to(self.model_git.device)
# get image captions
generated_ids = self.model_git.generate(
pixel_values=git_pixel_values, max_length=35
)
generated_captions = self.processor_git.batch_decode(
generated_ids, skip_special_tokens=True
)
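                    # batch_decode returns one caption string per input image, in input order.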
# get image embeddings
batch_emb = self.model_clip.get_image_features(pixel_values=batch)
                    # detach image embeddings from graph, move to CPU, and convert to numpy array
self.logger.info(
f"Shape of batch_emb after get_image_features: {batch_emb.shape}"
)
# Check the shape of the tensor before squeezing
if batch_emb.shape[0] > 1:
batch_emb = batch_emb.squeeze(0)
self.logger.info(
f"Shape of batch_emb after squeeze: {batch_emb.shape}"
)
batch_emb = batch_emb.cpu().detach().numpy()
# NORMALIZE
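                    # L2-normalize each embedding so that downstream similarity search can
                    # use a plain dot product (cosine similarity on unit vectors). With
                    # batch_emb of shape (num_images, dim), np.linalg.norm(..., axis=1) has
                    # shape (num_images,), so the division is applied to the transposed
                    # (dim, num_images) view and broadcast per column.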
if batch_emb.ndim > 1:
batch_emb = batch_emb.T / np.linalg.norm(batch_emb, axis=1)
self.logger.info(
f"Shape of batch_emb after normalization (2D case): {batch_emb.shape}"
)
                    # transpose back to (num_images, embedding_dim) and convert to a plain list
batch_emb = batch_emb.T.tolist()
embedding_end_time = timeit.default_timer()
self.logger.info(
f"Embedding calculation took {embedding_end_time - embedding_start_time} seconds"
)
# Return the embeddings
return {
"embeddings": batch_emb,
"metadata": processed_metadata,
"captions": generated_captions,
}
except Exception as e:
print(f"Error during Images processing: {str(e)}")
self.logger.error(f"Error during Images processing: {str(e)}")
return {"embeddings": [], "error": str(e)}
elif data["process_type"] == "text":
if "query" not in data or not isinstance(data["query"], str):
raise ValueError("Data must contain 'query' key which is a str.")
query = data["query"]
inputs = self.tokenizer_clip(query, return_tensors="pt").to(self.device)
text_emb = self.model_clip.get_text_features(**inputs)
# detach text emb from graph, move to CPU, and convert to numpy array
text_emb = text_emb.detach().cpu().numpy()
# calculate value to normalize each vector by and normalize them
norm_factor = np.linalg.norm(text_emb, axis=1)
text_emb = text_emb.T / norm_factor
            # transpose back to (1, embedding_dim)
text_emb = text_emb.T
            # Convert the numpy array to a nested list for the JSON response
text_emb_list = text_emb.tolist()
return {"embeddings": text_emb_list}
        else:
            error_msg = (
                f"Unsupported process_type {data['process_type']!r}; expected 'images' or 'text'."
            )
            print(f"Error during CLIP endpoint processing: {error_msg}")
            self.logger.error(error_msg)
            return {"embeddings": [], "error": error_msg}
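# ---------------------------------------------------------------------------
# Minimal local smoke test (illustrative only). This block is not part of the
# deployed endpoint; the URL and metadata below are placeholder assumptions and
# should be replaced with real values before running.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler()
    # Text query: returns {"embeddings": [[...]]} with one unit-normalized vector.
    text_result = handler({"process_type": "text", "query": "a photo of a dog"})
    print(len(text_result["embeddings"]), len(text_result["embeddings"][0]))
    # Image request: with the placeholder URL the download fails and the handler
    # returns an 'error' key instead of embeddings.
    image_result = handler(
        {
            "process_type": "images",
            "images_urls": ["https://example.com/dog.jpg"],
            "images_metadata": [{"caption": "a dog"}],
            "batch_size": 1,
        }
    )
    print(image_result.get("captions"), image_result.get("error"))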