"""Custom Hugging Face Inference Endpoint handler for a Marqo fashion SigLIP model.

Two request types are supported (see ``EndpointHandler.__call__``):

* ``"classify"`` - zero-shot classification of one image against candidate labels.
* ``"embedd"``   - embeddings for a batch of texts and/or image+description documents.
"""

from base64 import b64decode
from contextlib import nullcontext
from io import BytesIO
import logging
from typing import Any, Dict, List, Optional, Sequence

import numpy as np
import open_clip
import requests
import torch
from PIL import Image

logger = logging.getLogger(__name__)

# open_clip model identifier loaded by default.
DEFAULT_MODEL_NAME = "hf-hub:Styld/marqo-fashionSigLIP"
# Seconds to wait for a remote image download before giving up.
IMAGE_REQUEST_TIMEOUT = 30


class EndpointHandler:
    """Serve image classification and embedding requests with an open_clip model."""

    def __init__(
        self,
        path="hf-hub:Styld/marqo-fashionSigLIP",
        model_name: str = DEFAULT_MODEL_NAME,
    ):
        """Load the model, its preprocessing transforms and its tokenizer.

        Args:
            path: Accepted for compatibility with the endpoint toolkit, which
                presumably passes the local repository directory here. It is
                intentionally NOT used for loading; the model always comes
                from ``model_name``.
            model_name: open_clip model identifier to load.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.preprocess_train, self.preprocess_val = (
            open_clip.create_model_and_transforms(model_name)
        )
        self.model = self.model.to(self.device)
        # open_clip creates models in train mode; inference must run in eval
        # mode so dropout / stochastic layers are disabled.
        self.model.eval()
        self.tokenizer = open_clip.get_tokenizer(model_name)

    # ------------------------------------------------------------------ helpers

    def _autocast(self):
        """Return a mixed-precision context on CUDA and a no-op context on CPU."""
        if self.device.type == "cuda":
            return torch.autocast(device_type="cuda")
        return nullcontext()

    def _encode_image(self, image: Image.Image) -> torch.Tensor:
        """Return the L2-normalised embedding of ``image`` (leading batch dim of 1).

        Callers are expected to wrap this in ``torch.no_grad()``.
        """
        pixels = self.preprocess_val(image).unsqueeze(0).to(self.device)
        features = self.model.encode_image(pixels)
        return features / features.norm(dim=-1, keepdim=True)

    def _encode_text(self, text) -> torch.Tensor:
        """Return L2-normalised text embeddings, one row per input string.

        Callers are expected to wrap this in ``torch.no_grad()``.
        """
        tokens = self.tokenizer(text).to(self.device)
        features = self.model.encode_text(tokens)
        return features / features.norm(dim=-1, keepdim=True)

    def _load_image(self, source: str) -> Image.Image:
        """Load an image from an http(s) URL or from a base64 string / data URI.

        Raises:
            requests.HTTPError: if the URL responds with an error status.
        """
        source = source.strip()
        if source.startswith(("http://", "https://")):
            response = requests.get(source, timeout=IMAGE_REQUEST_TIMEOUT)
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
        return self.base64_image_to_pil(source)

    # --------------------------------------------------------------- public API

    def classify_image(
        self,
        candidate_labels: Sequence[str],
        image: Image.Image,
        batch_size: int = 10,
    ) -> Dict[str, Optional[str]]:
        """Return the candidate label whose text embedding best matches ``image``.

        Labels are encoded ``batch_size`` at a time (presumably to bound memory
        use). The winner is picked by raw cosine similarity. Unlike a softmax
        computed inside each batch, cosine similarity is comparable across
        batches.

        Args:
            candidate_labels: Non-empty sequence of label strings.
            image: PIL image to classify.
            batch_size: Number of labels encoded per forward pass.

        Returns:
            ``{"label": best_label}``.

        Raises:
            ValueError: if ``candidate_labels`` is empty or ``batch_size < 1``.
        """
        if not candidate_labels:
            raise ValueError("candidate_labels must contain at least one label.")
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        best_label: Optional[str] = None
        best_score = float("-inf")
        with torch.no_grad(), self._autocast():
            # The image is identical for every batch, so encode it only once.
            image_features = self._encode_image(image)
            for start in range(0, len(candidate_labels), batch_size):
                batch_labels = list(candidate_labels[start : start + batch_size])
                text_features = self._encode_text(batch_labels)
                similarities = (image_features @ text_features.T)[0]
                index = int(similarities.argmax())
                score = similarities[index].item()
                # Strict ">" keeps the earliest label on ties.
                if score > best_score:
                    best_label, best_score = batch_labels[index], score
        return {"label": best_label}

    def combine_embeddings(
        self, text_embeddings, image_embeddings, text_weight=0.5, image_weight=0.5
    ):
        """Combine text and image embeddings with the specified weights.

        Each list of vectors is first averaged element-wise. A side passed as
        ``None`` is replaced by a zero vector shaped like the other side's first
        element. The two averages are then combined with ``np.average`` using
        ``[text_weight, image_weight]``.

        Raises:
            ValueError: if both ``text_embeddings`` and ``image_embeddings`` are None.
        """
        if text_embeddings is None and image_embeddings is None:
            raise ValueError(
                "At least one of text_embeddings or image_embeddings is required."
            )
        if text_embeddings is not None:
            avg_text_embedding = np.mean(np.vstack(text_embeddings), axis=0)
        else:
            avg_text_embedding = np.zeros_like(image_embeddings[0])
        if image_embeddings is not None:
            avg_image_embedding = np.mean(np.vstack(image_embeddings), axis=0)
        else:
            avg_image_embedding = np.zeros_like(text_embeddings[0])
        return np.average(
            np.vstack((avg_text_embedding, avg_image_embedding)),
            axis=0,
            weights=[text_weight, image_weight],
        )

    def average_text(self, doc: str, chunk_words: int = 40) -> np.ndarray:
        """Embed ``doc`` as the mean of its per-chunk text embeddings.

        The text is split on single spaces into chunks of ``chunk_words`` words
        (presumably to stay within the tokenizer's context length - TODO
        confirm). Each chunk is embedded and L2-normalised, and the chunk
        vectors are averaged. The average itself is not re-normalised.
        """
        words = doc.split(" ")
        chunks = [
            " ".join(words[i : i + chunk_words])
            for i in range(0, len(words), chunk_words)
        ]
        with torch.no_grad():
            # .cpu() is required before .numpy() when running on CUDA.
            text_embeddings = [
                self._encode_text(chunk).squeeze().cpu().numpy() for chunk in chunks
            ]
        return self.combine_embeddings(
            text_embeddings, None, text_weight=1, image_weight=0
        )

    def embedd_image(self, doc) -> list:
        """Return the embedding of one batch item as a list of floats.

        Args:
            doc: Either a plain string, which gets a text-only embedding, or a
                dict with an ``"image"`` entry (http(s) URL or base64; when
                several ``|``-separated references are given, only the first is
                used) and an optional ``"description"``. If the description is
                present, the image and text embeddings are averaged 50/50. A
                dict without an image falls back to the description alone.

        Raises:
            ValueError: if a dict item has neither an image nor a description.
        """
        if isinstance(doc, str):
            return self.average_text(doc).tolist()

        image_ref = doc.get("image")
        description = doc.get("description")
        if not image_ref:
            if description:
                return self.average_text(description).tolist()
            raise ValueError("Batch item must contain an 'image' or a 'description'.")

        # Only the first reference is used, so only the first is downloaded.
        image = self._load_image(image_ref.split("|", 1)[0])
        with torch.no_grad():
            image_embedding = self._encode_image(image).squeeze().cpu().numpy()

        if not description:
            logger.info("empty description. Going with image alone")
            return image_embedding.tolist()

        text_embedding = self.average_text(description)
        combined = self.combine_embeddings(
            [text_embedding],
            [image_embedding],
            text_weight=0.5,
            image_weight=0.5,
        )
        return combined.tolist()

    def process_batch(self, batch) -> object:
        """Embed every item of ``batch["batch"]``.

        Returns:
            A list of embeddings on success. On failure it returns a
            ``(message, status_code)`` tuple instead of raising.
            NOTE(review): ``__call__`` wraps this value as ``{"embeddings": ...}``,
            so clients see the error tuple under the "embeddings" key - consider
            a dedicated error field.
        """
        try:
            items = batch.get("batch")
            if not isinstance(items, list):
                return "Invalid input: batch must be an array of strings.", 400
            return [self.embedd_image(item) for item in items]
        except Exception:
            logger.exception("Error processing request")
            return "An error occurred while processing the request.", 500

    def base64_image_to_pil(self, base64_str: str) -> Image.Image:
        """Decode a base64 string (optionally a ``data:...;base64,`` URI) to a PIL image."""
        if base64_str.startswith("data:") and "," in base64_str:
            base64_str = base64_str.split(",", 1)[1]
        return Image.open(BytesIO(b64decode(base64_str)))

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data for either classification or embedding generation.

        Args:
            data (:obj:`dict`):
                A dictionary containing the input data and parameters for inference.
                The payload may be nested under ``"inputs"``; ``"type"`` is read
                from the top level and defaults to ``"embedd"``.
                For classification:
                {
                    "type": "classify",
                    "inputs": {
                        "candidates": :obj:`list[str]`,
                        "image": :obj:`str`  # http(s) URL or base64 encoded image
                    }
                }
                For embedding:
                {
                    "type": "embedd",
                    "batch": :obj:`list[str | dict[str, str]]`  # Text or image+description
                }

        Returns:
            :obj:`dict`: The result of the operation.
                For classification:
                {
                    "label": :obj:`str`  # The predicted label
                }
                For embedding:
                {
                    "embeddings": :obj:`list[list[float]]`  # List of embeddings
                }
                If the embedding path raises, ``{"error": <message>}`` is returned.

        Raises:
            ValueError: for an unknown request type or empty candidate list.
            :obj:`Exception`: other errors during classification propagate.
        """
        # Read without popping so the caller's dict is left untouched.
        inputs = data.get("inputs", data)
        request_type = data.get("type", "embedd")  # or "classify"

        if request_type == "classify":
            image = self._load_image(inputs["image"])
            return self.classify_image(inputs["candidates"], image)

        if request_type == "embedd":
            try:
                return {"embeddings": self.process_batch(inputs)}
            except Exception as e:
                logger.exception("Embedding request failed")
                # Exception objects are not JSON-serialisable; return the message.
                return {"error": str(e)}

        raise ValueError(
            f"Unknown request type {request_type!r}; expected 'classify' or 'embedd'."
        )