from tempfile import TemporaryDirectory
from typing import Any, Dict, List

import requests
import torch
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Debug stub: the full model/processor setup below is intentionally
        # commented out so the endpoint can be deployed without loading weights.
        pass
        # self.processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
        # device = "cuda" if torch.cuda.is_available() else "cpu"
        # model = LlavaNextForConditionalGeneration.from_pretrained(
        #     "llava-hf/llava-v1.6-mistral-7b-hf",
        #     torch_dtype=torch.float32 if device == "cpu" else torch.float16,
        #     low_cpu_mem_usage=True,
        # )
        # model.to(device)
        # self.model = model
        # self.device = device

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Debug stub: echo the payload back so the endpoint wiring can be
        # verified without running the model.
        inputs = data.get("inputs", "")
        if not inputs:
            return [{"error": "No inputs provided"}]
        return [{"inputs": inputs}]
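
    # The commented-out implementation below expects a payload shaped roughly
    # like the following (a sketch only; "prompt" is optional and a default is
    # used when it is missing, and the URL is just a placeholder):
    #
    #   {
    #       "inputs": {
    #           "prompt": "Describe this picture.",
    #           "image": "https://example.com/photo.jpg"
    #       }
    #   }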
    # def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
    #     """
    #     data args:
    #         inputs (:obj:`dict`): must contain an "image" URL and may contain a "prompt".
    #     Return:
    #         A :obj:`list` | `dict`: will be serialized and returned
    #     """
    #     # get inputs
    #     inputs = data.get("inputs")
    #     if not inputs:
    #         return [{"error": f"'inputs' not found in payload, got {data}"}]
    #     # get the additional data fields
    #     prompt = inputs.get("prompt")
    #     image_url = inputs.get("image")
    #     if image_url is None:
    #         return [{"error": "You need to provide an image URL for LLaVA to work."}]
    #     if prompt is None:
    #         prompt = (
    #             "Can you describe this picture, focusing on specific visual artifacts "
    #             "and ambiance (objects, colors, people, atmosphere...)? Please stay "
    #             "concise and only output the keywords and concepts detected."
    #         )
    #     if getattr(self, "model", None) is None:
    #         return [{"error": "Model was not initialized"}]
    #     if getattr(self, "processor", None) is None:
    #         return [{"error": "Processor was not initialized"}]
    #     # Create a temporary directory for the downloaded image
    #     with TemporaryDirectory() as tmpdirname:
    #         # Download the image
    #         response = requests.get(image_url, timeout=30)
    #         if response.status_code != 200:
    #             return [{"error": "Failed to download the image."}]
    #         # Write the downloaded bytes to a temporary file
    #         image_path = f"{tmpdirname}/image.jpg"
    #         with open(image_path, "wb") as f:
    #             f.write(response.content)
    #         # Open the downloaded image and wrap the prompt in the LLaVA-Mistral instruction format
    #         with Image.open(image_path) as raw_image:
    #             image = raw_image.convert("RGB")
    #             prompt = f"[INST] <image>\n{prompt} [/INST]"
    #             inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device)
    #             output = self.model.generate(**inputs, max_new_tokens=100)
    #             if output is None or output.shape[0] == 0:
    #                 return [{"error": "Model failed to generate"}]
    #             clean = self.processor.decode(output[0], skip_special_tokens=True)
    #             return [{"generated_text": clean}]
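

# Minimal local smoke test (a sketch, not part of the endpoint contract): with
# the debug stub above, calling the handler simply echoes the payload back, so
# this runs without downloading the model. The example URL is a placeholder.
if __name__ == "__main__":
    handler = EndpointHandler()
    example_payload = {
        "inputs": {
            "prompt": "Describe this picture.",
            "image": "https://example.com/photo.jpg",
        }
    }
    print(handler(example_payload))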