rogerxavier committed on
Commit
1fe6a2c
1 Parent(s): f3b7d3b

Upload 20 files

easyocrlite/__init__.py ADDED
@@ -0,0 +1 @@
+ from easyocrlite.reader import ReaderLite
easyocrlite/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (304 Bytes).
 
easyocrlite/__pycache__/reader.cpython-38.pyc ADDED
Binary file (7 kB).
 
easyocrlite/__pycache__/types.cpython-38.pyc ADDED
Binary file (369 Bytes).
 
easyocrlite/model/__init__.py ADDED
@@ -0,0 +1 @@
+ from .craft import CRAFT
easyocrlite/model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (292 Bytes).
 
easyocrlite/model/__pycache__/craft.cpython-38.pyc ADDED
Binary file (5.01 kB).
 
easyocrlite/model/craft.py ADDED
@@ -0,0 +1,174 @@
+ """
+ Copyright (c) 2019-present NAVER Corp.
+ MIT License
+ """
+ from __future__ import annotations
+
+ from collections import namedtuple
+ from typing import Iterable, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision
+ from packaging import version
+ from torchvision import models
+
+ VGGOutputs = namedtuple(
+     "VggOutputs", ["fc7", "relu5_3", "relu4_3", "relu3_2", "relu2_2"]
+ )
+
+ def init_weights(modules: Iterable[nn.Module]):
+     for m in modules:
+         if isinstance(m, nn.Conv2d):
+             nn.init.xavier_uniform_(m.weight)
+             if m.bias is not None:
+                 nn.init.zeros_(m.bias)
+         elif isinstance(m, nn.BatchNorm2d):
+             nn.init.constant_(m.weight, 1.0)
+             nn.init.zeros_(m.bias)
+         elif isinstance(m, nn.Linear):
+             nn.init.normal_(m.weight, 0, 0.01)
+             nn.init.zeros_(m.bias)
+
+
+ class VGG16_BN(nn.Module):
+     def __init__(self, pretrained: bool=True, freeze: bool=True):
+         super().__init__()
+         if version.parse(torchvision.__version__) >= version.parse("0.13"):
+             vgg_pretrained_features = models.vgg16_bn(
+                 weights=models.VGG16_BN_Weights.DEFAULT if pretrained else None
+             ).features
+         else:  # torchvision.__version__ < 0.13
+             models.vgg.model_urls["vgg16_bn"] = models.vgg.model_urls[
+                 "vgg16_bn"
+             ].replace("https://", "http://")
+             vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
+
+         self.slice1 = torch.nn.Sequential()
+         self.slice2 = torch.nn.Sequential()
+         self.slice3 = torch.nn.Sequential()
+         self.slice4 = torch.nn.Sequential()
+         self.slice5 = torch.nn.Sequential()
+         for x in range(12):  # conv2_2
+             self.slice1.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(12, 19):  # conv3_3
+             self.slice2.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(19, 29):  # conv4_3
+             self.slice3.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(29, 39):  # conv5_3
+             self.slice4.add_module(str(x), vgg_pretrained_features[x])
+
+         # fc6, fc7 without atrous conv
+         self.slice5 = torch.nn.Sequential(
+             nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
+             nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
+             nn.Conv2d(1024, 1024, kernel_size=1),
+         )
+
+         if not pretrained:
+             init_weights(self.slice1.modules())
+             init_weights(self.slice2.modules())
+             init_weights(self.slice3.modules())
+             init_weights(self.slice4.modules())
+
+         init_weights(self.slice5.modules())  # no pretrained model for fc6 and fc7
+
+         if freeze:
+             for param in self.slice1.parameters():  # only first conv
+                 param.requires_grad = False
+
+     def forward(self, x: torch.Tensor) -> VGGOutputs:
+         h = self.slice1(x)
+         h_relu2_2 = h
+         h = self.slice2(h)
+         h_relu3_2 = h
+         h = self.slice3(h)
+         h_relu4_3 = h
+         h = self.slice4(h)
+         h_relu5_3 = h
+         h = self.slice5(h)
+         h_fc7 = h
+
+         out = VGGOutputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
+         return out
+
+
+ class DoubleConv(nn.Module):
+     def __init__(self, in_ch: int, mid_ch: int, out_ch: int):
+         super().__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
+             nn.BatchNorm2d(mid_ch),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
+             nn.BatchNorm2d(out_ch),
+             nn.ReLU(inplace=True),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.conv(x)
+         return x
+
+
+ class CRAFT(nn.Module):
+     def __init__(self, pretrained: bool=False, freeze: bool=False):
+         super(CRAFT, self).__init__()
+
+         """ Base network """
+         self.basenet = VGG16_BN(pretrained, freeze)
+
+         """ U network """
+         self.upconv1 = DoubleConv(1024, 512, 256)
+         self.upconv2 = DoubleConv(512, 256, 128)
+         self.upconv3 = DoubleConv(256, 128, 64)
+         self.upconv4 = DoubleConv(128, 64, 32)
+
+         num_class = 2
+         self.conv_cls = nn.Sequential(
+             nn.Conv2d(32, 32, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(32, 32, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(32, 16, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(16, 16, kernel_size=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(16, num_class, kernel_size=1),
+         )
+
+         init_weights(self.upconv1.modules())
+         init_weights(self.upconv2.modules())
+         init_weights(self.upconv3.modules())
+         init_weights(self.upconv4.modules())
+         init_weights(self.conv_cls.modules())
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Base network"""
+         sources = self.basenet(x)
+
+         """ U network """
+         y = torch.cat([sources[0], sources[1]], dim=1)
+         y = self.upconv1(y)
+
+         y = F.interpolate(
+             y, size=sources[2].size()[2:], mode="bilinear", align_corners=False
+         )
+         y = torch.cat([y, sources[2]], dim=1)
+         y = self.upconv2(y)
+
+         y = F.interpolate(
+             y, size=sources[3].size()[2:], mode="bilinear", align_corners=False
+         )
+         y = torch.cat([y, sources[3]], dim=1)
+         y = self.upconv3(y)
+
+         y = F.interpolate(
+             y, size=sources[4].size()[2:], mode="bilinear", align_corners=False
+         )
+         y = torch.cat([y, sources[4]], dim=1)
+         feature = self.upconv4(y)
+
+         y = self.conv_cls(feature)
+
+         return y.permute(0, 2, 3, 1), feature
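
Note (illustrative, not part of the uploaded files): as a quick sanity check of the CRAFT architecture above, a minimal sketch that builds an untrained detector and pushes a dummy batch through it. The score map comes back at half the input resolution with two channels (text score and link score), alongside the 32-channel shared feature map.

import torch
from easyocrlite.model import CRAFT

# Illustrative only: random weights, no pretrained VGG backbone is downloaded.
net = CRAFT(pretrained=False, freeze=False).eval()
dummy = torch.randn(1, 3, 256, 256)  # NCHW RGB input
with torch.no_grad():
    scores, feature = net(dummy)
print(scores.shape)   # torch.Size([1, 128, 128, 2]) -- NHWC text/link score maps
print(feature.shape)  # torch.Size([1, 32, 128, 128]) -- shared feature map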
easyocrlite/reader.py ADDED
@@ -0,0 +1,275 @@
+ from __future__ import annotations
+
+ import logging
+ from typing import Union
+ import os
+ from pathlib import Path
+ from typing import Tuple
+
+ import PIL.Image
+ import cv2
+ import numpy as np
+ import torch
+ from PIL import Image, ImageEnhance
+
+ from easyocrlite.model import CRAFT
+
+ from easyocrlite.utils.download_utils import prepare_model
+ from easyocrlite.utils.image_utils import (
+     adjust_result_coordinates,
+     boxed_transform,
+     normalize_mean_variance,
+     resize_aspect_ratio,
+ )
+ from easyocrlite.utils.detect_utils import (
+     extract_boxes,
+     extract_regions_from_boxes,
+     box_expand,
+     greedy_merge,
+ )
+ from easyocrlite.types import BoxTuple, RegionTuple
+ import easyocrlite.utils.utils as utils
+
+ logger = logging.getLogger(__name__)
+
+ MODULE_PATH = (
+     os.environ.get("EASYOCR_MODULE_PATH")
+     or os.environ.get("MODULE_PATH")
+     or os.path.expanduser("~/.EasyOCR/")
+ )
+
+
+ class ReaderLite(object):
+     def __init__(
+         self,
+         gpu=True,
+         model_storage_directory=None,
+         download_enabled=True,
+         verbose=True,
+         quantize=True,
+         cudnn_benchmark=False,
+     ):
+
+         self.verbose = verbose
+
+         model_storage_directory = Path(
+             model_storage_directory
+             if model_storage_directory
+             else MODULE_PATH + "/model"
+         )
+         self.detector_path = prepare_model(
+             model_storage_directory, download_enabled, verbose
+         )
+
+         self.quantize = quantize
+         self.cudnn_benchmark = cudnn_benchmark
+         if gpu is False:
+             self.device = "cpu"
+             if verbose:
+                 logger.warning(
+                     "Using CPU. Note: This module is much faster with a GPU."
+                 )
+         elif not torch.cuda.is_available():
+             self.device = "cpu"
+             if verbose:
+                 logger.warning(
+                     "CUDA not available - defaulting to CPU. Note: This module is much faster with a GPU."
+                 )
+         elif gpu is True:
+             self.device = "cuda"
+         else:
+             self.device = gpu
+
+         self.detector = CRAFT()
+
+         state_dict = torch.load(self.detector_path, map_location=self.device)
+         if list(state_dict.keys())[0].startswith("module"):
+             state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+         self.detector.load_state_dict(state_dict)
+
+         if self.device == "cpu":
+             if self.quantize:
+                 try:
+                     torch.quantization.quantize_dynamic(
+                         self.detector, dtype=torch.qint8, inplace=True
+                     )
+                 except:
+                     pass
+         else:
+             self.detector = torch.nn.DataParallel(self.detector).to(self.device)
+             import torch.backends.cudnn as cudnn
+
+             cudnn.benchmark = self.cudnn_benchmark
+
+         self.detector.eval()
+
+     def process(
+         self,
+         image_path: Union[str, PIL.Image.Image],
+         max_size: int = 960,
+         expand_ratio: float = 1.0,
+         sharp: float = 1.0,
+         contrast: float = 1.0,
+         text_confidence: float = 0.7,
+         text_threshold: float = 0.4,
+         link_threshold: float = 0.4,
+         slope_ths: float = 0.1,
+         ratio_ths: float = 0.5,
+         center_ths: float = 0.5,
+         dim_ths: float = 0.5,
+         space_ths: float = 1.0,
+         add_margin: float = 0.1,
+         min_size: float = 0.01,
+     ) -> Tuple[BoxTuple, list[np.ndarray]]:
+         if isinstance(image_path, str):
+             image = Image.open(image_path).convert('RGB')
+         elif isinstance(image_path, PIL.Image.Image):
+             image = image_path.convert('RGB')
+         tensor, inverse_ratio = self.preprocess(
+             image, max_size, expand_ratio, sharp, contrast
+         )
+
+         scores = self.forward_net(tensor)
+
+         boxes = self.detect(scores, text_confidence, text_threshold, link_threshold)
+
+         image = np.array(image)
+         region_list, box_list = self.postprocess(
+             image,
+             boxes,
+             inverse_ratio,
+             slope_ths,
+             ratio_ths,
+             center_ths,
+             dim_ths,
+             space_ths,
+             add_margin,
+             min_size,
+         )
+
+         # get cropped image
+         image_list = []
+         for region in region_list:
+             x_min, x_max, y_min, y_max = region
+             crop_img = image[y_min:y_max, x_min:x_max, :]
+             image_list.append(
+                 (
+                     ((x_min, y_min), (x_max, y_min), (x_max, y_max), (x_min, y_max)),
+                     crop_img,
+                 )
+             )
+
+         for box in box_list:
+             transformed_img = boxed_transform(image, np.array(box, dtype="float32"))
+             image_list.append((box, transformed_img))
+
+         # sort by top left point
+         image_list = sorted(image_list, key=lambda x: (x[0][0][1], x[0][0][0]))
+
+         return image_list
+
+     def preprocess(
+         self,
+         image: Image.Image,
+         max_size: int,
+         expand_ratio: float = 1.0,
+         sharp: float = 1.0,
+         contrast: float = 1.0,
+     ) -> torch.Tensor:
+         if sharp != 1:
+             enhancer = ImageEnhance.Sharpness(image)
+             image = enhancer.enhance(sharp)
+         if contrast != 1:
+             enhancer = ImageEnhance.Contrast(image)
+             image = enhancer.enhance(contrast)
+
+         image = np.array(image)
+
+         image, target_ratio = resize_aspect_ratio(
+             image, max_size, interpolation=cv2.INTER_LINEAR, expand_ratio=expand_ratio
+         )
+         inverse_ratio = 1 / target_ratio
+
+         x = np.transpose(normalize_mean_variance(image), (2, 0, 1))
+
+         x = torch.tensor(np.array([x]), device=self.device)
+
+         return x, inverse_ratio
+
+     @torch.no_grad()
+     def forward_net(self, tensor: torch.Tensor) -> torch.Tensor:
+         scores, feature = self.detector(tensor)
+         return scores[0]
+
+     def detect(
+         self,
+         scores: torch.Tensor,
+         text_confidence: float = 0.7,
+         text_threshold: float = 0.4,
+         link_threshold: float = 0.4,
+     ) -> list[BoxTuple]:
+         # make score and link map
+         score_text = scores[:, :, 0].cpu().data.numpy()
+         score_link = scores[:, :, 1].cpu().data.numpy()
+         # extract box
+         boxes, _ = extract_boxes(
+             score_text, score_link, text_confidence, text_threshold, link_threshold
+         )
+         return boxes
+
+     def postprocess(
+         self,
+         image: np.ndarray,
+         boxes: list[BoxTuple],
+         inverse_ratio: float,
+         slope_ths: float = 0.1,
+         ratio_ths: float = 0.5,
+         center_ths: float = 0.5,
+         dim_ths: float = 0.5,
+         space_ths: float = 1.0,
+         add_margin: float = 0.1,
+         min_size: int = 0,
+     ) -> Tuple[list[RegionTuple], list[BoxTuple]]:
+
+         # coordinate adjustment
+         boxes = adjust_result_coordinates(boxes, inverse_ratio)
+
+         max_y, max_x, _ = image.shape
+
+         # extract region and merge
+         region_list, box_list = extract_regions_from_boxes(boxes, slope_ths)
+
+         region_list = greedy_merge(
+             region_list,
+             ratio_ths=ratio_ths,
+             center_ths=center_ths,
+             dim_ths=dim_ths,
+             space_ths=space_ths,
+             verbose=0
+         )
+
+         # add margin
+         region_list = [
+             region.expand(add_margin, (max_x, max_y)).as_tuple()
+             for region in region_list
+         ]
+
+         box_list = [box_expand(box, add_margin, (max_x, max_y)) for box in box_list]
+
+         # filter by size
+         if min_size:
+             if min_size < 1:
+                 min_size = int(min(max_y, max_x) * min_size)
+
+             region_list = [
+                 i for i in region_list if max(i[1] - i[0], i[3] - i[2]) > min_size
+             ]
+             box_list = [
+                 i
+                 for i in box_list
+                 if max(utils.diff([c[0] for c in i]), utils.diff([c[1] for c in i]))
+                 > min_size
+             ]
+
+         return region_list, box_list
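
Note (illustrative, not part of this commit): a minimal usage sketch of the reader defined above. The image path is a placeholder, and the CRAFT weights are fetched into ~/.EasyOCR/model on first use.

from easyocrlite import ReaderLite

reader = ReaderLite(gpu=False)           # CPU is enough for a smoke test
results = reader.process("example.jpg")  # hypothetical input image
for box, crop in results:
    # box: four (x, y) corner points starting at the top left; crop: np.ndarray patch
    print(box, crop.shape)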
easyocrlite/types.py ADDED
@@ -0,0 +1,5 @@
+ from typing import Tuple
+
+ Point = Tuple[int, int]
+ BoxTuple = Tuple[Point, Point, Point, Point]
+ RegionTuple = Tuple[int, int, int, int]
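
For reference, a box in these aliases is four (x, y) corner points in clockwise order starting at the top left, while a region is axis-aligned; a small illustrative sketch:

from easyocrlite.types import BoxTuple, RegionTuple

box: BoxTuple = ((10, 20), (110, 20), (110, 60), (10, 60))  # tl, tr, br, bl
region: RegionTuple = (10, 110, 20, 60)                      # x_min, x_max, y_min, y_max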
easyocrlite/utils/__init__.py ADDED
File without changes
easyocrlite/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (257 Bytes).
 
easyocrlite/utils/__pycache__/detect_utils.cpython-38.pyc ADDED
Binary file (7.51 kB).
 
easyocrlite/utils/__pycache__/download_utils.cpython-38.pyc ADDED
Binary file (3.04 kB).
 
easyocrlite/utils/__pycache__/image_utils.cpython-38.pyc ADDED
Binary file (2.65 kB).
 
easyocrlite/utils/__pycache__/utils.cpython-38.pyc ADDED
Binary file (1.36 kB).
 
easyocrlite/utils/detect_utils.py ADDED
@@ -0,0 +1,327 @@
+ from __future__ import annotations
+
+ import itertools
+ import logging
+ import math
+ import operator
+ from collections import namedtuple
+ from functools import cached_property
+ from typing import Iterable, Optional, Tuple
+
+ import cv2
+ import numpy as np
+ from easyocrlite.types import BoxTuple, RegionTuple
+ from easyocrlite.utils.utils import grouped_by
+
+ logger = logging.getLogger(__name__)
+
+ class Region(namedtuple("Region", ["x_min", "x_max", "y_min", "y_max"])):
+     @cached_property
+     def ycenter(self):
+         return 0.5 * (self.y_min + self.y_max)
+
+     @cached_property
+     def xcenter(self):
+         return 0.5 * (self.x_min + self.x_max)
+
+     @cached_property
+     def height(self):
+         return self.y_max - self.y_min
+
+     @cached_property
+     def width(self):
+         return self.x_max - self.x_min
+
+     @classmethod
+     def from_box(cls, box: BoxTuple) -> Region:
+         (xtl, ytl), (xtr, ytr), (xbr, ybr), (xbl, ybl) = box
+
+         x_max = max(xtl, xtr, xbr, xbl)
+         x_min = min(xtl, xtr, xbr, xbl)
+         y_max = max(ytl, ytr, ybr, ybl)
+         y_min = min(ytl, ytr, ybr, ybl)
+
+         return cls(x_min, x_max, y_min, y_max)
+
+     def as_tuple(self) -> RegionTuple:
+         return self.x_min, self.x_max, self.y_min, self.y_max
+
+     def expand(
+         self, add_margin: float, size: Optional[Tuple[int, int] | int] = None
+     ) -> Region:
+
+         margin = int(add_margin * min(self.width, self.height))
+         if isinstance(size, Iterable):
+             max_x, max_y = size
+         elif size is None:
+             max_x = self.width * 2
+             max_y = self.height * 2
+         else:
+             max_x = max_y = size
+
+         return Region(
+             max(0, self.x_min - margin),
+             min(max_x, self.x_max + margin),
+             max(0, self.y_min - margin),
+             min(max_y, self.y_max + margin),
+         )
+
+     def __add__(self, region: Region) -> Region:
+         return Region(
+             min(self.x_min, region.x_min),
+             max(self.x_max, region.x_max),
+             min(self.y_min, region.y_min),
+             max(self.y_max, region.y_max),
+         )
+
+ def extract_boxes(
+     textmap: np.ndarray,
+     linkmap: np.ndarray,
+     text_threshold: float,
+     link_threshold: float,
+     low_text: float,
+ ) -> Tuple[list[BoxTuple], list[int]]:
+     # prepare data
+     linkmap = linkmap.copy()
+     textmap = textmap.copy()
+     img_h, img_w = textmap.shape
+
+     """ labeling method """
+     ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
+     ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)
+
+     text_score_comb = np.clip(text_score + link_score, 0, 1)
+     nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(
+         text_score_comb.astype(np.uint8), connectivity=4
+     )
+
+     boxes = []
+     mapper = []
+     for k in range(1, nLabels):
+         # size filtering
+         size = stats[k, cv2.CC_STAT_AREA]
+         if size < 10:
+             continue
+
+         # thresholding
+         if np.max(textmap[labels == k]) < text_threshold:
+             continue
+
+         # make segmentation map
+         segmap = np.zeros(textmap.shape, dtype=np.uint8)
+         segmap[labels == k] = 255
+
+         mapper.append(k)
+         segmap[np.logical_and(link_score == 1, text_score == 0)] = 0  # remove link area
+         x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
+         w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
+         niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
+         sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
+         # boundary check
+         if sx < 0:
+             sx = 0
+         if sy < 0:
+             sy = 0
+         if ex >= img_w:
+             ex = img_w
+         if ey >= img_h:
+             ey = img_h
+         kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1 + niter, 1 + niter))
+         segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)
+
+         # make box
+         np_contours = (
+             np.roll(np.array(np.where(segmap != 0)), 1, axis=0)
+             .transpose()
+             .reshape(-1, 2)
+         )
+         rectangle = cv2.minAreaRect(np_contours)
+         box = cv2.boxPoints(rectangle)
+
+         # align diamond-shape
+         w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
+         box_ratio = max(w, h) / (min(w, h) + 1e-5)
+         if abs(1 - box_ratio) <= 0.1:
+             l, r = min(np_contours[:, 0]), max(np_contours[:, 0])
+             t, b = min(np_contours[:, 1]), max(np_contours[:, 1])
+             box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)
+
+         # make clock-wise order
+         startidx = box.sum(axis=1).argmin()
+         box = np.roll(box, 4 - startidx, 0)
+         box = np.array(box)
+         boxes.append(box)
+
+     return boxes, mapper
+
+
+ def extract_regions_from_boxes(
+     boxes: list[BoxTuple], slope_ths: float
+ ) -> Tuple[list[Region], list[BoxTuple]]:
+
+     region_list: list[Region] = []
+     box_list = []
+
+     for box in boxes:
+         box = np.array(box).astype(np.int32)
+         (xtl, ytl), (xtr, ytr), (xbr, ybr), (xbl, ybl) = box
+
+         # get the tan of top and bottom edge
+         # why 10?
+         slope_top = (ytr - ytl) / max(10, xtr - xtl)
+         slope_bottom = (ybr - ybl) / max(10, xbr - xbl)
+         if max(abs(slope_top), abs(slope_bottom)) < slope_ths:
+             # not very tilted, rectangle box
+             region_list.append(Region.from_box(box))
+         else:
+             # tilted
+             box_list.append(box)
+     return region_list, box_list
+
+
+ def box_expand(
+     box: BoxTuple, add_margin: float, size: Optional[Tuple[int, int] | int] = None
+ ) -> BoxTuple:
+
+     (xtl, ytl), (xtr, ytr), (xbr, ybr), (xbl, ybl) = box
+     height = np.linalg.norm([xbl - xtl, ybl - ytl])  # from top left to bottom left
+     width = np.linalg.norm([xtr - xtl, ytr - ytl])  # from top left to top right
+
+     # margin is added based on the diagonal
+     margin = int(1.44 * add_margin * min(width, height))
+
+     theta13 = abs(np.arctan((ytl - ybr) / max(10, (xtl - xbr))))
+     theta24 = abs(np.arctan((ytr - ybl) / max(10, (xtr - xbl))))
+
+     if isinstance(size, Iterable):
+         max_x, max_y = size
+     elif size is None:
+         max_x = width * 2
+         max_y = height * 2
+     else:
+         max_x = max_y = size
+
+     new_box = (
+         (
+             max(0, int(xtl - np.cos(theta13) * margin)),
+             max(0, int(ytl - np.sin(theta13) * margin)),
+         ),
+         (
+             min(max_x, math.ceil(xtr + np.cos(theta24) * margin)),
+             max(0, int(ytr - np.sin(theta24) * margin)),
+         ),
+         (
+             min(max_x, math.ceil(xbr + np.cos(theta13) * margin)),
+             min(max_y, math.ceil(ybr + np.sin(theta13) * margin)),
+         ),
+         (
+             max(0, int(xbl - np.cos(theta24) * margin)),
+             min(max_y, math.ceil(ybl + np.sin(theta24) * margin)),
+         ),
+     )
+     return new_box
+
+
+ def greedy_merge(
+     regions: list[Region],
+     ratio_ths: float = 0.5,
+     center_ths: float = 0.5,
+     dim_ths: float = 0.5,
+     space_ths: float = 1.0,
+     verbose: int = 4,
+ ) -> list[Region]:
+
+     regions = sorted(regions, key=operator.attrgetter("ycenter"))
+
+     # grouped by ycenter
+     groups = grouped_by(
+         regions,
+         operator.attrgetter("ycenter"),
+         center_ths,
+         operator.attrgetter("height"),
+     )
+     for group in groups:
+         group.sort(key=operator.attrgetter("x_min"))
+         idx = 0
+         while idx < len(group) - 1:
+             region1, region2 = group[idx], group[idx + 1]
+             # both are horizontal regions
+             cond = (region1.width / region1.height) > ratio_ths and (
+                 region2.width / region2.height
+             ) > ratio_ths
+             # similar heights
+             cond = cond and abs(region1.height - region2.height) < dim_ths * np.mean(
+                 [region1.height, region2.height]
+             )
+             # similar ycenters
+             # cond = cond and abs(region1.ycenter - region2.ycenter) < center_ths * np.mean(
+             #     [region1.height, region2.height]
+             # )
+             # horizontal space is small
+             cond = cond and (region2.x_min - region1.x_max) < space_ths * np.mean(
+                 [region1.height, region2.height]
+             )
+             if cond:
+                 # merge regions
+                 region = region1 + region2
+
+                 if verbose > 2:
+                     logger.debug(f"horizontal merging {region1} {region2}")
+                 group.pop(idx)
+                 group.pop(idx)
+                 group.insert(idx, region)
+
+             else:
+                 if verbose > 0:
+                     logger.debug(f"not horizontal merging {region1} {region2}")
+                 idx += 1
+
+     # flatten groups
+     regions = list(itertools.chain.from_iterable(groups))
+
+     # grouped by xcenter
+     groups = grouped_by(
+         regions,
+         operator.attrgetter("xcenter"),
+         center_ths,
+         operator.attrgetter("width"),
+     )
+
+     for group in groups:
+         group.sort(key=operator.attrgetter("y_min"))
+         idx = 0
+         while idx < len(group) - 1:
+             region1, region2 = group[idx], group[idx + 1]
+             # both are vertical regions
+             cond = (region1.height / region1.width) > ratio_ths and (
+                 region2.height / region2.width
+             ) > ratio_ths
+             # similar widths
+             cond = cond and abs(region1.width - region2.width) < dim_ths * np.mean(
+                 [region1.width, region2.width]
+             )
+             # # similar xcenters
+             # cond = cond and abs(region1.xcenter - region2.xcenter) < center_ths * np.mean(
+             #     [region1.width, region2.width]
+             # )
+             # vertical space is small
+             cond = cond and (region2.y_min - region1.y_max) < space_ths * np.mean(
+                 [region1.width, region2.width]
+             )
+             if cond:
+                 # merge region
+                 region = region1 + region2
+                 if verbose > 2:
+                     logger.debug(f"vertical merging {region1} {region2}")
+                 group.pop(idx)
+                 group.pop(idx)
+                 group.insert(idx, region)
+             else:
+                 if verbose > 1:
+                     logger.debug(f"not vertical merging {region1} {region2}")
+                 idx += 1
+
+     # flatten groups
+     regions = list(itertools.chain.from_iterable(groups))
+
+     return regions
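
To illustrate the merging heuristics above, a small sketch (not part of the upload) with hand-made regions: two boxes on the same text line plus one line far below. With the default thresholds the first two merge horizontally and the third stays separate.

from easyocrlite.utils.detect_utils import Region, greedy_merge

regions = [
    Region(x_min=0, x_max=50, y_min=0, y_max=20),    # first word of a line
    Region(x_min=55, x_max=120, y_min=1, y_max=21),  # second word, same line
    Region(x_min=0, x_max=60, y_min=90, y_max=110),  # a separate line far below
]
merged = greedy_merge(regions, verbose=0)
# expect two regions: the first pair merged into (0, 120, 0, 21), the third untouched
print(merged)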
easyocrlite/utils/download_utils.py ADDED
@@ -0,0 +1,92 @@
+ import hashlib
+ import logging
+ from pathlib import Path
+ from typing import Callable, Optional
+ from urllib.request import urlretrieve
+ from zipfile import ZipFile
+
+ from tqdm.auto import tqdm
+
+ FILENAME = "craft_mlt_25k.pth"
+ URL = (
+     "https://xc-models.oss-cn-zhangjiakou.aliyuncs.com/modelscope/studio/easyocr/craft_mlt_25k.zip"
+ )
+ MD5SUM = "2f8227d2def4037cdb3b34389dcf9ec1"
+ MD5MSG = "MD5 hash mismatch, possible file corruption"
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def calculate_md5(path: Path) -> str:
+     hash_md5 = hashlib.md5()
+     with open(path, "rb") as f:
+         for chunk in iter(lambda: f.read(4096), b""):
+             hash_md5.update(chunk)
+     return hash_md5.hexdigest()
+
+
+ def print_progress_bar(t: tqdm) -> Callable[[int, int, Optional[int]], None]:
+     last = 0
+
+     def update_to(
+         count: int = 1, block_size: int = 1, total_size: Optional[int] = None
+     ):
+         nonlocal last
+         if total_size is not None:
+             t.total = total_size
+         t.update((count - last) * block_size)
+         last = count
+
+     return update_to
+
+
+ def download_and_unzip(
+     url: str, filename: str, model_storage_directory: Path, verbose: bool = True
+ ):
+     zip_path = model_storage_directory / "temp.zip"
+     with tqdm(
+         unit="B", unit_scale=True, unit_divisor=1024, miniters=1, disable=not verbose
+     ) as t:
+         reporthook = print_progress_bar(t)
+         urlretrieve(url, str(zip_path), reporthook=reporthook)
+     with ZipFile(zip_path, "r") as zipObj:
+         zipObj.extract(filename, str(model_storage_directory))
+     zip_path.unlink()
+
+
+ def prepare_model(model_storage_directory: Path, download=True, verbose: bool = True) -> Path:
+     model_storage_directory.mkdir(parents=True, exist_ok=True)
+
+     detector_path = model_storage_directory / FILENAME
+
+     # try get model path
+     model_available = False
+     if not detector_path.is_file():
+         if not download:
+             raise FileNotFoundError(f"Missing {detector_path} and downloads disabled")
+         logger.info(
+             "Downloading detection model, please wait. "
+             "This may take several minutes depending upon your network connection."
+         )
+     elif calculate_md5(detector_path) != MD5SUM:
+         logger.warning(MD5MSG)
+         if not download:
+             raise FileNotFoundError(
+                 f"MD5 mismatch for {detector_path} and downloads disabled"
+             )
+         detector_path.unlink()
+         logger.info(
+             "Re-downloading the detection model, please wait. "
+             "This may take several minutes depending upon your network connection."
+         )
+     else:
+         model_available = True
+
+     if not model_available:
+         download_and_unzip(URL, FILENAME, model_storage_directory, verbose)
+         if calculate_md5(detector_path) != MD5SUM:
+             raise ValueError(MD5MSG)
+         logger.info("Download complete")
+
+     return detector_path
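
The download helper can also be exercised on its own; a minimal sketch (the target directory here is arbitrary). It fetches and unzips craft_mlt_25k.pth when the file is missing or fails the MD5 check, then returns the path to the weights file.

from pathlib import Path
from easyocrlite.utils.download_utils import prepare_model

weights = prepare_model(Path("./models"), download=True, verbose=True)
print(weights)  # e.g. models/craft_mlt_25k.pth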
easyocrlite/utils/image_utils.py ADDED
@@ -0,0 +1,93 @@
+ from __future__ import annotations
+
+ from typing import Tuple
+
+ import cv2
+ import numpy as np
+ from easyocrlite.types import BoxTuple
+
+
+ def resize_aspect_ratio(
+     img: np.ndarray, max_size: int, interpolation: int, expand_ratio: float = 1.0
+ ) -> Tuple[np.ndarray, float]:
+     height, width, channel = img.shape
+
+     # magnify image size
+     target_size = expand_ratio * max(height, width)
+
+     # set original image size
+     if max_size and max_size > 0 and target_size > max_size:
+         target_size = max_size
+
+     ratio = target_size / max(height, width)
+
+     target_h, target_w = int(height * ratio), int(width * ratio)
+
+     if target_h != height or target_w != width:
+         proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation)
+         # make canvas and paste image
+         target_h32, target_w32 = target_h, target_w
+         if target_h % 32 != 0:
+             target_h32 = target_h + (32 - target_h % 32)
+         if target_w % 32 != 0:
+             target_w32 = target_w + (32 - target_w % 32)
+         resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
+         resized[0:target_h, 0:target_w, :] = proc
+         target_h, target_w = target_h32, target_w32
+     else:
+         resized = img
+     return resized, ratio
+
+
+ def adjust_result_coordinates(
+     box: BoxTuple, inverse_ratio: int = 1, ratio_net: int = 2
+ ) -> np.ndarray:
+     if len(box) > 0:
+         box = np.array(box)
+         for k in range(len(box)):
+             if box[k] is not None:
+                 box[k] *= (inverse_ratio * ratio_net, inverse_ratio * ratio_net)
+     return box
+
+
+ def normalize_mean_variance(
+     in_img: np.ndarray,
+     mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
+     variance: Tuple[float, float, float] = (0.229, 0.224, 0.225),
+ ) -> np.ndarray:
+     # should be RGB order
+     img = in_img.copy().astype(np.float32)
+
+     img -= np.array(
+         [mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32
+     )
+     img /= np.array(
+         [variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0],
+         dtype=np.float32,
+     )
+     return img
+
+ def boxed_transform(image: np.ndarray, box: BoxTuple) -> np.ndarray:
+     (tl, tr, br, bl) = box
+
+     widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
+     widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
+     maxWidth = max(int(widthA), int(widthB))
+
+     # compute the height of the new image, which will be the
+     # maximum distance between the top-right and bottom-right
+     # y-coordinates or the top-left and bottom-left y-coordinates
+     heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
+     heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
+     maxHeight = max(int(heightA), int(heightB))
+
+     dst = np.array(
+         [[0, 0], [maxWidth - 1, 0], [maxWidth - 1, maxHeight - 1], [0, maxHeight - 1]],
+         dtype="float32",
+     )
+
+     # compute the perspective transform matrix and then apply it
+     M = cv2.getPerspectiveTransform(box, dst)
+     warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
+
+     return warped
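
A short sketch (illustrative values, not part of the upload) of the preprocessing helpers above on a synthetic image: the resize keeps the aspect ratio, pads each side up to a multiple of 32 when needed, and returns the scale factor; normalization removes the ImageNet channel mean and variance.

import cv2
import numpy as np
from easyocrlite.utils.image_utils import normalize_mean_variance, resize_aspect_ratio

img = np.random.randint(0, 255, (300, 500, 3), dtype=np.uint8)  # dummy H x W x C image
resized, ratio = resize_aspect_ratio(
    img, max_size=960, interpolation=cv2.INTER_LINEAR, expand_ratio=2.0
)
print(resized.shape, ratio)  # both sides become multiples of 32, e.g. (576, 960, 3)
x = np.transpose(normalize_mean_variance(resized), (2, 0, 1))  # CHW float32 network input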
easyocrlite/utils/utils.py ADDED
@@ -0,0 +1,43 @@
+ from __future__ import annotations
+
+ import numpy as np
+
+ from typing import Iterable, TypeVar, Callable
+
+ T = TypeVar("T")
+ V = TypeVar("V")
+
+
+ def diff(input_list: Iterable[T]) -> T:
+     return max(input_list) - min(input_list)
+
+
+ def grouped_by(
+     items: list[T],
+     group_key: Callable[[T], V],
+     eps: float,
+     eps_key: Callable[[T], float],
+ ) -> list[list[T]]:
+     items = sorted(items, key=group_key)
+
+     groups = []
+     group = []
+
+     for item in items:
+         if not group:
+             group.append(item)
+             continue
+
+         if group:
+             cond = abs(
+                 group_key(item) - np.mean([group_key(item) for item in group])
+             ) < eps * np.mean([eps_key(item) for item in group])
+             if cond:
+                 group.append(item)
+             else:
+                 groups.append(group)
+                 group = [item]
+     else:
+         if group:
+             groups.append(group)
+     return groups
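
A quick sketch (not part of this commit) of how grouped_by clusters items whose key stays within eps of the running group mean, scaled by eps_key; plain numbers stand in for regions here.

from easyocrlite.utils.utils import diff, grouped_by

values = [10, 11, 12, 50, 52, 90]
groups = grouped_by(values, group_key=lambda v: v, eps=0.5, eps_key=lambda v: 20)
print(groups)        # [[10, 11, 12], [50, 52], [90]]
print(diff(values))  # 80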