Gabriel's picture
Update handler.py
6e7afc8 verified
raw
history blame
1.67 kB
import base64
import io
from typing import Any, Dict, List
import requests
import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# Target device for the model and its input tensors: use the GPU when torch
# reports CUDA as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class EndpointHandler:
    """Callable handler that runs TrOCR text recognition on a single image.

    The processor and the vision-encoder-decoder model are both loaded from
    ``path`` at construction time, and the model is moved to the module-level
    ``device``. Each call takes a request payload, resolves it to one RGB
    image, and returns the decoded text.
    """

    # Upper bound, in seconds, on how long fetching a remote image may take.
    # Without a timeout ``requests`` waits forever on a stalled server.
    REQUEST_TIMEOUT = 10

    def __init__(self, path=""):
        """Load the TrOCR processor and model stored at ``path``."""
        self.processor = TrOCRProcessor.from_pretrained(path)
        self.model = VisionEncoderDecoderModel.from_pretrained(path)
        self.model.to(device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Recognize the text in the image described by ``data``.

        Args:
            data: Request payload. The image is looked up under
                ``data["inputs"]["image"]``; when there is no ``"inputs"``
                key, ``data["image"]`` is used. ``data["inputs"]`` may also
                be the image string itself. The image string is either an
                ``http://`` / ``https://`` URL or base64-encoded image bytes.
                The payload is not modified.

        Returns:
            ``{"text": <prediction>}`` on success, or ``{"error": <message>}``
            when no image was supplied or it could not be fetched or decoded.
        """
        # Read with .get() rather than .pop() so the caller's dict is left intact.
        payload = data.get("inputs", data)
        # Accept both {"inputs": {"image": ...}} and {"inputs": "<image string>"}.
        image_input = payload.get("image") if isinstance(payload, dict) else payload
        if not image_input:
            return {"error": "No image provided."}

        # Broad catch on purpose: this is the request boundary, and network,
        # base64 and image-decoding failures (including a non-string image
        # value) must all surface as an error payload, not as a crash.
        try:
            # ":" is not in the base64 alphabet, so a scheme prefix cannot be
            # mistaken for encoded image data.
            if image_input.startswith(("http://", "https://")):
                response = requests.get(image_input, timeout=self.REQUEST_TIMEOUT)
                if response.status_code != 200:
                    return {
                        "error": f"Failed to fetch image. Status code: {response.status_code}"
                    }
                # ``response.content`` has any gzip/deflate transfer encoding
                # already undone and the connection released, unlike
                # ``response.raw``.
                image_bytes = response.content
            else:
                image_bytes = base64.b64decode(image_input)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to process the image. Details: {str(e)}"}

        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
        generated_ids = self.model.generate(pixel_values.to(device))
        prediction = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        # Exactly one image is submitted, so only the first decoded string is returned.
        return {"text": prediction[0]}