Spaces:
Running
Running
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from .ctc_postprocess import BaseRecLabelDecode | |
class VisionLANLabelDecode(BaseRecLabelDecode):
    """Convert VisionLAN network output (and optional labels) into text.

    Class index 0 is reserved as a non-character class: an index ``k`` with
    ``1 <= k <= len(self.character)`` maps to ``self.character[k - 1]``,
    which is why ``nclass`` is one larger than the character set.
    """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(VisionLANLabelDecode, self).__init__(character_dict_path,
                                                   use_space_char)
        # Not read by any method of this class; presumably consumed by
        # callers or the base class — TODO confirm.
        self.max_text_length = kwargs.get('max_text_length', 25)
        # Character set plus the reserved index 0.
        self.nclass = len(self.character) + 1

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert batches of label indices into ``(text, confidence)`` pairs.

        Args:
            text_index: per-sample sequences of integer indices (numpy
                arrays, e.g. the rows of a 2-D array). Index ``k`` maps to
                ``self.character[k - 1]``; indices returned by
                ``self.get_ignored_tokens()`` are dropped.
            text_prob: optional per-sample probabilities aligned with
                ``text_index``. When omitted, every kept token counts as 1.
            is_remove_duplicate: if True, collapse consecutive repeated
                indices before dropping ignored tokens.

        Returns:
            A list with one ``(text, confidence)`` tuple per sample. The
            confidence is the mean over the kept tokens, or 0 when no token
            is kept.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        for batch_idx in range(len(text_index)):
            indices = text_index[batch_idx]
            selection = np.ones(len(indices), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = indices[1:] != indices[:-1]
            for ignored_token in ignored_tokens:
                selection &= indices != ignored_token
            char_list = [
                self.character[text_id - 1]
                for text_id in indices[selection]
            ]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                # Size by the kept tokens, not by the full mask. Otherwise a
                # sample whose tokens are all ignored scores 1 instead of 0.
                conf_list = [1] * len(char_list)
            if len(conf_list) == 0:
                conf_list = [0]
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode VisionLAN predictions, plus the labels when a batch is given.

        Args:
            preds: in eval mode, a 2-item ``(net_out, length)`` pair.
                ``net_out`` holds the unnormalised class scores of all samples
                stacked along dim 0 (softmax over dim 1 is applied here), and
                ``length[i]`` is the number of rows belonging to sample ``i``.
                Any other length means train mode: ``preds[0]`` holds
                per-sample score sequences, which are trimmed to ``batch[5]``
                and concatenated.
            batch: optional batch. ``batch[1]`` holds the label indices and,
                in train mode, ``batch[5]`` holds the per-sample lengths
                (train mode therefore requires ``batch``).

        Returns:
            When ``batch`` is None, a list of ``(text, confidence)`` pairs for
            the predictions. Otherwise ``(pred_pairs, label_pairs)``, where
            the labels are decoded with :meth:`decode`. A prediction's
            confidence is the geometric mean of its per-step max softmax
            probability.
        """
        if len(preds) == 2:  # eval mode
            net_out, length = preds
            if batch is not None:
                label = batch[1]
        else:  # train mode
            net_out = preds[0]
            label, length = batch[1], batch[5]
            net_out = torch.cat([t[:l] for t, l in zip(net_out, length)],
                                dim=0)
        if not isinstance(net_out, torch.Tensor):
            net_out = torch.tensor(net_out, dtype=torch.float32)
        net_out = F.softmax(net_out, dim=1)

        # Move the per-sample lengths to host ints once. Slice offsets are
        # then a running sum, instead of re-summing length[:i] for each
        # sample (quadratic, with one device sync per sum).
        lengths = [int(n) for n in torch.as_tensor(length).reshape(-1).tolist()]

        text = []
        start = 0
        num_chars = len(self.character)
        for n in lengths:
            # One topk call yields both the best probability and the index.
            seg_prob, seg_idx = net_out[start:start + n].topk(1)
            start += n
            preds_text = ''.join([
                self.character[idx - 1] if 0 < idx <= num_chars else ''
                for idx in seg_idx[:, 0].tolist()
            ])
            seg_prob = seg_prob[:, 0]
            # Geometric mean of per-step probabilities. The epsilon avoids
            # division by zero for an empty segment, which then scores 1.0.
            conf = torch.exp(
                torch.log(seg_prob).sum() / (seg_prob.shape[0] + 1e-6))
            text.append((preds_text, float(conf)))

        if batch is None:
            return text
        if isinstance(label, torch.Tensor):
            label = label.detach().cpu().numpy()
        label = self.decode(np.asarray(label))
        return text, label