import json
import logging
from typing import Any, List, Optional

import numpy as np
import numpy.typing as npt
import torch
from numpy import uint8

from doctr.models.preprocessor import PreProcessor
from doctr.models.detection.predictor import DetectionPredictor  # pylint: disable=W0611
from doctr.models.detection.zoo import detection_predictor, detection

from utils import Annotation, getlogger, group_words_into_lines

ImageType = npt.NDArray[uint8]
ARCHS = [
    "db_resnet34",
    "db_resnet50",
    "db_mobilenet_v3_large",
    "linknet_resnet18",
    "linknet_resnet34",
    "linknet_resnet50",
    "fast_tiny",
    "fast_small",
    "fast_base",
]
class Wordboxes:
    """Container for a single detected word: its bounding box corners and confidence score."""

    def __init__(self, score, box):
        self.box = box
        self.score = score
class DoctrWordDetector:
    """
    A deepdoctection-style wrapper around the DocTR text line detector. Text line detection is modelled as an
    ObjectDetector, and the detector is intended to be used inside an ImageLayoutService.

    DocTR supports several text line detection implementations but provides pre-trained models only for a subset.
    The most usable one for document OCR with a pre-trained model is DBNet, as described in "Real-time Scene
    Text Detection with Differentiable Binarization", with a ResNet-50 backbone. This model can be used with either
    TensorFlow or PyTorch.

    Some other pre-trained models exist that have not been registered in `ModelCatalog`. Please check the DocTR
    library and organize the download of the pre-trained model yourself.

    **Example:**

        path_weights_tl = ModelDownloadManager.maybe_download_weights_and_configs("doctr/db_resnet50/pt/db_resnet50-ac60cadc.pt")
        # Use "doctr/db_resnet50/tf/db_resnet50-adcafc63.zip" for TensorFlow
        categories = ModelCatalog.get_profile("doctr/db_resnet50/pt/db_resnet50-ac60cadc.pt").categories
        det = DoctrTextlineDetector("db_resnet50", path_weights_tl, categories, "cpu")
        layout = ImageLayoutService(det, to_image=True, crop_image=True)

        path_weights_tr = dd.ModelDownloadManager.maybe_download_weights_and_configs("doctr/crnn_vgg16_bn/pt/crnn_vgg16_bn-9762b0b0.pt")
        rec = DoctrTextRecognizer("crnn_vgg16_bn", path_weights_tr, "cpu")
        text = TextExtractionService(rec, extract_from_roi="word")

        analyzer = DoctectionPipe(pipeline_component_list=[layout, text])
        path = "/path/to/image_dir"
        df = analyzer.analyze(path=path)
        for dp in df:
            ...
    """
    def __init__(
        self,
        architecture: str,
        path_weights: str,
        path_config_json: str,
    ) -> None:
        """
        :param architecture: DocTR supports various text line detection models, e.g. "db_resnet50",
                             "db_mobilenet_v3_large". The full list can be found here:
                             https://github.com/mindee/doctr/blob/main/doctr/models/detection/zoo.py#L20
        :param path_weights: Path to the model weights
        :param path_config_json: Path to a JSON config for the model; the mean/std values used by the
                                 pre-processor are read from it
        """
        self.architecture = architecture
        self.path_weights = path_weights
        self.path_config_json = path_config_json
        # Use CUDA if available, otherwise fall back to CPU
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Initialize the model with the given architecture and path to weights
        self.doctr_predictor = self.get_wrapped_model()
""" | |
Two static method so that they can be called without creating an instance of the class | |
Also, they don't require any instance specific data | |
""" | |
def get_wrapped_model( | |
self | |
) -> Any: | |
""" | |
Get the inner (wrapped) model. | |
:param architecture: DocTR supports various text line detection models, e.g. "db_resnet50", | |
"db_mobilenet_v3_large". The full list can be found here: | |
https://github.com/mindee/doctr/blob/main/doctr/models/detection/zoo.py#L20 | |
:param path_weights: Path to the weights of the model | |
:return: Inner model which is a "nn.Module" in PyTorch or a "tf.keras.Model" in Tensorflow | |
""" | |
""" | |
(function) detection_predictor: ((arch: Any = "db_resnet50", pretrained: bool = False, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor) | |
""" | |
#doctr_predictor = detection_predictor(arch=architecture, pretrained=False, pretrained_backbone=False) | |
#doctr_predictor = detection_predictor(arch=architecture, pretrained=False) | |
doctr_predictor = self.build_model(self.architecture, self.path_config_json) | |
self.load_model(self.path_weights, doctr_predictor, self.device) | |
return doctr_predictor | |
    @staticmethod
    def build_model(
        arch: str,
        pretrained: bool = False,
        assume_straight_pages: bool = True,
        path_config_json: Optional[str] = None,
    ) -> "DetectionPredictor":
        """Build the detection model.

        1. If a config JSON is given, the keys `arch`, `url` and `task` are dropped from the custom configs and
           the `mean`/`std` values are moved into the pre-processing (detection) configs.
        2. Create the model:
           Case 1: If `arch` is a string, it must be one of the predefined architectures (ARCHS). If valid, an
           instance of that model is created.
           Case 2: If `arch` is not a string, it must be an **instance** of one of the recognized model classes
           (detection.DBNet, detection.LinkNet, detection.FAST). If valid, the provided instance is used directly.
        3. Retrieve the input_shape from the model's configuration.
        4. Return a DetectionPredictor initialized with a PreProcessor and the model.
        """
        custom_configs = {}
        detection_configs = {}
        batch_size = 4
        if path_config_json:
            with open(path_config_json, "r", encoding="utf-8") as f:
                custom_configs = json.load(f)
            custom_configs.pop("arch", None)
            custom_configs.pop("url", None)
            custom_configs.pop("task", None)
            # mean/std belong to the pre-processor, not to the model itself
            detection_configs["mean"] = custom_configs.pop("mean")
            detection_configs["std"] = custom_configs.pop("std")
            # batch_size = custom_configs.pop("batch_size")

        if isinstance(arch, str):
            if arch not in ARCHS:
                raise ValueError(f"unknown architecture '{arch}'")
            model = detection.__dict__[arch](
                pretrained=pretrained,
                assume_straight_pages=assume_straight_pages,
            )
        else:
            if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
                raise ValueError(f"unknown architecture: {type(arch)}")
            model = arch
            model.assume_straight_pages = assume_straight_pages

        # (H, W) of the expected input; batch_size is passed explicitly, so it is not added to detection_configs
        input_shape = model.cfg["input_shape"][-2:]
        predictor = DetectionPredictor(
            PreProcessor(input_shape, batch_size=batch_size, **detection_configs),
            model,
        )
        return predictor
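
    # A minimal sketch of what the config JSON read by `build_model` could look like (hypothetical values);
    # only "mean" and "std" are consumed for the pre-processor, while "arch", "url" and "task" are dropped:
    #
    #     {
    #         "arch": "db_resnet50",
    #         "url": "https://example.com/db_resnet50.pt",
    #         "task": "detection",
    #         "mean": [0.798, 0.785, 0.772],
    #         "std": [0.264, 0.2749, 0.287]
    #     }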
    @staticmethod
    def load_model(path_weights: str, doctr_predictor: Any, device: torch.device) -> None:
        """Load the model weights.

        1. Load the state dictionary from `path_weights` and map it to `device`.
        2. Prepend "model." to every key so the keys match those expected by the DetectionPredictor wrapper,
           whose inner model is stored under its `model` attribute.
        3. Load the modified state dictionary into the predictor.
        4. Move the predictor to `device`.
        """
        state_dict = torch.load(path_weights, map_location=device)
        for key in list(state_dict.keys()):
            state_dict["model." + key] = state_dict.pop(key)
        doctr_predictor.load_state_dict(state_dict)
        doctr_predictor.to(device)
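
    # Illustration of the key remapping performed by `load_model` (hypothetical key name):
    #     "feat_extractor.conv1.weight" -> "model.feat_extractor.conv1.weight"
    # The checkpoint was saved for the bare detection model, so its keys need the "model." prefix to fit the
    # DetectionPredictor wrapper.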
    def predict(self, np_img: ImageType, sort_vertical: bool = False) -> List[Wordboxes]:
        """
        Prediction per image.

        :param np_img: image as numpy array
        :param sort_vertical: If False, the input is assumed to contain a single line and boxes are sorted
                              left-to-right. If True, boxes are first grouped into lines.
        :return: A list of Wordboxes
        """
        raw_output = self.doctr_predictor([np_img])
        height, width = np_img.shape[:2]
        # raw_output is a list with one dictionary per input image; each dictionary has a single key "words".
        # raw_output[0]["words"] is a numpy array of shape (n, 5), where n is the number of detected words:
        # elements 0-3 of a row are the relative coordinates (xmin, ymin, xmax, ymax), element 4 is the score.
        # Wordboxes expects a (4, 2) array of corner points instead, ordered
        # top-left, top-right, bottom-right, bottom-left (image coordinates, y grows downwards).
        logger = getlogger("array")
        # Clear any existing handlers so log records are not emitted twice
        if logger.hasHandlers():
            logger.handlers.clear()
        # Attach a stream handler with a simple "LEVEL:message" format
        handler = logging.StreamHandler()
        formatter = logging.Formatter("%(levelname)s:%(message)s")
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        # logger.info(raw_output[0]["words"])

        array = raw_output[0]["words"]
        if not sort_vertical:
            # Only valid when the input contains a single line: sort boxes left-to-right by xmin
            sorted_array = array[array[:, 0].argsort()]
        else:
            # The input can contain multiple lines: group boxes into lines first
            sorted_array = group_words_into_lines(array)
        # logger.info(sorted_array)

        detection_results = []
        for box in sorted_array:
            xmin, ymin, xmax, ymax = box[:4]
            # Scale relative coordinates to absolute pixel coordinates
            xmin = xmin * width
            ymin = ymin * height
            xmax = xmax * width
            ymax = ymax * height
            newb = np.array(
                [
                    [xmin, ymin],
                    [xmax, ymin],
                    [xmax, ymax],
                    [xmin, ymax],
                ],
                dtype=np.float32,
            )
            assert newb.shape == (4, 2), f"Points array must be of shape (4, 2), but got {newb.shape}"
            assert newb.dtype == np.float32, f"Points array must be of dtype float32, but got {newb.dtype}"
            w = Wordboxes(score=box[4], box=newb)
            detection_results.append(w)
        return detection_results
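

if __name__ == "__main__":
    # Minimal usage sketch with placeholder paths; the weight, config and image paths below are hypothetical
    # and need to be replaced with real files. Pillow is assumed to be available (it ships with doctr).
    from PIL import Image

    detector = DoctrWordDetector(
        architecture="db_resnet50",
        path_weights="/path/to/db_resnet50.pt",
        path_config_json="/path/to/db_resnet50_config.json",
    )
    image = np.asarray(Image.open("/path/to/text_line.png").convert("RGB"), dtype=np.uint8)
    word_boxes = detector.predict(image, sort_vertical=False)
    for wb in word_boxes:
        print(wb.score, wb.box.tolist())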