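"""End-to-end scene-text OCR for Indic scripts: TextBPN++ detection,
ViT-based script identification, and PARseq recognition, combined in
the OCR class below."""
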
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image


from IndicPhotoOCR.detection.textbpn.textbpnpp_detector import TextBPNpp_detector
from IndicPhotoOCR.recognition.parseq_recogniser import PARseqrecogniser
from IndicPhotoOCR.script_identification.vit.vit_infer import VIT_identifier
from IndicPhotoOCR.utils.helper import detect_para


class OCR:
    """Wrap the detector, script identifier, and recogniser behind a single ocr() call."""

    def __init__(self, device='cuda:0', verbose=False):
        self.device = device
        self.verbose = verbose
        self.detector = TextBPNpp_detector(device=self.device)
        self.recogniser = PARseqrecogniser()
        self.identifier = VIT_identifier()
    def detect(self, image_path):
        """Run the detection model and return the detected text regions."""
        self.detections = self.detector.detect(image_path)
        return self.detections['detections']
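
    # Note (illustrative, not from the detector's documentation): the rest of
    # the pipeline treats each detection as a polygon, i.e. a list of [x, y]
    # corner points such as [[10, 20], [110, 20], [110, 60], [10, 60]].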

    def visualize_detection(self, image_path, detections, save_path=None, show=False):
        """Draw an axis-aligned box around each detection and save the annotated image."""
        # Default save path if none is provided
        path_to_save = save_path if save_path is not None else "test.png"

        # Create the output directory if it doesn't exist
        directory = os.path.dirname(path_to_save)
        if directory and not os.path.exists(directory):
            os.makedirs(directory)
            print(f"Created directory: {directory}")

        # Read the image and draw bounding boxes
        image = cv2.imread(image_path)
        for box in detections:
            # Convert list of points to a numpy array with int type
            points = np.array(box, np.int32)

            # Compute the top-left and bottom-right corners of the bounding box
            x_min = np.min(points[:, 0])
            y_min = np.min(points[:, 1])
            x_max = np.max(points[:, 0])
            y_max = np.max(points[:, 1])

            # Draw the rectangle
            cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color=(0, 255, 0), thickness=3)

        # Show the image if 'show' is True
        if show:
            plt.figure(figsize=(10, 10))
            plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            plt.axis("off")
            plt.show()

        # Save the annotated image
        cv2.imwrite(path_to_save, image)
        print(f"Image saved at: {path_to_save}")

    def crop_and_identify_script(self, image, bbox):
        """
        Crop a text area from the image and identify its script language.

        Args:
            image (PIL.Image): The full image.
            bbox (list): List of four corner points, each a [x, y] pair.

        Returns:
            tuple: (identified script language, path to the saved cropped image).
        """
        # Extract x and y coordinates from the four corner points
        x_coords = [point[0] for point in bbox]
        y_coords = [point[1] for point in bbox]

        # Get the bounding box coordinates (min and max)
        x_min, y_min = min(x_coords), min(y_coords)
        x_max, y_max = max(x_coords), max(y_coords)

        # Crop the image based on the bounding box
        cropped_image = image.crop((x_min, y_min, x_max, y_max))
        # Temporarily save the crop so it can be passed to the models by path
        root_image_dir = "IndicPhotoOCR/script_identification"
        os.makedirs(f"{root_image_dir}/images", exist_ok=True)
        cropped_path = f'{root_image_dir}/images/temp_crop_{x_min}_{y_min}.jpg'
        cropped_image.save(cropped_path)

        # Predict the script language; "hindi" is the model name passed to the identifier
        if self.verbose:
            print("Identifying script for the cropped area...")
        script_lang = self.identifier.identify(cropped_path, "hindi", self.device)

        # The temporary crop is deliberately kept on disk: its path is returned
        # and later passed to recognise()
        return script_lang, cropped_path

    def recognise(self, cropped_image_path, script_lang):
        """Recognize text in a cropped image area using the identified script."""
        if self.verbose:
            print("Recognizing text in detected area...")
        recognized_text = self.recogniser.recognise(script_lang, cropped_image_path, script_lang, self.verbose, self.device)
        return recognized_text

    def ocr(self, image_path):
        """Process the image by detecting text areas, identifying script, and recognizing text."""
        recognized_texts = {}
        image = Image.open(image_path)

        # Run detection
        detections = self.detect(image_path)

        # Process each detected text area
        for idx, bbox in enumerate(detections):
            # Crop the region and identify its script
            script_lang, cropped_path = self.crop_and_identify_script(image, bbox)

            # Axis-aligned bounding box of the detected polygon
            x1 = min(point[0] for point in bbox)
            y1 = min(point[1] for point in bbox)
            x2 = max(point[0] for point in bbox)
            y2 = max(point[1] for point in bbox)

            if script_lang:
                recognized_text = self.recognise(cropped_path, script_lang)
                recognized_texts[f"img_{idx}"] = {"txt": recognized_text, "bbox": [x1, y1, x2, y2]}

        # Group the word-level results into paragraphs
        return detect_para(recognized_texts)
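
    # Illustrative shape of the intermediate dict handed to detect_para()
    # (text and coordinates invented for the example):
    #   {"img_0": {"txt": "word", "bbox": [12, 30, 118, 64]},
    #    "img_1": {"txt": "another", "bbox": [130, 28, 240, 66]}}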

if __name__ == '__main__':
    sample_image_path = 'test_images/image_88.jpg'
    cropped_image_path = 'test_images/cropped_image/image_141_0.jpg'

    ocr = OCR(device="cuda", verbose=False)

    # Individual stages can also be exercised on their own, e.g.:
    #   detections = ocr.detect(sample_image_path)
    #   ocr.visualize_detection(sample_image_path, detections)
    #   recognition = ocr.recognise(cropped_image_path, "hindi")

    # Full pipeline: detection, script identification, recognition
    recognised_words = ocr.ocr(sample_image_path)
    print(recognised_words)