File size: 7,962 Bytes
daf0288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import os
from abc import ABC
from pathlib import Path
from typing import Any, List, Literal, Mapping, Optional, Tuple
from zipfile import ZipFile
import json
from typing import Any, List, Literal, Mapping, Optional,Dict
import uuid
from doctr.models.preprocessor import PreProcessor
from doctr.models.recognition.predictor import RecognitionPredictor  # pylint: disable=W0611
from doctr.models.recognition.zoo import ARCHS, recognition
import torch
# Numpy image type
import numpy.typing as npt
from numpy import uint8
# Type alias for images handled in this module: NumPy arrays with uint8 elements.
ImageType = npt.NDArray[uint8]

from utils import WordAnnotation,getlogger

class DoctrTextRecognizer:
    """Text recognizer wrapping a DocTR ``RecognitionPredictor`` loaded with custom weights.

    The predictor is built from a DocTR recognition architecture, the weights at ``path_weights``
    are loaded into it, and it is placed on CUDA when available, otherwise on CPU.
    """

    def __init__(
        self,
        architecture: str,
        path_weights: str,
        path_config_json: Optional[str] = None,
    ) -> None:
        """
        :param architecture: DocTR supports various text recognition models, e.g. "crnn_vgg16_bn",
        "crnn_mobilenet_v3_small". The full list can be found here:
        https://github.com/mindee/doctr/blob/main/doctr/models/recognition/zoo.py#L16.
        :param path_weights: Path to the weights of the model.
        :param path_config_json: Path to a json file containing the configuration of the model. Useful, if you have
        a model trained on custom vocab.

        The device is selected automatically: "cuda" if available, otherwise "cpu".
        """
        self.architecture = architecture
        self.path_weights = path_weights

        self.name = self.get_name(self.path_weights, self.architecture)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.path_config_json = path_config_json

        # Build the predictor and load its weights exactly once (building it a second time only
        # doubles start-up time and memory use).
        self.doctr_predictor = self.get_wrapped_model()
        # Kept as an alias of the same predictor for callers that still read ``built_model``.
        self.built_model = self.doctr_predictor

    def predict(self, inputs: Dict[uuid.UUID, Tuple[ImageType, WordAnnotation]]) -> List[WordAnnotation]:
        """
        Recognize the text of a batch of word crops.

        :param inputs: Dictionary mapping a word's object id to a tuple of
            (cropped word image, word annotation).
        :return: The word annotations from ``inputs``, in the dict's iteration order, with their
            ``text`` attribute set to the recognized string. The annotations are modified in place.
        """
        if not inputs:
            return []

        # Ids and images are taken in the same (dict iteration) order so that the predictor's
        # outputs can be zipped back onto the right annotation.
        word_ids = list(inputs.keys())
        cropped_images = [value[0] for value in inputs.values()]

        raw_output = self.doctr_predictor(cropped_images)

        results = []
        for word_id, output in zip(word_ids, raw_output):
            annotation = inputs[word_id][1]
            # The first element of each predictor output is used as the recognized text.
            annotation.text = output[0]
            results.append(annotation)
        return results

    def predict_for_tables(self, inputs: List[ImageType]) -> List[str]:
        """
        Recognize the text of a batch of cropped images and return only the strings.

        :param inputs: Cropped images.
        :return: One recognized string per image (first element of each predictor output),
            in input order; an empty list for empty input.
        """
        if not inputs:
            return []
        raw_output = self.doctr_predictor(list(inputs))
        return [output[0] for output in raw_output]

    @staticmethod
    def load_model(path_weights: str, doctr_predictor: Any, device: torch.device) -> None:
        """Load the weights at ``path_weights`` into ``doctr_predictor`` and move it to ``device``.

        Every key of the loaded state dict is prefixed with ``"model."`` before
        ``load_state_dict`` is called. This is presumably because the checkpoint stores the bare
        recognition model's keys, while the predictor holds that model under ``model``.

        :param path_weights: Path to a file readable by ``torch.load``.
        :param doctr_predictor: Predictor (``torch.nn.Module``-like) receiving the weights.
        :param device: Device used both as ``map_location`` and as the target of ``.to``.
        """
        # NOTE(review): torch.load unpickles the file; only load weights from trusted sources
        # (consider ``weights_only=True`` where the installed torch version supports it).
        state_dict = torch.load(path_weights, map_location=device)
        # Pop every entry first, then re-insert with the prefix. Renaming one key at a time could
        # overwrite an existing key that already equals "model." + key. The loaded mapping is
        # mutated rather than replaced so that attributes attached to it (torch state dicts may
        # carry ``_metadata``) are preserved; insertion order is unchanged.
        entries = [(key, state_dict.pop(key)) for key in list(state_dict.keys())]
        for key, value in entries:
            state_dict["model." + key] = value
        doctr_predictor.load_state_dict(state_dict)
        doctr_predictor.to(device)

    @staticmethod
    def build_model(architecture: str, path_config_json: Optional[str] = None) -> "RecognitionPredictor":
        """Build a DocTR ``RecognitionPredictor``.

        1. If ``path_config_json`` is given, the JSON file is loaded. The ``arch``, ``url`` and
           ``task`` keys are dropped. ``mean`` and ``std`` (both required in that case) are moved
           into the preprocessor configuration, and the preprocessor batch size is set to 1024.
           All remaining keys are passed to the model constructor.
        2. The model is created:
           - if ``architecture`` is a string, it must be in ``ARCHS``. The model is instantiated
             with ``pretrained=True, pretrained_backbone=True`` and the custom configuration;
           - otherwise it must already be an instance of one of the supported DocTR recognition
             classes (CRNN, SAR, MASTER, ViTSTR, PARSeq) and is used as is.
        3. The model is wrapped in a ``RecognitionPredictor``. Its ``PreProcessor`` targets the last
           two entries of ``model.cfg["input_shape"]`` and preserves the aspect ratio.

        Custom weights are loaded separately by ``load_model``.

        :param architecture: Architecture name or a DocTR recognition model instance.
        :param path_config_json: Optional path to a JSON model configuration.
        :return: The wrapped predictor.
        :raises ValueError: If the architecture is unknown or of an unsupported type.
        :raises KeyError: If the JSON configuration lacks ``mean`` or ``std``.
        """
        # inspired and adapted from https://github.com/mindee/doctr/blob/main/doctr/models/recognition/zoo.py
        custom_configs: Dict[str, Any] = {}
        recognition_configs: Dict[str, Any] = {}
        if path_config_json:
            with open(path_config_json, "r", encoding="utf-8") as f:
                custom_configs = json.load(f)
            for key in ("arch", "url", "task"):
                custom_configs.pop(key, None)
            recognition_configs["mean"] = custom_configs.pop("mean")
            recognition_configs["std"] = custom_configs.pop("std")
            # The batch size is fixed here rather than read from the configuration file.
            batch_size = 1024
            recognition_configs["batch_size"] = batch_size

        if isinstance(architecture, str):
            if architecture not in ARCHS:
                raise ValueError(f"unknown architecture '{architecture}'")

            model = recognition.__dict__[architecture](pretrained=True, pretrained_backbone=True, **custom_configs)
        else:
            if not isinstance(
                architecture,
                (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq),
            ):
                raise ValueError(f"unknown architecture: {type(architecture)}")
            model = architecture

        input_shape = model.cfg["input_shape"][-2:]
        # PreProcessor performs casting, resizing, batching and normalization.
        # output_size is (H, W); batch_size/mean/std come from recognition_configs when given.
        return RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **recognition_configs), model)

    def get_wrapped_model(self) -> Any:
        """
        Build a new predictor from ``self.architecture`` and ``self.path_config_json``.

        The weights at ``self.path_weights`` are loaded into it and it is moved to ``self.device``.

        :return: The freshly built predictor.
        """
        doctr_predictor = self.build_model(self.architecture, self.path_config_json)
        self.load_model(self.path_weights, doctr_predictor, self.device)
        return doctr_predictor

    @staticmethod
    def get_name(path_weights: str, architecture: str) -> str:
        """Return the model name.

        The name is ``"doctr_" + architecture`` followed by the last two path components of
        ``path_weights`` joined with ``"_"``.

        NOTE(review): there is no separator between the architecture and the path part, e.g.
        ``"doctr_crnn_vgg16_bnweights_model.pt"`` for ``"weights/model.pt"``. It is kept unchanged
        because the name may be used as an identifier elsewhere.
        """
        return f"doctr_{architecture}" + "_".join(Path(path_weights).parts[-2:])