import re
from io import BytesIO
from typing import Any, Dict, List, Tuple, Union

import requests
import torch
from PIL import Image
from transformers import (
    BitsAndBytesConfig,
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
)


class EndpointHandler:
    """Inference-endpoint handler that scores an image against a prompt with LLaVA-NeXT.

    Each request downloads an image from a URL, runs it together with a text prompt
    through a 4-bit quantized ``llava-v1.6-mistral-7b`` model, and parses lines of the
    form ``"<n>. <category>: <score>"`` out of the generated answer.
    """

    # Hub checkpoint used for both the processor and the model.
    MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"
    # Upper bound on tokens generated per request.
    MAX_NEW_TOKENS = 100
    # Seconds to wait on the image host (connect and read) before failing the request.
    IMAGE_TIMEOUT = 30
    # "<number>. <category name>: <integer score>"; '.' does not cross newlines.
    _SCORE_PATTERN = re.compile(r'(\d+)\.\s*(.*?):\s*(\d+)')

    def __init__(self, path: str = "") -> None:
        """Load the processor and the quantized model once per process.

        Args:
            path: Repository path supplied by the serving runtime. Unused: the model is
                always loaded from ``MODEL_ID``.
        """
        # 4-bit NF4 weights with float16 compute.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        # Loaded once here so every request reuses the same processor and model.
        self.processor = LlavaNextProcessor.from_pretrained(self.MODEL_ID)
        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            self.MODEL_ID,
            quantization_config=quantization_config,
            device_map="auto",
        )

    def __call__(
        self, data: Dict[str, Any]
    ) -> Union[List[Tuple[str, int]], Dict[str, Any]]:
        """Score the image at ``data["inputs"]["url"]`` using ``data["inputs"]["prompt"]``.

        Args:
            data: Request body. ``data["inputs"]`` must be a mapping containing a
                non-empty ``"url"`` and ``"prompt"``.

        Returns:
            On success, a list of ``(category, score)`` pairs sorted by descending score.
            On failure, a dict with an ``"error"`` message; failures after input
            validation also carry ``"stage"``, ``"image_url"`` and ``"prompt"``.
        """
        payload = data.get("inputs")
        if not payload:
            return {"error": "Invalid input format. 'inputs' key is missing."}
        if not isinstance(payload, dict):
            # Previously crashed with AttributeError on ``payload.get``.
            return {
                "error": "Invalid input format. 'inputs' must be an object "
                "with 'url' and 'prompt'."
            }

        image_url = payload.get("url")
        prompt = payload.get("prompt")
        if not image_url or not prompt:
            return {"error": "Both 'url' and 'prompt' must be provided in 'inputs'."}

        # Returned to the caller (with stage/error filled in) when a later step fails.
        debug_info = {
            "stage": "processing input",
            "image_url": image_url,
            "prompt": prompt,
        }

        try:
            image = self._load_image(image_url)
        except Exception as e:
            debug_info["stage"] = "loading image"
            debug_info["error"] = str(e)
            return debug_info

        try:
            # Keyword arguments: the positional order of LlavaNextProcessor.__call__
            # differs between transformers versions.
            model_inputs = self.processor(
                text=prompt, images=image, return_tensors="pt"
            ).to(self.model.device)
            output = self.model.generate(
                **model_inputs, max_new_tokens=self.MAX_NEW_TOKENS
            )
            # generate() returns prompt + continuation; decode only the new tokens so
            # numbered text inside the prompt is never parsed as a score.
            prompt_length = model_inputs["input_ids"].shape[-1]
            answer = self.processor.decode(
                output[0][prompt_length:], skip_special_tokens=True
            )
            scores = self.extract_scores(answer)
            return sorted(scores.items(), key=lambda item: item[1], reverse=True)
        except Exception as e:
            debug_info["stage"] = "processing model"
            debug_info["error"] = str(e)
            return debug_info

    def _load_image(self, url: str) -> Image.Image:
        """Download ``url`` and return the decoded image as an RGB PIL image.

        Raises:
            requests.RequestException: on network errors, timeouts or non-2xx status.
            PIL.UnidentifiedImageError: if the body is not a decodable image.
        """
        # NOTE(review): ``url`` is caller-supplied, so arbitrary hosts (including
        # internal ones) can be fetched — consider an allow-list (SSRF risk).
        response = requests.get(url, timeout=self.IMAGE_TIMEOUT)
        response.raise_for_status()
        # ``response.content`` honours Content-Encoding (unlike ``response.raw``) and
        # the non-streamed request releases the connection once the body is read.
        image = Image.open(BytesIO(response.content))
        # convert() forces a full decode here, so corrupt images fail in this stage,
        # and normalises every mode (P, RGBA, L, ...) to RGB without lossy re-encoding.
        return image.convert("RGB")

    def extract_scores(self, response: str) -> Dict[str, int]:
        """Parse numbered ``"<n>. <category>: <score>"`` entries from a model answer.

        Only text after the last ``[/INST]`` marker is considered (the whole string
        when the marker is absent). Category names are stripped; if a category appears
        more than once, the last score wins.

        Args:
            response: Decoded model output.

        Returns:
            Mapping of category name to integer score.
        """
        answer = response.split("[/INST]")[-1].strip()
        return {
            name.strip(): int(score)
            for _number, name, score in self._SCORE_PATTERN.findall(answer)
        }