Spaces:
Build error
Build error
import functools
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytesseract
import streamlit as st
from cv2 import dnn_superres
from PIL import Image
from pytesseract import Output
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

import postprocess as pp
import TDTSR
# Windows location of the Tesseract binary used by the original author. Only
# point pytesseract at it when the file actually exists, so the app still works
# on hosts where `tesseract` is simply on PATH (e.g. Linux containers).
TESSERACT_WINDOWS_CMD = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
if os.path.isfile(TESSERACT_WINDOWS_CMD):
    pytesseract.pytesseract.tesseract_cmd = TESSERACT_WINDOWS_CMD

# Page config goes first: Streamlit requires it before other page commands.
st.set_page_config(layout='wide')
try:
    st.set_option('deprecation.showPyplotGlobalUse', False)
except Exception:
    # Best effort: this option only silenced a pyplot warning, and newer
    # Streamlit releases may reject it as unknown. Carry on without it.
    pass
st.title("Table Detection and Table Structure Recognition")
# Three equal-width columns: c1/c2 show detection plots, c3 shows structure
# plots and the OCR'd tables.
c1, c2, c3 = st.columns((1, 1, 1))
def PIL_to_cv(pil_img):
    """Convert a PIL image to an OpenCV-style BGR uint8 array.

    Images in any mode other than RGB (L, P, 1, RGBA, ...) are converted to
    RGB first. For single-band modes ``np.array`` yields a 2-D array, which
    ``COLOR_RGB2BGR`` rejects. For RGBA the result is unchanged, because both
    paths simply drop the alpha channel.
    """
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")
    return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
def cv_to_PIL(cv_img):
    """Convert an OpenCV image array back to a PIL image.

    - 2-D (single-channel) arrays are wrapped directly as grayscale;
      ``COLOR_BGR2RGB`` would raise on them.
    - 4-channel arrays are treated as BGRA and returned as RGBA.
    - Everything else is treated as BGR and returned as RGB (original behavior).
    """
    if cv_img.ndim == 2:
        return Image.fromarray(cv_img)
    if cv_img.shape[2] == 4:
        return Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGRA2RGBA))
    return Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
def join_ocr_tokens(tokens):
    """Join OCR word tokens with single spaces, skipping blank entries.

    ``image_to_data`` returns rows for layout levels as well as words. Those
    rows carry empty text, and joining them verbatim produced runs of spaces
    between words. Tokens are passed through ``str()`` in case a numeric-looking
    entry arrives as a number.
    """
    words = (str(tok).strip() for tok in tokens)
    return ' '.join(word for word in words if word)


def pytess(cell_pil_img):
    """OCR a single cell image with Tesseract and return its text as one line.

    Returns '' when Tesseract finds no text.
    """
    data = pytesseract.image_to_data(
        cell_pil_img,
        output_type=Output.DICT,
        # Tesseract takes config variables as "-c name=value". A bare
        # "preserve_interword_spaces" would be treated as a config-file name.
        config='-c preserve_interword_spaces=1',
    )
    return join_ocr_tokens(data['text'])
@functools.lru_cache(maxsize=2)
def _load_trocr(model_name="SalML/trocr-base-printed"):
    """Load and cache the TrOCR processor/model pair for *model_name*.

    ``from_pretrained`` is expensive, and the original rebuilt both objects on
    every cell. The cached pair is shared by every caller in the process.
    """
    processor = TrOCRProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    return processor, model


