"""Custom inference handler that classifies a document image with LayoutLMv3."""

import io
from subprocess import run
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import (
    LayoutLMv3FeatureExtractor,
    LayoutLMv3ForSequenceClassification,
    LayoutLMv3Processor,
    LayoutLMv3Tokenizer,
)

# Install runtime dependencies when the module is imported: the tesseract-ocr
# system package and the CPU build of detectron2. ``check=True`` makes a failed
# install raise ``CalledProcessError`` instead of continuing silently.
# NOTE(review): presumably tesseract is needed by the feature extractor's OCR
# step -- confirm; nothing in this file uses detectron2 directly.
run("apt install -y tesseract-ocr", shell=True, check=True)
run(
    "python -m pip install detectron2 -f "
    "https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.10/index.html",
    shell=True,
    check=True,
)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    """Callable wrapper around a LayoutLMv3 sequence-classification model."""

    def __init__(self, path=""):
        """Build the processor and load the classification model.

        Args:
            path: Accepted for interface compatibility but not used; the
                tokenizer and model are loaded from the fixed hub identifiers
                below rather than from this location.
        """
        self.FEATURE_EXTRACTOR = LayoutLMv3FeatureExtractor()
        self.TOKENIZER = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
        self.PROCESSOR = LayoutLMv3Processor(self.FEATURE_EXTRACTOR, self.TOKENIZER)
        self.MODEL = LayoutLMv3ForSequenceClassification.from_pretrained(
            "OtraBoi/document_classifier_testing"
        ).to(device)

    @staticmethod
    def _to_rgb_image(payload: Any) -> Any:
        """Return *payload* as an RGB PIL image when it can be interpreted as one.

        Raw ``bytes``/``bytearray`` are decoded with PIL; PIL images in another
        mode are converted to RGB. Anything else is returned unchanged so the
        processor can handle (or reject) it as before.
        """
        if isinstance(payload, (bytes, bytearray)):
            return Image.open(io.BytesIO(payload)).convert("RGB")
        if isinstance(payload, Image.Image) and payload.mode != "RGB":
            return payload.convert("RGB")
        return payload

    def __call__(self, data: Dict[str, Any]) -> str:
        """Classify the document image carried by *data*.

        Args:
            data: Request payload. The image is taken from (and removed from)
                the ``"inputs"`` key; when that key is absent, *data* itself is
                used as the input. The image may be encoded image bytes or a
                PIL image.

        Returns:
            The entry of the model config's ``id2label`` mapping for the
            highest-scoring class.
        """
        payload = data.pop("inputs", data)
        image = self._to_rgb_image(payload)

        encoding = self.PROCESSOR(
            image, return_tensors="pt", padding="max_length", truncation=True
        )

        # Move every input tensor to the device the model lives on.
        model_device = self.MODEL.device
        model_inputs = {name: tensor.to(model_device) for name, tensor in encoding.items()}

        # Inference only: no autograd bookkeeping needed.
        with torch.inference_mode():
            outputs = self.MODEL(**model_inputs)

        # ``.item()`` requires a single prediction, i.e. one image per call.
        predicted_class_idx = outputs.logits.argmax(-1).item()
        return self.MODEL.config.id2label[predicted_class_idx]