import base64
import os
from io import BytesIO
from typing import Any, Dict

import requests
import torch
from dotenv import load_dotenv
from PIL import Image
from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering

# Load environment variables (e.g. SECRET_TOKEN) from a local .env file.
load_dotenv()

# Run on GPU when available; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the BLIP VQA processor and model once at startup and keep the model in eval mode.
        self.processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
        self.model = AutoModelForVisualQuestionAnswering.from_pretrained(
            "Salesforce/blip-vqa-capfilt-large"
        ).to(device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Answer one or more questions about a single image.

        Expected payload:
            {
                "secret_token": "...",  # required only when SECRET_TOKEN is set
                "inputs": {
                    "image": {"base64": "..."} or {"url": "https://..."},
                    "questions": ["What animal is this?", ...],
                },
            }
        """
        # Reject the request when a SECRET_TOKEN is configured and the caller's token does not match.
        if os.getenv("SECRET_TOKEN") and os.getenv("SECRET_TOKEN") != data.get("secret_token"):
            return {"answers": [], "error": "Invalid secret token"}

        input_data = data.get("inputs", {})
        input_image = input_data.get("image")
        if not input_image:
            return {"answers": [], "error": "No image provided"}

        questions = input_data.get("questions")
        if not questions:
            return {"answers": [], "error": "No questions provided"}

        try:
            # Decode an inline base64 image, or fetch the image over HTTP.
            if input_image.get("base64"):
                raw_image = Image.open(BytesIO(base64.b64decode(input_image.get("base64")))).convert("RGB")
            elif input_image.get("url"):
                response = requests.get(input_image.get("url"), timeout=30)
                response.raise_for_status()
                raw_image = Image.open(BytesIO(response.content)).convert("RGB")
            else:
                return {"answers": [], "error": "Invalid image input"}

            answers = []
            for question in questions:
                # Tokenize the image/question pair, generate an answer, and decode it to text.
                processed_input = self.processor(raw_image, question, return_tensors="pt").to(device)
                with torch.no_grad():
                    out = self.model.generate(**processed_input)
                answer = self.processor.batch_decode(out, skip_special_tokens=True)[0]
                answers.append(answer)

            return {"answers": answers}

        except Exception as e:
            print(f"Error during processing: {e}")
            return {"answers": [], "error": str(e)}
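

# Minimal local smoke-test sketch, assuming this file is run directly rather than
# inside an inference endpoint. The image URL is a placeholder; substitute any
# reachable image, or use the base64 path shown in the comments below.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "secret_token": os.getenv("SECRET_TOKEN"),
        "inputs": {
            "image": {"url": "https://example.com/cat.jpg"},  # placeholder URL
            "questions": ["What animal is in the picture?", "What color is it?"],
        },
    }
    # To exercise the base64 path instead (sketch):
    #   with open("cat.jpg", "rb") as f:
    #       payload["inputs"]["image"] = {"base64": base64.b64encode(f.read()).decode()}
    print(handler(payload))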