# UrduOCR-UTRNet / utils.py
# Abdur Rahman — "Deploy to HuggingFace spaces" (commit 390ca68)
# A simplified version of the original code - https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition
import math
import torch
import torchvision.transforms as T
import warnings
# Silence every UserWarning for the whole process. This runs at import time and
# affects all modules, not only this one. NOTE(review): presumably meant to hide
# torch/torchvision warning noise in the deployed demo — confirm, and consider
# narrowing it (e.g. by `module=`) or moving it into the app entry point.
warnings.filterwarnings("ignore", category=UserWarning)
class NormalizePAD(object):
    """Convert an image to a normalized tensor and right-pad it to a fixed size.

    The image goes through ``self.toTensor`` (``torchvision.transforms.ToTensor``),
    is shifted/scaled in place by ``(x - 0.5) / 0.5`` (maps [0, 1] to [-1, 1]
    for PIL / uint8 inputs), and is copied into the left part of a zero tensor of
    shape ``max_size``. Any remaining columns are filled by replicating the
    image's last column ("border" padding) rather than left as zeros.
    """

    def __init__(self, max_size, PAD_type='right'):
        """
        Args:
            max_size: target ``(channels, height, width)`` of the output tensor.
            PAD_type: stored as ``self.PAD_type``; ``__call__`` does not read it
                and always pads on the right.
        """
        self.toTensor = T.ToTensor()
        self.max_size = max_size
        # Not read anywhere in this class; presumably kept for parity with the
        # upstream UTRNet code or external callers.
        self.max_width_half = math.floor(max_size[2] / 2)
        self.PAD_type = PAD_type

    def __call__(self, img):
        """Normalize ``img`` and right-pad it to ``self.max_size``.

        Args:
            img: anything ``self.toTensor`` accepts (e.g. a PIL image). Its
                width must not exceed ``self.max_size[2]``.

        Returns:
            A float32 tensor of shape ``self.max_size``.

        Raises:
            ValueError: if the image is wider than ``self.max_size[2]``.
        """
        img = self.toTensor(img)
        img.sub_(0.5).div_(0.5)  # [0, 1] -> [-1, 1]
        c, h, w = img.size()
        max_w = self.max_size[2]
        if w > max_w:
            # Without this check the copy below fails with an opaque
            # shape-mismatch error; callers must resize before padding.
            raise ValueError(
                f"image width {w} exceeds the maximum padded width {max_w}"
            )
        pad_img = torch.zeros(*self.max_size, dtype=torch.float32)
        pad_img[:, :, :w] = img  # image on the left, i.e. right padding
        if w < max_w:
            # Border padding: repeat the last image column across the rest of
            # the width so the pad region matches the image edge.
            pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, max_w - w)
        return pad_img
class CTCLabelConverter(object):
    """Convert between text labels and CTC index sequences.

    Index 0 is reserved for the CTC blank token; the characters of
    ``character`` are mapped to indices ``1..len(character)`` in order.
    """

    def __init__(self, character):
        """
        Args:
            character (str): the set of possible characters, in index order.
                If a character occurs more than once, its last position wins
                in ``self.dict``.
        """
        dict_character = list(character)
        # NOTE: 0 is reserved for the 'CTCblank' token required by CTCLoss.
        self.dict = {char: i + 1 for i, char in enumerate(dict_character)}
        # Dummy '[CTCblank]' at index 0 so that self.character[k] decodes index k.
        self.character = ['[CTCblank]'] + dict_character

    def encode(self, text, batch_max_length=25):
        """Convert text labels into a padded index tensor for CTCLoss.

        Args:
            text: text labels of each image. [batch_size]
            batch_max_length: width of the output index tensor; 25 by default.

        Returns:
            (batch_text, length):
                batch_text: LongTensor [batch_size, batch_max_length], where
                    unused positions are 0 (the blank index, which does not
                    affect the CTC loss).
                length: IntTensor [batch_size] with the length of each label.

        Raises:
            KeyError: if a label contains a character not in the alphabet.
            ValueError: if a label is longer than ``batch_max_length``.
        """
        length = [len(s) for s in text]
        batch_text = torch.zeros(len(text), batch_max_length, dtype=torch.long)
        for i, label in enumerate(text):
            if len(label) > batch_max_length:
                raise ValueError(
                    f"label of length {len(label)} exceeds "
                    f"batch_max_length={batch_max_length}"
                )
            indices = [self.dict[char] for char in label]
            batch_text[i, :len(indices)] = torch.tensor(indices, dtype=torch.long)
        return (batch_text, torch.IntTensor(length))

    def decode(self, text_index, length):
        """Convert index sequences back into text (greedy CTC decoding).

        For each row, the first ``length[row]`` indices are read. Blanks (0) are
        dropped, and an index equal to the one immediately before it is
        collapsed, so ``a a _ a`` decodes to ``aa``.

        Args:
            text_index: 2-D indexable of integer indices (e.g. a LongTensor
                [batch_size, T]) supporting ``.tolist()`` on a row.
            length: per-row number of positions to decode.

        Returns:
            list[str]: the decoded string for each row.
        """
        texts = []
        for index, l in enumerate(length):
            # One conversion per row instead of a tensor op per element.
            t = text_index[index, :].tolist()
            char_list = [
                self.character[t[i]]
                for i in range(l)
                if t[i] != 0 and not (i > 0 and t[i - 1] == t[i])  # drop blanks and repeats
            ]
            texts.append(''.join(char_list))
        return texts