rishiraj commited on
Commit
d32f72c
1 Parent(s): b1577c2

Create utils/tt_module.py

Browse files
Files changed (1) hide show
  1. pdf-extractor/utils/tt_module.py +230 -0
pdf-extractor/utils/tt_module.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForObjectDetection
2
+ import torch
3
+ from pdf2image import convert_from_bytes
4
+ from torchvision import transforms
5
+ from transformers import TableTransformerForObjectDetection
6
+ import numpy as np
7
+ import easyocr
8
+ from tqdm.auto import tqdm
9
+
10
+ model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")
11
+ model.config.id2label
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ model.to(device)
14
+ structure_model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-all")
15
+ structure_model.to(device)
16
+ reader = easyocr.Reader(['en'])
17
+
18
+ def pdf_to_img(pdf_path):
19
+ image_list = []
20
+ images = convert_from_bytes(pdf_path)
21
+ for i in range(len(images)):
22
+ image = images[i].convert("RGB")
23
+ image_list.append(image)
24
+ return image_list
25
+
26
+ class MaxResize(object):
27
+ def __init__(self, max_size=800):
28
+ self.max_size = max_size
29
+
30
+ def __call__(self, image):
31
+ width, height = image.size
32
+ current_max_size = max(width, height)
33
+ scale = self.max_size / current_max_size
34
+ resized_image = image.resize((int(round(scale*width)), int(round(scale*height))))
35
+
36
+ return resized_image
37
+
38
+ def box_cxcywh_to_xyxy(x):
39
+ x_c, y_c, w, h = x.unbind(-1)
40
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
41
+ return torch.stack(b, dim=1)
42
+
43
+ def rescale_bboxes(out_bbox, size):
44
+ img_w, img_h = size
45
+ b = box_cxcywh_to_xyxy(out_bbox)
46
+ b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
47
+ return b
48
+
49
+ def outputs_to_objects(outputs, img_size, id2label):
50
+ m = outputs.logits.softmax(-1).max(-1)
51
+ pred_labels = list(m.indices.detach().cpu().numpy())[0]
52
+ pred_scores = list(m.values.detach().cpu().numpy())[0]
53
+ pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
54
+ pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]
55
+
56
+ objects = []
57
+ for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
58
+ class_label = id2label[int(label)]
59
+ if not class_label == 'no object':
60
+ objects.append({'label': class_label, 'score': float(score),
61
+ 'bbox': [float(elem) for elem in bbox]})
62
+
63
+ return objects
64
+
65
+ def objects_to_crops(img, tokens, objects, class_thresholds, padding=10):
66
+ """
67
+ Process the bounding boxes produced by the table detection model into
68
+ cropped table images and cropped tokens.
69
+ """
70
+
71
+ table_crops = []
72
+ for obj in objects:
73
+ if obj['score'] < class_thresholds[obj['label']]:
74
+ continue
75
+
76
+ cropped_table = {}
77
+
78
+ bbox = obj['bbox']
79
+ bbox = [bbox[0]-padding, bbox[1]-padding, bbox[2]+padding, bbox[3]+padding]
80
+
81
+ cropped_img = img.crop(bbox)
82
+
83
+ table_tokens = [token for token in tokens if iob(token['bbox'], bbox) >= 0.5]
84
+ for token in table_tokens:
85
+ token['bbox'] = [token['bbox'][0]-bbox[0],
86
+ token['bbox'][1]-bbox[1],
87
+ token['bbox'][2]-bbox[0],
88
+ token['bbox'][3]-bbox[1]]
89
+
90
+ # If table is predicted to be rotated, rotate cropped image and tokens/words:
91
+ if obj['label'] == 'table rotated':
92
+ cropped_img = cropped_img.rotate(270, expand=True)
93
+ for token in table_tokens:
94
+ bbox = token['bbox']
95
+ bbox = [cropped_img.size[0]-bbox[3]-1,
96
+ bbox[0],
97
+ cropped_img.size[0]-bbox[1]-1,
98
+ bbox[2]]
99
+ token['bbox'] = bbox
100
+
101
+ cropped_table['image'] = cropped_img
102
+ cropped_table['tokens'] = table_tokens
103
+
104
+ table_crops.append(cropped_table)
105
+
106
+ return table_crops
107
+
108
+ def get_cell_coordinates_by_row(table_data):
109
+ # Extract rows and columns
110
+ rows = [entry for entry in table_data if entry['label'] == 'table row']
111
+ columns = [entry for entry in table_data if entry['label'] == 'table column']
112
+
113
+ # Sort rows and columns by their Y and X coordinates, respectively
114
+ rows.sort(key=lambda x: x['bbox'][1])
115
+ columns.sort(key=lambda x: x['bbox'][0])
116
+
117
+ # Function to find cell coordinates
118
+ def find_cell_coordinates(row, column):
119
+ cell_bbox = [column['bbox'][0], row['bbox'][1], column['bbox'][2], row['bbox'][3]]
120
+ return cell_bbox
121
+
122
+ # Generate cell coordinates and count cells in each row
123
+ cell_coordinates = []
124
+
125
+ for row in rows:
126
+ row_cells = []
127
+ for column in columns:
128
+ cell_bbox = find_cell_coordinates(row, column)
129
+ row_cells.append({'column': column['bbox'], 'cell': cell_bbox})
130
+
131
+ # Sort cells in the row by X coordinate
132
+ row_cells.sort(key=lambda x: x['column'][0])
133
+
134
+ # Append row information to cell_coordinates
135
+ cell_coordinates.append({'row': row['bbox'], 'cells': row_cells, 'cell_count': len(row_cells)})
136
+
137
+ # Sort rows from top to bottom
138
+ cell_coordinates.sort(key=lambda x: x['row'][1])
139
+
140
+ return cell_coordinates
141
+
142
+ def apply_ocr(cell_coordinates, cropped_table):
143
+ # let's OCR row by row
144
+ data = dict()
145
+ max_num_columns = 0
146
+ for idx, row in enumerate(tqdm(cell_coordinates)):
147
+ row_text = []
148
+ for cell in row["cells"]:
149
+ # crop cell out of image
150
+ cell_image = np.array(cropped_table.crop(cell["cell"]))
151
+ # apply OCR
152
+ result = reader.readtext(np.array(cell_image))
153
+ if len(result) > 0:
154
+ # print([x[1] for x in list(result)])
155
+ text = " ".join([x[1] for x in result])
156
+ row_text.append(text)
157
+
158
+ if len(row_text) > max_num_columns:
159
+ max_num_columns = len(row_text)
160
+
161
+ data[idx] = row_text
162
+
163
+ print("Max number of columns:", max_num_columns)
164
+
165
+ # pad rows which don't have max_num_columns elements
166
+ # to make sure all rows have the same number of columns
167
+ for row, row_data in data.copy().items():
168
+ if len(row_data) != max_num_columns:
169
+ row_data = row_data + ["" for _ in range(max_num_columns - len(row_data))]
170
+ data[row] = row_data
171
+
172
+ return data
173
+
174
+ def get_tables(pdf_path):
175
+ image_list = pdf_to_img(pdf_path)
176
+ data_dict = {}
177
+ for index, image in enumerate(image_list):
178
+ detection_transform = transforms.Compose([
179
+ MaxResize(800),
180
+ transforms.ToTensor(),
181
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
182
+ ])
183
+
184
+ pixel_values = detection_transform(image).unsqueeze(0)
185
+ pixel_values = pixel_values.to(device)
186
+
187
+ with torch.no_grad():
188
+ outputs = model(pixel_values)
189
+
190
+ id2label = model.config.id2label
191
+ id2label[len(model.config.id2label)] = "no object"
192
+
193
+ objects = outputs_to_objects(outputs, image.size, id2label)
194
+
195
+ tokens = []
196
+ detection_class_thresholds = {
197
+ "table": 0.5,
198
+ "table rotated": 0.5,
199
+ "no object": 10
200
+ }
201
+ crop_padding = 10
202
+
203
+ tables_crops = objects_to_crops(image, tokens, objects, detection_class_thresholds, padding=0)
204
+
205
+ for table_index, table_crop in enumerate(tables_crops):
206
+ cropped_table = table_crop['image'].convert("RGB")
207
+
208
+ structure_transform = transforms.Compose([
209
+ MaxResize(1000),
210
+ transforms.ToTensor(),
211
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
212
+ ])
213
+
214
+ pixel_values = structure_transform(cropped_table).unsqueeze(0)
215
+ pixel_values = pixel_values.to(device)
216
+
217
+ with torch.no_grad():
218
+ outputs = structure_model(pixel_values)
219
+
220
+ structure_id2label = structure_model.config.id2label
221
+ structure_id2label[len(structure_id2label)] = "no object"
222
+
223
+ cells = outputs_to_objects(outputs, cropped_table.size, structure_id2label)
224
+ if cells[0]['score'] > 0.95:
225
+ cell_coordinates = get_cell_coordinates_by_row(cells)
226
+
227
+ data = apply_ocr(cell_coordinates, cropped_table)
228
+ data_dict[f"{index+1}_{table_index+1}"] = data
229
+
230
+ return data_dict