"""Custom Hugging Face Inference Endpoints handler for LLaVA-NeXT.

Downloads an image from a URL, runs the 4-bit quantized
llava-v1.6-mistral-7b model on it with the supplied prompt, and returns
the numbered category scores parsed from the generated answer.
"""

from io import BytesIO
import re
from typing import Any, Dict

import requests
import torch
from PIL import Image
from transformers import (
    BitsAndBytesConfig,
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
)


class EndpointHandler:
    def __init__(self, path: str = ""):
        # 4-bit NF4 quantization keeps the 7B model within a single-GPU
        # memory budget; computation runs in float16.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # Prefer the locally deployed weights when the endpoint supplies a
        # path; otherwise fall back to the hub checkpoint.
        model_id = path or "llava-hf/llava-v1.6-mistral-7b-hf"
        self.processor = LlavaNextProcessor.from_pretrained(model_id)
        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            model_id,
            quantization_config=quantization_config,
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> Any:
        image_url = data.get("url")
        prompt = data.get("prompt")
        if not image_url or not prompt:
            return {"error": "Request must include both 'url' and 'prompt'."}

        try:
            # Download the full image body and normalize to RGB; converting
            # directly covers PNG, palette, and RGBA inputs without the
            # extra JPEG re-encoding round-trip.
            response = requests.get(image_url, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            return {"error": str(e)}

        # Keyword arguments guard against the text/images positional order,
        # which has changed across transformers releases.
        inputs = self.processor(text=prompt, images=image, return_tensors="pt")
        inputs = inputs.to(self.model.device)
        output = self.model.generate(**inputs, max_new_tokens=100)
        result = self.processor.decode(output[0], skip_special_tokens=True)

        scores = self.extract_scores(result)
        # Return categories sorted by score, highest first.
        return sorted(scores.items(), key=lambda item: item[1], reverse=True)

    def extract_scores(self, response: str) -> Dict[str, int]:
        """Parse "N. Category name: score" lines from the model's answer.

        The decoded output still contains the prompt, so only the text
        after the final [/INST] tag is scanned. Expected lines look like
        "1. Composition: 7".
        """
        scores = {}
        result_part = response.split("[/INST]")[-1].strip()
        pattern = re.compile(r"(\d+)\.\s*(.*?):\s*(\d+)")
        # The leading item number is not needed; key on the category name.
        for _number, category_name, score in pattern.findall(result_part):
            scores[category_name.strip()] = int(score)
        return scores
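

# A minimal local smoke test, assuming a CUDA-capable machine, network
# access, and the llava-v1.6-mistral prompt template; the URL and prompt
# below are placeholders rather than part of the handler's contract.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "url": "https://example.com/sample.jpg",  # placeholder image URL
        "prompt": "[INST] <image>\nRate each category from 0 to 10. [/INST]",
    }
    print(handler(payload))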