Amitai Getzler
:heavy_plus_sign: Add
9ef8061
raw
history blame contribute delete
No virus
6.67 kB
import logging
from base64 import b64decode
from io import BytesIO
from typing import List, Dict, Any, Union

import numpy as np
import open_clip
import requests
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from PIL import Image
from pydantic import BaseModel
# FastAPI application object; the @app.post routes in this module register on it.
app = FastAPI()
class EndpointHandler:
    """Wrap an open_clip model for zero-shot classification and embedding.

    The handler loads the model, its preprocessing transforms and its
    tokenizer once, and exposes helpers to classify an image against a set of
    candidate labels and to embed text and/or images as plain Python lists.
    """

    # Number of candidate labels tokenized and encoded per forward pass.
    LABEL_BATCH_SIZE = 10
    # Number of space-separated words grouped into one text chunk.
    WORDS_PER_CHUNK = 40
    # Seconds to wait for a remote image before giving up.
    REQUEST_TIMEOUT = 10

    def __init__(self, path="hf-hub:Styld/marqo-fashionSigLIP"):
        """Load the model, transforms and tokenizer identified by *path*.

        Args:
            path: Model identifier passed to
                ``open_clip.create_model_and_transforms`` and
                ``open_clip.get_tokenizer``.
        """
        self.model, self.preprocess_train, self.preprocess_val = (
            open_clip.create_model_and_transforms(path)
        )
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = self.model.to(self.device)
        # Inference only: switch off train-mode behaviour such as dropout.
        self.model.eval()
        self.tokenizer = open_clip.get_tokenizer(path)

    @staticmethod
    def _normalized_vector(features):
        """Return L2-normalized *features* as a squeezed NumPy array on CPU."""
        features = features / features.norm(dim=-1, keepdim=True)
        return features.detach().squeeze().cpu().numpy()

    def _load_image(self, source):
        """Open an image from a URL string or a base64-encoded string.

        A string containing ``https://`` is treated as one or more
        ``|``-separated URLs, of which only the first is fetched. Anything
        else is decoded as base64.

        Raises:
            requests.RequestException: If the download fails, times out or
                returns an HTTP error status.
        """
        if "https://" in source:
            first_url = source.split("|", 1)[0]
            # NOTE(security): the URL is caller-supplied; restrict allowed
            # hosts at the deployment level if the service is exposed.
            response = requests.get(first_url, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
        return self.base64_image_to_pil(source)

    def classify_image(self, candidate_labels, image):
        """Pick the candidate label most similar to *image*.

        Args:
            candidate_labels: Non-empty sequence of label strings.
            image: Image object accepted by ``self.preprocess_val``.

        Returns:
            ``{"label": <best label>}``.

        Raises:
            ValueError: If *candidate_labels* is empty.
        """
        if not candidate_labels:
            raise ValueError(
                "candidate_labels must contain at least one label."
            )

        # The image does not change between label batches: encode it once.
        image_tensor = self.preprocess_val(image).unsqueeze(0).to(self.device)
        use_amp = self.device.type == "cuda"
        best_label = None
        best_score = float("-inf")

        with torch.no_grad(), torch.autocast(
            device_type=self.device.type, enabled=use_amp
        ):
            image_features = self.model.encode_image(image_tensor)
            image_features = image_features / image_features.norm(
                dim=-1, keepdim=True
            )
            for start in range(
                0, len(candidate_labels), self.LABEL_BATCH_SIZE
            ):
                batch_labels = candidate_labels[
                    start : start + self.LABEL_BATCH_SIZE
                ]
                tokens = self.tokenizer(batch_labels).to(self.device)
                text_features = self.model.encode_text(tokens)
                text_features = text_features / text_features.norm(
                    dim=-1, keepdim=True
                )
                # Compare raw cosine similarities rather than per-batch
                # softmax probabilities: a softmax is only meaningful inside
                # its own batch (a batch holding a single label would always
                # score 1.0), whereas similarities are comparable across
                # batches and give the same winner as one global softmax.
                similarity = (image_features @ text_features.T)[0].float()
                top_index = int(similarity.argmax().item())
                top_score = similarity[top_index].item()
                if top_score > best_score:
                    best_score = top_score
                    best_label = batch_labels[top_index]

        return {"label": best_label}

    def combine_embeddings(
        self, text_embeddings, image_embeddings, text_weight=0.5, image_weight=0.5
    ):
        """Blend averaged text and image embeddings into one vector.

        Each non-empty input is stacked and averaged along axis 0. A missing
        (``None`` or empty) side is replaced by a zero vector of the same
        shape as the other side's average and still takes part in the
        weighted average with its weight.

        Args:
            text_embeddings: Sequence of 1-D arrays, or ``None``.
            image_embeddings: Sequence of 1-D arrays, or ``None``.
            text_weight: Weight of the averaged text embedding.
            image_weight: Weight of the averaged image embedding.

        Returns:
            The weighted average as a NumPy array.

        Raises:
            ValueError: If both inputs are ``None`` or empty.
        """
        has_text = text_embeddings is not None and len(text_embeddings) > 0
        has_image = image_embeddings is not None and len(image_embeddings) > 0
        if not has_text and not has_image:
            raise ValueError(
                "At least one of text_embeddings or image_embeddings "
                "is required."
            )

        avg_text = (
            np.mean(np.vstack(text_embeddings), axis=0) if has_text else None
        )
        avg_image = (
            np.mean(np.vstack(image_embeddings), axis=0) if has_image else None
        )
        if avg_text is None:
            avg_text = np.zeros_like(avg_image)
        if avg_image is None:
            avg_image = np.zeros_like(avg_text)

        return np.average(
            np.vstack((avg_text, avg_image)),
            axis=0,
            weights=[text_weight, image_weight],
        )

    def average_text(self, doc):
        """Embed *doc* in chunks of words and average the chunk embeddings.

        The text is split on single spaces into chunks of
        ``WORDS_PER_CHUNK`` words; each chunk is tokenized, encoded and
        L2-normalized, and the chunk vectors are averaged.

        Args:
            doc: Text to embed.

        Returns:
            The averaged embedding as a NumPy array.
        """
        words = doc.split(" ")
        chunks = [
            " ".join(words[start : start + self.WORDS_PER_CHUNK])
            for start in range(0, len(words), self.WORDS_PER_CHUNK)
        ]

        text_embeddings = []
        # Inference only: avoid building autograd graphs.
        with torch.no_grad():
            for chunk in chunks:
                tokens = self.tokenizer(chunk).to(self.device)
                text_embeddings.append(
                    self._normalized_vector(self.model.encode_text(tokens))
                )

        return self.combine_embeddings(
            text_embeddings, None, text_weight=1, image_weight=0
        )

    def embedd_image(self, doc) -> list:
        """Embed one batch item.

        Args:
            doc: Either a plain string (embedded as text) or a mapping with
                an ``"image"`` entry (URL or base64 string) and an optional
                ``"description"`` entry.

        Returns:
            The embedding as a list of floats. For mappings with a non-empty
            description this is the equal-weight blend of the image embedding
            and the description's text embedding; otherwise the image
            embedding alone.

        Raises:
            ValueError: If a mapping has no usable ``"image"`` entry.
        """
        if isinstance(doc, str):
            return self.average_text(doc).tolist()

        image_source = doc.get("image")
        if not image_source:
            raise ValueError(
                "Document must include a non-empty 'image' field."
            )

        image = self._load_image(image_source)
        image_tensor = self.preprocess_val(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_embedding = self._normalized_vector(
                self.model.encode_image(image_tensor)
            )

        description = doc.get("description", "")
        if not description:
            return image_embedding.tolist()

        text_embedding = self.average_text(description)
        combined = self.combine_embeddings(
            [text_embedding],
            [image_embedding],
            text_weight=0.5,
            image_weight=0.5,
        )
        return combined.tolist()

    def process_batch(self, batch) -> object:
        """Embed every item found under the ``"batch"`` key of *batch*.

        Args:
            batch: Mapping whose ``"batch"`` entry is a list of items
                accepted by :meth:`embedd_image`.

        Returns:
            A list with one embedding per item on success. On failure a
            ``(message, status_code)`` tuple is returned instead of raising:
            status 400 when ``"batch"`` is not a list, 500 for any other
            error (which is logged with its traceback).
        """
        try:
            items = batch.get("batch")
            if not isinstance(items, list):
                return "Invalid input: batch must be an array of strings.", 400
            return [self.embedd_image(item) for item in items]
        except Exception:
            # Keep the tuple contract for callers, but do not lose the cause.
            logging.getLogger(__name__).exception("Failed to embed batch")
            return "An error occurred while processing the request.", 500

    def base64_image_to_pil(self, base64_str) -> "Image.Image":
        """Decode a base64 string (optionally a ``data:`` URI) into an image.

        Args:
            base64_str: Base64-encoded image bytes. A leading
                ``data:<mime>;base64,`` prefix is stripped if present.

        Returns:
            The image opened with ``PIL.Image.open``.
        """
        if base64_str.startswith("data:") and "," in base64_str:
            base64_str = base64_str.split(",", 1)[1]
        return Image.open(BytesIO(b64decode(base64_str)))
# Single module-level handler used by the endpoints below. Constructing it
# runs EndpointHandler.__init__, i.e. the model is loaded at import time.
handler = EndpointHandler()
class ClassifyRequest(BaseModel):
    """Request body for the /classify endpoint."""

    # Labels the image is scored against.
    candidates: List[str]
    # Image reference: the endpoint downloads it when the string contains
    # "https://", otherwise it is decoded as base64.
    image: str
class EmbeddRequest(BaseModel):
    """Request body for the /embedd endpoint."""

    # Items to embed: each is either a plain string or a string-to-string
    # mapping (the handler reads its "image" and "description" keys).
    batch: List[Union[str, Dict[str, str]]]
@app.post("/classify")
def classify(request: ClassifyRequest):
    """Classify the request's image against its candidate labels.

    The image is downloaded when ``request.image`` contains ``https://`` and
    decoded from base64 otherwise.

    Returns:
        The dictionary produced by ``handler.classify_image``.

    Raises:
        HTTPException: 400 when no candidates are supplied; 500 when loading
            the image or running the model fails.
    """
    if not request.candidates:
        raise HTTPException(
            status_code=400, detail="candidates must not be empty."
        )
    try:
        if "https://" in request.image:
            # NOTE(security): the URL is caller-supplied; restrict allowed
            # hosts at the deployment level if the service is exposed.
            fetched = requests.get(request.image, timeout=10)
            # Fail on HTTP errors instead of trying to parse an error page
            # as an image.
            fetched.raise_for_status()
            image = Image.open(BytesIO(fetched.content))
        else:
            image = handler.base64_image_to_pil(request.image)
        return handler.classify_image(request.candidates, image)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/embedd")
def embedd(request: EmbeddRequest):
    """Embed every item of the request's batch.

    Returns:
        ``{"embeddings": [...]}`` with one embedding per batch item.

    Raises:
        HTTPException: With the status code reported by
            ``handler.process_batch`` when it signals a failure, or 500 when
            it raises.
    """
    try:
        result = handler.process_batch(request.dict())
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # process_batch signals failure by returning a (message, status_code)
    # tuple rather than raising; a successful result is always a list.
    # Translate the tuple into a real HTTP error so it is not sent back as a
    # 200 response disguised as embeddings.
    if (
        isinstance(result, tuple)
        and len(result) == 2
        and isinstance(result[1], int)
    ):
        message, status_code = result
        raise HTTPException(status_code=status_code, detail=message)

    return {"embeddings": result}
@app.post("/process")
async def process(request: Request):
    """Dispatch a raw JSON body to the classify or embedd endpoint.

    A body with both ``"candidates"`` and ``"image"`` keys is handled as a
    classification request; a body with a ``"batch"`` key is handled as an
    embedding request.

    Raises:
        HTTPException: 400 for malformed JSON or an unrecognized body shape,
            422 when the body fails model validation, 500 for unexpected
            errors. HTTP errors raised by the delegated endpoint are passed
            through unchanged.
    """
    # Local import keeps this endpoint self-contained.
    from pydantic import ValidationError

    try:
        data = await request.json()
    except ValueError as e:  # json.JSONDecodeError is a ValueError
        raise HTTPException(
            status_code=400, detail=f"Invalid JSON body: {e}"
        ) from e

    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Invalid request format.")

    # NOTE: classify() and embedd() are synchronous and are called directly
    # here, so they run on the event loop for the duration of the request.
    try:
        if "candidates" in data and "image" in data:
            return classify(ClassifyRequest(**data))
        if "batch" in data:
            return embedd(EmbeddRequest(**data))
    except HTTPException:
        # Keep the status code chosen by the delegated endpoint instead of
        # re-wrapping it as a generic 500.
        raise
    except ValidationError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    raise HTTPException(status_code=400, detail="Invalid request format.")
# When executed as a script, serve the app with uvicorn on port 8000,
# listening on all network interfaces (host "0.0.0.0").
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)