|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import os |
|
import sys |
|
|
|
__dir__ = os.path.dirname(__file__) |
|
sys.path.append(__dir__) |
|
sys.path.append(os.path.join(__dir__, '..')) |
|
|
|
import numpy as np |
|
from .locality_aware_nms import nms_locality |
|
import paddle |
|
import cv2 |
|
import time |
|
|
|
|
|
class SASTPostProcess(object):
    """
    The post process for SAST.

    Turns the four network output maps (read from the ``f_score``,
    ``f_border``, ``f_tvo`` and ``f_tco`` entries in ``__call__``) into a list
    of text polygons per image.
    """

    def __init__(self,
                 score_thresh=0.5,
                 nms_thresh=0.2,
                 sample_pts_num=2,
                 shrink_ratio_of_width=0.3,
                 expand_scale=1.0,
                 tcl_map_thresh=0.5,
                 **kwargs):
        """
        Store the post-processing hyper-parameters.

        Args:
            score_thresh: stored on the instance; not read by any method of
                this class.
            nms_thresh: threshold forwarded to the quad NMS routine.
            sample_pts_num: number of center-line points sampled per text
                instance; 0 means "estimate it from the instance geometry".
            shrink_ratio_of_width: ratio used to widen both ends of each
                polygon (see ``expand_poly_along_width``).
            expand_scale: forwarded to ``detect_sast`` as ``offset_expand``.
            tcl_map_thresh: threshold applied to channel 0 of the score map.
            **kwargs: accepted and ignored.
        """
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.sample_pts_num = sample_pts_num
        self.shrink_ratio_of_width = shrink_ratio_of_width
        self.expand_scale = expand_scale
        self.tcl_map_thresh = tcl_map_thresh

        # ``nms`` switches to the ``lanms`` package on Python 3.5 only.
        version = sys.version_info
        self.is_python35 = version.major == 3 and version.minor == 5

    def point_pair2poly(self, point_pair_list):
        """
        Transfer vertical point_pairs into poly point in clockwise.

        The first point of every pair is emitted in list order, followed by
        the second point of every pair in reverse list order.

        Returns:
            ndarray of shape (2 * len(point_pair_list), 2).
        """
        first_points = [pair[0] for pair in point_pair_list]
        second_points = [pair[1] for pair in reversed(point_pair_list)]
        return np.array(first_points + second_points).reshape(-1, 2)

    def shrink_quad_along_width(self,
                                quad,
                                begin_width_ratio=0.,
                                end_width_ratio=1.):
        """
        Generate shrink_quad_along_width.

        Interpolates along edges 0->1 and 3->2 of ``quad`` at the two given
        ratios. Ratios outside [0, 1] extrapolate, i.e. expand the quad.

        Returns:
            ndarray of four points ordered like the input quad.
        """
        ratio_pair = np.array(
            [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
        p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
        p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
        return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])

    def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
        """
        expand poly along width.

        Pushes the first/last vertices and the two middle vertices of ``poly``
        outwards by ``shrink_ratio_of_width`` times the length of the 0-3 edge
        of the respective end quad.

        NOTE: ``poly`` is modified in place and also returned. It needs at
        least 4 vertices, since ``point_num // 2 + 1`` is indexed.
        """
        point_num = poly.shape[0]
        half = point_num // 2

        left_quad = np.array(
            [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
        left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
            (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
        left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio,
                                                        1.0)

        right_quad = np.array(
            [poly[half - 2], poly[half - 1], poly[half], poly[half + 1]],
            dtype=np.float32)
        right_ratio = 1.0 + \
            shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
            (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
        right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0,
                                                         right_ratio)

        # Both end quads are computed before any vertex is overwritten so a
        # 4-point poly (where the two quads coincide) is expanded consistently.
        poly[0] = left_quad_expand[0]
        poly[-1] = left_quad_expand[-1]
        poly[half - 1] = right_quad_expand[1]
        poly[half] = right_quad_expand[2]
        return poly

    def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
        """
        Restore quad.

        For every pixel whose channel-0 score exceeds ``tcl_map_thresh``,
        subtract the pixel's ``tvo_map`` vector from its tiled (x, y)
        coordinate to obtain one candidate quad.

        Returns:
            scores: (n, 1) scores of the selected pixels.
            quads: (n, 8) candidate quads.
            xy_text: (n, 2) selected pixel coordinates as (x, y), ordered by y.
        """
        xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
        xy_text = xy_text[:, ::-1]  # (row, col) -> (x, y)
        xy_text = xy_text[np.argsort(xy_text[:, 1])]

        scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
        scores = scores[:, np.newaxis]

        # The last axis of tvo_map must hold 4 points, i.e. 8 values.
        point_num = tvo_map.shape[-1] // 2
        assert point_num == 4
        tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
        xy_text_tile = np.tile(xy_text, (1, point_num))
        quads = xy_text_tile - tvo_map

        return scores, quads, xy_text

    def quad_area(self, quad):
        """
        compute area of a quad.

        Evaluates sum((x_next - x_cur) * (y_next + y_cur)) / 2 over the four
        edges. The result is signed: the quad (0,0),(2,0),(2,2),(0,2) gives
        -4, which is why ``detect_sast`` negates it.
        """
        edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
                (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
                (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
                (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
        return np.sum(edge) / 2.

    def nms(self, dets):
        """
        Run quad NMS on ``dets`` with ``self.nms_thresh``.

        Uses ``lanms.merge_quadrangle_n9`` on Python 3.5 and ``nms_locality``
        otherwise.
        """
        if self.is_python35:
            import lanms
            dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh)
        else:
            dets = nms_locality(dets, self.nms_thresh)
        return dets

    def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
        """
        Cluster pixels in tcl_map based on quads.

        Each pixel above the threshold predicts a center (its (x, y) minus its
        ``tco_map`` vector) and is assigned to the quad whose mean vertex is
        nearest to that prediction.

        Returns:
            instance_count: number of quads + 1 (label 0 is background).
            instance_label_map: int32 map with the spatial shape of
                ``tcl_map``; entry k > 0 marks pixels assigned to
                ``quads[k - 1]``.
        """
        instance_count = quads.shape[0] + 1
        instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
        if instance_count == 1:
            return instance_count, instance_label_map

        # Predicted text center of every foreground pixel.
        xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
        xy_text = xy_text[:, ::-1]
        tco = tco_map[xy_text[:, 1], xy_text[:, 0], :]
        pred_tc = xy_text - tco

        # Center of every quad; quads is expected as (m, 4, 2) here.
        gt_tc = np.mean(quads, axis=1)

        # (n, 1, 2) - (1, m, 2) broadcasts to the (n, m) distance matrix
        # without materialising tiled copies.
        dist_mat = np.linalg.norm(
            pred_tc[:, np.newaxis, :] - gt_tc[np.newaxis, :, :], axis=2)
        xy_text_assign = np.argmin(dist_mat, axis=1) + 1

        instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
        return instance_count, instance_label_map

    def estimate_sample_pts_num(self, quad, xy_text):
        """
        Estimate sample points number.

        Approximates the arc length of the (already ordered) pixel list
        ``xy_text`` and divides it by the mean height of ``quad``.

        Returns:
            int >= 2. A quad with zero (or NaN) height yields the minimum of 2
            instead of failing in the division.
        """
        eh = (np.linalg.norm(quad[0] - quad[3]) +
              np.linalg.norm(quad[1] - quad[2])) / 2.0
        ew = (np.linalg.norm(quad[0] - quad[1]) +
              np.linalg.norm(quad[2] - quad[3])) / 2.0

        # Degenerate height: arc_len / eh would be inf/nan and int() would
        # raise, so fall back to the minimum sample count.
        if not eh > 0:
            return 2

        dense_sample_pts_num = max(2, int(ew))
        dense_xy_center_line = xy_text[np.linspace(
            0,
            xy_text.shape[0] - 1,
            dense_sample_pts_num,
            endpoint=True,
            dtype=np.float32).astype(np.int32)]

        dense_xy_center_line_diff = dense_xy_center_line[
            1:] - dense_xy_center_line[:-1]
        estimate_arc_len = np.sum(
            np.linalg.norm(
                dense_xy_center_line_diff, axis=1))

        return max(2, int(estimate_arc_len / eh))

    @staticmethod
    def _expand_offset(offset, offset_expand):
        """Lengthen each row of ``offset`` by a clipped (0.5..3.0) amount along its own direction."""
        offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
        expand_length = np.clip(
            offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0)
        # A zero-length offset has no direction; dividing by 1 keeps it at
        # zero instead of producing NaN coordinates.
        safe_length = np.where(offset_length > 0, offset_length, 1.0)
        return offset + offset / safe_length * expand_length

    @staticmethod
    def _stack_polys(poly_list):
        """Pack polygons into one ndarray, using a 1-D object array when their shapes differ."""
        if len({np.shape(poly) for poly in poly_list}) <= 1:
            return np.array(poly_list)
        # Ragged input (different vertex counts): np.array() would raise on
        # recent NumPy, so fill an object array element by element.
        stacked = np.empty(len(poly_list), dtype=object)
        for idx, poly in enumerate(poly_list):
            stacked[idx] = poly
        return stacked

    def detect_sast(self,
                    tcl_map,
                    tvo_map,
                    tbo_map,
                    tco_map,
                    ratio_w,
                    ratio_h,
                    src_w,
                    src_h,
                    shrink_ratio_of_width=0.3,
                    tcl_map_thresh=0.5,
                    offset_expand=1.0,
                    out_strid=4.0):
        """
        Restore text polygons for one image from its output maps.

        Steps: restore candidate quads from ``tvo_map``, merge them with NMS,
        assign foreground pixels to quads via ``tco_map``, then for every
        instance sample center-line pixels and use ``tbo_map`` to build the
        polygon border. Coordinates are multiplied by ``out_strid`` and
        divided by (ratio_w, ratio_h), then clipped to [0, src_w] / [0, src_h].

        All maps are indexed as [row, col, channel]; ``tbo_map`` must provide
        4 values per pixel (reshaped to two 2-vectors).

        Returns:
            list of (2 * sample_pts_num, 2) polygons; empty list when nothing
            is detected.
        """
        scores, quads, _ = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
        if quads.shape[0] == 0:
            # No pixel passed the threshold: nothing to merge or cluster.
            return []
        dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
        dets = self.nms(dets)
        if dets.shape[0] == 0:
            return []
        quads = dets[:, :-1].reshape(-1, 4, 2)

        # quad_area is negative for this vertex order, hence the sign flip.
        quad_areas = [-self.quad_area(quad) for quad in quads]

        instance_count, instance_label_map = self.cluster_by_quads_tco(
            tcl_map, tcl_map_thresh, quads, tco_map)

        scale = np.array([ratio_w, ratio_h]).reshape(-1, 2)
        poly_list = []
        for instance_idx in range(1, instance_count):
            quad = quads[instance_idx - 1]
            q_area = quad_areas[instance_idx - 1]

            # Drop tiny quads.
            if q_area < 5:
                continue

            # Drop quads with a very short side.
            len1 = float(np.linalg.norm(quad[0] - quad[1]))
            len2 = float(np.linalg.norm(quad[1] - quad[2]))
            if min(len1, len2) < 3:
                continue

            # Drop instances that received no pixels.
            xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1]
            if xy_text.shape[0] <= 0:
                continue

            # Drop instances whose summed score is small relative to the area.
            xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
            if np.sum(xy_text_scores) / q_area < 0.1:
                continue

            # Order the pixels along the quad's left-center -> right-center
            # direction.
            left_center_pt = np.array(
                [[(quad[0, 0] + quad[-1, 0]) / 2.0,
                  (quad[0, 1] + quad[-1, 1]) / 2.0]])
            right_center_pt = np.array(
                [[(quad[1, 0] + quad[2, 0]) / 2.0,
                  (quad[1, 1] + quad[2, 1]) / 2.0]])
            proj_unit_vec = (right_center_pt - left_center_pt) / \
                (np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
            proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
            xy_text = xy_text[np.argsort(proj_value)]

            # Sample evenly spaced pixels along the ordered list.
            if self.sample_pts_num == 0:
                sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
            else:
                sample_pts_num = self.sample_pts_num
            xy_center_line = xy_text[np.linspace(
                0,
                xy_text.shape[0] - 1,
                sample_pts_num,
                endpoint=True,
                dtype=np.float32).astype(np.int32)]

            point_pair_list = []
            for x, y in xy_center_line:
                # Two border offsets per pixel, added in (y, x) order below.
                offset = tbo_map[y, x, :].reshape(2, 2)
                if offset_expand != 1.0:
                    offset = self._expand_offset(offset, offset_expand)

                ori_yx = np.array([y, x], dtype=np.float32)
                # Flip to (x, y) and map back to source-image coordinates.
                point_pair = (ori_yx + offset)[:, ::-1] * out_strid / scale
                point_pair_list.append(point_pair)

            detected_poly = self.point_pair2poly(point_pair_list)
            detected_poly = self.expand_poly_along_width(detected_poly,
                                                         shrink_ratio_of_width)
            detected_poly[:, 0] = np.clip(
                detected_poly[:, 0], a_min=0, a_max=src_w)
            detected_poly[:, 1] = np.clip(
                detected_poly[:, 1], a_min=0, a_max=src_h)
            poly_list.append(detected_poly)

        return poly_list

    def __call__(self, outs_dict, shape_list):
        """
        Post-process a batch.

        Args:
            outs_dict: mapping with keys 'f_score', 'f_border', 'f_tvo' and
                'f_tco'; paddle tensors are converted to numpy. Each
                per-image map is transposed with (1, 2, 0) before use.
            shape_list: one (src_h, src_w, ratio_h, ratio_w) entry per image.

        Returns:
            list with one ``{'points': ndarray}`` dict per image. When the
            polygons of an image have different vertex counts, ``points`` is a
            1-D object array of per-polygon arrays.
        """
        score_list = outs_dict['f_score']
        border_list = outs_dict['f_border']
        tvo_list = outs_dict['f_tvo']
        tco_list = outs_dict['f_tco']
        if isinstance(score_list, paddle.Tensor):
            score_list = score_list.numpy()
            border_list = border_list.numpy()
            tvo_list = tvo_list.numpy()
            tco_list = tco_list.numpy()

        poly_lists = []
        for ino in range(len(shape_list)):
            p_score = score_list[ino].transpose((1, 2, 0))
            p_border = border_list[ino].transpose((1, 2, 0))
            p_tvo = tvo_list[ino].transpose((1, 2, 0))
            p_tco = tco_list[ino].transpose((1, 2, 0))
            src_h, src_w, ratio_h, ratio_w = shape_list[ino]

            poly_list = self.detect_sast(
                p_score,
                p_tvo,
                p_border,
                p_tco,
                ratio_w,
                ratio_h,
                src_w,
                src_h,
                shrink_ratio_of_width=self.shrink_ratio_of_width,
                tcl_map_thresh=self.tcl_map_thresh,
                offset_expand=self.expand_scale)
            poly_lists.append({'points': self._stack_polys(poly_list)})

        return poly_lists
|
|