eBoreal commited on
Commit
b493bea
1 Parent(s): 0ce869a

base commit

Browse files
Files changed (3) hide show
  1. README.md +6 -0
  2. handler.py +67 -0
  3. requirements.txt +6 -0
README.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ tags:
5
+ - llava-next
6
+ license: apache-2.0
handler.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from tempfile import TemporaryDirectory
3
+ from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
4
+ from PIL import Image
5
+ import torch
6
+ import requests
7
+
8
+
9
+ class EndpointHandler:
10
+ def __init__(self, path=""):
11
+ self.processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
12
+
13
+
14
+ device = 'cpu' if torch.cuda.is_available() else 'cpu'
15
+
16
+ model = LlavaNextForConditionalGeneration.from_pretrained(
17
+ "llava-hf/llava-v1.6-mistral-7b-hf",
18
+ torch_dtype=torch.float32 if device == 'cpu' else torch.float16,
19
+ low_cpu_mem_usage=True
20
+ )
21
+ model.to(device)
22
+
23
+ self.model = model
24
+ self.device = device
25
+
26
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
27
+ """
28
+ data args:
29
+ text (:obj: `str`)
30
+ files (:obj: `list`) - List of URLs to images
31
+ Return:
32
+ A :obj:`list` | `dict`: will be serialized and returned
33
+ """
34
+ # get inputs
35
+ prompt = data.pop("prompt", data)
36
+ # get additional date field0
37
+ image_url = data.pop("files", None)[-1]['path']
38
+
39
+ print(image_url)
40
+ print(prompt)
41
+
42
+ if image_url is None:
43
+ return "You need to upload an image URL for LLaVA to work."
44
+
45
+ # Create a temporary directory
46
+ with TemporaryDirectory() as tmpdirname:
47
+ # Download the image
48
+ response = requests.get(image_url)
49
+ if response.status_code != 200:
50
+ return "Failed to download the image."
51
+
52
+ # Define the path for the downloaded image
53
+ image_path = f"{tmpdirname}/image.jpg"
54
+ with open(image_path, "wb") as f:
55
+ f.write(response.content)
56
+
57
+ # Open the downloaded image
58
+ with Image.open(image_path).convert("RGB") as image:
59
+ prompt = f"[INST] <image>\n{prompt} [/INST]"
60
+
61
+ inputs = self.processor(prompt, image, return_tensors="pt").to(self.device)
62
+
63
+ output = self.model.generate(**inputs, max_new_tokens=100)
64
+
65
+ clean = self.processor.decode(output[0], skip_special_tokens=True)
66
+
67
+ return clean
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ git+https://github.com/huggingface/transformers.git
3
+ spaces
4
+ pillow
5
+ accelerate
6
+ requests