|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import paddle |
|
|
|
from .rec_postprocess import AttnLabelDecode |
|
|
|
|
|
class TableLabelDecode(AttnLabelDecode):
    """Decode table-structure predictions (and optional labels).

    Turns per-step token probabilities into lists of structure tokens and
    rescales the box predicted for every table-cell token (see
    ``self.td_token``) with the per-image shape information.
    """

    def __init__(self,
                 character_dict_path,
                 merge_no_span_structure=False,
                 **kwargs):
        """Load the token dictionary and build the token <-> index maps.

        NOTE(review): the parent initializer is not invoked here; the
        attributes set below are the ones the methods of this class read.

        Args:
            character_dict_path: Path to a UTF-8 file with one token per line.
            merge_no_span_structure: When true, make sure ``<td></td>`` is in
                the vocabulary and drop the bare ``<td>`` token.
            **kwargs: Accepted for call compatibility and ignored.
        """
        dict_character = []
        with open(character_dict_path, "rb") as fin:
            for raw_line in fin.readlines():
                # Stripping the character set "\r\n" removes every leading or
                # trailing CR/LF, which also covers the plain "\n" case.
                dict_character.append(raw_line.decode('utf-8').strip("\r\n"))

        if merge_no_span_structure:
            if "<td></td>" not in dict_character:
                dict_character.append("<td></td>")
            if "<td>" in dict_character:
                dict_character.remove("<td>")

        # add_special_char is expected to define the special-token attributes
        # (for example self.end_str) that the decoding methods read later.
        dict_character = self.add_special_char(dict_character)
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character
        # Tokens that open a table cell; each one carries a predicted box.
        self.td_token = ['<td>', '<td', '<td></td>']

    def __call__(self, preds, batch=None):
        """Decode predictions and, when labels are present, the labels too.

        Args:
            preds: Mapping with ``'structure_probs'`` and ``'loc_preds'``
                (``paddle.Tensor`` values are converted to numpy arrays).
            batch: Sequence whose last element is the per-image shape list.
                When it has more than one element, ``batch[1]`` and
                ``batch[2]`` are decoded as label token indices and boxes.

        Returns:
            The prediction result dict when ``batch`` holds only the shape
            list, otherwise ``(prediction_result, label_result)``.

        Raises:
            TypeError: If ``batch`` is None.
        """
        if batch is None:
            # Decoding needs the shape list stored in batch[-1]; fail with an
            # explicit message instead of an opaque "not subscriptable" error.
            raise TypeError(
                "batch must be a sequence whose last element is the shape "
                "list, got None")

        structure_probs = preds['structure_probs']
        bbox_preds = preds['loc_preds']
        if isinstance(structure_probs, paddle.Tensor):
            structure_probs = structure_probs.numpy()
        if isinstance(bbox_preds, paddle.Tensor):
            bbox_preds = bbox_preds.numpy()
        shape_list = batch[-1]
        result = self.decode(structure_probs, bbox_preds, shape_list)
        if len(batch) == 1:
            return result

        label_decode_result = self.decode_label(batch)
        return result, label_decode_result

    def decode(self, structure_probs, bbox_preds, shape_list):
        """Convert predicted token probabilities into tokens and cell boxes.

        Args:
            structure_probs: Array reduced over axis 2 (argmax / max) to get
                the token index and its score for every (sample, step).
            bbox_preds: Array indexed as ``bbox_preds[sample, step]`` to get
                the box predicted at that step.
            shape_list: Per-sample shape entries passed to ``_bbox_decode``.

        Returns:
            Dict with ``'bbox_batch_list'`` (one array of decoded boxes per
            sample) and ``'structure_batch_list'`` (one
            ``[token_list, mean_score]`` pair per sample).
        """
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]

        structure_idx = structure_probs.argmax(axis=2)
        structure_probs = structure_probs.max(axis=2)

        structure_batch_list = []
        bbox_batch_list = []
        for batch_idx, step_indices in enumerate(structure_idx):
            structure_list = []
            bbox_list = []
            score_list = []
            for idx, raw_idx in enumerate(step_indices):
                char_idx = int(raw_idx)
                # An end token stops decoding, except at the very first step
                # where it is only skipped via the ignored-token check below.
                if idx > 0 and char_idx == end_idx:
                    break
                if char_idx in ignored_tokens:
                    continue
                text = self.character[char_idx]
                if text in self.td_token:
                    bbox = self._bbox_decode(bbox_preds[batch_idx, idx],
                                             shape_list[batch_idx])
                    bbox_list.append(bbox)
                structure_list.append(text)
                score_list.append(structure_probs[batch_idx, idx])
            # np.mean([]) yields NaN plus a RuntimeWarning; report 0.0 when no
            # token survived decoding.
            score = np.mean(score_list) if score_list else 0.0
            structure_batch_list.append([structure_list, score])
            bbox_batch_list.append(np.array(bbox_list))
        result = {
            'bbox_batch_list': bbox_batch_list,
            'structure_batch_list': structure_batch_list,
        }
        return result

    def decode_label(self, batch):
        """Convert label token indices and boxes into tokens and cell boxes.

        Args:
            batch: Sequence where ``batch[1]`` holds the token indices per
                sample, ``batch[2]`` the matching boxes and ``batch[-1]`` the
                per-sample shape entries.

        Returns:
            Dict with ``'bbox_batch_list'`` (list of decoded boxes per
            sample) and ``'structure_batch_list'`` (token list per sample).
        """
        structure_idx = batch[1]
        gt_bbox_list = batch[2]
        shape_list = batch[-1]
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]

        structure_batch_list = []
        bbox_batch_list = []
        for batch_idx, step_indices in enumerate(structure_idx):
            structure_list = []
            bbox_list = []
            for idx, raw_idx in enumerate(step_indices):
                char_idx = int(raw_idx)
                if idx > 0 and char_idx == end_idx:
                    break
                if char_idx in ignored_tokens:
                    continue
                structure_list.append(self.character[char_idx])

                bbox = gt_bbox_list[batch_idx][idx]
                # Boxes whose coordinates sum to zero are skipped.
                if bbox.sum() != 0:
                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
                    bbox_list.append(bbox)
            structure_batch_list.append(structure_list)
            bbox_batch_list.append(bbox_list)
        result = {
            'bbox_batch_list': bbox_batch_list,
            'structure_batch_list': structure_batch_list,
        }
        return result

    def _bbox_decode(self, bbox, shape):
        """Scale a box by the image width and height taken from ``shape``.

        Even-indexed entries are multiplied by ``w`` and odd-indexed entries
        by ``h``. The computation runs on a copy so the caller's array (a
        view into the prediction or label batch) is left untouched.

        Args:
            bbox: Array of box coordinates.
            shape: Six values ``(h, w, ratio_h, ratio_w, pad_h, pad_w)``;
                only ``h`` and ``w`` are used here.

        Returns:
            A new array holding the scaled coordinates.
        """
        h, w, ratio_h, ratio_w, pad_h, pad_w = shape
        bbox = np.array(bbox)  # copy: never scale the caller's data in place
        bbox[0::2] *= w
        bbox[1::2] *= h
        return bbox
