rroset commited on
Commit
70be7c1
1 Parent(s): 616b2e5

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +82 -0
handler.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ import torch
3
+ from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig
4
+ from PIL import Image
5
+ import requests
6
+ from io import BytesIO
7
+ import base64
8
+
9
+ class EndpointHandler:
10
+ def __init__(self, path=""):
11
+ # Configuració de la quantització
12
+ quantization_config = BitsAndBytesConfig(
13
+ load_in_4bit=True,
14
+ bnb_4bit_quant_type="nf4",
15
+ bnb_4bit_compute_dtype=torch.float16,
16
+ )
17
+
18
+ # Carrega el processador i model de forma global
19
+ self.processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
20
+ self.model = LlavaNextForConditionalGeneration.from_pretrained(
21
+ "llava-hf/llava-v1.6-mistral-7b-hf",
22
+ quantization_config=quantization_config,
23
+ device_map="auto"
24
+ )
25
+
26
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
27
+ logs = []
28
+ logs.append("Iniciant processament de la petició.")
29
+
30
+ inputs = data.get("inputs")
31
+ if not inputs:
32
+ logs.append("Format d'entrada invàlid. Manca la clau 'inputs'.")
33
+ return {"error": "Invalid input format. 'inputs' key is missing.", "logs": logs}
34
+
35
+ image_url = inputs.get("url")
36
+ image_data = inputs.get("image_data")
37
+ prompt = inputs.get("prompt")
38
+ max_tokens = inputs.get("max_tokens", 100)
39
+
40
+ if not prompt:
41
+ logs.append("S'ha de proporcionar 'prompt' en 'inputs'.")
42
+ return {"error": "The 'prompt' must be provided in 'inputs'.", "logs": logs}
43
+
44
+ if not image_url and not image_data:
45
+ logs.append("S'ha de proporcionar 'url' o 'image_data' en 'inputs'.")
46
+ return {"error": "Either 'url' or 'image_data' must be provided in 'inputs'.", "logs": logs}
47
+
48
+ logs.append(f"Processant entrada: url={image_url}, image_data={'present' if image_data else 'absent'}, prompt={prompt}")
49
+
50
+ try:
51
+ if image_url:
52
+ logs.append(f"Carregant imatge des de URL: {image_url}")
53
+ response = requests.get(image_url, stream=True)
54
+ image = Image.open(response.raw)
55
+ elif image_data:
56
+ logs.append("Carregant imatge des de dades d'imatge en brut.")
57
+ image = Image.open(BytesIO(base64.b64decode(image_data)))
58
+
59
+ if image.format == 'PNG':
60
+ logs.append("Convertint imatge PNG a JPG.")
61
+ image = image.convert('RGB')
62
+ buffer = BytesIO()
63
+ image.save(buffer, format="JPEG")
64
+ buffer.seek(0)
65
+ image = Image.open(buffer)
66
+
67
+ except Exception as e:
68
+ logs.append(f"Error carregant imatge: {str(e)}")
69
+ return {"error": str(e), "logs": logs}
70
+
71
+ try:
72
+ logs.append("Processant imatge amb el model.")
73
+ inputs = self.processor(prompt, image, return_tensors="pt").to("cuda")
74
+ output = self.model.generate(**inputs, max_new_tokens=max_tokens)
75
+ result = self.processor.decode(output[0], skip_special_tokens=True)
76
+ logs.append("Processament complet.")
77
+ return {"input_prompt": prompt, "model_output": result, "logs": logs}
78
+
79
+ except Exception as e:
80
+ logs.append(f"Error processant el model: {str(e)}")
81
+ return {"error": str(e), "logs": logs}
82
+