# marqo-fashionSigLIP / handler.py
# Amitai Getzler
# :heavy_plus_sign: Ad
# dc7652d
# raw
# history blame
# 7.59 kB
from base64 import b64decode
from io import BytesIO
import open_clip
import requests
import torch
import numpy as np
from PIL import Image
from typing import Dict, Any
class EndpointHandler:
    """Request handler that classifies images and builds text/image embeddings.

    The handler loads the open_clip model named by ``MODEL_NAME`` once and
    serves two request types through ``__call__``: ``"classify"`` (pick the
    best candidate label for an image) and ``"embedd"`` (embed a batch of
    texts or image+description documents).
    """

    # open_clip identifier used for both the model and its tokenizer.
    MODEL_NAME = "hf-hub:Styld/marqo-fashionSigLIP"
    # Number of candidate labels tokenized/encoded per forward pass.
    LABEL_BATCH_SIZE = 10
    # Number of space-separated words per text chunk in ``average_text``.
    WORDS_PER_CHUNK = 40
    # Seconds to wait for a remote image before giving up.
    REQUEST_TIMEOUT = 30

    def __init__(self, path="hf-hub:Styld/marqo-fashionSigLIP"):
        """Load the model, its preprocessing transforms and its tokenizer.

        Args:
            path: Accepted for interface compatibility but not used; the model
                is always loaded from ``MODEL_NAME``.
                NOTE(review): presumably the hosting toolkit passes a local
                directory here -- confirm before wiring it into the loader.
        """
        # Remember the device so every input tensor can be moved next to the
        # model; mixing CPU inputs with a CUDA model raises at inference time.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess_train, self.preprocess_val = (
            open_clip.create_model_and_transforms(
                self.MODEL_NAME, device=self.device
            )
        )
        # open_clip returns models in train mode; switch to inference mode.
        self.model.eval()
        self.tokenizer = open_clip.get_tokenizer(self.MODEL_NAME)

    def _load_image(self, source):
        """Return a PIL image from an http(s) URL or a base64 string.

        NOTE(review): URLs come from the request payload and are fetched
        as-is; restrict allowed hosts if the endpoint is exposed to
        untrusted callers.
        """
        if isinstance(source, str):
            url = source.strip()
            if url.startswith(("http://", "https://")):
                response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
                # Fail on HTTP errors instead of handing an error page to PIL.
                response.raise_for_status()
                return Image.open(BytesIO(response.content))
        return self.base64_image_to_pil(source)

    def _encode_image(self, image):
        """Return the L2-normalised image features for a single image."""
        pixels = self.preprocess_val(image).unsqueeze(0).to(self.device)
        features = self.model.encode_image(pixels)
        return features / features.norm(dim=-1, keepdim=True)

    def _encode_text(self, text):
        """Return L2-normalised text features for a string or list of strings."""
        tokens = self.tokenizer(text).to(self.device)
        features = self.model.encode_text(tokens)
        return features / features.norm(dim=-1, keepdim=True)

    def classify_image(self, candidate_labels, image):
        """Return the candidate label most similar to ``image``.

        Labels are encoded in batches of ``LABEL_BATCH_SIZE`` to bound memory,
        but the winner is chosen from the image/text similarities of *all*
        labels. Comparing per-batch softmax scores (the previous behaviour) is
        wrong because each batch is normalised separately: a trailing batch
        with a single label always scored 1.0 and won.

        Args:
            candidate_labels: Non-empty list of label strings (a single string
                is treated as a one-element list).
            image: Image accepted by the model's validation transform.

        Returns:
            ``{"label": <best label>}``.

        Raises:
            ValueError: If no candidate labels are given.
        """
        if isinstance(candidate_labels, str):
            candidate_labels = [candidate_labels]
        labels = list(candidate_labels) if candidate_labels else []
        if not labels:
            raise ValueError("candidate_labels must contain at least one label.")

        similarities = []
        # Autocast is only meaningful on CUDA; it is disabled on CPU.
        with torch.no_grad(), torch.autocast(
            device_type=self.device, enabled=self.device == "cuda"
        ):
            # The image does not change between label batches: encode it once.
            image_features = self._encode_image(image)
            for start in range(0, len(labels), self.LABEL_BATCH_SIZE):
                batch_labels = labels[start : start + self.LABEL_BATCH_SIZE]
                text_features = self._encode_text(batch_labels)
                similarities.append((image_features @ text_features.T)[0])

        # Softmax is monotonic, so the arg-max of the raw similarities equals
        # the arg-max of a softmax taken over every label at once.
        best_index = torch.cat(similarities).argmax().item()
        return {"label": labels[best_index]}

    def combine_embeddings(
        self, text_embeddings, image_embeddings, text_weight=0.5, image_weight=0.5
    ):
        """Combine text and image embeddings with specified weights.

        Each non-``None`` list is first averaged element-wise; a missing
        modality is replaced by a zero vector. The two averages are then
        merged with ``np.average`` using ``[text_weight, image_weight]``.

        Args:
            text_embeddings: List of 1-D arrays, or ``None``.
            image_embeddings: List of 1-D arrays, or ``None``.
            text_weight: Weight of the averaged text embedding.
            image_weight: Weight of the averaged image embedding.

        Returns:
            The combined 1-D ``numpy`` array.

        Raises:
            ValueError: If both embedding lists are ``None``.
        """
        if text_embeddings is None and image_embeddings is None:
            raise ValueError("At least one of text or image embeddings is required.")

        if text_embeddings is not None:
            avg_text_embedding = np.mean(np.vstack(text_embeddings), axis=0)
        else:
            avg_text_embedding = np.zeros_like(image_embeddings[0])

        if image_embeddings is not None:
            avg_image_embedding = np.mean(np.vstack(image_embeddings), axis=0)
        else:
            avg_image_embedding = np.zeros_like(text_embeddings[0])

        return np.average(
            np.vstack((avg_text_embedding, avg_image_embedding)),
            axis=0,
            weights=[text_weight, image_weight],
        )

    def average_text(self, doc):
        """Embed ``doc`` in chunks of ``WORDS_PER_CHUNK`` words and average them.

        Args:
            doc: Text to embed; split on single spaces.

        Returns:
            1-D ``numpy`` array: the mean of the normalised chunk embeddings.
        """
        # Split once instead of once per chunk.
        words = doc.split(" ")
        chunks = [
            " ".join(words[start : start + self.WORDS_PER_CHUNK])
            for start in range(0, len(words), self.WORDS_PER_CHUNK)
        ]

        chunk_embeddings = []
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            for chunk in chunks:
                features = self._encode_text(chunk)
                # Move to CPU first; CUDA tensors cannot be converted to numpy.
                chunk_embeddings.append(features.detach().squeeze().cpu().numpy())

        return self.combine_embeddings(
            chunk_embeddings, None, text_weight=1, image_weight=0
        )

    def embedd_image(self, doc) -> list:
        """Embed one batch item.

        Args:
            doc: Either a plain string (embedded as text), or a dict with an
                ``"image"`` entry (http(s) URL, several URLs joined by ``"|"``
                of which the first is used, or a base64 string) and an
                optional ``"description"``.

        Returns:
            The embedding as a list of floats. For dicts with a non-empty
            description it is the equal-weight average of the image embedding
            and the averaged description embedding; otherwise the image
            embedding alone.

        Raises:
            TypeError: If ``doc`` is neither a string nor a dict.
            ValueError: If a dict item has no ``"image"`` value.
        """
        if isinstance(doc, str):
            return self.average_text(doc).tolist()
        if not isinstance(doc, dict):
            raise TypeError("Batch items must be strings or dicts.")

        image_ref = doc.get("image")
        if not image_ref:
            raise ValueError("Document is missing an 'image' URL or base64 string.")
        if isinstance(image_ref, str):
            # Only the first "|"-separated entry is used, so only that one is
            # fetched (base64 payloads never contain "|").
            image_ref = image_ref.split("|", 1)[0]
        image = self._load_image(image_ref)

        with torch.no_grad():
            image_features = self._encode_image(image)
        image_embedding = image_features.detach().squeeze().cpu().numpy()

        description = doc.get("description")
        if not description:
            print("empty description. Going with image alone")
            return image_embedding.tolist()

        combined = self.combine_embeddings(
            [self.average_text(description)],
            [image_embedding],
            text_weight=0.5,
            image_weight=0.5,
        )
        return combined.tolist()

    def process_batch(self, batch) -> object:
        """Embed every item of ``batch["batch"]`` without raising.

        This keeps the legacy contract of this method: failures are reported
        as a ``(message, status_code)`` tuple rather than as an exception.

        Args:
            batch: Dict whose ``"batch"`` entry is the list of items to embed.

        Returns:
            A list of embeddings on success; ``(message, 400)`` if the batch
            is not a list; ``(message, 500)`` if embedding fails.
        """
        try:
            items = batch.get("batch")
            if not isinstance(items, list):
                return "Invalid input: batch must be an array of strings.", 400
            return [self.embedd_image(item) for item in items]
        except Exception as e:
            print("Error processing request", e)
            return "An error occurred while processing the request.", 500

    def base64_image_to_pil(self, base64_str) -> "Image.Image":
        """Decode a base64 string (optionally a ``data:`` URI) into a PIL image."""
        if isinstance(base64_str, str) and base64_str.startswith("data:"):
            # Drop the "data:<mime>;base64," header: its characters are valid
            # base64 and would otherwise be decoded into garbage bytes.
            _, separator, payload = base64_str.partition(",")
            if separator:
                base64_str = payload
        return Image.open(BytesIO(b64decode(base64_str)))

    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        Process the input data for either classification or embedding generation.

        The request dict is read but not modified.

        Args:
            data (:obj:`dict`): A dictionary containing the input data and parameters for inference.
                For classification:
                {
                    "type": "classify",
                    "inputs": {
                        "candidates": :obj:`list[str]`,
                        "image": :obj:`str`  # URL or base64 encoded image
                    }
                }
                For embedding (the batch may also be nested under "inputs",
                either as {"batch": [...]} or directly as a list):
                {
                    "type": "embedd",
                    "batch": :obj:`list[str | dict[str, str]]`  # Text or image+description
                }
                "type" defaults to "embedd" when omitted.

        Returns:
            :obj:`dict`: The result of the operation.
                For classification:
                {
                    "label": :obj:`str`  # The predicted label
                }
                For embedding:
                {
                    "embeddings": :obj:`list[list[float]]`  # List of embeddings
                }

        Raises:
            :obj:`ValueError`: If the request type is unknown or the batch is not a list.
            :obj:`Exception`: If an error occurs during processing.
        """
        inputs = data.get("inputs", data)
        request_type = data.get("type", "embedd")  # Or classify

        if request_type == "classify":
            image = self._load_image(inputs["image"])
            return self.classify_image(inputs["candidates"], image)

        if request_type == "embedd":
            items = inputs.get("batch") if isinstance(inputs, dict) else inputs
            if not isinstance(items, list):
                raise ValueError("Invalid input: batch must be an array of strings.")
            # Errors propagate so the caller gets a real failure instead of an
            # error message disguised as an "embeddings" payload.
            return {"embeddings": [self.embedd_image(item) for item in items]}

        raise ValueError(f"Unsupported request type: {request_type!r}")