def TrOCR(cell_pil_img, model_name="SalML/trocr-base-printed"):
    """Recognise the text in one cell image with TrOCR and return it as a string.

    Args:
        cell_pil_img: the cell image passed to the TrOCR processor.
        model_name: Hugging Face model id; defaults to the original checkpoint.
    """
    processor, model = _load_trocr(model_name)
    pixel_values = processor(images=cell_pil_img, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    # Single image in, so take the first (only) decoded sequence.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
def parse_sr_model_path(model_path):
    """Return ``(algorithm_name, scale)`` parsed from a super-res model file name.

    The file name must look like ``<Algo>_x<scale>.pb``, e.g.
    ``"./LapSRN_x8.pb" -> ("lapsrn", 8)``. Parsing uses the basename, so any
    directory prefix and multi-digit scales (``_x16``) work.
    """
    stem = os.path.splitext(os.path.basename(model_path))[0]  # e.g. "LapSRN_x8"
    name, scale = stem.split('_')[:2]
    return name.lower(), int(scale.lstrip('xX'))


@functools.lru_cache(maxsize=4)
def _load_sr_model(model_path):
    """Create and cache a configured DnnSuperResImpl for *model_path*.

    The original re-read the model file for every cell it upscaled.
    NOTE(review): the cached object is shared by every caller in the process.
    If the app serves concurrent sessions, confirm that ``upsample`` is safe to
    call concurrently, or guard it with a lock.
    """
    name, scale = parse_sr_model_path(model_path)
    sr = dnn_superres.DnnSuperResImpl_create()
    sr.readModel(model_path)
    sr.setModel(name, scale)
    return sr


def super_res(pil_img, model_path="./LapSRN_x8.pb"):
    """Upscale a PIL image with an OpenCV dnn_superres model and return a PIL image.

    Args:
        pil_img: image to upscale.
        model_path: ``.pb`` model file named ``<Algo>_x<scale>.pb``; defaults
            to the original LapSRN x8 model in the working directory.
    """
    # requires opencv-contrib-python installed without the opencv-python
    sr = _load_sr_model(model_path)
    return cv_to_PIL(sr.upsample(PIL_to_cv(pil_img)))
def sharpen_image(pil_img):
    """Return a sharpened copy of *pil_img* as a new PIL image.

    A 3x3 kernel weights the centre pixel by 9 and each of its eight
    neighbours by -1. The weights sum to 1, so flat regions keep their value
    while edges are amplified.
    """
    kernel = np.array([
        [-1, -1, -1],
        [-1, 9, -1],
        [-1, -1, -1],
    ])
    # A milder 4-neighbour alternative: [[0, -1, 0], [-1, 5, -1], [0, -1, 0]].
    sharpened = cv2.filter2D(PIL_to_cv(pil_img), -1, kernel)
    return cv_to_PIL(sharpened)
def preprocess_magic(pil_img):
    """Binarise an image so it ends up as dark text on a white background.

    Steps:
    1. Convert the image to grayscale.
    2. Threshold it with Otsu's method.
    3. If black pixels outnumber white ones, invert the result; the majority
       colour is assumed to be the background.

    Returns an RGB PIL image containing only black and white pixels.
    """
    grayscale_image = cv2.cvtColor(PIL_to_cv(pil_img), cv2.COLOR_BGR2GRAY)
    # THRESH_BINARY is 0, so this is the same flag value as a bare THRESH_OTSU.
    _, binary_image = cv2.threshold(grayscale_image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    count_white = np.count_nonzero(binary_image)
    count_black = binary_image.size - count_white
    if count_black > count_white:
        binary_image = 255 - binary_image
    # binary_image is single-channel; expand it to BGR so cv_to_PIL gets the
    # 3-channel input its COLOR_BGR2RGB conversion expects.
    return cv_to_PIL(cv2.cvtColor(binary_image, cv2.COLOR_GRAY2BGR))
### main code:
# Directory of sample page images to process. Defaults to the original
# author's local folder; set the TD_SAMPLES_DIR environment variable to point
# the app somewhere else.
TD_SAMPLES_DIR = os.environ.get("TD_SAMPLES_DIR", "D:/Jupyter/Multi-Type-TD-TSR/TD_samples/")


def make_headers(texts):
    """Turn OCR'd header-cell texts into unique, non-empty column names.

    An empty text at position ``idx`` becomes ``'empty_col<idx>'``. A name that
    is already taken gets a ``_1``, ``_2``, ... suffix, so the resulting
    DataFrame never has duplicate columns.
    """
    headers = []
    used = set()
    for idx, text in enumerate(texts):
        base = text if text else 'empty_col' + str(idx)
        name, suffix = base, 1
        while name in used:
            name = f"{base}_{suffix}"
            suffix += 1
        used.add(name)
        headers.append(name)
    return headers


def rows_to_dataframe(headers, body_rows):
    """Build a DataFrame of OCR'd cell texts, one row per entry in *body_rows*.

    Each row is padded with '' or truncated to ``len(headers)`` cells, so a row
    whose cell count differs from the header row cannot break the frame.
    """
    n_cols = len(headers)
    fitted = [(list(row) + [""] * n_cols)[:n_cols] for row in body_rows]
    return pd.DataFrame(fitted, columns=headers)


def process_table(td_sample, unpadded_table):
    """Recognise one cropped table's structure, OCR its cells and show the table in c3.

    The first row of cells is used as the header row. If structure recognition
    raises, the error is written to the page and the table is skipped.
    """
    table = TDTSR.add_margin(unpadded_table)
    model, image, probas, bboxes_scaled = TDTSR.table_struct_recog(table, THRESHOLD_PROBA=0.6)
    try:
        rows, cols = TDTSR.plot_structure(c3, model, image, probas, bboxes_scaled, class_to_show=0)
        rows, cols = TDTSR.sort_table_featuresv2(rows, cols)
        rows, cols = TDTSR.individual_table_featuresv2(table, rows, cols)
    except Exception as printableException:
        st.write(td_sample, ' terminated with exception:', printableException)
        # rows/cols are unset (or stale from a previous table) at this point,
        # so continuing would crash or show the wrong table.
        return
    cells_img = TDTSR.object_to_cellsv2(rows, cols)

    header_texts = []
    body_rows = []
    for n, row_images in enumerate(cells_img.values()):
        # Each cell is upscaled with super_res before Tesseract reads it.
        texts = [pytess(super_res(cell)) for cell in row_images]
        if n == 0:
            header_texts = texts
        else:
            body_rows.append(texts)
    c3.dataframe(rows_to_dataframe(make_headers(header_texts), body_rows))


def main():
    """Detect tables in every sample image and OCR each detected table."""
    for td_sample in sorted(os.listdir(TD_SAMPLES_DIR)):
        with Image.open(os.path.join(TD_SAMPLES_DIR, td_sample)) as src:
            image = src.convert("RGB")
        model, image, probas, bboxes_scaled = TDTSR.table_detector(image, THRESHOLD_PROBA=0.6)
        TDTSR.plot_results_detection(c1, model, image, probas, bboxes_scaled)
        cropped_img_list = TDTSR.plot_table_detection(c2, model, image, probas, bboxes_scaled)
        for unpadded_table in cropped_img_list:
            process_table(td_sample, unpadded_table)


if __name__ == "__main__":
    # `streamlit run` executes the script as __main__. The guard only stops the
    # pipeline from running when the module is imported (e.g. by tests).
    main()