Spaces:
Build error
Build error
import os | |
from abc import ABC | |
from pathlib import Path | |
from typing import Any, List, Literal, Mapping, Optional, Tuple | |
from zipfile import ZipFile | |
import json | |
from typing import Any, List, Literal, Mapping, Optional,Dict | |
import uuid | |
from doctr.models.preprocessor import PreProcessor | |
from doctr.models.recognition.predictor import RecognitionPredictor # pylint: disable=W0611 | |
from doctr.models.recognition.zoo import ARCHS, recognition | |
import torch | |
# Numpy image type | |
import numpy.typing as npt | |
from numpy import uint8 | |
ImageType = npt.NDArray[uint8] | |
from utils import WordAnnotation,getlogger | |
class DoctrTextRecognizer:
    """Text recognizer wrapping a docTR ``RecognitionPredictor`` on the PyTorch backend.

    On construction it builds the docTR recognition model for ``architecture``, loads the
    weights stored at ``path_weights`` and keeps the resulting predictor for batch inference
    on cropped word images (``predict``) or table-cell images (``predict_for_tables``).
    """

    def __init__(
        self,
        architecture: str,
        path_weights: str,
        path_config_json: Optional[str] = None,
    ) -> None:
        """
        :param architecture: DocTR supports various text recognition models, e.g. "crnn_vgg16_bn",
                             "crnn_mobilenet_v3_small". The full list can be found here:
                             https://github.com/mindee/doctr/blob/main/doctr/models/recognition/zoo.py#L16.
        :param path_weights: Path to the weights (a PyTorch state dict) of the model.
        :param path_config_json: Path to a json file containing the configuration of the model. Useful, if you have
                                 a model trained on custom vocab.

        The device is selected automatically: "cuda" when available, otherwise "cpu".
        """
        self.architecture = architecture
        self.path_weights = path_weights
        self.name = self.get_name(self.path_weights, self.architecture)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.path_config_json = path_config_json
        self.built_model = self.build_model(self.architecture, self.path_config_json)
        self.load_model(self.path_weights, self.built_model, self.device)
        # The predictor above is already built and loaded on self.device; reuse it instead of
        # building and loading a second identical copy via get_wrapped_model().
        self.doctr_predictor = self.built_model

    def predict(
        self, inputs: "Dict[uuid.UUID, Tuple[ImageType, WordAnnotation]]"
    ) -> "List[WordAnnotation]":
        """
        Run text recognition on a batch of cropped word images.

        :param inputs: Dictionary where the key is the word's id and the value is a tuple
                       ``(cropped_image, word_annotation)``.
        :return: The word annotations, in the iteration order of ``inputs``, with ``text`` set to
                 the recognized string. The annotation objects are mutated in place.
                 Returns ``[]`` when ``inputs`` is empty.
        """
        if not inputs:
            return []
        pairs = list(inputs.values())
        raw_output = self.doctr_predictor([pair[0] for pair in pairs])
        results = []
        for pair, output in zip(pairs, raw_output):
            annotation = pair[1]
            # presumably each output is a (text, confidence) pair; only the text is kept
            annotation.text = output[0]
            results.append(annotation)
        return results

    def predict_for_tables(self, inputs: "List[ImageType]") -> List[str]:
        """
        Run text recognition on a batch of images (e.g. table cells).

        :param inputs: Images to recognize.
        :return: The recognized strings, one per image, in input order; ``[]`` for empty input.
        """
        if not inputs:
            return []
        return [output[0] for output in self.doctr_predictor(list(inputs))]

    @staticmethod
    def load_model(path_weights: str, doctr_predictor: Any, device: torch.device) -> None:
        """
        Load a state dict from ``path_weights`` into ``doctr_predictor`` and move it to ``device``.

        Every key of the checkpoint is prefixed with ``"model."`` before loading. Presumably the
        checkpoint holds the bare recognition model's weights, while ``RecognitionPredictor``
        stores that model under its ``model`` attribute.

        :param path_weights: Path to a file readable by ``torch.load``.
        :param doctr_predictor: Object exposing ``load_state_dict`` and ``to`` (a ``RecognitionPredictor``).
        :param device: Device the tensors are mapped to and the predictor is moved to.
        """
        # NOTE(security): torch.load unpickles the file; only load weights from trusted sources.
        state_dict = torch.load(path_weights, map_location=device)
        doctr_predictor.load_state_dict({f"model.{key}": value for key, value in state_dict.items()})
        doctr_predictor.to(device)

    @staticmethod
    def build_model(
        architecture: str, path_config_json: Optional[str] = None, batch_size: int = 1024
    ) -> "RecognitionPredictor":
        """
        Build (without loading custom weights) a docTR ``RecognitionPredictor``.

        :param architecture: Either a name from docTR's recognition ``ARCHS``, or an already created
                             instance of ``CRNN``, ``SAR``, ``MASTER``, ``ViTSTR`` or ``PARSeq``.
        :param path_config_json: Optional json config. Its ``arch``/``url``/``task``/``batch_size`` keys
                                 are dropped, ``mean``/``std`` go to the preprocessor, and all
                                 remaining keys are passed to the model constructor (only when
                                 ``architecture`` is a name).
        :param batch_size: Batch size of the preprocessor.
        :return: A ``RecognitionPredictor`` combining a ``PreProcessor`` and the model.
        :raises ValueError: If ``architecture`` is an unknown name or an unsupported object.
        """
        # inspired and adapted from https://github.com/mindee/doctr/blob/main/doctr/models/recognition/zoo.py
        custom_configs: Dict[str, Any] = {}
        recognition_configs: Dict[str, Any] = {"batch_size": batch_size}
        if path_config_json:
            with open(path_config_json, "r", encoding="utf-8") as f:
                custom_configs = json.load(f)
            # These keys are metadata or predictor settings, not model constructor arguments;
            # a config batch_size is deliberately ignored in favour of the `batch_size` argument.
            for key in ("arch", "url", "task", "batch_size"):
                custom_configs.pop(key, None)
            for key in ("mean", "std"):
                if key in custom_configs:
                    recognition_configs[key] = custom_configs.pop(key)

        if isinstance(architecture, str):
            if architecture not in ARCHS:
                raise ValueError(f"unknown architecture '{architecture}'")
            model = recognition.__dict__[architecture](pretrained=True, pretrained_backbone=True, **custom_configs)
        else:
            if not isinstance(
                architecture,
                (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq),
            ):
                raise ValueError(f"unknown architecture: {type(architecture)}")
            model = architecture

        # Like docTR's own zoo, normalize with the model's statistics unless the config overrides them.
        for key in ("mean", "std"):
            if key not in recognition_configs and key in model.cfg:
                recognition_configs[key] = model.cfg[key]

        # PreProcessor expects output_size as (H, W); presumably cfg["input_shape"] is (C, H, W)
        # on the PyTorch backend, so only the last two entries are kept.
        input_shape = model.cfg["input_shape"][-2:]
        return RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **recognition_configs), model)

    def get_wrapped_model(self) -> Any:
        """
        Build a fresh predictor and load ``self.path_weights`` into it on ``self.device``.
        """
        doctr_predictor = self.build_model(self.architecture, self.path_config_json)
        self.load_model(self.path_weights, doctr_predictor, self.device)
        return doctr_predictor

    @staticmethod
    def get_name(path_weights: str, architecture: str) -> str:
        """
        Return the model name: ``"doctr_" + architecture`` followed by the last two path components
        of ``path_weights`` joined with ``"_"``.

        NOTE(review): there is no separator between the architecture and the path part
        (e.g. "doctr_crnnweights_model.pt"); kept unchanged in case callers key on this name.
        """
        return f"doctr_{architecture}" + "_".join(Path(path_weights).parts[-2:])