rroset's picture
Update handler.py
44df4d6 verified
raw
history blame
No virus
2.3 kB
from typing import Dict, List, Any
import torch
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig
from PIL import Image
import requests
from io import BytesIO
import re
class EndpointHandler():
    """Inference handler that scores an image with LLaVA-NeXT (Mistral-7B).

    Each request downloads an image from a URL, runs it together with a text
    prompt through a 4-bit quantized LLaVA-NeXT model, and parses numbered
    ``"N. Category: score"`` items out of the generated answer.
    """

    # Hugging Face Hub id used for both the processor and the model weights.
    MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"
    # Seconds to wait for the image server before the download fails.
    REQUEST_TIMEOUT = 30
    # Upper bound on the number of tokens generated per request.
    MAX_NEW_TOKENS = 100
    # Matches numbered items such as "1. Sky: 7" -> ("1", "Sky", "7").
    _SCORE_PATTERN = re.compile(r'(\d+)\.\s*(.*?):\s*(\d+)')

    def __init__(self, path=""):
        """Load the processor and the 4-bit quantized model once.

        Args:
            path: Argument supplied by the caller; not used here -- the
                processor and model are always loaded from ``MODEL_ID``.
        """
        # Quantization configuration: 4-bit NF4 weights, fp16 compute.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        # Load the processor and model once so every request reuses them.
        self.processor = LlavaNextProcessor.from_pretrained(self.MODEL_ID)
        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            self.MODEL_ID,
            quantization_config=quantization_config,
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score the image at ``data["url"]`` according to ``data["prompt"]``.

        Args:
            data: Request payload with keys ``"url"`` (image URL) and
                ``"prompt"`` (text prompt passed to the model).

        Returns:
            A list of ``(category, score)`` pairs sorted by score, highest
            first. If a field is missing or the image cannot be downloaded or
            decoded, a dict ``{"error": <message>}`` is returned instead.
            NOTE(review): the annotation says ``List[Dict]`` but the items are
            tuples; kept as-is so existing clients see the same payload.
        """
        image_url = data.get("url")
        prompt = data.get("prompt")
        if not image_url or not prompt:
            return {"error": "Request must include non-empty 'url' and 'prompt' fields."}

        try:
            image = self._load_image(image_url)
        except Exception as e:  # request boundary: report the failure to the client
            return {"error": str(e)}

        # Keyword arguments: newer processors take `images` first positionally.
        # Inputs follow the model's device rather than a hard-coded "cuda",
        # since placement is chosen by device_map="auto".
        inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.model.device)
        output = self.model.generate(**inputs, max_new_tokens=self.MAX_NEW_TOKENS)
        # Decode only the newly generated tokens so numbered items that appear
        # in the prompt itself are never mistaken for scores.
        prompt_length = inputs["input_ids"].shape[1]
        result = self.processor.decode(output[0][prompt_length:], skip_special_tokens=True)
        scores = self.extract_scores(result)
        return sorted(scores.items(), key=lambda item: item[1], reverse=True)

    def _load_image(self, image_url):
        """Download ``image_url`` and return it as a fully decoded RGB PIL image."""
        response = requests.get(image_url, timeout=self.REQUEST_TIMEOUT)
        # Fail on HTTP 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        # `.content` applies any gzip/deflate content-encoding, unlike `.raw`.
        image = Image.open(BytesIO(response.content))
        # Normalise every mode (RGBA/P PNGs, grayscale, CMYK, ...) to RGB without
        # a lossy JPEG round-trip; convert() also decodes the pixels now, so a
        # corrupt file raises here, inside the caller's try block.
        return image.convert("RGB")

    def extract_scores(self, response):
        """Parse numbered ``"N. Category: score"`` items from model output.

        Only the text after the last ``[/INST]`` marker is examined (the whole
        string when the marker is absent). If a category appears more than
        once, its last score wins.

        Args:
            response: Decoded model output.

        Returns:
            Dict mapping each stripped category name to its integer score.
        """
        result_part = response.split("[/INST]")[-1].strip()
        return {
            name.strip(): int(score)
            for _number, name, score in self._SCORE_PATTERN.findall(result_part)
        }