import string
import re
import math
import html
import os
from itertools import groupby

import torch
import torchvision.transforms as T
import torch.nn as nn
from torchvision.models import resnet101
from torch.utils.data import Dataset

import numpy as np
import numba as nb
import cv2
import unicodedata
import gradio as gr
import pytesseract
import pandas as pd

"""
Data preproc functions:
    adjust_to_see: adjust image for better visualization (rotate and transpose)
    augmentation: apply variations to a list of images
    normalization: apply normalization and variations on images (if required)
    preprocess: main preprocessing function. Applies to the image:
        illumination_compensation: apply illumination regularization
        remove_cursive_style: remove cursive style from image (if necessary)
        sauvola: apply Sauvola binarization
    text_standardize: preprocess and standardize sentence
"""


def adjust_to_see(img):
    """Rotate and transpose image for visualization (cv2 window or Jupyter notebook)"""
    (h, w) = img.shape[:2]
    (cX, cY) = (w // 2, h // 2)

    M = cv2.getRotationMatrix2D((cX, cY), -90, 1.0)
    cos = np.abs(M[0, 0])
    sin = np.abs(M[0, 1])

    nW = int((h * sin) + (w * cos))
    nH = int((h * cos) + (w * sin))

    M[0, 2] += (nW / 2) - cX
    M[1, 2] += (nH / 2) - cY

    img = cv2.warpAffine(img, M, (nW + 1, nH + 1))
    img = cv2.warpAffine(img.transpose(), M, (nW, nH))

    return img


def augmentation(imgs,
                 rotation_range=0,
                 scale_range=0,
                 height_shift_range=0,
                 width_shift_range=0,
                 dilate_range=1,
                 erode_range=1):
    """Apply variations to a list of images (rotate, width and height shift, scale, erode, dilate)"""

    imgs = imgs.astype(np.float32)
    _, h, w = imgs.shape

    dilate_kernel = np.ones((int(np.random.uniform(1, dilate_range)),), np.uint8)
    erode_kernel = np.ones((int(np.random.uniform(1, erode_range)),), np.uint8)
    height_shift = np.random.uniform(-height_shift_range, height_shift_range)
    rotation = np.random.uniform(-rotation_range, rotation_range)
    scale = np.random.uniform(1 - scale_range, 1)
    width_shift = np.random.uniform(-width_shift_range, width_shift_range)

    trans_map = np.float32([[1, 0, width_shift * w], [0, 1, height_shift * h]])
    rot_map = cv2.getRotationMatrix2D((w // 2, h // 2), rotation, scale)

    trans_map_aff = np.r_[trans_map, [[0, 0, 1]]]
    rot_map_aff = np.r_[rot_map, [[0, 0, 1]]]
    affine_mat = rot_map_aff.dot(trans_map_aff)[:2, :]

    for i in range(len(imgs)):
        imgs[i] = cv2.warpAffine(imgs[i], affine_mat, (w, h), flags=cv2.INTER_NEAREST, borderValue=255)
        imgs[i] = cv2.erode(imgs[i], erode_kernel, iterations=1)
        imgs[i] = cv2.dilate(imgs[i], dilate_kernel, iterations=1)

    return imgs


def normalization(img):
    """Normalize a single image (zero mean, unit standard deviation)"""
    m, s = cv2.meanStdDev(img)
    img = img - m[0][0]
    img = img / s[0][0] if s[0][0] > 0 else img
    return img


def preprocess(img, input_size):
    """Load, clean and resize the image to fit `input_size`"""

    def imread(path):
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)

        if len(img.shape) == 3:
            if img.shape[2] == 4:
                trans_mask = img[:, :, 3] == 0
                img[trans_mask] = [255, 255, 255, 255]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        return img

    if isinstance(img, str):
        img = imread(img)

    if isinstance(img, tuple):
        image, boundbox = img
        img = imread(image)

        for i in range(len(boundbox)):
            if isinstance(boundbox[i], float):
                total = len(img) if i < 2 else len(img[0])
                boundbox[i] = int(total * boundbox[i])

        img = np.asarray(img[boundbox[0]:boundbox[1], boundbox[2]:boundbox[3]], dtype=np.uint8)

    wt, ht, _ = input_size
    h, w = np.asarray(img).shape
    f = max((w / wt), (h / ht))

    new_size = (max(min(wt, int(w / f)), 1), max(min(ht, int(h / f)), 1))

    img = illumination_compensation(img)
    img = remove_cursive_style(img)
    img = cv2.resize(img, new_size)

    target = np.ones([ht, wt], dtype=np.uint8) * 255
    target[0:new_size[1], 0:new_size[0]] = img
    img = cv2.transpose(target)

    return img
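
# --- Added usage sketch (not part of the original pipeline) -----------------
# A minimal, hedged example of how `preprocess` is expected to be called on a
# single cropped line image. The file name below is hypothetical, and the
# helper is never invoked, so nothing runs at import time.
def _example_preprocess_usage():
    """Hedged sketch: preprocess one line image to the model's input size."""
    line_img = preprocess("cropped_lines/example_line.jpg", input_size=(1024, 128, 1))
    # The result is transposed, so the array shape is (width, height) = (1024, 128).
    return line_img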
def illumination_compensation(img, only_cei=False):
    """Illumination compensation technique for text images"""

    _, binary = cv2.threshold(img, 254, 255, cv2.THRESH_BINARY)

    if np.sum(binary) > np.sum(img) * 0.8:
        return np.asarray(img, dtype=np.uint8)

    def scale(img):
        s = np.max(img) - np.min(img)
        res = img / s
        res -= np.min(res)
        res *= 255
        return res

    img = img.astype(np.float32)
    height, width = img.shape
    sqrt_hw = np.sqrt(height * width)

    bins = np.arange(0, 300, 10)
    bins[26] = 255
    hp = np.histogram(img, bins)

    # fall back to 255 if no histogram bin exceeds the threshold (avoids an unbound variable)
    hr = 255
    for i in range(len(hp[0])):
        if hp[0][i] > sqrt_hw:
            hr = i * 10
            break

    np.seterr(divide='ignore', invalid='ignore')
    cei = (img - (hr + 50 * 0.3)) * 2
    cei[cei > 255] = 255
    cei[cei < 0] = 0

    if only_cei:
        return np.asarray(cei, dtype=np.uint8)

    m1 = np.asarray([-1, 0, 1, -2, 0, 2, -1, 0, 1]).reshape((3, 3))
    m2 = np.asarray([-2, -1, 0, -1, 0, 1, 0, 1, 2]).reshape((3, 3))
    m3 = np.asarray([-1, -2, -1, 0, 0, 0, 1, 2, 1]).reshape((3, 3))
    m4 = np.asarray([0, 1, 2, -1, 0, 1, -2, -1, 0]).reshape((3, 3))

    eg1 = np.abs(cv2.filter2D(img, -1, m1))
    eg2 = np.abs(cv2.filter2D(img, -1, m2))
    eg3 = np.abs(cv2.filter2D(img, -1, m3))
    eg4 = np.abs(cv2.filter2D(img, -1, m4))
    eg_avg = scale((eg1 + eg2 + eg3 + eg4) / 4)

    h, w = eg_avg.shape
    eg_bin = np.zeros((h, w))
    eg_bin[eg_avg >= 30] = 255

    h, w = cei.shape
    cei_bin = np.zeros((h, w))
    cei_bin[cei >= 60] = 255

    h, w = eg_bin.shape
    tli = 255 * np.ones((h, w))
    tli[eg_bin == 255] = 0
    tli[cei_bin == 255] = 0

    kernel = np.ones((3, 3), np.uint8)
    erosion = cv2.erode(tli, kernel, iterations=1)
    int_img = np.asarray(cei)

    estimate_light_distribution(width, height, erosion, cei, int_img)

    mean_filter = 1 / 121 * np.ones((11, 11), np.uint8)
    ldi = cv2.filter2D(scale(int_img), -1, mean_filter)

    result = np.divide(cei, ldi) * 260
    result[erosion != 0] *= 1.5
    result[result < 0] = 0
    result[result > 255] = 255

    return np.asarray(result, dtype=np.uint8)


@nb.jit(nopython=True)
def estimate_light_distribution(width, height, erosion, cei, int_img):
    """Light distribution performed by numba (thanks @Sundrops)"""

    for y in range(width):
        for x in range(height):
            if erosion[x][y] == 0:
                i = x

                while i < erosion.shape[0] and erosion[i][y] == 0:
                    i += 1

                end = i - 1
                n = end - x + 1

                if n <= 30:
                    h, e = [], []

                    for k in range(5):
                        if x - k >= 0:
                            h.append(cei[x - k][y])

                        if end + k < cei.shape[0]:
                            e.append(cei[end + k][y])

                    mpv_h, mpv_e = max(h), max(e)

                    for m in range(n):
                        int_img[x + m][y] = mpv_h + (m + 1) * ((mpv_e - mpv_h) / n)

                x = end
                break


def remove_cursive_style(img):
    """Remove cursive writing style from image with the deslanting algorithm"""

    def calc_y_alpha(vec):
        indices = np.where(vec > 0)[0]
        h_alpha = len(indices)

        if h_alpha > 0:
            delta_y_alpha = indices[h_alpha - 1] - indices[0] + 1

            if h_alpha == delta_y_alpha:
                return h_alpha * h_alpha

        return 0

    alpha_vals = [-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0]
    rows, cols = img.shape
    results = []

    ret, otsu = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    binary = otsu if ret < 127 else sauvola(img, (int(img.shape[0] / 2), int(img.shape[0] / 2)), 127, 1e-2)

    for alpha in alpha_vals:
        shift_x = max(-alpha * rows, 0.)
        size = (cols + int(np.ceil(abs(alpha * rows))), rows)
        transform = np.asarray([[1, alpha, shift_x], [0, 1, 0]], dtype=np.float32)

        shear_img = cv2.warpAffine(binary, transform, size, cv2.INTER_NEAREST)
        sum_alpha = 0
        sum_alpha += np.apply_along_axis(calc_y_alpha, 0, shear_img)

        results.append([np.sum(sum_alpha), size, transform])

    result = sorted(results, key=lambda x: x[0], reverse=True)[0]
    result = cv2.warpAffine(img, result[2], result[1], borderValue=255)
    result = cv2.resize(result, dsize=(cols, rows))

    return np.asarray(result, dtype=np.uint8)
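
# --- Added usage sketch (not part of the original pipeline) -----------------
# Hedged example of the two cleaning steps above on a synthetic grayscale
# image, so expected dtypes and shapes are visible. The sample text and values
# are arbitrary; the helper is never invoked by the app.
def _example_illumination_and_deslant():
    """Hedged sketch: run illumination compensation and deslanting on a dummy image."""
    dummy = np.full((128, 1024), 255, dtype=np.uint8)
    cv2.putText(dummy, "ejemplo", (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 2, 0, 3)
    compensated = illumination_compensation(dummy)   # uint8, same shape as input
    deslanted = remove_cursive_style(compensated)    # uint8, same shape as input
    return deslanted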
def sauvola(img, window, thresh, k):
    """Sauvola binarization"""

    rows, cols = img.shape
    pad = int(np.floor(window[0] / 2))
    sum2, sqsum = cv2.integral2(
        cv2.copyMakeBorder(img, pad, pad, pad, pad, cv2.BORDER_CONSTANT))

    isum = sum2[window[0]:rows + window[0], window[1]:cols + window[1]] + \
        sum2[0:rows, 0:cols] - \
        sum2[window[0]:rows + window[0], 0:cols] - \
        sum2[0:rows, window[1]:cols + window[1]]

    isqsum = sqsum[window[0]:rows + window[0], window[1]:cols + window[1]] + \
        sqsum[0:rows, 0:cols] - \
        sqsum[window[0]:rows + window[0], 0:cols] - \
        sqsum[0:rows, window[1]:cols + window[1]]

    ksize = window[0] * window[1]
    mean = isum / ksize
    std = (((isqsum / ksize) - (mean ** 2) / ksize) / ksize) ** 0.5
    threshold = (mean * (1 + k * (std / thresh - 1))) * (mean >= 100)

    return np.asarray(255 * (img >= threshold), 'uint8')


RE_DASH_FILTER = re.compile(r'[\-\˗\֊\‐\‑\‒\–\—\⁻\₋\−\﹣\－]', re.UNICODE)
RE_APOSTROPHE_FILTER = re.compile(r'&#39;|[ʼ՚＇‘’‛❛❜ߴߵ`‵´ˊˋ{}{}{}{}{}{}{}{}{}]'.format(
    chr(768), chr(769), chr(832), chr(833), chr(2387),
    chr(5151), chr(5152), chr(65344), chr(8242)), re.UNICODE)
RE_RESERVED_CHAR_FILTER = re.compile(r'[¶¤«»]', re.UNICODE)
RE_LEFT_PARENTH_FILTER = re.compile(r'[\(\[\{\⁽\₍\❨\❪\﹙\（]', re.UNICODE)
RE_RIGHT_PARENTH_FILTER = re.compile(r'[\)\]\}\⁾\₎\❩\❫\﹚\）]', re.UNICODE)
RE_BASIC_CLEANER = re.compile(r'[^\w\s{}]'.format(re.escape(string.punctuation)), re.UNICODE)

LEFT_PUNCTUATION_FILTER = """!%&),.:;<=>?@\\]^_`|}~"""
RIGHT_PUNCTUATION_FILTER = """"(/<=>@[\\^_`{|~"""
NORMALIZE_WHITESPACE_REGEX = re.compile(r'[^\S\n]+', re.UNICODE)


def text_standardize(text):
    """Organize/add spaces around punctuation marks"""

    if text is None:
        return ""

    text = html.unescape(text).replace("\\n", "").replace("\\t", "")

    text = RE_RESERVED_CHAR_FILTER.sub("", text)
    text = RE_DASH_FILTER.sub("-", text)
    text = RE_APOSTROPHE_FILTER.sub("'", text)
    text = RE_LEFT_PARENTH_FILTER.sub("(", text)
    text = RE_RIGHT_PARENTH_FILTER.sub(")", text)
    text = RE_BASIC_CLEANER.sub("", text)

    text = text.lstrip(LEFT_PUNCTUATION_FILTER)
    text = text.rstrip(RIGHT_PUNCTUATION_FILTER)
    text = text.translate(str.maketrans({c: f" {c} " for c in string.punctuation}))
    text = NORMALIZE_WHITESPACE_REGEX.sub(" ", text.strip())

    return text


class Tokenizer():
    """Manage token functions and charset/dictionary properties"""

    def __init__(self, chars, max_text_length=128):
        self.PAD_TK, self.UNK_TK, self.SOS, self.EOS = "¶", "¤", "SOS", "EOS"
        self.chars = [self.PAD_TK] + [self.UNK_TK] + [self.SOS] + [self.EOS] + list(chars)
        self.PAD = self.chars.index(self.PAD_TK)
        self.UNK = self.chars.index(self.UNK_TK)

        self.vocab_size = len(self.chars)
        self.maxlen = max_text_length

    def encode(self, text):
        """Encode text to a vector of token indices"""
        text = unicodedata.normalize("NFKD", text).encode("ASCII", "ignore").decode("ASCII")
        text = " ".join(text.split())

        groups = ["".join(group) for _, group in groupby(text)]
        text = "".join([self.UNK_TK.join(list(x)) if len(x) > 1 else x for x in groups])

        encoded = []
        text = ['SOS'] + list(text) + ['EOS']

        for item in text:
            # list.index raises ValueError for unknown symbols, so map them to UNK explicitly
            index = self.chars.index(item) if item in self.chars else self.UNK
            encoded.append(index)

        return np.asarray(encoded)

    def decode(self, text):
        """Decode vector of token indices back to text"""
        decoded = "".join([self.chars[int(x)] for x in text if x > -1])
        decoded = self.remove_tokens(decoded)
        decoded = text_standardize(decoded)

        return decoded

    def remove_tokens(self, text):
        """Remove PAD and UNK tokens from text"""
        return text.replace(self.PAD_TK, "").replace(self.UNK_TK, "")
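
# --- Added usage sketch (not part of the original pipeline) -----------------
# Hedged example of the Tokenizer round trip: encode wraps the text in SOS/EOS,
# separates repeated characters with the UNK token, and decode strips the
# special tokens again. The sample sentence is arbitrary.
def _example_tokenizer_roundtrip():
    """Hedged sketch: encode/decode a sample string with the printable charset."""
    tk = Tokenizer(string.printable[:95], max_text_length=128)
    ids = tk.encode("hello world")     # first index is SOS, last is EOS
    text = tk.decode(ids[1:-1])        # drop SOS/EOS indices before decoding
    return ids, text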
charset_base = string.printable[:95]
tokenizer = Tokenizer(charset_base)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_memory(model, imgs):
    x = model.conv(model.get_feature(imgs))
    bs, _, H, W = x.shape
    pos = torch.cat([
        model.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
        model.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
    ], dim=-1).flatten(0, 1).unsqueeze(1)

    return model.transformer.encoder(pos + 0.1 * x.flatten(2).permute(2, 0, 1))


def test(model, test_loader, max_text_length):
    model.eval()
    predicts = []
    gt = []
    imgs = []

    with torch.no_grad():
        for batch in test_loader:
            src, trg = batch
            imgs.append(src.flatten(0, 1))
            src, trg = src.to(device), trg.to(device)

            memory = get_memory(model, src.float())
            out_indexes = [tokenizer.chars.index('SOS'), ]

            for i in range(max_text_length):
                mask = model.generate_square_subsequent_mask(i + 1).to(device)
                trg_tensor = torch.LongTensor(out_indexes).unsqueeze(1).to(device)

                output = model.vocab(model.transformer.decoder(
                    model.query_pos(model.decoder(trg_tensor)), memory, tgt_mask=mask))
                out_token = output.argmax(2)[-1].item()
                out_indexes.append(out_token)

                if out_token == tokenizer.chars.index('EOS'):
                    break

            predicts.append(tokenizer.decode(out_indexes))
            gt.append(tokenizer.decode(trg.flatten(0, 1)))

    return predicts, gt, imgs


class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=128):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
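
# --- Added usage sketch (not part of the original pipeline) -----------------
# Hedged example of the tensor layout PositionalEncoding expects: inputs are
# (seq_len, batch, d_model), matching how it is used by the decoder below.
# The dimensions chosen here are illustrative only.
def _example_positional_encoding():
    """Hedged sketch: PositionalEncoding keeps the (seq_len, batch, d_model) shape."""
    pe = PositionalEncoding(d_model=256, dropout=0.1, max_len=128)
    x = torch.zeros(10, 2, 256)        # 10 time steps, batch of 2
    return pe(x).shape                 # -> torch.Size([10, 2, 256])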
class OCR(nn.Module):

    def __init__(self, vocab_len, hidden_dim, nheads,
                 num_encoder_layers, num_decoder_layers):
        super().__init__()

        # create ResNet-101 backbone
        self.backbone = resnet101()
        del self.backbone.fc

        # create conversion layer
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # prediction head over the vocabulary
        # (DETR used a basic 3-layer MLP for its output head)
        self.vocab = nn.Linear(hidden_dim, vocab_len)

        # output positional encodings (object queries)
        self.decoder = nn.Embedding(vocab_len, hidden_dim)
        self.query_pos = PositionalEncoding(hidden_dim, .2)

        # spatial positional encodings; sine positional encoding could be used instead
        # (the DETR baseline uses sine positional encoding)
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

        self.trg_mask = None

    def generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def get_feature(self, x):
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)
        return x

    def make_len_mask(self, inp):
        return (inp == 0).transpose(0, 1)

    def forward(self, inputs, trg):
        # propagate inputs through ResNet-101 up to the avg-pool layer
        x = self.get_feature(inputs)

        # convert from 2048 to `hidden_dim` feature planes for the transformer
        h = self.conv(x)

        # construct positional encodings
        bs, _, H, W = h.shape
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # generate subsequent (causal) mask for the target
        if self.trg_mask is None or self.trg_mask.size(0) != len(trg):
            self.trg_mask = self.generate_square_subsequent_mask(trg.shape[1]).to(trg.device)

        # padding mask
        trg_pad_mask = self.make_len_mask(trg)

        # positional encoding for the target
        trg = self.decoder(trg)
        trg = self.query_pos(trg)

        output = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
                                  trg.permute(1, 0, 2),
                                  tgt_mask=self.trg_mask,
                                  tgt_key_padding_mask=trg_pad_mask.permute(1, 0))

        return self.vocab(output.transpose(0, 1))


def make_model(vocab_len, hidden_dim=256, nheads=4,
               num_encoder_layers=4, num_decoder_layers=4):
    return OCR(vocab_len, hidden_dim, nheads, num_encoder_layers, num_decoder_layers)


class DataGenerator_Spanish(Dataset):

    def __init__(self, source_dict, charset, max_text_length, transform, shuffle=True):
        self.tokenizer = Tokenizer(charset, max_text_length)
        self.transform = transform
        self.shuffle = shuffle

        self.dataset = source_dict.copy()

        if self.shuffle:
            randomize = np.arange(len(self.dataset['gt']))
            np.random.seed(42)
            np.random.shuffle(randomize)

            self.dataset['dt'] = np.array(self.dataset['dt'])[randomize]
            self.dataset['gt'] = np.array(self.dataset['gt'])[randomize]

        # ground truth strings are stored as bytes by crop_dict, so decode them
        # here regardless of shuffling (the original only decoded when shuffling)
        self.dataset['gt'] = [x.decode() if isinstance(x, bytes) else x for x in self.dataset['gt']]

        self.size = len(self.dataset['gt'])

    def __getitem__(self, i):
        img = self.dataset['dt'][i]
        img = np.repeat(img[..., np.newaxis], 3, -1)
        img = normalization(img)

        if self.transform is not None:
            img = self.transform(img)

        y_train = self.tokenizer.encode(self.dataset['gt'][i])
        y_train = np.pad(y_train, (0, self.tokenizer.maxlen - len(y_train)))
        gt = torch.Tensor(y_train)

        return img, gt

    def __len__(self):
        return self.size
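
# --- Added usage sketch (not part of the original pipeline) -----------------
# Hedged example of one training-style forward pass through the model. The
# shapes follow the preprocessing above (lines resized to 1024x128 and repeated
# to 3 channels); vocab_len=100 matches the value used in generate(). The
# helper is never invoked by the app.
def _example_model_forward():
    """Hedged sketch: build the OCR model and run a dummy batch through it."""
    model = make_model(vocab_len=100)
    imgs = torch.rand(2, 3, 1024, 128)                            # batch of 2 preprocessed lines
    targets = torch.randint(0, 100, (2, 128), dtype=torch.long)   # padded token ids
    logits = model(imgs, targets)                                 # (batch, seq_len, vocab_len)
    return logits.shape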
def crop_dict(page):
    """Detect text lines on the page with Tesseract and crop them for the recognizer"""
    master_page_par_line_list = []

    image = cv2.cvtColor(np.array(page), cv2.COLOR_RGB2GRAY)
    _, image = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)

    data = pytesseract.image_to_data(image, config='--oem 3 --psm 6', output_type='dict')

    page_num = 1
    df = pd.DataFrame(data)
    df = df[df["conf"] > 0]
    df["page_num"] = page_num

    page_par_line_dict = {}
    for index, row in df.iterrows():
        page_par_line = f"{page_num}_{row['par_num']}_{row['line_num']}"

        if page_par_line not in page_par_line_dict:
            page_par_line_dict[page_par_line] = {
                "text": str(row["text"]) + " ",
                "box": (row['left'], row['top'],
                        row['left'] + row['width'], row['top'] + row['height'])
            }
        else:
            page_par_line_dict[page_par_line]["text"] = \
                page_par_line_dict[page_par_line]["text"] + str(row["text"]) + " "
            page_par_line_dict[page_par_line]['box'] = (
                min(page_par_line_dict[page_par_line]['box'][0], row['left']),
                min(page_par_line_dict[page_par_line]['box'][1], row['top']),
                max(page_par_line_dict[page_par_line]['box'][2], row['left'] + row['width']),
                max(page_par_line_dict[page_par_line]['box'][3], row['top'] + row['height']))

    for entry in page_par_line_dict:
        splitted_key = entry.split('_')
        entry_value = page_par_line_dict[entry]

        master_page_par_line_list.append({
            'page_number': splitted_key[0],
            'paragraph_number': splitted_key[1],
            'line_number': splitted_key[2],
            'entry_text': entry_value['text'],
            'bounding_box': entry_value['box']
        })

    imgs_cropped = {}
    img_text_dict = {"dt": [], "gt": []}

    for line in page_par_line_dict.values():
        if line['box'] is not None:
            cv2.rectangle(image, (line['box'][0], line['box'][1]),
                          (line['box'][2], line['box'][3]), (0, 0, 255), 2)
            img_cropped = image[line['box'][1]:line['box'][3], line['box'][0]:line['box'][2]]

            if not os.path.exists('cropped_lines'):
                os.makedirs('cropped_lines')

            cv2.imwrite(f"cropped_lines/{line['box'][1]}.jpg", img_cropped)
            imgs_cropped[line['box'][1]] = img_cropped

            assert os.path.exists(f'cropped_lines/{line["box"][1]}.jpg')
            img_text_dict["dt"].append(preprocess(f"cropped_lines/{line['box'][1]}.jpg", (1024, 128, 1)))
            img_text_dict["gt"].append(line['text'].encode())

    return img_text_dict


# inference
def generate(img_path):
    """Run the full pipeline (line detection, preprocessing, recognition) on an input page"""
    pretrained_model = make_model(vocab_len=100)
    pretrained_model.to(device)
    pretrained_model.load_state_dict(
        torch.load('span_fine_tuned_model.pt', map_location=torch.device('cpu')))

    max_text_length = 128
    transform = T.Compose([T.ToTensor()])

    sp_loader = torch.utils.data.DataLoader(
        DataGenerator_Spanish(crop_dict(img_path), charset_base, max_text_length,
                              transform, shuffle=False),
        batch_size=1, shuffle=False, num_workers=2)

    predicts2, gt2, imgs = test(pretrained_model, sp_loader, max_text_length)
    predicts2 = [x.replace('SOS', '').replace('EOS', '') for x in predicts2]

    final_pred_str = ""
    for s in predicts2:
        final_pred_str += s + "\n"

    return final_pred_str


gr.Interface(fn=generate,
             inputs=[gr.Image(label='Input image')],
             outputs=[gr.Textbox(label='Read Text')],
             allow_flagging='never',
             title='Transformer Text Detection - HumanAI',
             theme=gr.themes.Monochrome()).launch()