# doctrfiles/doctr_recognizer.py
import json
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import torch
import numpy.typing as npt
from numpy import uint8
from doctr.models.preprocessor import PreProcessor
from doctr.models.recognition.predictor import RecognitionPredictor  # pylint: disable=W0611
from doctr.models.recognition.zoo import ARCHS, recognition
from utils import WordAnnotation
# Numpy image type
ImageType = npt.NDArray[uint8]
class DoctrTextRecognizer:
    def __init__(
        self,
        architecture: str,
        path_weights: str,
        path_config_json: Optional[str] = None,
    ) -> None:
"""
:param architecture: DocTR supports various text recognition models, e.g. "crnn_vgg16_bn",
"crnn_mobilenet_v3_small". The full list can be found here:
https://github.com/mindee/doctr/blob/main/doctr/models/recognition/zoo.py#L16.
:param path_weights: Path to the weights of the model
:param device: "cpu" or "cuda".
:param lib: "TF" or "PT" or None. If None, env variables USE_TENSORFLOW, USE_PYTORCH will be used.
:param path_config_json: Path to a json file containing the configuration of the model. Useful, if you have
a model trained on custom vocab.
"""
self.architecture = architecture
self.path_weights = path_weights
self.name = self.get_name(self.path_weights, self.architecture)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.path_config_json = path_config_json
self.built_model = self.build_model(self.architecture, self.path_config_json)
self.load_model(self.path_weights, self.built_model, self.device)
self.doctr_predictor = self.get_wrapped_model()
    def predict(self, inputs: Dict[uuid.UUID, Tuple[ImageType, WordAnnotation]]) -> List[WordAnnotation]:
        """
        Prediction on a batch of cropped text lines.
        :param inputs: Dictionary where the key is the word's object id and the value is a tuple of the
            cropped image and its word annotation.
        :return: A list of WordAnnotation objects with their recognized text filled in.
        """
        if inputs:
            predictor = self.doctr_predictor
            word_uuids = list(inputs.keys())
            cropped_images = [value[0] for value in inputs.values()]
            raw_output = predictor(list(cropped_images))
            det_results = []
            for word_uuid, output in zip(word_uuids, raw_output):
                ann = inputs[word_uuid][1]
                ann.text = output[0]
                det_results.append(ann)
            return det_results
        return []
    def predict_for_tables(self, inputs: List[ImageType]) -> List[str]:
        """Run text recognition on a list of cropped table-cell images and return one recognized string per crop."""
        if inputs:
            predictor = self.doctr_predictor
            raw_output = predictor(list(inputs))
            return [output[0] for output in raw_output]
        return []
@staticmethod
def load_model(path_weights: str, doctr_predictor: Any, device: torch.device) -> None:
"""Loading model weights
1. Load the State Dictionary:
state_dict = torch.load(path_weights, map_location=device) loads the state dictionary from the specified file path and maps it to the specified device.
2. Modify Keys in the State Dictionary:
The code prepends "model." to each key in the state dictionary. This is likely necessary to match the keys expected by the doctr_predictor model.
3. Load State Dictionary into Model:
doctr_predictor.load_state_dict(state_dict) loads the modified state dictionary into the model.
4. Move Model to Device:
doctr_predictor.to(device) moves the model to the specified device.
"""
state_dict = torch.load(path_weights, map_location=device)
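        # For example, a checkpoint key such as "feat_extractor.0.weight" becomes
        # "model.feat_extractor.0.weight" (illustrative; exact key names depend on the architecture).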
for key in list(state_dict.keys()):
state_dict["model." + key] = state_dict.pop(key)
doctr_predictor.load_state_dict(state_dict)
doctr_predictor.to(device)
@staticmethod
    def build_model(architecture: str, path_config_json: Optional[str] = None) -> "RecognitionPredictor":
        """Build the model.

        1. If a config JSON is given, specific keys (arch, url, task) are removed from custom_configs,
           and the mean and std values are moved to recognition_configs.
        2. Create the model:
           Case 1: if architecture is a string, it must be one of the predefined architectures (ARCHS);
           an instance is then created with the custom configurations.
           Case 2: otherwise, architecture must already be an **instance** of one of the recognized model
           classes (e.g. recognition.CRNN, recognition.SAR, ...) and is used as-is.
        3. Retrieve the input_shape from the model's configuration.
        4. Return an instance of RecognitionPredictor initialized with a PreProcessor and the model.
        """
# inspired and adapted from https://github.com/mindee/doctr/blob/main/doctr/models/recognition/zoo.py
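        # A config JSON for a custom-vocab model typically carries the keys handled below;
        # the values shown here are purely illustrative:
        #     {
        #       "arch": "crnn_vgg16_bn",
        #       "task": "recognition",
        #       "url": null,
        #       "mean": [0.694, 0.695, 0.693],
        #       "std": [0.299, 0.296, 0.301],
        #       "vocab": "0123456789abcdefghijklmnopqrstuvwxyz",
        #       "input_shape": [3, 32, 128]
        #     }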
custom_configs = {}
batch_size = 1024
recognition_configs = {}
if path_config_json:
with open(path_config_json, "r", encoding="utf-8") as f:
custom_configs = json.load(f)
custom_configs.pop("arch", None)
custom_configs.pop("url", None)
custom_configs.pop("task", None)
recognition_configs["mean"] = custom_configs.pop("mean")
recognition_configs["std"] = custom_configs.pop("std")
#batch_size = custom_configs.pop("batch_size")
recognition_configs["batch_size"] = batch_size
if isinstance(architecture, str):
if architecture not in ARCHS:
raise ValueError(f"unknown architecture '{architecture}'")
model = recognition.__dict__[architecture](pretrained=True, pretrained_backbone=True, **custom_configs)
else:
if not isinstance(
architecture,
(recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq),
):
raise ValueError(f"unknown architecture: {type(architecture)}")
model = architecture
input_shape = model.cfg["input_shape"][-2:]
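        # e.g. an input_shape of (3, 32, 128) yields (32, 128), i.e. the (H, W) output_size
        # expected by the PreProcessor (values illustrative).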
"""
(class) PreProcessor
Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
Args:
output_size: expected size of each page in format (H, W)
batch_size: the size of page batches
mean: mean value of the training distribution by channel
std: standard deviation of the training distribution by channel
"""
return RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **recognition_configs), model)
    def get_wrapped_model(self) -> "RecognitionPredictor":
        """
        Build the inner (wrapped) doctr predictor and load the trained weights onto the device.
        """
        doctr_predictor = self.build_model(self.architecture, self.path_config_json)
        self.load_model(self.path_weights, doctr_predictor, self.device)
        return doctr_predictor
@staticmethod
def get_name(path_weights: str, architecture: str) -> str:
"""Returns the name of the model"""
return f"doctr_{architecture}" + "_".join(Path(path_weights).parts[-2:])