llava-v1.6-34b / handler.py
rroset's picture
Update handler.py
d1b5158 verified
raw
history blame contribute delete
No virus
3.44 kB
from typing import Dict, Any
import torch
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig
from PIL import Image
import requests
from io import BytesIO
import base64
class EndpointHandler:
    """Inference Endpoints handler for a LLaVA-NeXT vision-language model.

    The processor and the 4-bit quantized model are loaded once in
    ``__init__``. Each call receives an image (by URL or base64) plus a text
    prompt. It returns the decoded generation together with the list of
    progress log messages collected while handling the request.
    """

    # Hub repository the processor (tokenizer + image preprocessing) is loaded from.
    # NOTE(review): this is the mistral-7b processor, while MODEL_ID names a 34b
    # checkpoint. Upstream LLaVA-v1.6 34b is built on a different base LLM (and
    # likely a different tokenizer / prompt template) than the mistral-7b
    # variant. Confirm the two actually match.
    PROCESSOR_ID = "llava-hf/llava-v1.6-mistral-7b-hf"
    # Hub repository the model weights are loaded from.
    MODEL_ID = "rroset/llava-v1.6-34b"
    # max_new_tokens used when the request omits "max_tokens" (or sends null).
    DEFAULT_MAX_TOKENS = 100
    # Seconds to wait on the image server before the download is abandoned.
    IMAGE_DOWNLOAD_TIMEOUT = 30

    def __init__(self, path=""):
        """Load the processor and the 4-bit quantized model.

        Args:
            path: Repository path handed in by the endpoint runtime. Currently
                unused: both components are loaded from PROCESSOR_ID / MODEL_ID.
        """
        # 4-bit NF4 weight quantization, with computations carried out in float16.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        # Loaded once here so every request reuses the same processor and model.
        self.processor = LlavaNextProcessor.from_pretrained(self.PROCESSOR_ID)
        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            self.MODEL_ID,
            quantization_config=quantization_config,
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one image + prompt request through the model.

        Args:
            data: Request payload. ``data["inputs"]`` must be a dict with:
                ``prompt`` (required): text given to the processor. Presumably
                it must contain the model's ``<image>`` placeholder (TODO
                confirm against the prompt template).
                ``url`` or ``image_data`` (one required): image URL, or base64
                image bytes (a ``data:...;base64,`` URI is also accepted).
                ``url`` wins if both are given.
                ``max_tokens`` (optional, default 100): ``max_new_tokens``
                passed to ``generate``. It must convert to a positive int.

        Returns:
            On success, ``{"input_prompt", "model_output", "logs"}``.
            Validation, image-loading and model errors are reported as
            ``{"error", "logs"}`` instead of being raised.
        """
        logs = []
        logs.append("Iniciant processament de la petició.")

        inputs = data.get("inputs")
        if not inputs:
            logs.append("Format d'entrada invàlid. Manca la clau 'inputs'.")
            return {"error": "Invalid input format. 'inputs' key is missing.", "logs": logs}
        if not isinstance(inputs, dict):
            # e.g. a plain-text payload {"inputs": "..."}; without this guard the
            # .get() calls below raise an uncaught AttributeError.
            logs.append("Format d'entrada invàlid. 'inputs' ha de ser un objecte.")
            return {"error": "Invalid input format. 'inputs' must be an object.", "logs": logs}

        image_url = inputs.get("url")
        image_data = inputs.get("image_data")
        prompt = inputs.get("prompt")

        if not prompt:
            logs.append("S'ha de proporcionar 'prompt' en 'inputs'.")
            return {"error": "The 'prompt' must be provided in 'inputs'.", "logs": logs}
        if not image_url and not image_data:
            logs.append("S'ha de proporcionar 'url' o 'image_data' en 'inputs'.")
            return {"error": "Either 'url' or 'image_data' must be provided in 'inputs'.", "logs": logs}

        # Validate up front rather than letting generate() fail after the image
        # has already been downloaded and decoded.
        raw_max_tokens = inputs.get("max_tokens")
        if raw_max_tokens is None:
            raw_max_tokens = self.DEFAULT_MAX_TOKENS
        try:
            max_tokens = int(raw_max_tokens)
        except (TypeError, ValueError, OverflowError):
            max_tokens = 0
        if max_tokens < 1:
            logs.append("'max_tokens' ha de ser un enter positiu.")
            return {"error": "'max_tokens' must be a positive integer.", "logs": logs}

        logs.append(f"Processant entrada: url={image_url}, image_data={'present' if image_data else 'absent'}, prompt={prompt}")

        # Broad excepts are deliberate: this is the request boundary, and every
        # failure is returned to the client as an error payload.
        try:
            image = self._load_image(image_url, image_data, logs)
        except Exception as e:
            logs.append(f"Error carregant imatge: {e}")
            return {"error": str(e), "logs": logs}

        try:
            logs.append("Processant imatge amb el model.")
            result = self._generate(prompt, image, max_tokens)
        except Exception as e:
            logs.append(f"Error processant el model: {e}")
            return {"error": str(e), "logs": logs}

        logs.append("Processament complet.")
        return {"input_prompt": prompt, "model_output": result, "logs": logs}

    def _load_image(self, image_url, image_data, logs):
        """Fetch or decode the request image and return it as an RGB PIL image.

        Progress messages are appended to *logs*; exceptions propagate.
        """
        if image_url:
            logs.append(f"Carregant imatge des de URL: {image_url}")
            # Bounded wait, connection always closed, and HTTP errors raised
            # instead of trying to decode an error page as an image.
            with requests.get(image_url, timeout=self.IMAGE_DOWNLOAD_TIMEOUT) as response:
                response.raise_for_status()
                raw_bytes = response.content
        else:
            logs.append("Carregant imatge des de dades d'imatge en brut.")
            raw_bytes = self._decode_image_data(image_data)

        image = Image.open(BytesIO(raw_bytes))
        if image.mode != "RGB":
            # Covers RGBA PNGs as well as palette/grayscale/CMYK images. This
            # replaces the old PNG-only conversion, which also did a lossy JPEG
            # re-encode.
            logs.append(f"Convertint imatge de mode {image.mode} a RGB.")
        # convert() also forces Pillow's lazy decode, so a corrupt image fails
        # here rather than inside the model step.
        return image.convert("RGB")

    @staticmethod
    def _decode_image_data(image_data) -> bytes:
        """Decode base64 image data, accepting an optional ``data:...;base64,`` prefix."""
        if isinstance(image_data, str) and image_data.startswith("data:"):
            # b64decode discards non-alphabet characters by default, so the
            # letters of a data-URI header would be decoded as image bytes.
            image_data = image_data.partition(",")[2]
        return base64.b64decode(image_data)

    def _generate(self, prompt, image, max_tokens):
        """Encode prompt + image, run generation, and return the decoded text."""
        # Keyword arguments: the positional order of the processor's arguments
        # differs between transformers versions.
        model_inputs = self.processor(text=prompt, images=image, return_tensors="pt")
        # Follow wherever device_map placed the model instead of hard-coding "cuda".
        model_inputs = model_inputs.to(self.model.device)
        output = self.model.generate(**model_inputs, max_new_tokens=max_tokens)
        # NOTE(review): for decoder-only generation output[0] likely holds the
        # prompt tokens followed by the new ones, so model_output echoes the
        # prompt. Kept as-is for compatibility with existing clients.
        return self.processor.decode(output[0], skip_special_tokens=True)