import numpy as np
import torch
import torch.nn.functional as F

from .ctc_postprocess import BaseRecLabelDecode


class VisionLANLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index."""

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(VisionLANLabelDecode, self).__init__(character_dict_path,
                                                   use_space_char)
        self.max_text_length = kwargs.get('max_text_length', 25)
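        # one extra class for the blank token at index 0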
        self.nclass = len(self.character) + 1

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert text-index into text-label."""
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
            if is_remove_duplicate:
                # drop indices repeated at consecutive time steps
                selection[1:] = text_index[batch_idx][1:] != text_index[
                    batch_idx][:-1]
            for ignored_token in ignored_tokens:
                selection &= text_index[batch_idx] != ignored_token

            # index 0 is the blank token, so character indices are shifted by one
            char_list = [
                self.character[text_id - 1]
                for text_id in text_index[batch_idx][selection]
            ]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                conf_list = [1] * len(selection)
            if len(conf_list) == 0:
                conf_list = [0]

            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, batch=None, *args, **kwargs):
        if len(preds) == 2:  # eval mode: (flattened logits, per-sample lengths)
            net_out, length = preds
            if batch is not None:
                label = batch[1]
        else:  # train mode: keep only each sample's valid steps and flatten
            net_out = preds[0]
            label, length = batch[1], batch[5]
            net_out = torch.cat([t[:l] for t, l in zip(net_out, length)],
                                dim=0)
        text = []
        if not isinstance(net_out, torch.Tensor):
            net_out = torch.tensor(net_out, dtype=torch.float32)
        net_out = F.softmax(net_out, dim=1)
        for i in range(length.shape[0]):
            # slice the i-th sample's time steps out of the flattened output
            start = int(length[:i].sum())
            end = start + int(length[i])
            sample = net_out[start:end]
            preds_idx = sample.topk(1)[1][:, 0].tolist()
            # index 0 is the blank token, so character indices are shifted by one
            preds_text = ''.join([
                self.character[idx - 1]
                if 0 < idx <= len(self.character) else ''
                for idx in preds_idx
            ])
            # confidence: geometric mean of the per-step top-1 probabilities
            preds_prob = sample.topk(1)[0][:, 0]
            preds_prob = torch.exp(
                torch.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
            text.append((preds_text, float(preds_prob)))
        if batch is None:
            return text
        label = self.decode(label.detach().cpu().numpy())
        return text, label
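

# Minimal usage sketch, assuming BaseRecLabelDecode falls back to a default
# alphanumeric charset when character_dict_path is None:
#
#     decoder = VisionLANLabelDecode(character_dict_path=None)
#     logits = torch.randn(7, decoder.nclass)   # 7 time steps for one sample
#     lengths = torch.tensor([7])
#     decoder((logits, lengths))                # -> [(predicted_text, confidence)]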