import string
import re
import math
import html
import os
from itertools import groupby

import torch
import torchvision.transforms as T
import torch.nn as nn
from torchvision.models import resnet101
from torch.utils.data import Dataset

import numpy as np
import cv2
import unicodedata
import gradio as gr
import pytesseract
import pandas as pd
""" | |
Data preproc functions: | |
adjust_to_see: adjust image to better visualize (rotate and transpose) | |
augmentation: apply variations to a list of images | |
normalization: apply normalization and variations on images (if required) | |
preprocess: main function for preprocess. | |
Make the image: | |
illumination_compensation: apply illumination regularitation | |
remove_cursive_style: remove cursive style from image (if necessary) | |
sauvola: apply sauvola binarization | |
text_standardize: preprocess and standardize sentence | |
""" | |
import numba as nb
def adjust_to_see(img):
    """Rotate and transpose an image for easier visualization (cv2 window or Jupyter notebook)."""
    (h, w) = img.shape[:2]
    (cX, cY) = (w // 2, h // 2)

    # Rotate 90 degrees clockwise around the center, expanding the canvas
    M = cv2.getRotationMatrix2D((cX, cY), -90, 1.0)
    cos = np.abs(M[0, 0])
    sin = np.abs(M[0, 1])

    nW = int((h * sin) + (w * cos))
    nH = int((h * cos) + (w * sin))

    M[0, 2] += (nW / 2) - cX
    M[1, 2] += (nH / 2) - cY

    img = cv2.warpAffine(img, M, (nW + 1, nH + 1))
    img = cv2.warpAffine(img.transpose(), M, (nW, nH))

    return img
def augmentation(imgs,
                 rotation_range=0,
                 scale_range=0,
                 height_shift_range=0,
                 width_shift_range=0,
                 dilate_range=1,
                 erode_range=1):
    """Apply variations to a list of images (rotate, width and height shift, scale, erode, dilate)"""
    imgs = imgs.astype(np.float32)
    _, h, w = imgs.shape

    dilate_kernel = np.ones((int(np.random.uniform(1, dilate_range)),), np.uint8)
    erode_kernel = np.ones((int(np.random.uniform(1, erode_range)),), np.uint8)
    height_shift = np.random.uniform(-height_shift_range, height_shift_range)
    rotation = np.random.uniform(-rotation_range, rotation_range)
    scale = np.random.uniform(1 - scale_range, 1)
    width_shift = np.random.uniform(-width_shift_range, width_shift_range)

    # Compose translation and rotation/scale into a single affine transform
    trans_map = np.float32([[1, 0, width_shift * w], [0, 1, height_shift * h]])
    rot_map = cv2.getRotationMatrix2D((w // 2, h // 2), rotation, scale)

    trans_map_aff = np.r_[trans_map, [[0, 0, 1]]]
    rot_map_aff = np.r_[rot_map, [[0, 0, 1]]]
    affine_mat = rot_map_aff.dot(trans_map_aff)[:2, :]

    for i in range(len(imgs)):
        imgs[i] = cv2.warpAffine(imgs[i], affine_mat, (w, h), flags=cv2.INTER_NEAREST, borderValue=255)
        imgs[i] = cv2.erode(imgs[i], erode_kernel, iterations=1)
        imgs[i] = cv2.dilate(imgs[i], dilate_kernel, iterations=1)

    return imgs
def normalization(img):
    """Normalize an image to zero mean and unit standard deviation"""
    m, s = cv2.meanStdDev(img)
    img = img - m[0][0]
    img = img / s[0][0] if s[0][0] > 0 else img
    return img
def preprocess(img, input_size):
    """Read, clean, resize and pad the image to `input_size`"""

    def imread(path):
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)

        if len(img.shape) == 3:
            # Flatten transparent pixels to white before grayscale conversion
            if img.shape[2] == 4:
                trans_mask = img[:, :, 3] == 0
                img[trans_mask] = [255, 255, 255, 255]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        return img

    if isinstance(img, str):
        img = imread(img)

    if isinstance(img, tuple):
        image, boundbox = img
        img = imread(image)

        # Fractional bounding-box coordinates are relative to the image size
        for i in range(len(boundbox)):
            if isinstance(boundbox[i], float):
                total = len(img) if i < 2 else len(img[0])
                boundbox[i] = int(total * boundbox[i])

        img = np.asarray(img[boundbox[0]:boundbox[1], boundbox[2]:boundbox[3]], dtype=np.uint8)

    wt, ht, _ = input_size
    h, w = np.asarray(img).shape
    f = max((w / wt), (h / ht))

    new_size = (max(min(wt, int(w / f)), 1), max(min(ht, int(h / f)), 1))

    img = illumination_compensation(img)
    img = remove_cursive_style(img)
    img = cv2.resize(img, new_size)

    # Paste the resized line onto a white canvas and transpose to (width, height)
    target = np.ones([ht, wt], dtype=np.uint8) * 255
    target[0:new_size[1], 0:new_size[0]] = img
    img = cv2.transpose(target)

    return img
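# A minimal usage sketch for the pipeline above (the path is illustrative;
# (1024, 128, 1) matches the input_size that crop_dict passes further below):
#
#     line_img = preprocess("some_line.jpg", input_size=(1024, 128, 1))
#     line_img.shape  # -> (1024, 128): a transposed, white-padded grayscale canvas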
def illumination_compensation(img, only_cei=False):
    """Illumination compensation technique for text image"""
    _, binary = cv2.threshold(img, 254, 255, cv2.THRESH_BINARY)

    # Skip images that are already almost entirely white
    if np.sum(binary) > np.sum(img) * 0.8:
        return np.asarray(img, dtype=np.uint8)

    def scale(img):
        s = np.max(img) - np.min(img)
        res = img / s
        res -= np.min(res)
        res *= 255
        return res

    img = img.astype(np.float32)
    height, width = img.shape
    sqrt_hw = np.sqrt(height * width)

    bins = np.arange(0, 300, 10)
    bins[26] = 255
    hp = np.histogram(img, bins)

    hr = 255  # guard: fallback if no histogram bin exceeds sqrt(height * width)
    for i in range(len(hp[0])):
        if hp[0][i] > sqrt_hw:
            hr = i * 10
            break

    np.seterr(divide='ignore', invalid='ignore')
    cei = (img - (hr + 50 * 0.3)) * 2
    cei[cei > 255] = 255
    cei[cei < 0] = 0

    if only_cei:
        return np.asarray(cei, dtype=np.uint8)

    # Four directional edge kernels
    m1 = np.asarray([-1, 0, 1, -2, 0, 2, -1, 0, 1]).reshape((3, 3))
    m2 = np.asarray([-2, -1, 0, -1, 0, 1, 0, 1, 2]).reshape((3, 3))
    m3 = np.asarray([-1, -2, -1, 0, 0, 0, 1, 2, 1]).reshape((3, 3))
    m4 = np.asarray([0, 1, 2, -1, 0, 1, -2, -1, 0]).reshape((3, 3))

    eg1 = np.abs(cv2.filter2D(img, -1, m1))
    eg2 = np.abs(cv2.filter2D(img, -1, m2))
    eg3 = np.abs(cv2.filter2D(img, -1, m3))
    eg4 = np.abs(cv2.filter2D(img, -1, m4))
    eg_avg = scale((eg1 + eg2 + eg3 + eg4) / 4)

    h, w = eg_avg.shape
    eg_bin = np.zeros((h, w))
    eg_bin[eg_avg >= 30] = 255

    h, w = cei.shape
    cei_bin = np.zeros((h, w))
    cei_bin[cei >= 60] = 255

    h, w = eg_bin.shape
    tli = 255 * np.ones((h, w))
    tli[eg_bin == 255] = 0
    tli[cei_bin == 255] = 0

    kernel = np.ones((3, 3), np.uint8)
    erosion = cv2.erode(tli, kernel, iterations=1)

    int_img = np.asarray(cei)
    estimate_light_distribution(width, height, erosion, cei, int_img)

    mean_filter = 1 / 121 * np.ones((11, 11), np.uint8)
    ldi = cv2.filter2D(scale(int_img), -1, mean_filter)

    result = np.divide(cei, ldi) * 260
    result[erosion != 0] *= 1.5
    result[result < 0] = 0
    result[result > 255] = 255

    return np.asarray(result, dtype=np.uint8)
@nb.jit(nopython=True)
def estimate_light_distribution(width, height, erosion, cei, int_img):
    """Light distribution performed by numba (thanks @Sundrops)"""
    for y in range(width):
        for x in range(height):
            if erosion[x][y] == 0:
                i = x

                while i < erosion.shape[0] and erosion[i][y] == 0:
                    i += 1

                end = i - 1
                n = end - x + 1

                if n <= 30:
                    h, e = [], []

                    for k in range(5):
                        if x - k >= 0:
                            h.append(cei[x - k][y])
                        if end + k < cei.shape[0]:
                            e.append(cei[end + k][y])

                    # Linearly interpolate between the brightest neighbors
                    mpv_h, mpv_e = max(h), max(e)

                    for m in range(n):
                        int_img[x + m][y] = mpv_h + (m + 1) * ((mpv_e - mpv_h) / n)

                x = end
                break
def remove_cursive_style(img):
    """Remove cursive writing style from image with deslanting algorithm"""

    def calc_y_alpha(vec):
        indices = np.where(vec > 0)[0]
        h_alpha = len(indices)

        if h_alpha > 0:
            delta_y_alpha = indices[h_alpha - 1] - indices[0] + 1

            if h_alpha == delta_y_alpha:
                return h_alpha * h_alpha
        return 0

    alpha_vals = [-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0]
    rows, cols = img.shape
    results = []

    ret, otsu = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    binary = otsu if ret < 127 else sauvola(img, (int(img.shape[0] / 2), int(img.shape[0] / 2)), 127, 1e-2)

    # Score each shear angle by how vertically contiguous the strokes become
    for alpha in alpha_vals:
        shift_x = max(-alpha * rows, 0.)
        size = (cols + int(np.ceil(abs(alpha * rows))), rows)
        transform = np.asarray([[1, alpha, shift_x], [0, 1, 0]], dtype=np.float32)

        shear_img = cv2.warpAffine(binary, transform, size, flags=cv2.INTER_NEAREST)

        sum_alpha = 0
        sum_alpha += np.apply_along_axis(calc_y_alpha, 0, shear_img)
        results.append([np.sum(sum_alpha), size, transform])

    result = sorted(results, key=lambda x: x[0], reverse=True)[0]
    result = cv2.warpAffine(img, result[2], result[1], borderValue=255)
    result = cv2.resize(result, dsize=(cols, rows))

    return np.asarray(result, dtype=np.uint8)
def sauvola(img, window, thresh, k):
    """Sauvola binarization"""
    rows, cols = img.shape
    pad = int(np.floor(window[0] / 2))

    # Integral images give O(1) local sums for the windowed mean/std
    sum2, sqsum = cv2.integral2(
        cv2.copyMakeBorder(img, pad, pad, pad, pad, cv2.BORDER_CONSTANT))

    isum = sum2[window[0]:rows + window[0], window[1]:cols + window[1]] + \
        sum2[0:rows, 0:cols] - \
        sum2[window[0]:rows + window[0], 0:cols] - \
        sum2[0:rows, window[1]:cols + window[1]]

    isqsum = sqsum[window[0]:rows + window[0], window[1]:cols + window[1]] + \
        sqsum[0:rows, 0:cols] - \
        sqsum[window[0]:rows + window[0], 0:cols] - \
        sqsum[0:rows, window[1]:cols + window[1]]

    ksize = window[0] * window[1]
    mean = isum / ksize
    std = (((isqsum / ksize) - (mean**2) / ksize) / ksize) ** 0.5
    threshold = (mean * (1 + k * (std / thresh - 1))) * (mean >= 100)

    return np.asarray(255 * (img >= threshold), 'uint8')
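# remove_cursive_style above already exercises this as
# sauvola(img, (h // 2, h // 2), 127, 1e-2); a standalone call on any
# grayscale uint8 array follows the same pattern (window values illustrative,
# gray_img is a hypothetical input):
#
#     binary = sauvola(gray_img, (40, 40), 127, 1e-2)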
RE_DASH_FILTER = re.compile(r'[\-\˗\֊\‐\‑\‒\–\—\⁻\₋\−\﹣\-]', re.UNICODE)
# Matches the raw HTML apostrophe entity as well as unicode apostrophe variants
RE_APOSTROPHE_FILTER = re.compile(r"&#39;|[ʼ՚'‘’‛❛❜ߴߵ`‵´ˊˋ{}{}{}{}{}{}{}{}{}]".format(
    chr(768), chr(769), chr(832), chr(833), chr(2387),
    chr(5151), chr(5152), chr(65344), chr(8242)), re.UNICODE)
RE_RESERVED_CHAR_FILTER = re.compile(r'[¶¤«»]', re.UNICODE)
RE_LEFT_PARENTH_FILTER = re.compile(r'[\(\[\{\⁽\₍\❨\❪\﹙\(]', re.UNICODE)
RE_RIGHT_PARENTH_FILTER = re.compile(r'[\)\]\}\⁾\₎\❩\❫\﹚\)]', re.UNICODE)
RE_BASIC_CLEANER = re.compile(r'[^\w\s{}]'.format(re.escape(string.punctuation)), re.UNICODE)

LEFT_PUNCTUATION_FILTER = """!%&),.:;<=>?@\\]^_`|}~"""
RIGHT_PUNCTUATION_FILTER = """"(/<=>@[\\^_`{|~"""
NORMALIZE_WHITESPACE_REGEX = re.compile(r'[^\S\n]+', re.UNICODE)
def text_standardize(text):
    """Organize/add spaces around punctuation marks"""
    if text is None:
        return ""

    text = html.unescape(text).replace("\\n", "").replace("\\t", "")

    text = RE_RESERVED_CHAR_FILTER.sub("", text)
    text = RE_DASH_FILTER.sub("-", text)
    text = RE_APOSTROPHE_FILTER.sub("'", text)
    text = RE_LEFT_PARENTH_FILTER.sub("(", text)
    text = RE_RIGHT_PARENTH_FILTER.sub(")", text)
    text = RE_BASIC_CLEANER.sub("", text)

    text = text.lstrip(LEFT_PUNCTUATION_FILTER)
    text = text.rstrip(RIGHT_PUNCTUATION_FILTER)
    text = text.translate(str.maketrans({c: f" {c} " for c in string.punctuation}))
    text = NORMALIZE_WHITESPACE_REGEX.sub(" ", text.strip())

    return text
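# Hand-checked examples of the expected behavior (illustrative, not from tests):
#
#     text_standardize("Hello,world!")  # -> "Hello , world !"
#     text_standardize("dash—test")     # em-dash normalized -> "dash - test"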
class Tokenizer():
    """Manager tokens functions and charset/dictionary properties"""

    def __init__(self, chars, max_text_length=128):
        self.PAD_TK, self.UNK_TK, self.SOS, self.EOS = "¶", "¤", "SOS", "EOS"
        self.chars = [self.PAD_TK, self.UNK_TK, self.SOS, self.EOS] + list(chars)
        self.PAD = self.chars.index(self.PAD_TK)
        self.UNK = self.chars.index(self.UNK_TK)
        self.vocab_size = len(self.chars)
        self.maxlen = max_text_length

    def encode(self, text):
        """Encode text to vector"""
        text = unicodedata.normalize("NFKD", text).encode("ASCII", "ignore").decode("ASCII")
        text = " ".join(text.split())

        # Insert UNK between repeated characters so consecutive duplicates
        # stay distinguishable after decoding
        groups = ["".join(group) for _, group in groupby(text)]
        text = "".join([self.UNK_TK.join(list(x)) if len(x) > 1 else x for x in groups])

        encoded = []
        text = ['SOS'] + list(text) + ['EOS']

        for item in text:
            # list.index never returns -1 (it raises ValueError), so fall back
            # to UNK explicitly for out-of-charset symbols
            index = self.chars.index(item) if item in self.chars else self.UNK
            encoded.append(index)

        return np.asarray(encoded)

    def decode(self, text):
        """Decode vector to text"""
        decoded = "".join([self.chars[int(x)] for x in text if x > -1])
        decoded = self.remove_tokens(decoded)
        decoded = text_standardize(decoded)

        return decoded

    def remove_tokens(self, text):
        """Remove tokens (PAD and UNK) from text"""
        return text.replace(self.PAD_TK, "").replace(self.UNK_TK, "")
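# A round-trip sketch under the charset defined just below (printable ASCII):
#
#     tok = Tokenizer(string.printable[:95])
#     vec = tok.encode("hola")   # indices for SOS, h, o, l, a, EOS
#     tok.decode(vec)            # -> "SOSholaEOS"; generate() strips SOS/EOS later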
charset_base = string.printable[:95]
tokenizer = Tokenizer(charset_base)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_memory(model, imgs):
    """Run the CNN backbone and transformer encoder, returning the memory"""
    x = model.conv(model.get_feature(imgs))
    bs, _, H, W = x.shape

    # Build 2D positional encodings from the learned row/column embeddings
    pos = torch.cat([
        model.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
        model.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
    ], dim=-1).flatten(0, 1).unsqueeze(1)

    return model.transformer.encoder(pos + 0.1 * x.flatten(2).permute(2, 0, 1))
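# Shape sketch (assuming hidden_dim=256 as in make_model below): for imgs of
# shape (B, 3, H0, W0), ResNet-101 layer4 yields (B, 2048, H0/32, W0/32),
# conv maps that to (B, 256, H, W), and flatten+permute produces the
# (H*W, B, 256) sequence that the transformer encoder consumes and returns.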
def test(model, test_loader, max_text_length):
    model.eval()
    predicts = []
    gt = []
    imgs = []

    with torch.no_grad():
        for batch in test_loader:
            src, trg = batch
            imgs.append(src.flatten(0, 1))
            src, trg = src.to(device), trg.to(device)

            memory = get_memory(model, src.float())
            out_indexes = [tokenizer.chars.index('SOS'), ]

            # Greedy autoregressive decoding, one token at a time
            for i in range(max_text_length):
                mask = model.generate_square_subsequent_mask(i + 1).to(device)
                trg_tensor = torch.LongTensor(out_indexes).unsqueeze(1).to(device)

                output = model.vocab(model.transformer.decoder(
                    model.query_pos(model.decoder(trg_tensor)), memory, tgt_mask=mask))

                out_token = output.argmax(2)[-1].item()
                out_indexes.append(out_token)

                if out_token == tokenizer.chars.index('EOS'):
                    break

            predicts.append(tokenizer.decode(out_indexes))
            gt.append(tokenizer.decode(trg.flatten(0, 1)))

    return predicts, gt, imgs
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=128):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
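# Shape check: pe is registered as (max_len, 1, d_model) and broadcasts over
# the batch dimension, so for example
#
#     PositionalEncoding(256)(torch.zeros(10, 2, 256)).shape  # -> (10, 2, 256)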
class OCR(nn.Module):
    def __init__(self, vocab_len, hidden_dim, nheads,
                 num_encoder_layers, num_decoder_layers):
        super().__init__()

        # create ResNet-101 backbone
        self.backbone = resnet101()
        del self.backbone.fc

        # create conversion layer
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # prediction head over the vocabulary
        # (DETR used a basic 3-layer MLP for its output head)
        self.vocab = nn.Linear(hidden_dim, vocab_len)

        # output positional encodings (object queries)
        self.decoder = nn.Embedding(vocab_len, hidden_dim)
        self.query_pos = PositionalEncoding(hidden_dim, .2)

        # spatial positional encodings; sine positional encoding can also be
        # used here (the DETR baseline uses sine positional encoding)
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.trg_mask = None

    def generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask
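    # For sz=3 the mask looks like (0 = attend, -inf = blocked):
    #     [[0., -inf, -inf],
    #      [0.,   0., -inf],
    #      [0.,   0.,   0.]]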
    def get_feature(self, x):
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)
        return x

    def make_len_mask(self, inp):
        return (inp == 0).transpose(0, 1)

    def forward(self, inputs, trg):
        # propagate inputs through ResNet-101 up to the avg-pool layer
        x = self.get_feature(inputs)

        # convert from 2048 to `hidden_dim` feature planes for the transformer
        h = self.conv(x)

        # construct positional encodings
        bs, _, H, W = h.shape
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # generate the subsequent mask for the target (trg arrives as (batch, seq))
        if self.trg_mask is None or self.trg_mask.size(0) != trg.shape[1]:
            self.trg_mask = self.generate_square_subsequent_mask(trg.shape[1]).to(trg.device)

        # padding mask
        trg_pad_mask = self.make_len_mask(trg)

        # positional encoding for the target
        trg = self.decoder(trg)
        trg = self.query_pos(trg)

        output = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1), trg.permute(1, 0, 2),
                                  tgt_mask=self.trg_mask,
                                  tgt_key_padding_mask=trg_pad_mask.permute(1, 0))

        return self.vocab(output.transpose(0, 1))
def make_model(vocab_len, hidden_dim=256, nheads=4,
               num_encoder_layers=4, num_decoder_layers=4):
    return OCR(vocab_len, hidden_dim, nheads,
               num_encoder_layers, num_decoder_layers)
class DataGenerator_Spanish(Dataset):
    def __init__(self, source_dict, charset, max_text_length, transform, shuffle=True):
        self.tokenizer = Tokenizer(charset, max_text_length)
        self.transform = transform
        self.shuffle = shuffle
        self.dataset = source_dict.copy()

        if self.shuffle:
            randomize = np.arange(len(self.dataset['gt']))
            np.random.seed(42)
            np.random.shuffle(randomize)
            self.dataset['dt'] = np.array(self.dataset['dt'])[randomize]
            self.dataset['gt'] = np.array(self.dataset['gt'])[randomize]

        # Ground truth arrives as bytes (see crop_dict); decode whether or not
        # shuffling reordered the arrays
        self.dataset['gt'] = [x.decode() for x in self.dataset['gt']]
        self.size = len(self.dataset['gt'])

    def __getitem__(self, i):
        img = self.dataset['dt'][i]

        # replicate the grayscale channel so the ResNet backbone gets 3 channels
        img = np.repeat(img[..., np.newaxis], 3, -1)
        img = normalization(img)

        if self.transform is not None:
            img = self.transform(img)

        y_train = self.tokenizer.encode(self.dataset['gt'][i])
        y_train = np.pad(y_train, (0, self.tokenizer.maxlen - len(y_train)))

        gt = torch.Tensor(y_train)
        return img, gt

    def __len__(self):
        return self.size
def crop_dict(page):
    # page = cv2.imread(page)
    master_page_par_line_list = []

    image = cv2.cvtColor(np.array(page), cv2.COLOR_RGB2GRAY)
    _, image = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)

    # Tesseract layout analysis: group word boxes by page/paragraph/line
    data = pytesseract.image_to_data(image, config='--oem 3 --psm 6', output_type='dict')
    page_num = 1

    df = pd.DataFrame(data)
    df = df[df["conf"] > 0]
    df["page_num"] = page_num

    page_par_line_dict = {}

    for index, row in df.iterrows():
        page_par_line = f"{page_num}_{row['par_num']}_{row['line_num']}"

        if page_par_line not in page_par_line_dict:
            page_par_line_dict[page_par_line] = {
                "text": str(row["text"]) + " ",
                "box": (row['left'], row['top'], row['left'] + row['width'], row['top'] + row['height'])}
        else:
            # Extend the running text and grow the box to cover the new word
            page_par_line_dict[page_par_line]["text"] = page_par_line_dict[page_par_line]["text"] + str(row["text"]) + " "
            page_par_line_dict[page_par_line]['box'] = (
                min(page_par_line_dict[page_par_line]['box'][0], row['left']),
                min(page_par_line_dict[page_par_line]['box'][1], row['top']),
                max(page_par_line_dict[page_par_line]['box'][2], row['left'] + row['width']),
                max(page_par_line_dict[page_par_line]['box'][3], row['top'] + row['height']))

    for entry in page_par_line_dict:
        splitted_key = entry.split('_')
        entry_value = page_par_line_dict[entry]

        master_page_par_line_list.append({
            'page_number': splitted_key[0],
            'paragraph_number': splitted_key[1],
            'line_number': splitted_key[2],
            'entry_text': entry_value['text'],
            'bounding_box': entry_value['box']
        })

    imgs_cropped = {}
    img_text_dict = {"dt": [], "gt": []}

    # Crop each detected line, save it to disk, and preprocess it for the model
    for line in page_par_line_dict.values():
        if line['box'] is not None:
            cv2.rectangle(image, (line['box'][0], line['box'][1]), (line['box'][2], line['box'][3]), (0, 0, 255), 2)
            img_cropped = image[line['box'][1]:line['box'][3], line['box'][0]:line['box'][2]]

            if not os.path.exists('cropped_lines'):
                os.makedirs('cropped_lines')

            cv2.imwrite(f"cropped_lines/{line['box'][1]}.jpg", img_cropped)
            # print(line['text'])
            imgs_cropped[line['box'][1]] = img_cropped

            assert os.path.exists(f'cropped_lines/{line["box"][1]}.jpg')
            img_text_dict["dt"].append(preprocess(f"cropped_lines/{line['box'][1]}.jpg", (1024, 128, 1)))
            img_text_dict["gt"].append(line['text'].encode())

    return img_text_dict
# inference
def generate(img_path):
    pretrained_model = make_model(vocab_len=100)
    _ = pretrained_model.to(device)
    pretrained_model.load_state_dict(torch.load('span_fine_tuned_model.pt', map_location=torch.device('cpu')))

    max_text_length = 128
    transform = T.Compose([T.ToTensor()])

    sp_loader = torch.utils.data.DataLoader(
        DataGenerator_Spanish(crop_dict(img_path), charset_base, max_text_length, transform, shuffle=False),
        batch_size=1, shuffle=False, num_workers=2)

    predicts2, gt2, imgs = test(pretrained_model, sp_loader, max_text_length)
    predicts2 = list(map(lambda x: x.replace('SOS', '').replace('EOS', ''), predicts2))

    final_pred_str = ""
    for s in predicts2:
        final_pred_str += s + "\n"

    return final_pred_str
gr.Interface(fn=generate,
             inputs=[gr.Image(label='Input image')],
             outputs=[gr.Textbox(label='Read Text')],
             allow_flagging='never',
             title='Transformer Text Detection - HumanAI',
             theme=gr.themes.Monochrome()).launch()