adasdimchom's picture
Upload handler.py
016cd63
raw
history blame contribute delete
No virus
1.59 kB
from transformers import BlipProcessor, BlipForConditionalGeneration
from typing import Dict, List, Any
from PIL import Image
from transformers import pipeline
import requests
import torch
class EndpointHandler():
    """Custom inference handler that generates text for an image with BLIP.

    The handler downloads the image referenced by the request, optionally
    conditions generation on a text prompt, and returns the decoded output.
    """

    # Seconds to wait on the image host. Without a timeout a stalled
    # download would block the worker indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(self, path=""):
        """Load the BLIP processor and model.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                processor and the model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Use half precision only on GPU. float16 kernels are not reliably
        # available on CPU, so fall back to float32 there instead of failing
        # at generation time.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.processor = BlipProcessor.from_pretrained(path)
        self.model = BlipForConditionalGeneration.from_pretrained(
            path, torch_dtype=self.dtype
        ).to(self.device)

    def _load_image(self, image_url):
        """Download ``image_url`` and return it as an RGB ``PIL.Image``."""
        with requests.get(
            image_url, stream=True, timeout=self.REQUEST_TIMEOUT
        ) as response:
            # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            # ``convert`` fully decodes the image, so the pixel data is in
            # memory before the connection is released by the ``with`` block.
            return Image.open(response.raw).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run generation for a single request.

        Args:
            data: Request payload. The mapping under ``"inputs"`` (or ``data``
                itself when that key is absent) must contain ``"image_url"``
                and may contain ``"prompt"``.

        Returns:
            A dict with a single key, ``"text_output"``, holding the decoded
            text of the first generated sequence.

        Raises:
            KeyError: If ``"image_url"`` is missing from the inputs.
            requests.RequestException: If the image download fails, times
                out, or returns an HTTP error status.
        """
        # ``get`` rather than ``pop`` so the caller's payload is not mutated.
        inputs = data.get("inputs", data)
        image_url = inputs["image_url"]
        prompt = inputs.get("prompt")

        image = self._load_image(image_url)

        processor_kwargs = {"images": image, "return_tensors": "pt"}
        if prompt:
            # Only pass text when a non-empty prompt was supplied.
            processor_kwargs["text"] = prompt
        # Cast the inputs to the same dtype the model was loaded with.
        model_inputs = self.processor(**processor_kwargs).to(
            self.device, self.dtype
        )

        output = self.model.generate(**model_inputs)
        text_output = self.processor.decode(output[0], skip_special_tokens=True)
        return {"text_output": text_output}