# llava-next-inference / handler.py
# Source: Hugging Face Hub (author: eBoreal; last commit: "remove req", 5f53f7b)
from typing import Dict, List, Any
from tempfile import TemporaryDirectory
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from PIL import Image
import torch
import requests
class EndpointHandler:
    """Custom Inference Endpoints handler for llava-v1.6-mistral-7b.

    The model-backed path is currently disabled (commented out below);
    in its present state the handler simply echoes back the payload's
    "inputs" value, or returns an error record when it is missing/empty.
    """

    def __init__(self):
        # Intentionally a no-op while model loading is disabled.
        pass
        # self.processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
        # device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fixed: 'gpu' is not a valid torch device
        # model = LlavaNextForConditionalGeneration.from_pretrained(
        #     "llava-hf/llava-v1.6-mistral-7b-hf",
        #     torch_dtype=torch.float32 if device == 'cpu' else torch.float16,
        #     low_cpu_mem_usage=True
        # )
        # model.to(device)
        # self.model = model
        # self.device = device

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Echo the request payload's "inputs" value.

        Args:
            data: Request payload dict; the value under the "inputs" key
                is returned unchanged.

        Returns:
            ``data["inputs"]`` when present and truthy; otherwise a
            one-element error list ``[{"error": "No inputs provided"}]``.
            (Return annotation corrected from ``List[Dict[str, Any]]``,
            which the success path never satisfied.)
        """
        inputs = data.get("inputs", "")
        if not inputs:
            return [{"error": "No inputs provided"}]
        return inputs
# def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
# """
# data args:
# inputs (:obj: `dict`)
# Return:
# A :obj:`list` | `dict`: will be serialized and returned
# """
# # get inputs
# inputs = data.get("inputs")
# if not inputs:
# return f"Inputs not in payload got {data}"
# # get additional data fields
# prompt = inputs.get("prompt")
# image_url = inputs.get("image")
# if image_url is None:
# return "You need to upload an image URL for LLaVA to work."
# if prompt is None:
# prompt = "Can you describe this picture focusing on specific visual artifacts and ambiance (objects, colors, person, atmosphere..). Please stay concise: only output keywords and concepts detected."
# if not self.model:
# return "Model was not initialized"
# if not self.processor:
# return "Processor was not initialized"
# # Create a temporary directory
# with TemporaryDirectory() as tmpdirname:
# # Download the image
# response = requests.get(image_url)
# if response.status_code != 200:
# return "Failed to download the image."
# # Define the path for the downloaded image
# image_path = f"{tmpdirname}/image.jpg"
# with open(image_path, "wb") as f:
# f.write(response.content)
# # Open the downloaded image
# with Image.open(image_path).convert("RGB") as image:
# prompt = f"[INST] <image>\n{prompt} [/INST]"
# inputs = self.processor(prompt, image, return_tensors="pt").to(self.device)
# output = self.model.generate(**inputs, max_new_tokens=100)
# if not output:
# return 'Model failed to generate'
# clean = self.processor.decode(output[0], skip_special_tokens=True)
# return clean