Commit
•
ef5ce50
1
Parent(s):
964ef5d
This commit refactors the `handler.py` file to improve the performance of the Visual Question Answering (VQA) model. The changes include:
- Loading the VQA pipeline for the model
- Modifying the `__call__` method to extract the image and question from the request
- Performing the VQA using the pipeline
These changes aim to enhance the efficiency and accuracy of the VQA process.
- handler.py +34 -20
handler.py
CHANGED
@@ -1,24 +1,38 @@
|
|
1 |
-
from typing import Dict, List
|
2 |
-
from transformers import AutoModel, AutoTokenizer
|
3 |
-
from PIL import Image
|
4 |
|
5 |
-
|
|
|
|
|
|
|
6 |
def __init__(self, path=""):
|
7 |
-
# Preload all the elements you are going to need at inference.
|
8 |
-
self.model = AutoModel.from_pretrained('openbmb/MiniCPM-Llama3-V-2_5-int4', trust_remote_code=True)
|
9 |
-
self.tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-Llama3-V-2_5-int4', trust_remote_code=True)
|
10 |
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
)
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List
|
|
|
|
|
2 |
|
3 |
+
from transformers import AutoModel, AutoTokenizer, pipeline
|
4 |
+
|
5 |
+
|
6 |
+
class EndpointHandler:
    """Inference handler that answers a question about an image.

    A ``visual-question-answering`` pipeline is built once in ``__init__``
    and reused for every request handled by ``__call__``.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer and wrap them in a VQA pipeline.

        Args:
            path: Not read by this handler (the checkpoint id is fixed
                below); kept so callers that pass it keep working.
        """
        # Single source of truth for the checkpoint so the model and the
        # tokenizer can never drift apart.
        model_id = "openbmb/MiniCPM-Llama3-V-2_5-int4"

        # NOTE(review): this checkpoint is loaded with trust_remote_code, i.e.
        # it runs code shipped with the model repo. Confirm that the stock
        # "visual-question-answering" pipeline accepts this custom model class
        # and works without an explicit image processor.
        model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, trust_remote_code=True
        )

        self.pipeline = pipeline(
            "visual-question-answering",
            model=model,
            tokenizer=tokenizer,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run visual question answering for one request.

        Args:
            data: Request payload. ``image`` and ``question`` are read from
                the top level, or from a nested ``"inputs"`` mapping when one
                is present.

        Returns:
            Whatever the underlying pipeline returns for the image/question
            pair.

        Raises:
            ValueError: If ``image`` or ``question`` is absent from the
                payload.
        """
        # Accept both {"image": ..., "question": ...} and the same mapping
        # wrapped as {"inputs": {...}}; anything else falls back to `data`.
        payload = data.get("inputs")
        if not isinstance(payload, dict):
            payload = data

        image = payload.get("image")
        question = payload.get("question")

        # Fail early with a clear message instead of handing None to the
        # pipeline and surfacing an opaque error from deep inside the model.
        missing = [
            key
            for key, value in (("image", image), ("question", question))
            if value is None
        ]
        if missing:
            raise ValueError(
                f"Missing required field(s) in request: {', '.join(missing)}"
            )

        # Keywords make the argument mapping explicit and order-independent.
        return self.pipeline(image=image, question=question)
|
30 |
+
|
31 |
+
|
32 |
+
# if __name__ == "__main__":
|
33 |
+
# handler = EndpointHandler()
|
34 |
+
# data = {
|
35 |
+
# "image": "https://pwm.im-cdn.it/image/1524723057/xxl.jpg",
|
36 |
+
# "question": "Describe the image:",
|
37 |
+
# }
|
38 |
+
# print(handler(data))
|