# A simplified version of the original code - https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition
import math
import warnings

import torch
import torchvision.transforms as T

warnings.filterwarnings("ignore", category=UserWarning)


class NormalizePAD(object):
    """Normalize an image and right-pad it to a fixed ``max_size`` tensor.

    Args:
        max_size: Target output size as ``(channels, height, width)``.
        PAD_type: Padding side. Only right padding is implemented; the value
            is stored on the instance but not consulted by ``__call__``.
    """

    def __init__(self, max_size, PAD_type='right'):
        self.toTensor = T.ToTensor()
        self.max_size = max_size
        # Kept for API compatibility with the original code; unused here.
        self.max_width_half = math.floor(max_size[2] / 2)
        self.PAD_type = PAD_type

    def __call__(self, img):
        """Convert *img* to a tensor, normalize it, and right-pad it.

        The padded region (columns ``w:``) is filled by repeating the image's
        last column rather than with zeros.

        Args:
            img: Anything ``torchvision.transforms.ToTensor`` accepts (e.g. a
                PIL image). The resulting tensor's channel count and height
                must equal ``max_size[0]`` and ``max_size[1]``.

        Returns:
            A float32 tensor of shape ``max_size``.

        Raises:
            ValueError: if the image is wider than ``max_size[2]``.
        """
        img = self.toTensor(img)
        # (x - 0.5) / 0.5: maps values in [0, 1] to [-1, 1].
        img.sub_(0.5).div_(0.5)
        c, h, w = img.size()
        max_w = self.max_size[2]
        if w > max_w:
            # Without this check the slice assignment below fails with an
            # opaque shape-mismatch error.
            raise ValueError(
                f"image width {w} exceeds the padded width {max_w}; "
                "resize the image before applying NormalizePAD"
            )

        pad_img = torch.zeros(*self.max_size, dtype=torch.float32)
        pad_img[:, :, :w] = img  # right pad
        if w != max_w:
            # Border padding: replicate the last image column across the gap.
            pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, max_w - w)
        return pad_img


class CTCLabelConverter(object):
    """Convert between text labels and CTC index sequences.

    Index 0 is reserved for the CTC blank token; the characters in
    ``character`` are numbered from 1 in the order given.
    """

    def __init__(self, character):
        # character (str): set of the possible characters.
        dict_character = list(character)
        # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss.
        # If a character appears more than once, its last position wins.
        self.dict = {char: i + 1 for i, char in enumerate(dict_character)}
        # dummy '[CTCblank]' token for CTCLoss (index 0)
        self.character = ['[CTCblank]'] + dict_character

    def encode(self, text, batch_max_length=25):
        """Convert text labels into a zero-padded index matrix.

        Args:
            text: Text labels of each image. [batch_size]
            batch_max_length: Max length of a text label in the batch
                (25 by default).

        Returns:
            A tuple ``(batch_text, length)``:
            batch_text: LongTensor of indices for CTCLoss,
                [batch_size, batch_max_length], padded with 0.
            length: IntTensor with the length of each label. [batch_size]

        Raises:
            KeyError: if a label contains a character not in the charset.
            ValueError: if a label is longer than ``batch_max_length``.
        """
        length = [len(s) for s in text]
        # The index used for padding (=0) would not affect the CTC loss calculation.
        batch_text = torch.zeros(len(text), batch_max_length, dtype=torch.long)
        for i, label in enumerate(text):
            indices = [self.dict[char] for char in label]
            if len(indices) > batch_max_length:
                raise ValueError(
                    f"label {i} has {len(indices)} characters, "
                    f"more than batch_max_length={batch_max_length}"
                )
            batch_text[i, :len(indices)] = torch.tensor(indices, dtype=torch.long)
        return (batch_text, torch.tensor(length, dtype=torch.int32))

    def decode(self, text_index, length):
        """Convert index sequences back into text labels (greedy CTC decoding).

        For each row, the first ``length[row]`` indices are read. Blanks (0)
        are dropped, and an index equal to the one at the previous step is
        dropped (collapsing repeats).

        Args:
            text_index: 2-D index matrix (tensor or ndarray), one row per
                sample; presumably [batch_size, num_steps] — confirm at caller.
            length: Number of steps to decode for each row.

        Returns:
            A list of decoded strings, one per entry of ``length``.
        """
        texts = []
        for index, seq_len in enumerate(length):
            # Convert the row to plain Python ints once; comparing and
            # indexing 0-d tensors element by element is much slower.
            row = text_index[index, :].tolist()
            char_list = []
            for i in range(seq_len):
                # removing repeated characters and blank.
                if row[i] != 0 and not (i > 0 and row[i - 1] == row[i]):
                    char_list.append(self.character[row[i]])
            texts.append(''.join(char_list))
        return texts