from typing import Dict, List, Any
from transformers import pipeline
from PIL import Image
import requests
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration


class EndpointHandler:
    """Inference-endpoint handler that answers a question about an image with LLaVA."""

    # Image fetched when the request carries no "image_url".
    DEFAULT_IMAGE_URL = "https://cdn.faire.com/fastly/3c335e5c06d3027964ee8351093784c94dfa264e5eb26430c803f4ab3c44da84.jpeg"
    # Question asked when the request carries no usable "inputs" text.
    DEFAULT_QUESTION = "What's in the image"
    # Seconds to wait on the image download before failing the request.
    IMAGE_TIMEOUT = 30
    # Upper bound on generated tokens per request.
    MAX_NEW_TOKENS = 200

    def __init__(self, path="."):
        """Load the LLaVA model and its processor from *path*.

        The model weights are loaded in float16 and moved to CUDA device 0.
        """
        self.model = LlavaForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to(0)
        self.processor = AutoProcessor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Answer a question about an image.

        data args:
            inputs (:obj: `str`, optional): question to ask about the image;
                falls back to ``DEFAULT_QUESTION`` when missing or blank.
            image_url (:obj: `str`, optional): URL of the image; falls back
                to ``DEFAULT_IMAGE_URL``.
        Note: both keys are popped, so *data* is mutated.
        Return:
            A :obj:`list` with one ``{"generated_text": str}`` dict holding the
            model's answer (prompt text excluded); JSON-serializable.
        Raises:
            requests.HTTPError / requests.RequestException: image download failed.
        """
        url = data.pop("image_url", self.DEFAULT_IMAGE_URL)
        inputs = data.pop("inputs", None)
        question = inputs if isinstance(inputs, str) and inputs.strip() else self.DEFAULT_QUESTION
        # LLaVA-1.5 chat format; "<image>" marks where the processor inserts image features.
        prompt = f"USER: <image>\n{question}\nASSISTANT:"

        # Bounded, status-checked download; the response is closed when done.
        with requests.get(url, stream=True, timeout=self.IMAGE_TIMEOUT) as response:
            response.raise_for_status()
            with Image.open(response.raw) as img:
                # Normalize mode (RGBA/L/P -> RGB); convert() also forces the decode
                # before the stream is closed.
                image = img.convert("RGB")

        # Keyword args: the positional order of the processor's text/images changed across releases.
        model_inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(0, torch.float16)

        with torch.inference_mode():
            output_ids = self.model.generate(
                **model_inputs, max_new_tokens=self.MAX_NEW_TOKENS, do_sample=False
            )

        # generate() echoes the prompt tokens before the continuation; drop them.
        prompt_len = model_inputs["input_ids"].shape[1]
        texts = self.processor.batch_decode(output_ids[:, prompt_len:], skip_special_tokens=True)
        return [{"generated_text": text.strip()} for text in texts]