Added requirements, handler and server.
- handler.py +77 -0
- requirements.txt +6 -0
- server.py +18 -0
handler.py
ADDED
@@ -0,0 +1,77 @@
import requests
from typing import Dict, Any
from PIL import Image
import torch
import base64
from io import BytesIO
from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering
from dotenv import load_dotenv
import os

load_dotenv()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    def __init__(self, path=""):
        self.processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
        self.model = AutoModelForVisualQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large").to(device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Reject the request if a secret token is configured and the caller's token does not match
        if os.getenv("SECRET_TOKEN") and os.getenv("SECRET_TOKEN") != data.get("secret_token"):
            return {"captions": [], "error": "Invalid secret token"}

        input_data = data.get("inputs", {})
        input_images = input_data.get("images")

        if not input_images:
            return {"captions": [], "error": "No images provided"}

        # One list of questions per image; default to an empty list for each image
        texts_per_image = input_data.get("texts", [[] for _ in input_images])

        try:
            raw_images = []

            # Each image entry is a dict carrying either a "base64" or a "url" key
            for img in input_images:
                if "base64" in img:
                    raw_images.append(Image.open(BytesIO(base64.b64decode(img["base64"]))).convert("RGB"))
                elif "url" in img:
                    raw_images.append(Image.open(BytesIO(requests.get(img["url"]).content)).convert("RGB"))
                else:
                    return {"captions": [], "error": f"Invalid image input: expected 'base64' or 'url', got {list(img.keys())}"}

            # Final results: one entry per image, each holding that image's answers
            results = []

            # Iterate over each image and its corresponding list of questions
            for image, questions in zip(raw_images, texts_per_image):
                image_captions = []  # Answers for the current image

                for question in questions:
                    print(f"Question: {question}")

                    # Encode the image/question pair
                    processed_input = self.processor(image, question, return_tensors="pt").to(device)

                    # Generate the answer tokens
                    out = self.model.generate(**processed_input)

                    # Decode the answer text
                    caption = self.processor.batch_decode(out, skip_special_tokens=True)[0]

                    # Add the answer to the list for the current image
                    image_captions.append({"answer": caption})

                # Store results for the current image
                results.append({"image_results": image_captions})

            return {"captions": results}

        except Exception as e:
            print(f"Error during processing: {str(e)}")
            return {"captions": [], "error": str(e)}
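For reference, a minimal sketch of the payload shape the handler expects and of calling it directly in-process (the secret token value, image URL, and question strings below are illustrative placeholders, not part of this commit):

# Local smoke test for EndpointHandler (sketch; placeholder values throughout)
from handler import EndpointHandler

handler = EndpointHandler()

payload = {
    "secret_token": "my-secret",  # only checked if SECRET_TOKEN is set in .env
    "inputs": {
        "images": [
            {"url": "https://example.com/cat.jpg"}  # or {"base64": "<base64-encoded image bytes>"}
        ],
        "texts": [
            ["What animal is in the picture?", "What color is it?"]  # one list of questions per image
        ],
    },
}

print(handler(payload))
# -> {"captions": [{"image_results": [{"answer": "..."}, {"answer": "..."}]}]}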
requirements.txt
ADDED
@@ -0,0 +1,6 @@
requests
Pillow
torch
transformers
flask
python-dotenv
server.py
ADDED
@@ -0,0 +1,18 @@
from flask import Flask, request, jsonify
from handler import EndpointHandler  # Import the class from handler.py

app = Flask(__name__)

handler = EndpointHandler()


@app.route("/", methods=["POST"])
def generate_captions():
    try:
        data = request.json
        output = handler(data)
        return jsonify(output)
    except Exception as e:
        return jsonify({"error": str(e)}), 400


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)
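A quick way to exercise the service end to end, after installing the dependencies from requirements.txt and starting it with python server.py (a sketch; the host and port match the app.run defaults above, while the image URL, question, and token are placeholder values):

# Sketch of a client request against the running Flask server (placeholder values)
import requests

payload = {
    "secret_token": "my-secret",  # only needed if SECRET_TOKEN is set in the server's .env
    "inputs": {
        "images": [{"url": "https://example.com/dog.jpg"}],  # placeholder image URL
        "texts": [["How many dogs are there?"]],
    },
}

resp = requests.post("http://localhost:5000/", json=payload)
print(resp.json())
# -> {"captions": [{"image_results": [{"answer": "..."}]}]}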