|
|
|
|
|
class TableMasterLabelDecode(TableLabelDecode):
    """Table decoder variant with its own special tokens and box decoding.

    Appends ``<UKN>``, ``<SOS>``, ``<EOS>`` and ``<PAD>`` to the vocabulary
    and converts each box from ``(x, y, w, h)`` form into corner form.
    """

    def __init__(self,
                 character_dict_path,
                 box_shape='ori',
                 merge_no_span_structure=True,
                 **kwargs):
        """Build the vocabulary and record which shape normalizes the boxes.

        Args:
            character_dict_path: Path to the token dictionary file.
            box_shape: ``'ori'`` to scale boxes by the first two shape values
                or ``'pad'`` to scale them by the padded height and width.
            merge_no_span_structure: Forwarded to the parent initializer.
            **kwargs: Accepted for call compatibility and ignored.

        Raises:
            AssertionError: If ``box_shape`` is neither ``'ori'`` nor ``'pad'``.
        """
        super(TableMasterLabelDecode, self).__init__(character_dict_path,
                                                     merge_no_span_structure)
        self.box_shape = box_shape
        assert box_shape in [
            'ori', 'pad'
        ], 'The shape used for box normalization must be ori or pad'

    def add_special_char(self, dict_character):
        """Define the special tokens and append them to the vocabulary.

        Args:
            dict_character: List of vocabulary tokens.

        Returns:
            A new list: the given tokens followed by the unknown, start, end
            and padding tokens, in that order.
        """
        self.beg_str = '<SOS>'
        self.end_str = '<EOS>'
        self.unknown_str = '<UKN>'
        self.pad_str = '<PAD>'
        return dict_character + [
            self.unknown_str, self.beg_str, self.end_str, self.pad_str
        ]

    def get_ignored_tokens(self):
        """Return the indices of the start, end, padding and unknown tokens."""
        pad_idx = self.dict[self.pad_str]
        start_idx = self.dict[self.beg_str]
        end_idx = self.dict[self.end_str]
        unknown_idx = self.dict[self.unknown_str]
        return [start_idx, end_idx, pad_idx, unknown_idx]

    def _bbox_decode(self, bbox, shape):
        """Rescale a four-value box and convert it to corner coordinates.

        The box is multiplied by the reference width/height (the padded size
        when ``self.box_shape == 'pad'``), divided by the resize ratios, read
        as ``(x, y, w, h)`` and turned into ``(x1, y1, x2, y2)`` using
        floor-divided half extents. All arithmetic runs on a copy so the
        caller's array is not modified.

        Args:
            bbox: Array with exactly four values.
            shape: Six values ``(h, w, ratio_h, ratio_w, pad_h, pad_w)``.

        Returns:
            A new array ``[x1, y1, x2, y2]``.
        """
        h, w, ratio_h, ratio_w, pad_h, pad_w = shape
        if self.box_shape == 'pad':
            h, w = pad_h, pad_w
        bbox = np.array(bbox)  # copy: never rescale the caller's data in place
        bbox[0::2] *= w
        bbox[1::2] *= h
        bbox[0::2] /= ratio_w
        bbox[1::2] /= ratio_h
        center_x, center_y, box_w, box_h = bbox
        # Floor division of the half extents is kept exactly as in the
        # original computation.
        half_w = box_w // 2
        half_h = box_h // 2
        return np.array([
            center_x - half_w, center_y - half_h, center_x + half_w,
            center_y + half_h
        ])
|
|