"""Inference endpoint handler for the Marqo FashionSigLIP open_clip model.

Supports zero-shot image classification and text/image embedding generation.
"""

from base64 import b64decode
from io import BytesIO
from typing import Any, Dict

import numpy as np
import open_clip
import requests
import torch
from PIL import Image


class EndpointHandler:
    def __init__(self, path="hf-hub:Styld/marqo-fashionSigLIP"):
        # Resolve the device once so inputs can be moved onto it later.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess_train, self.preprocess_val = (
            open_clip.create_model_and_transforms(path, device=self.device)
        )
        self.tokenizer = open_clip.get_tokenizer(path)
        self.model.eval()  # inference only

    def classify_image(self, candidate_labels, image):
        def get_top_prediction(text_probs, labels):
            max_index = text_probs[0].argmax().item()
            return {
                "label": labels[max_index],
                "score": text_probs[0][max_index].item(),
            }

        # Encode the image once up front; only the label batches change below.
        image_tensor = self.preprocess_val(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image_tensor)
            image_features /= image_features.norm(dim=-1, keepdim=True)

        # Score the labels in batches of 10 and keep the best prediction.
        # Note: softmax scores are relative to each batch, so comparing them
        # across batches is a heuristic.
        top_prediction = None
        for i in range(0, len(candidate_labels), 10):
            batch_labels = candidate_labels[i : i + 10]
            text = self.tokenizer(batch_labels).to(self.device)

            with torch.no_grad():
                text_features = self.model.encode_text(text)
                text_features /= text_features.norm(dim=-1, keepdim=True)
                text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

            current_top = get_top_prediction(text_probs, batch_labels)
            if top_prediction is None or current_top["score"] > top_prediction["score"]:
                top_prediction = current_top

        return {"label": top_prediction["label"]}

    def combine_embeddings(
        self, text_embeddings, image_embeddings, text_weight=0.5, image_weight=0.5
    ):
        """Combine text and image embeddings into a single weighted vector."""
        # Average each modality first; substitute a zero vector when one is absent.
        if text_embeddings is not None:
            avg_text_embedding = np.mean(np.vstack(text_embeddings), axis=0)
        else:
            avg_text_embedding = np.zeros_like(image_embeddings[0])

        if image_embeddings is not None:
            avg_image_embedding = np.mean(np.vstack(image_embeddings), axis=0)
        else:
            avg_image_embedding = np.zeros_like(text_embeddings[0])

        # Element-wise weighted average of the two modality vectors.
        combined_embedding = np.average(
            np.vstack((avg_text_embedding, avg_image_embedding)),
            axis=0,
            weights=[text_weight, image_weight],
        )
        return combined_embedding

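    # Worked example with hypothetical 3-d vectors: combining text [1, 0, 0]
    # and image [0, 1, 0] at weights 0.5/0.5 yields [0.5, 0.5, 0.0].
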
    def average_text(self, doc):
        # Split the document into 40-word chunks so each fits the text encoder.
        words = doc.split(" ")
        text_chunks = [
            " ".join(words[i : i + 40]) for i in range(0, len(words), 40)
        ]
        text_embeddings = []
        for chunk in text_chunks:
            inputs = self.tokenizer(chunk).to(self.device)
            with torch.no_grad():
                text_features = self.model.encode_text(inputs)
                text_features /= text_features.norm(dim=-1, keepdim=True)
            text_embeddings.append(text_features.squeeze().cpu().numpy())
        # Average the chunk embeddings; the image side gets zero weight here.
        return self.combine_embeddings(
            text_embeddings, None, text_weight=1, image_weight=0
        )

    def embedd_image(self, doc) -> list:
        if isinstance(doc, str):
            # Plain text input: embed the text alone.
            return self.average_text(doc).tolist()

        image = doc.get("image")
        if "https://" in image:
            # Multiple URLs may be pipe-separated; only the first is used.
            url = image.split("|")[0]
            image = Image.open(BytesIO(requests.get(url).content))
        else:
            # Otherwise the field is assumed to be a base64-encoded image.
            image = self.base64_image_to_pil(image)

        image_tensor = self.preprocess_val(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image_tensor)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        image_embedding = image_features.squeeze().cpu().numpy()

        if doc.get("description", "") == "":
            print("Empty description; using the image embedding alone.")
            return image_embedding.tolist()

        # Combine the averaged description embedding with the image embedding.
        average_texts = self.average_text(doc.get("description"))
        combined = self.combine_embeddings(
            [average_texts],
            [image_embedding],
            text_weight=0.5,
            image_weight=0.5,
        )
        return combined.tolist()

    def process_batch(self, batch) -> object:
        try:
            batch = batch.get("batch")
            if not isinstance(batch, list):
                return "Invalid input: batch must be an array of strings.", 400
            return [self.embedd_image(item) for item in batch]
        except Exception as e:
            print("Error processing request", e)
            return "An error occurred while processing the request.", 500

    def base64_image_to_pil(self, base64_str) -> Image.Image:
        image_data = b64decode(base64_str)
        return Image.open(BytesIO(image_data))

    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        Process the input data for either classification or embedding generation.

        Args:
            data (:obj:`dict`): A dictionary containing the input data and parameters for inference.
                For classification:
                    {
                        "type": "classify",
                        "inputs": {
                            "candidates": :obj:`list[str]`,
                            "image": :obj:`str`  # URL or base64-encoded image
                        }
                    }
                For embedding:
                    {
                        "type": "embedd",
                        "batch": :obj:`list[str | dict[str, str]]`  # Text or image+description
                    }

        Returns:
            :obj:`dict`: The result of the operation.
                For classification:
                    {
                        "label": :obj:`str`  # The predicted label
                    }
                For embedding:
                    {
                        "embeddings": :obj:`list[list[float]]`  # List of embeddings
                    }

        Raises:
            :obj:`Exception`: If an error occurs during processing.
        """
        inputs = data.pop("inputs", data)
        request_type = data.pop("type", "embedd")
        if request_type == "classify":
            candidate_labels = inputs["candidates"]
            # Accept either an image URL or a base64-encoded image string.
            image = (
                Image.open(BytesIO(requests.get(inputs["image"]).content))
                if "https://" in inputs["image"]
                else self.base64_image_to_pil(inputs["image"])
            )
            return self.classify_image(candidate_labels, image)
        elif request_type == "embedd":
            try:
                embeddings = self.process_batch(inputs)
                return {"embeddings": embeddings}
            except Exception as e:
                print(e)
                return {"error": str(e)}
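
# ---------------------------------------------------------------------------
# Usage sketch, assuming the handler is exercised locally rather than behind
# a deployed inference endpoint. Payload shapes follow the __call__ docstring;
# the image URL below is a hypothetical placeholder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler()

    # Zero-shot classification against a small candidate set.
    result = handler(
        {
            "type": "classify",
            "inputs": {
                "candidates": ["red dress", "blue jeans", "leather jacket"],
                "image": "https://example.com/look.jpg",  # hypothetical URL
            },
        }
    )
    print(result)  # e.g. {"label": "red dress"}

    # Embedding generation for a mixed batch of text and image+description.
    response = handler(
        {
            "type": "embedd",
            "batch": [
                "lightweight linen summer dress",
                {
                    "image": "https://example.com/look.jpg",  # hypothetical URL
                    "description": "A red midi dress with short sleeves.",
                },
            ],
        }
    )
    print(len(response["embeddings"]))  # one embedding per batch item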