from collections import defaultdict
from copy import deepcopy
from typing import List, Dict
import torch
from PIL import Image

from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
from surya.schema import TableResult, TableCell, Bbox
from surya.settings import settings
from tqdm import tqdm
import numpy as np
from surya.model.table_rec.config import SPECIAL_TOKENS


def get_batch_size():
    batch_size = settings.TABLE_REC_BATCH_SIZE
    if batch_size is None:
        batch_size = 8  # default for CPU and mps
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 64
    return batch_size
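
# Note (assumed behavior, check your surya version): surya settings are
# pydantic-based, so TABLE_REC_BATCH_SIZE can typically also be overridden via
# an environment variable of the same name, e.g.:
#   TABLE_REC_BATCH_SIZE=32 python run_table_rec.py  # hypothetical script name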


def sort_bboxes(bboxes, tolerance=1):
    # Group bboxes into visual rows by snapping the top y-coordinate to a grid
    # of size `tolerance`
    vertical_groups = defaultdict(list)
    for block in bboxes:
        group_key = round(block["bbox"][1] / tolerance) * tolerance
        vertical_groups[group_key].append(block)

    # Sort each group horizontally and flatten the groups into a single list
    sorted_page_blocks = []
    for _, group in sorted(vertical_groups.items()):
        sorted_group = sorted(group, key=lambda x: x["bbox"][0])
        sorted_page_blocks.extend(sorted_group)

    return sorted_page_blocks
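
# Example of the sort_bboxes grouping (illustrative values): with tolerance=1,
# tops at y=10.2 and y=10.4 both snap to group key 10, so the two boxes form
# one visual row and are ordered left-to-right:
#   sort_bboxes([{"bbox": [50, 10.4, 60, 20]}, {"bbox": [5, 10.2, 15, 20]}])
#   -> [{"bbox": [5, 10.2, 15, 20]}, {"bbox": [50, 10.4, 60, 20]}]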


def is_rotated(rows, cols):
    # Determine if the table is rotated by looking at row and column width / height ratios
    # Rows should have a >1 ratio, cols <1
    widths = sum([r.width for r in rows])
    heights = sum([r.height for r in rows]) + 1  # +1 avoids division by zero
    r_ratio = widths / heights

    widths = sum([c.width for c in cols])
    heights = sum([c.height for c in cols]) + 1  # +1 avoids division by zero
    c_ratio = widths / heights

    return r_ratio * 2 < c_ratio
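
# Note on the is_rotated heuristic: an upright table has wide, short rows
# (r_ratio > 1) and tall, narrow columns (c_ratio < 1); r_ratio * 2 < c_ratio
# therefore only holds when those relationships flip, i.e. the table is
# rotated roughly 90 degrees.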

def batch_table_recognition(images: List, table_cells: List[List[Dict]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[TableResult]:
    assert all([isinstance(image, Image.Image) for image in images])
    assert len(images) == len(table_cells)
    if batch_size is None:
        batch_size = get_batch_size()

    output_order = []
    for i in tqdm(range(0, len(images), batch_size), desc="Recognizing tables"):
        batch_table_cells = deepcopy(table_cells[i:i+batch_size])
        batch_table_cells = [sort_bboxes(page_bboxes) for page_bboxes in batch_table_cells] # Sort bboxes before passing in
        batch_list_bboxes = [[block["bbox"] for block in page] for page in batch_table_cells]

        batch_images = images[i:i+batch_size]
        batch_images = [image.convert("RGB") for image in batch_images]  # also copies the images
        current_batch_size = len(batch_images)

        orig_sizes = [image.size for image in batch_images]
        model_inputs = processor(images=batch_images, boxes=deepcopy(batch_list_bboxes))

        batch_pixel_values = model_inputs["pixel_values"]
        batch_bboxes = model_inputs["input_boxes"]
        batch_bbox_mask = model_inputs["input_boxes_mask"]
        batch_bbox_counts = model_inputs["input_boxes_counts"]

        batch_bboxes = torch.from_numpy(np.array(batch_bboxes, dtype=np.int32)).to(model.device)
        batch_bbox_mask = torch.from_numpy(np.array(batch_bbox_mask, dtype=np.int32)).to(model.device)
        batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
        batch_bbox_counts = torch.tensor(np.array(batch_bbox_counts), dtype=torch.long).to(model.device)

        # Setup inputs for the decoder: each decoder "token" is a 5-tuple of
        # four box coordinates plus a row/column class id, seeded here with BOS
        batch_decoder_input = [[[model.config.decoder.bos_token_id] * 5] for _ in range(current_batch_size)]
        batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device)
        inference_token_count = batch_decoder_input.shape[1]

        # Cap generation at the number of input boxes (and the configured max)
        max_tokens = min(batch_bbox_counts[:, 1].max().item(), settings.TABLE_REC_MAX_BOXES)
        decoder_position_ids = torch.ones_like(batch_decoder_input[0, :, 0], dtype=torch.int64, device=model.device).cumsum(0) - 1
        # Note: caches are allocated for the configured batch_size, not the
        # (possibly smaller) final batch
        model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
        model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)

        batch_predictions = [[] for _ in range(current_batch_size)]

        with torch.inference_mode():
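            # Two-stage encoding: the image encoder embeds the page pixels, and
            # the text encoder fuses the input cell boxes with those image
            # features (passed as encoder_hidden_states); its hidden states
            # then condition the decoder loop below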
            encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values).last_hidden_state
            text_encoder_hidden_states = model.text_encoder(
                input_boxes=batch_bboxes,
                input_boxes_counts=batch_bbox_counts,
                cache_position=None,
                attention_mask=batch_bbox_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=None,
                use_cache=False
            ).hidden_states

            token_count = 0
            all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)

            while token_count < max_tokens:
                is_prefill = token_count == 0
                return_dict = model.decoder(
                    input_ids=batch_decoder_input,
                    encoder_hidden_states=text_encoder_hidden_states,
                    cache_position=decoder_position_ids,
                    use_cache=True,
                    prefill=is_prefill
                )

                decoder_position_ids = decoder_position_ids[-1:] + 1
                box_logits = return_dict["bbox_logits"][:, -1, :].detach()
                rowcol_logits = return_dict["class_logits"][:, -1, :].detach()

                rowcol_preds = torch.argmax(rowcol_logits, dim=-1)
                box_preds = torch.argmax(box_logits, dim=-1)

                done = (rowcol_preds == processor.tokenizer.eos_id) | (rowcol_preds == processor.tokenizer.pad_id)
                all_done = all_done | done

                if all_done.all():
                    break

                # Next decoder input has shape (batch, 1, 5): 4 box coords + 1 class id
                batch_decoder_input = torch.cat([box_preds.unsqueeze(1), rowcol_preds.unsqueeze(1).unsqueeze(1)], dim=-1)

                for j, (pred, status) in enumerate(zip(batch_decoder_input, all_done)):
                    if not status:  # only keep predictions for unfinished sequences
                        batch_predictions[j].append(pred[0].tolist())

                token_count += inference_token_count
                inference_token_count = batch_decoder_input.shape[1]

        for j, (preds, input_cells, orig_size) in enumerate(zip(batch_predictions, batch_table_cells, orig_sizes)):
            img_w, img_h = orig_size
            width_scaler = img_w / model.config.decoder.out_box_size
            height_scaler = img_h / model.config.decoder.out_box_size

            # Convert predicted boxes from (cx, cy, w, h) to corner format and
            # scale from model output space back to the original image size.
            # pred_idx avoids shadowing the outer batch index i.
            for pred_idx, pred in enumerate(preds):
                w = pred[2] / 2
                h = pred[3] / 2
                x1 = pred[0] - w
                y1 = pred[1] - h
                x2 = pred[0] + w
                y2 = pred[1] + h
                class_ = int(pred[4] - SPECIAL_TOKENS)

                preds[pred_idx] = [x1 * width_scaler, y1 * height_scaler, x2 * width_scaler, y2 * height_scaler, class_]
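            # Worked example: a pred of [100, 50, 40, 20, c] becomes corners
            # [80, 40, 120, 60] before scaling (center (100, 50), w=40, h=20),
            # with class c - SPECIAL_TOKENS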

            # Split predictions into row boxes (class 0) and column boxes (class 1)
            bb_rows = [p[:4] for p in preds if p[4] == 0]
            bb_cols = [p[:4] for p in preds if p[4] == 1]

            rows = []
            cols = []
            for row_idx, row in enumerate(bb_rows):
                cell = TableCell(
                    bbox=row,
                    row_id=row_idx
                )
                rows.append(cell)

            for col_idx, col in enumerate(bb_cols):
                cell = TableCell(
                    bbox=col,
                    col_id=col_idx,
                )
                cols.append(cell)

            # Assign each input cell to the row and column it overlaps most
            cells = []
            for cell in input_cells:
                cell_bbox = Bbox(bbox=cell["bbox"])

                max_intersection = 0
                row_pred = None
                for row_idx, row in enumerate(rows):
                    intersection_pct = cell_bbox.intersection_pct(row)
                    if intersection_pct > max_intersection:
                        max_intersection = intersection_pct
                        row_pred = row_idx

                max_intersection = 0
                col_pred = None
                for col_idx, col in enumerate(cols):
                    intersection_pct = cell_bbox.intersection_pct(col)
                    if intersection_pct > max_intersection:
                        max_intersection = intersection_pct
                        col_pred = col_idx

                cells.append(
                    TableCell(
                        bbox=cell["bbox"],
                        text=cell.get("text"),
                        row_id=row_pred,
                        col_id=col_pred
                    )
                )

            # Fallback for cells that matched no row/column: borrow the id of
            # the nearest assigned cell along the relevant axis (the axes swap
            # when the table is rotated)
            rotated = is_rotated(rows, cols)
            for cell in cells:
                if cell.row_id is None:
                    closest_row = None
                    closest_row_dist = None
                    for cell2 in cells:
                        if cell2.row_id is None:
                            continue
                        if rotated:
                            cell_y_center = cell.center[0]
                            cell2_y_center = cell2.center[0]
                        else:
                            cell_y_center = cell.center[1]
                            cell2_y_center = cell2.center[1]
                        y_dist = abs(cell_y_center - cell2_y_center)
                        if closest_row_dist is None or y_dist < closest_row_dist:
                            closest_row = cell2.row_id
                            closest_row_dist = y_dist
                    cell.row_id = closest_row

                if cell.col_id is None:
                    closest_col = None
                    closest_col_dist = None
                    for cell2 in cells:
                        if cell2.col_id is None:
                            continue
                        if rotated:
                            cell_x_center = cell.center[1]
                            cell2_x_center = cell2.center[1]
                        else:
                            cell_x_center = cell.center[0]
                            cell2_x_center = cell2.center[0]

                        x_dist = abs(cell2_x_center - cell_x_center)
                        if closest_col_dist is None or x_dist < closest_col_dist:
                            closest_col = cell2.col_id
                            closest_col_dist = x_dist

                    cell.col_id = closest_col

            result = TableResult(
                cells=cells,
                rows=rows,
                cols=cols,
                image_bbox=[0, 0, img_w, img_h],
            )

            output_order.append(result)

    return output_order
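

# --- Usage sketch (illustrative; not part of the original module) ---
# Assumes surya exposes load_model / load_processor helpers for the table
# recognition model; exact import paths may differ between surya versions.
#
#   from PIL import Image
#   from surya.model.table_rec.model import load_model          # assumed path
#   from surya.model.table_rec.processor import load_processor  # assumed path
#
#   image = Image.open("table.png")  # hypothetical input image
#   # One dict per already-detected cell, each with a "bbox" of [x1, y1, x2, y2]
#   cells = [{"bbox": [10, 10, 120, 40], "text": "header"}]
#
#   model = load_model()
#   processor = load_processor()
#   results = batch_table_recognition([image], [cells], model, processor)
#   for cell in results[0].cells:
#       print(cell.row_id, cell.col_id, cell.bbox)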