SalML commited on
Commit
b965b5d
1 Parent(s): bbef25b

initial commit of app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -0
app.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import os
4
+ import TDTSR
5
+ import pytesseract
6
+ from pytesseract import Output
7
+ import postprocess as pp
8
+ import pandas as pd
9
+ import matplotlib.pyplot as plt
10
+ import cv2
11
+ import numpy as np
12
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
13
+ from cv2 import dnn_superres
14
+
15
+ pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
16
+
17
+
18
+
19
+ st.set_option('deprecation.showPyplotGlobalUse', False)
20
+ st.set_page_config(layout='wide')
21
+ st.title("Table Detection and Table Structure Recognition")
22
+
23
+ c1, c2, c3 = st.columns((1,1,1))
24
+
25
+
26
+ def PIL_to_cv(pil_img):
27
+ return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
28
+
29
+ def cv_to_PIL(cv_img):
30
+ return Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
31
+
32
+ def pytess(cell_pil_img):
33
+ return ' '.join(pytesseract.image_to_data(cell_pil_img, output_type=Output.DICT, config='preserve_interword_spaces')['text']).strip()
34
+
35
+ def TrOCR(cell_pil_img):
36
+
37
+ processor = TrOCRProcessor.from_pretrained("SalML/trocr-base-printed")
38
+ model = VisionEncoderDecoderModel.from_pretrained("SalML/trocr-base-printed")
39
+ pixel_values = processor(images=cell_pil_img, return_tensors="pt").pixel_values
40
+
41
+ generated_ids = model.generate(pixel_values)
42
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
43
+
44
+ return generated_text
45
+
46
+
47
+
48
+ def super_res(pil_img):
49
+ # requires opencv-contrib-python installed without the opencv-python
50
+ sr = dnn_superres.DnnSuperResImpl_create()
51
+ image = PIL_to_cv(pil_img)
52
+ model_path = "./LapSRN_x8.pb"
53
+ model_name = model_path.split('/')[1].split('_')[0].lower()
54
+ model_scale = int(model_path.split('/')[1].split('_')[1].split('.')[0][1])
55
+
56
+ sr.readModel(model_path)
57
+ sr.setModel(model_name, model_scale)
58
+ final_img = sr.upsample(image)
59
+ final_img = cv_to_PIL(final_img)
60
+
61
+ return final_img
62
+
63
+
64
+ def sharpen_image(pil_img):
65
+
66
+ img = PIL_to_cv(pil_img)
67
+ sharpen_kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]])
68
+ # sharpen_kernel = np.array([[0, -1, 0],
69
+ # [-1, 5,-1],
70
+ # [0, -1, 0]])
71
+ sharpen = cv2.filter2D(img, -1, sharpen_kernel)
72
+ pil_img = cv_to_PIL(sharpen)
73
+ return pil_img
74
+
75
+
76
+ def preprocess_magic(pil_img):
77
+
78
+ cv_img = PIL_to_cv(pil_img)
79
+ grayscale_image = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
80
+ _, binary_image = cv2.threshold(grayscale_image, 0, 255, cv2.THRESH_OTSU)
81
+
82
+ count_white = np.sum(binary_image > 0)
83
+ count_black = np.sum(binary_image == 0)
84
+
85
+ if count_black > count_white:
86
+ binary_image = 255 - binary_image
87
+
88
+ black_text_white_background_image = binary_image
89
+
90
+ return cv_to_PIL(black_text_white_background_image)
91
+
92
+
93
+ ### main code:
94
+ for td_sample in os.listdir('D:/Jupyter/Multi-Type-TD-TSR/TD_samples/'):
95
+
96
+ image = Image.open("D:/Jupyter/Multi-Type-TD-TSR/TD_samples/"+td_sample).convert("RGB")
97
+ model, image, probas, bboxes_scaled = TDTSR.table_detector(image, THRESHOLD_PROBA=0.6)
98
+ TDTSR.plot_results_detection(c1, model, image, probas, bboxes_scaled)
99
+ cropped_img_list = TDTSR.plot_table_detection(c2, model, image, probas, bboxes_scaled)
100
+
101
+ for unpadded_table in cropped_img_list:
102
+ # table : pil_img
103
+ table = TDTSR.add_margin(unpadded_table)
104
+ model, image, probas, bboxes_scaled = TDTSR.table_struct_recog(table, THRESHOLD_PROBA=0.6)
105
+
106
+ # The try, except block of code below plots table header row and simple rows
107
+ try:
108
+ rows, cols = TDTSR.plot_structure(c3, model, image, probas, bboxes_scaled, class_to_show=0)
109
+ rows, cols = TDTSR.sort_table_featuresv2(rows, cols)
110
+ # headers, rows, cols are ordered dictionaries with 5th element value of tuple being pil_img
111
+ rows, cols = TDTSR.individual_table_featuresv2(table, rows, cols)
112
+ # TDTSR.plot_table_features(c1, header, row_header, rows, cols)
113
+ except Exception as printableException:
114
+ st.write(td_sample, ' terminated with exception:', printableException)
115
+
116
+ # master_row = TDTSR.master_row_set(header, row_header, rows, cols)
117
+ master_row = rows
118
+
119
+ # cells_img = TDTSR.object_to_cells(master_row, cols)
120
+ cells_img = TDTSR.object_to_cellsv2(master_row, cols)
121
+
122
+ headers = []
123
+ cells_list = []
124
+ # st.write(cells_img)
125
+ for n, kv in enumerate(cells_img.items()):
126
+ k, row_images = kv
127
+ if n == 0:
128
+ for idx, header in enumerate(row_images):
129
+ # plt.imshow(header)
130
+ # c2.pyplot()
131
+ # c2.write(pytess(header))
132
+ ############################
133
+ SR_img = super_res(header)
134
+ # # w, h = SR_img.size
135
+ # # SR_img = SR_img.crop((0 ,0 ,w, h-60))
136
+ # plt.imshow(SR_img)
137
+ # c3.pyplot()
138
+ # c3.write(pytess(SR_img))
139
+ header_text = pytess(SR_img)
140
+ if header_text == '':
141
+ header_text = 'empty_col'+str(idx)
142
+ headers.append(header_text)
143
+
144
+
145
+ else:
146
+ for cells in row_images:
147
+ # plt.imshow(cells)
148
+ # c2.pyplot()
149
+ # c2.write(pytess(cells))
150
+ ##############################
151
+ SR_img = super_res(cells)
152
+ # # w, h = SR_img.size
153
+ # # SR_img = SR_img.crop((0 ,0 ,w, h-60))
154
+ # plt.imshow(SR_img)
155
+ # c3.pyplot()
156
+ # c3.write(pytess(SR_img))
157
+ cells_list.append(pytess(SR_img))
158
+
159
+
160
+
161
+ df = pd.DataFrame("", index=range(0, len(master_row)), columns=headers)
162
+
163
+ cell_idx = 0
164
+
165
+ for nrows in range(len(master_row)-1):
166
+ for ncols in range(len(cols)):
167
+
168
+ df.iat[nrows, ncols] = cells_list[cell_idx]
169
+ cell_idx += 1
170
+
171
+ c3.dataframe(df)
172
+ # break
173
+