import string
import re
import math
import html
import os
from itertools import groupby
import torch
import torchvision.transforms as T
import torch.nn as nn
from torchvision.models import resnet101
from torch.utils.data import Dataset
import numpy as np
import cv2
import unicodedata
import gradio as gr
import pytesseract
import pandas as pd
"""
Data preprocessing functions:
    adjust_to_see: adjust image for better visualization (rotate and transpose)
    augmentation: apply variations to a list of images
    normalization: apply normalization and variations on images (if required)
    preprocess: main preprocessing function.
        Makes the image via:
            illumination_compensation: apply illumination regularization
            remove_cursive_style: remove cursive writing style from image (if necessary)
            sauvola: apply Sauvola binarization
    text_standardize: preprocess and standardize a sentence
"""
import numba as nb
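# Typical preprocessing flow (illustrative sketch; the file name below is hypothetical):
# a cropped line image is cleaned with illumination_compensation and remove_cursive_style
# inside preprocess(), and label/prediction strings are cleaned with text_standardize():
#
#   line_img = preprocess("some_line.png", input_size=(1024, 128, 1))
#   label = text_standardize("Hello,world")  # -> "Hello , world"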
def adjust_to_see(img):
"""Rotate and transpose to image visualize (cv2 method or jupyter notebook)"""
(h, w) = img.shape[:2]
(cX, cY) = (w // 2, h // 2)
M = cv2.getRotationMatrix2D((cX, cY), -90, 1.0)
cos = np.abs(M[0, 0])
sin = np.abs(M[0, 1])
nW = int((h * sin) + (w * cos))
nH = int((h * cos) + (w * sin))
M[0, 2] += (nW / 2) - cX
M[1, 2] += (nH / 2) - cY
img = cv2.warpAffine(img, M, (nW + 1, nH + 1))
img = cv2.warpAffine(img.transpose(), M, (nW, nH))
return img
def augmentation(imgs,
rotation_range=0,
scale_range=0,
height_shift_range=0,
width_shift_range=0,
dilate_range=1,
erode_range=1):
"""Apply variations to a list of images (rotate, width and height shift, scale, erode, dilate)"""
imgs = imgs.astype(np.float32)
_, h, w = imgs.shape
dilate_kernel = np.ones((int(np.random.uniform(1, dilate_range)),), np.uint8)
erode_kernel = np.ones((int(np.random.uniform(1, erode_range)),), np.uint8)
height_shift = np.random.uniform(-height_shift_range, height_shift_range)
rotation = np.random.uniform(-rotation_range, rotation_range)
scale = np.random.uniform(1 - scale_range, 1)
width_shift = np.random.uniform(-width_shift_range, width_shift_range)
trans_map = np.float32([[1, 0, width_shift * w], [0, 1, height_shift * h]])
rot_map = cv2.getRotationMatrix2D((w // 2, h // 2), rotation, scale)
trans_map_aff = np.r_[trans_map, [[0, 0, 1]]]
rot_map_aff = np.r_[rot_map, [[0, 0, 1]]]
affine_mat = rot_map_aff.dot(trans_map_aff)[:2, :]
for i in range(len(imgs)):
imgs[i] = cv2.warpAffine(imgs[i], affine_mat, (w, h), flags=cv2.INTER_NEAREST, borderValue=255)
imgs[i] = cv2.erode(imgs[i], erode_kernel, iterations=1)
imgs[i] = cv2.dilate(imgs[i], dilate_kernel, iterations=1)
return imgs
def normalization(img):
"""Normalize list of image"""
m, s = cv2.meanStdDev(img)
img = img - m[0][0]
img = img / s[0][0] if s[0][0] > 0 else img
return img
def preprocess(img, input_size):
"""Make the process with the `input_size` to the scale resize"""
def imread(path):
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if len(img.shape) == 3:
if img.shape[2] == 4:
trans_mask = img[:, :, 3] == 0
img[trans_mask] = [255, 255, 255, 255]
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
return img
if isinstance(img, str):
img = imread(img)
if isinstance(img, tuple):
image, boundbox = img
img = imread(image)
for i in range(len(boundbox)):
if isinstance(boundbox[i], float):
total = len(img) if i < 2 else len(img[0])
boundbox[i] = int(total * boundbox[i])
img = np.asarray(img[boundbox[0]:boundbox[1], boundbox[2]:boundbox[3]], dtype=np.uint8)
wt, ht, _ = input_size
h, w = np.asarray(img).shape
f = max((w / wt), (h / ht))
new_size = (max(min(wt, int(w / f)), 1), max(min(ht, int(h / f)), 1))
img = illumination_compensation(img)
img = remove_cursive_style(img)
img = cv2.resize(img, new_size)
target = np.ones([ht, wt], dtype=np.uint8) * 255
target[0:new_size[1], 0:new_size[0]] = img
img = cv2.transpose(target)
return img
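# Example (illustrative): preprocess(path, input_size=(1024, 128, 1)) pads the enhanced
# line image onto a 1024x128 white canvas and returns its transpose, i.e. a uint8 array
# of shape (1024, 128) with the width dimension first.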
def illumination_compensation(img, only_cei=False):
"""Illumination compensation technique for text image"""
_, binary = cv2.threshold(img, 254, 255, cv2.THRESH_BINARY)
if np.sum(binary) > np.sum(img) * 0.8:
return np.asarray(img, dtype=np.uint8)
def scale(img):
s = np.max(img) - np.min(img)
res = img / s
res -= np.min(res)
res *= 255
return res
img = img.astype(np.float32)
height, width = img.shape
sqrt_hw = np.sqrt(height * width)
bins = np.arange(0, 300, 10)
bins[26] = 255
hp = np.histogram(img, bins)
    hr = 255  # fallback in case no histogram bin count exceeds sqrt(height * width)
    for i in range(len(hp[0])):
        if hp[0][i] > sqrt_hw:
            hr = i * 10
            break
np.seterr(divide='ignore', invalid='ignore')
cei = (img - (hr + 50 * 0.3)) * 2
cei[cei > 255] = 255
cei[cei < 0] = 0
if only_cei:
return np.asarray(cei, dtype=np.uint8)
m1 = np.asarray([-1, 0, 1, -2, 0, 2, -1, 0, 1]).reshape((3, 3))
m2 = np.asarray([-2, -1, 0, -1, 0, 1, 0, 1, 2]).reshape((3, 3))
m3 = np.asarray([-1, -2, -1, 0, 0, 0, 1, 2, 1]).reshape((3, 3))
m4 = np.asarray([0, 1, 2, -1, 0, 1, -2, -1, 0]).reshape((3, 3))
eg1 = np.abs(cv2.filter2D(img, -1, m1))
eg2 = np.abs(cv2.filter2D(img, -1, m2))
eg3 = np.abs(cv2.filter2D(img, -1, m3))
eg4 = np.abs(cv2.filter2D(img, -1, m4))
eg_avg = scale((eg1 + eg2 + eg3 + eg4) / 4)
h, w = eg_avg.shape
eg_bin = np.zeros((h, w))
eg_bin[eg_avg >= 30] = 255
h, w = cei.shape
cei_bin = np.zeros((h, w))
cei_bin[cei >= 60] = 255
h, w = eg_bin.shape
tli = 255 * np.ones((h, w))
tli[eg_bin == 255] = 0
tli[cei_bin == 255] = 0
kernel = np.ones((3, 3), np.uint8)
erosion = cv2.erode(tli, kernel, iterations=1)
    int_img = np.array(cei)  # copy so the light-distribution estimate does not modify cei in place
estimate_light_distribution(width, height, erosion, cei, int_img)
mean_filter = 1 / 121 * np.ones((11, 11), np.uint8)
ldi = cv2.filter2D(scale(int_img), -1, mean_filter)
result = np.divide(cei, ldi) * 260
result[erosion != 0] *= 1.5
result[result < 0] = 0
result[result > 255] = 255
return np.asarray(result, dtype=np.uint8)
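# illumination_compensation stages (as implemented above): contrast enhancement (CEI),
# four directional edge maps, a binary text-location image (tli), erosion, column-wise
# light-distribution estimation, then division of CEI by the smoothed light map (ldi).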
@nb.jit(nopython=True)
def estimate_light_distribution(width, height, erosion, cei, int_img):
"""Light distribution performed by numba (thanks @Sundrops)"""
for y in range(width):
for x in range(height):
if erosion[x][y] == 0:
i = x
while i < erosion.shape[0] and erosion[i][y] == 0:
i += 1
end = i - 1
n = end - x + 1
if n <= 30:
h, e = [], []
for k in range(5):
if x - k >= 0:
h.append(cei[x - k][y])
if end + k < cei.shape[0]:
e.append(cei[end + k][y])
mpv_h, mpv_e = max(h), max(e)
for m in range(n):
int_img[x + m][y] = mpv_h + (m + 1) * ((mpv_e - mpv_h) / n)
x = end
break
def remove_cursive_style(img):
"""Remove cursive writing style from image with deslanting algorithm"""
def calc_y_alpha(vec):
indices = np.where(vec > 0)[0]
h_alpha = len(indices)
if h_alpha > 0:
delta_y_alpha = indices[h_alpha - 1] - indices[0] + 1
if h_alpha == delta_y_alpha:
return h_alpha * h_alpha
return 0
alpha_vals = [-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0]
rows, cols = img.shape
results = []
ret, otsu = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
binary = otsu if ret < 127 else sauvola(img, (int(img.shape[0] / 2), int(img.shape[0] / 2)), 127, 1e-2)
for alpha in alpha_vals:
shift_x = max(-alpha * rows, 0.)
size = (cols + int(np.ceil(abs(alpha * rows))), rows)
transform = np.asarray([[1, alpha, shift_x], [0, 1, 0]], dtype=np.float32)
        shear_img = cv2.warpAffine(binary, transform, size, flags=cv2.INTER_NEAREST)
sum_alpha = 0
sum_alpha += np.apply_along_axis(calc_y_alpha, 0, shear_img)
results.append([np.sum(sum_alpha), size, transform])
result = sorted(results, key=lambda x: x[0], reverse=True)[0]
result = cv2.warpAffine(img, result[2], result[1], borderValue=255)
result = cv2.resize(result, dsize=(cols, rows))
return np.asarray(result, dtype=np.uint8)
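# Deslanting: for each shear factor alpha in {-1.0, -0.75, ..., 1.0} the binarized line is
# sheared and scored column-wise (a column contributes the square of its foreground-pixel
# count when those pixels form one contiguous run); the best-scoring shear is applied to
# the original image, which is then resized back to its input dimensions.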
def sauvola(img, window, thresh, k):
"""Sauvola binarization"""
rows, cols = img.shape
pad = int(np.floor(window[0] / 2))
sum2, sqsum = cv2.integral2(
cv2.copyMakeBorder(img, pad, pad, pad, pad, cv2.BORDER_CONSTANT))
isum = sum2[window[0]:rows + window[0], window[1]:cols + window[1]] + \
sum2[0:rows, 0:cols] - \
sum2[window[0]:rows + window[0], 0:cols] - \
sum2[0:rows, window[1]:cols + window[1]]
isqsum = sqsum[window[0]:rows + window[0], window[1]:cols + window[1]] + \
sqsum[0:rows, 0:cols] - \
sqsum[window[0]:rows + window[0], 0:cols] - \
sqsum[0:rows, window[1]:cols + window[1]]
ksize = window[0] * window[1]
mean = isum / ksize
std = (((isqsum / ksize) - (mean**2) / ksize) / ksize) ** 0.5
threshold = (mean * (1 + k * (std / thresh - 1))) * (mean >= 100)
return np.asarray(255 * (img >= threshold), 'uint8')
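# Per-pixel Sauvola threshold: T = mean * (1 + k * (std / thresh - 1)), with window sums
# computed from integral images; where the window mean is below 100 the threshold is forced
# to 0, so those pixels stay white (255). Used above as
# sauvola(img, (img.shape[0] // 2, img.shape[0] // 2), 127, 1e-2).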
RE_DASH_FILTER = re.compile(r'[\-\˗\֊\‐\‑\‒\–\—\⁻\₋\−\﹣\-]', re.UNICODE)
RE_APOSTROPHE_FILTER = re.compile(r"&#39;|[ʼ՚＇‘’‛❛❜ߴߵ`‵´ˊˋ{}{}{}{}{}{}{}{}{}]".format(
    chr(768), chr(769), chr(832), chr(833), chr(2387),
    chr(5151), chr(5152), chr(65344), chr(8242)), re.UNICODE)
RE_RESERVED_CHAR_FILTER = re.compile(r'[¶¤«»]', re.UNICODE)
RE_LEFT_PARENTH_FILTER = re.compile(r'[\(\[\{\⁽\₍\❨\❪\﹙\(]', re.UNICODE)
RE_RIGHT_PARENTH_FILTER = re.compile(r'[\)\]\}\⁾\₎\❩\❫\﹚\)]', re.UNICODE)
RE_BASIC_CLEANER = re.compile(r'[^\w\s{}]'.format(re.escape(string.punctuation)), re.UNICODE)
LEFT_PUNCTUATION_FILTER = """!%&),.:;<=>?@\\]^_`|}~"""
RIGHT_PUNCTUATION_FILTER = """"(/<=>@[\\^_`{|~"""
NORMALIZE_WHITESPACE_REGEX = re.compile(r'[^\S\n]+', re.UNICODE)
def text_standardize(text):
"""Organize/add spaces around punctuation marks"""
if text is None:
return ""
text = html.unescape(text).replace("\\n", "").replace("\\t", "")
text = RE_RESERVED_CHAR_FILTER.sub("", text)
text = RE_DASH_FILTER.sub("-", text)
text = RE_APOSTROPHE_FILTER.sub("'", text)
text = RE_LEFT_PARENTH_FILTER.sub("(", text)
text = RE_RIGHT_PARENTH_FILTER.sub(")", text)
text = RE_BASIC_CLEANER.sub("", text)
text = text.lstrip(LEFT_PUNCTUATION_FILTER)
text = text.rstrip(RIGHT_PUNCTUATION_FILTER)
text = text.translate(str.maketrans({c: f" {c} " for c in string.punctuation}))
text = NORMALIZE_WHITESPACE_REGEX.sub(" ", text.strip())
return text
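# Example (illustrative): text_standardize("«Hello,world»") -> "Hello , world"
# (reserved characters are dropped and punctuation is padded with spaces).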
class Tokenizer():
"""Manager tokens functions and charset/dictionary properties"""
def __init__(self, chars, max_text_length=128):
        self.PAD_TK, self.UNK_TK, self.SOS, self.EOS = "¶", "¤", "SOS", "EOS"
        self.chars = [self.PAD_TK] + [self.UNK_TK] + [self.SOS] + [self.EOS] + list(chars)
self.PAD = self.chars.index(self.PAD_TK)
self.UNK = self.chars.index(self.UNK_TK)
self.vocab_size = len(self.chars)
self.maxlen = max_text_length
def encode(self, text):
"""Encode text to vector"""
text = unicodedata.normalize("NFKD", text).encode("ASCII", "ignore").decode("ASCII")
text = " ".join(text.split())
groups = ["".join(group) for _, group in groupby(text)]
text = "".join([self.UNK_TK.join(list(x)) if len(x) > 1 else x for x in groups])
encoded = []
text = ['SOS'] + list(text) + ['EOS']
for item in text:
            # list.index raises ValueError for unknown items, so map them to UNK explicitly
            index = self.chars.index(item) if item in self.chars else self.UNK
encoded.append(index)
return np.asarray(encoded)
def decode(self, text):
"""Decode vector to text"""
decoded = "".join([self.chars[int(x)] for x in text if x > -1])
decoded = self.remove_tokens(decoded)
decoded = text_standardize(decoded)
return decoded
def remove_tokens(self, text):
"""Remove tokens (PAD) from text"""
return text.replace(self.PAD_TK, "").replace(self.UNK_TK, "")
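# Tokenizer usage (illustrative), with charset string.printable[:95]:
#   tokenizer.encode("ab") -> indices for [SOS, 'a', 'b', EOS]
#   tokenizer.encode("aa") -> the repeated letter is split with the UNK token ("a¤a")
# decode() maps indices back to characters, strips PAD/UNK and re-standardizes the text.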
charset_base = string.printable[:95]
tokenizer = Tokenizer(charset_base)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_memory(model,imgs):
x = model.conv(model.get_feature(imgs))
bs,_,H, W = x.shape
pos = torch.cat([
model.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
model.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).flatten(0, 1).unsqueeze(1)
return model.transformer.encoder(pos + 0.1 * x.flatten(2).permute(2, 0, 1))
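# get_memory returns the transformer encoder output ("memory") of shape
# (H*W, batch, hidden_dim): the CNN feature map is flattened over its spatial
# positions and combined with the learned row/column positional embeddings.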
def test(model, test_loader, max_text_length):
model.eval()
predicts = []
gt = []
imgs = []
with torch.no_grad():
for batch in test_loader:
src, trg = batch
imgs.append(src.flatten(0,1))
src, trg = src.to(device), trg.to(device)
memory = get_memory(model,src.float())
out_indexes = [tokenizer.chars.index('SOS'), ]
for i in range(max_text_length):
mask = model.generate_square_subsequent_mask(i+1).to(device)
trg_tensor = torch.LongTensor(out_indexes).unsqueeze(1).to(device)
output = model.vocab(model.transformer.decoder(model.query_pos(model.decoder(trg_tensor)), memory,tgt_mask=mask))
out_token = output.argmax(2)[-1].item()
out_indexes.append(out_token)
if out_token == tokenizer.chars.index('EOS'):
break
predicts.append(tokenizer.decode(out_indexes))
gt.append(tokenizer.decode(trg.flatten(0,1)))
return predicts, gt, imgs
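# Greedy decoding: starting from SOS, the decoder is re-run once per step with a causal
# mask and the argmax token is appended until EOS or max_text_length is reached.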
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=128):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:x.size(0), :]
return self.dropout(x)
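# Standard sinusoidal encoding: pe[pos, 2i] = sin(pos / 10000^(2i/d_model)) and
# pe[pos, 2i+1] = cos(pos / 10000^(2i/d_model)), added to a (seq_len, batch, d_model) input.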
class OCR(nn.Module):
def __init__(self, vocab_len, hidden_dim, nheads,
num_encoder_layers, num_decoder_layers):
super().__init__()
# create ResNet-101 backbone
self.backbone = resnet101()
del self.backbone.fc
# create conversion layer
self.conv = nn.Conv2d(2048, hidden_dim, 1)
# create a default PyTorch transformer
self.transformer = nn.Transformer(
hidden_dim, nheads, num_encoder_layers, num_decoder_layers)
# prediction heads with length of vocab
# DETR used basic 3 layer MLP for output
self.vocab = nn.Linear(hidden_dim,vocab_len)
# output positional encodings (object queries)
self.decoder = nn.Embedding(vocab_len, hidden_dim)
self.query_pos = PositionalEncoding(hidden_dim, .2)
# spatial positional encodings, sine positional encoding can be used.
# Detr baseline uses sine positional encoding.
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.trg_mask = None
def generate_square_subsequent_mask(self, sz):
mask = torch.triu(torch.ones(sz, sz), 1)
mask = mask.masked_fill(mask==1, float('-inf'))
return mask
def get_feature(self,x):
x = self.backbone.conv1(x)
x = self.backbone.bn1(x)
x = self.backbone.relu(x)
x = self.backbone.maxpool(x)
x = self.backbone.layer1(x)
x = self.backbone.layer2(x)
x = self.backbone.layer3(x)
x = self.backbone.layer4(x)
return x
def make_len_mask(self, inp):
return (inp == 0).transpose(0, 1)
def forward(self, inputs, trg):
# propagate inputs through ResNet-101 up to avg-pool layer
x = self.get_feature(inputs)
# convert from 2048 to 256 feature planes for the transformer
h = self.conv(x)
# construct positional encodings
bs,_,H, W = h.shape
pos = torch.cat([
self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).flatten(0, 1).unsqueeze(1)
# generating subsequent mask for target
        if self.trg_mask is None or self.trg_mask.size(0) != trg.shape[1]:
            self.trg_mask = self.generate_square_subsequent_mask(trg.shape[1]).to(trg.device)
# Padding mask
trg_pad_mask = self.make_len_mask(trg)
# Getting postional encoding for target
trg = self.decoder(trg)
trg = self.query_pos(trg)
output = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1), trg.permute(1,0,2), tgt_mask=self.trg_mask,
tgt_key_padding_mask=trg_pad_mask.permute(1,0))
return self.vocab(output.transpose(0,1))
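# Shapes (illustrative): generate_square_subsequent_mask(3) yields
# [[0, -inf, -inf], [0, 0, -inf], [0, 0, 0]], and for a target of length T and batch
# size N, forward() returns logits of shape (N, T, vocab_len).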
def make_model(vocab_len, hidden_dim=256, nheads=4,
num_encoder_layers=4, num_decoder_layers=4):
return OCR(vocab_len, hidden_dim, nheads,
num_encoder_layers, num_decoder_layers)
class DataGenerator_Spanish(Dataset):
def __init__(self, source_dict, charset, max_text_length, transform, shuffle = True):
self.tokenizer = Tokenizer(charset, max_text_length)
self.transform = transform
self.shuffle = shuffle
self.dataset = source_dict.copy()
if self.shuffle:
randomize = np.arange(len(self.dataset['gt']))
np.random.seed(42)
np.random.shuffle(randomize)
self.dataset['dt'] = np.array(self.dataset['dt'])[randomize]
self.dataset['gt'] = np.array(self.dataset['gt'])[randomize]
self.dataset['gt'] = [x.decode() for x in self.dataset['gt']]
self.size = len(self.dataset['gt'])
def __getitem__(self, i):
img = self.dataset['dt'][i]
img = np.repeat(img[..., np.newaxis], 3, -1)
img = normalization(img)
if self.transform is not None:
img = self.transform(img)
y_train = self.tokenizer.encode(self.dataset['gt'][i])
y_train = np.pad(y_train, (0, self.tokenizer.maxlen - len(y_train)))
gt = torch.Tensor(y_train)
return img, gt
def __len__(self):
return self.size
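# Each __getitem__ yields (image, target): the grayscale line image is stacked to 3
# channels and normalized, and the encoded label is zero-padded to max_text_length
# (index 0 is the PAD token, which make_len_mask treats as padding in the model).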
def crop_dict(page):
# page = cv2.imread(page)
master_page_par_line_list = []
image = cv2.cvtColor(np.array(page), cv2.COLOR_RGB2GRAY)
_, image = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)
data = pytesseract.image_to_data(image,config='--oem 3 --psm 6', output_type='dict')
page_num = 1
df = pd.DataFrame(data)
df = df[df["conf"] > 0]
df["page_num"] = page_num
page_par_line_dict = {}
for index, row in df.iterrows():
page_par_line = f"{page_num}_{row['par_num']}_{row['line_num']}"
if(page_par_line not in page_par_line_dict):
page_par_line_dict[page_par_line] = {"text": str(row["text"]) + " ", "box": (row['left'], row['top'], row['left'] + row['width'], row['top'] + row['height'])}
else:
page_par_line_dict[page_par_line]["text"] = page_par_line_dict[page_par_line]["text"] + str(row["text"]) + " "
page_par_line_dict[page_par_line]['box'] = (min(page_par_line_dict[page_par_line]['box'][0], row['left']),
min(page_par_line_dict[page_par_line]['box'][1], row['top']),
max(page_par_line_dict[page_par_line]['box'][2], row['left'] + row['width']),
max(page_par_line_dict[page_par_line]['box'][3], row['top'] + row['height']))
for entry in page_par_line_dict:
splitted_key = entry.split('_')
entry_value = page_par_line_dict[entry]
master_page_par_line_list.append({
'page_number' : splitted_key[0],
'paragraph_number' : splitted_key[1],
'line_number' : splitted_key[2],
'entry_text' : entry_value['text'],
'bounding_box' : entry_value['box']
})
imgs_cropped = {}
img_text_dict = {"dt" : [], "gt" : []}
for line in page_par_line_dict.values():
if line['box'] is not None:
cv2.rectangle(image, (line['box'][0], line['box'][1]), (line['box'][2], line['box'][3]), (0, 0, 255), 2)
img_cropped = image[line['box'][1]:line['box'][3], line['box'][0]:line['box'][2]]
if not os.path.exists('cropped_lines'):
os.makedirs('cropped_lines')
cv2.imwrite(f"cropped_lines/{line['box'][1]}.jpg", img_cropped)
# print(line['text'])
imgs_cropped[line['box'][1]] = img_cropped
assert os.path.exists(f'cropped_lines/{line["box"][1]}.jpg')
img_text_dict["dt"].append(preprocess(f"cropped_lines/{line['box'][1]}.jpg",(1024,128,1)))
img_text_dict["gt"].append(line['text'].encode())
return img_text_dict
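# crop_dict output (as built above): {"dt": [preprocessed 1024x128 line images],
# "gt": [byte-encoded Tesseract text for each detected line]}, with one entry per
# page/paragraph/line group returned by pytesseract.image_to_data.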
#inference
def generate(img_path):
pretrained_model = make_model(vocab_len = 100)
_=pretrained_model.to(device)
pretrained_model.load_state_dict(torch.load('span_fine_tuned_model.pt', map_location=torch.device('cpu')))
max_text_length = 128
transform = T.Compose([T.ToTensor()])
    sp_loader = torch.utils.data.DataLoader(
        DataGenerator_Spanish(crop_dict(img_path), charset_base, max_text_length, transform, shuffle=False),
        batch_size=1, shuffle=False, num_workers=2)
predicts2, gt2, imgs = test(pretrained_model, sp_loader, max_text_length)
predicts2 = list(map(lambda x : x.replace('SOS','').replace('EOS',''),predicts2))
final_pred_str = ""
for s in predicts2:
final_pred_str += s+"\n"
return final_pred_str
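# Inference sketch: generate() loads the checkpoint 'span_fine_tuned_model.pt'
# (presumably fine-tuned for Spanish handwriting, per DataGenerator_Spanish), preprocesses
# each Tesseract-detected line, decodes it greedily, strips the SOS/EOS markers and joins
# the per-line predictions with newlines for display in the Gradio UI below.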
gr.Interface(fn=generate,
inputs=[gr.Image(label='Input image')],
outputs=[gr.Textbox(label='Read Text')],
allow_flagging='never',
title='Transformer Text Detection - HumanAI',
theme=gr.themes.Monochrome()).launch()