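# Custom handler for a Hugging Face Inference Endpoint wrapping the
# Styld/marqo-fashionSigLIP open_clip model. It supports zero-shot image
# classification ("classify") and text/image embedding generation ("embedd").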
from base64 import b64decode
from io import BytesIO
from typing import Any, Dict

import numpy as np
import open_clip
import requests
import torch
from PIL import Image


class EndpointHandler:
    def __init__(self, path="hf-hub:Styld/marqo-fashionSigLIP"):
        self.model, self.preprocess_train, self.preprocess_val = (
            open_clip.create_model_and_transforms(path)
        )
        # Track the device explicitly so input tensors can be moved to the
        # same device as the model (the original moved only the model).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device).eval()
        self.tokenizer = open_clip.get_tokenizer(path)
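
    # preprocess_train is kept only because create_model_and_transforms
    # returns it; inference below uses preprocess_val exclusively.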

    def classify_image(self, candidate_labels, image):
        def get_top_prediction(text_probs, labels):
            max_index = text_probs[0].argmax().item()
            return {
                "label": labels[max_index],
                "score": text_probs[0][max_index].item(),
            }

        # Preprocess the image once; only the label batches vary per iteration.
        image_tensor = self.preprocess_val(image).unsqueeze(0).to(self.device)
        top_prediction = None
        for i in range(0, len(candidate_labels), 10):
            batch_labels = candidate_labels[i : i + 10]
            text = self.tokenizer(batch_labels).to(self.device)
            with torch.no_grad(), torch.autocast(self.device):
                image_features = self.model.encode_image(image_tensor)
                text_features = self.model.encode_text(text)
                image_features /= image_features.norm(dim=-1, keepdim=True)
                text_features /= text_features.norm(dim=-1, keepdim=True)
                text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            # Softmax scores are normalized within each batch of 10 labels, so
            # comparing winners across batches is a heuristic, not a true
            # probability over all candidates.
            current_top = get_top_prediction(text_probs, batch_labels)
            if top_prediction is None or current_top["score"] > top_prediction["score"]:
                top_prediction = current_top
        return {"label": top_prediction["label"]}

    def combine_embeddings(
        self, text_embeddings, image_embeddings, text_weight=0.5, image_weight=0.5
    ):
        """Combine text and image embeddings with the specified weights.

        At least one of `text_embeddings` / `image_embeddings` must be given;
        a zero vector stands in for a missing modality.
        """
        if text_embeddings is not None:
            avg_text_embedding = np.mean(np.vstack(text_embeddings), axis=0)
        else:
            avg_text_embedding = np.zeros_like(image_embeddings[0])
        if image_embeddings is not None:
            avg_image_embedding = np.mean(np.vstack(image_embeddings), axis=0)
        else:
            avg_image_embedding = np.zeros_like(text_embeddings[0])
        # Weighted average of the two per-modality mean embeddings.
        combined_embedding = np.average(
            np.vstack((avg_text_embedding, avg_image_embedding)),
            axis=0,
            weights=[text_weight, image_weight],
        )
        return combined_embedding
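
    # Worked example (toy vectors, not model output): with avg text embedding
    # [1.0, 0.0], avg image embedding [0.0, 1.0], and weights 0.5/0.5, the
    # combined embedding is [0.5, 0.5].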

    def average_text(self, doc):
        # Split the document into 40-word chunks and average their embeddings.
        words = doc.split(" ")
        text_chunks = [
            " ".join(words[i : i + 40]) for i in range(0, len(words), 40)
        ]
        text_embeddings = []
        for chunk in text_chunks:
            inputs = self.tokenizer(chunk).to(self.device)
            with torch.no_grad():
                text_features = self.model.encode_text(inputs)
                text_features /= text_features.norm(dim=-1, keepdim=True)
            # Move to CPU before converting to NumPy (required when on GPU).
            text_embeddings.append(text_features.squeeze().cpu().numpy())
        return self.combine_embeddings(
            text_embeddings, None, text_weight=1, image_weight=0
        )

    def embedd_image(self, doc) -> list:
        if isinstance(doc, str):
            # Plain text input: embed the text alone.
            return self.average_text(doc).tolist()
        image = doc.get("image")
        if "https://" in image:
            # Multiple URLs may be joined with "|"; only the first image is
            # embedded, so only the first is fetched.
            first_url = image.split("|")[0]
            image = Image.open(BytesIO(requests.get(first_url).content))
        image_tensor = self.preprocess_val(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image_tensor)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        image_embedding = image_features.squeeze().cpu().numpy()
        if doc.get("description", "") == "":
            print("Empty description; embedding the image alone.")
            return image_embedding.tolist()
        average_texts = self.average_text(doc.get("description"))
        combined = self.combine_embeddings(
            [average_texts],
            [image_embedding],
            text_weight=0.5,
            image_weight=0.5,
        )
        return combined.tolist()
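
    # embedd_image accepts either a raw text string or a dict such as
    # {"image": "https://example.com/item.jpg", "description": "red cotton dress"}
    # (the URL and description here are illustrative).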

    def process_batch(self, batch) -> object:
        try:
            batch = batch.get("batch")
            # Validate the batch input; items may be strings or dicts.
            if not isinstance(batch, list):
                return "Invalid input: batch must be a list of strings or objects.", 400
            # Return one embedding per batch item.
            return [self.embedd_image(item) for item in batch]
        except Exception as e:
            print("Error processing request", e)
            return "An error occurred while processing the request.", 500

    def base64_image_to_pil(self, base64_str) -> Image.Image:
        image_data = b64decode(base64_str)
        return Image.open(BytesIO(image_data))

    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        Process the input data for either classification or embedding generation.

        Args:
            data (:obj:`dict`): A dictionary containing the input data and parameters for inference.
                For classification:
                    {
                        "type": "classify",
                        "inputs": {
                            "candidates": :obj:`list[str]`,
                            "image": :obj:`str`  # URL or base64 encoded image
                        }
                    }
                For embedding:
                    {
                        "type": "embedd",
                        "batch": :obj:`list[str | dict[str, str]]`  # Text or image+description
                    }

        Returns:
            :obj:`dict`: The result of the operation.
                For classification:
                    {
                        "label": :obj:`str`  # The predicted label
                    }
                For embedding:
                    {
                        "embeddings": :obj:`list[list[float]]`  # List of embeddings
                    }

        Raises:
            :obj:`Exception`: If an error occurs during processing.
        """
        inputs = data.pop("inputs", data)
        task = data.pop("type", "embedd")  # "embedd" (default) or "classify"
        if task == "classify":
            candidate_labels = inputs["candidates"]
            image = (
                Image.open(BytesIO(requests.get(inputs["image"]).content))
                if "https://" in inputs["image"]
                else self.base64_image_to_pil(inputs["image"])
            )
            return self.classify_image(candidate_labels, image)
        elif task == "embedd":
            try:
                embeddings = self.process_batch(inputs)
                return {"embeddings": embeddings}
            except Exception as e:
                print(e)
                # Return a serializable error payload rather than the raw exception.
                return {"error": str(e)}
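

# Minimal local smoke test; a sketch, not part of the deployed handler.
# The payloads mirror the docstring above, but the example text and URL
# are illustrative assumptions.
if __name__ == "__main__":
    handler = EndpointHandler()
    # Text-only embedding request.
    print(handler({"type": "embedd", "batch": ["a floral summer dress"]}))
    # Zero-shot classification request (uncomment and supply a real image URL):
    # print(handler({
    #     "type": "classify",
    #     "inputs": {"candidates": ["dress", "sneakers"], "image": "https://example.com/item.jpg"},
    # }))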