Commit db1babc by sergeymal
Parent(s): e53f952

Added requirements, handler and server.

Files changed (3):
  1. handler.py +77 -0
  2. requirements.txt +6 -0
  3. server.py +18 -0
handler.py ADDED
@@ -0,0 +1,77 @@
+ import base64
+ import os
+ from io import BytesIO
+ from typing import Any, Dict
+
+ import requests
+ import torch
+ from dotenv import load_dotenv
+ from PIL import Image
+ from transformers import AutoModelForVisualQuestionAnswering, AutoProcessor
+
+ # Load environment variables (e.g. SECRET_TOKEN) from a local .env file.
+ load_dotenv()
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ class EndpointHandler:
+     def __init__(self, path: str = ""):
+         # The model id is hardcoded; `path` is kept for the hosting runtime's handler contract.
+         self.processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
+         self.model = AutoModelForVisualQuestionAnswering.from_pretrained(
+             "Salesforce/blip-vqa-capfilt-large"
+         ).to(device)
+         self.model.eval()
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         # Reject the request if a secret token is configured and does not match.
+         if os.getenv("SECRET_TOKEN") and os.getenv("SECRET_TOKEN") != data.get("secret_token"):
+             return {"captions": [], "error": "Invalid secret token"}
+
+         input_data = data.get("inputs", {})
+         input_images = input_data.get("images")
+
+         if not input_images:
+             return {"captions": [], "error": "No images provided"}
+
+         # One list of questions per image; default to no questions.
+         texts_per_image = input_data.get("texts", [[] for _ in input_images])
+
+         try:
+             # Decode each image from either an inline base64 string or a URL.
+             raw_images = []
+             for img in input_images:
+                 if "base64" in img:
+                     raw_images.append(Image.open(BytesIO(base64.b64decode(img["base64"]))).convert("RGB"))
+                 elif "url" in img:
+                     raw_images.append(Image.open(BytesIO(requests.get(img["url"], timeout=30).content)).convert("RGB"))
+                 else:
+                     return {"captions": [], "error": f"Invalid image input: {list(img)}"}
+
+             # Final answers, one entry per image.
+             results = []
+
+             # Iterate over each image and its corresponding list of questions.
+             for image, questions in zip(raw_images, texts_per_image):
+                 image_captions = []  # Answers for the current image
+
+                 for question in questions:
+                     # Encode the image/question pair for the model.
+                     processed_input = self.processor(image, question, return_tensors="pt").to(device)
+
+                     # Generate and decode the answer.
+                     with torch.no_grad():
+                         out = self.model.generate(**processed_input)
+                     caption = self.processor.batch_decode(out, skip_special_tokens=True)[0]
+
+                     image_captions.append({"answer": caption})
+
+                 results.append({"image_results": image_captions})
+
+             return {"captions": results}
+
+         except Exception as e:
+             print(f"Error during processing: {e}")
+             return {"captions": [], "error": str(e)}
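
For reference, a quick local smoke test of the handler above might look like the sketch below. The image URL and questions are made-up placeholders, and it assumes SECRET_TOKEN is unset in the environment (otherwise a matching "secret_token" field must be added to the payload):

    from handler import EndpointHandler

    handler = EndpointHandler()
    payload = {
        "inputs": {
            "images": [{"url": "https://example.com/cat.jpg"}],  # placeholder URL
            "texts": [["What animal is this?", "What color is it?"]],
        },
    }
    print(handler(payload))
    # e.g. {"captions": [{"image_results": [{"answer": "..."}, {"answer": "..."}]}]}
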
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ requests
+ Pillow
+ torch
+ transformers
+ flask
+ python-dotenv
server.py ADDED
@@ -0,0 +1,18 @@
+ from flask import Flask, request, jsonify
+
+ from handler import EndpointHandler  # The handler class defined in handler.py
+
+ app = Flask(__name__)
+
+ # Instantiate once at startup so the model is loaded before the first request.
+ handler = EndpointHandler()
+
+
+ @app.route("/", methods=["POST"])
+ def generate_captions():
+     try:
+         data = request.json
+         output = handler(data)
+         return jsonify(output)
+     except Exception as e:
+         return jsonify({"error": str(e)}), 400
+
+
+ if __name__ == "__main__":
+     app.run(host="0.0.0.0", port=5000)
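
With the server running, any HTTP client can POST a request to it. A minimal Python sketch, assuming the server is reachable at http://localhost:5000 and using a placeholder image URL:

    import requests

    resp = requests.post(
        "http://localhost:5000/",
        json={
            "inputs": {
                "images": [{"url": "https://example.com/cat.jpg"}],  # placeholder URL
                "texts": [["How many animals are in the picture?"]],
            },
        },
        timeout=120,
    )
    print(resp.json())
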