fashion-clip / handler.py
EMaghakyan's picture
add custom handler
98660fb
raw
history blame
1.23 kB
from io import BytesIO
from typing import Any, Dict, List

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer, CLIPModel
class EndpointHandler:
def __init__(self):
self.model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
self.processor = AutoProcessor.from_pretrained("patrickjohncyh/fashion-clip")
self.tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
parameters = data.pop("parameters", {"mode": "image"})
inputs = data.pop("inputs", data)
with torch.no_grad():
if parameters["mode"] == "text":
inputs = self.tokenizer(inputs, padding=True, return_tensors="pt")
features = self.model.get_text_features(**inputs)
if parameters["mode"] == "image":
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = self.processor(images=image, return_tensors="pt")
features = self.model.get_image_features(**inputs)
return features[0].tolist()