from typing import Any, Dict, List

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration


class EndpointHandler:
    def __init__(self, path: str = "."):
        # Load the LLaVA model in half precision onto the first GPU, together
        # with its processor (tokenizer + image preprocessor).
        self.model = LlavaForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to(0)
        self.processor = AutoProcessor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`): optional question to ask about the image
            image_url (:obj:`str`): optional URL of the image to describe
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        # Get inputs, falling back to a default image and question.
        default_url = "https://cdn.faire.com/fastly/3c335e5c06d3027964ee8351093784c94dfa264e5eb26430c803f4ab3c44da84.jpeg"
        url = data.pop("image_url", default_url)
        question = data.pop("inputs", None) or "What's in the image?"
        prompt = f"USER: <image>\n{question}\nASSISTANT:"
        image = Image.open(requests.get(url, stream=True).raw)
        inputs = self.processor(prompt, image, return_tensors="pt").to(0, torch.float16)
        # Run greedy decoding and return the generated text so the endpoint
        # can serialize it as JSON.
        output = self.model.generate(**inputs, max_new_tokens=200, do_sample=False)
        generated_text = self.processor.decode(output[0], skip_special_tokens=True)
        return [{"generated_text": generated_text}]
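

# Minimal local smoke test (a sketch, not part of the endpoint contract): it
# assumes a LLaVA checkpoint such as "llava-hf/llava-1.5-7b-hf" is available and
# that a CUDA device is present, then calls the handler the way the Inference
# Endpoints runtime would, passing an `inputs` question and an `image_url`.
if __name__ == "__main__":
    handler = EndpointHandler(path="llava-hf/llava-1.5-7b-hf")
    result = handler(
        {
            "inputs": "What's in the image?",
            "image_url": "https://cdn.faire.com/fastly/3c335e5c06d3027964ee8351093784c94dfa264e5eb26430c803f4ab3c44da84.jpeg",
        }
    )
    print(result)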