Jiangxz01 committed on
Commit
52f1bcb
·
verified ·
1 Parent(s): bbd7630

Upload 56 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +13 -0
  2. detect_layout.py +67 -0
  3. detect_text.py +81 -0
  4. ocr_app.py +257 -0
  5. ocr_text.py +98 -0
  6. pyproject.toml +59 -0
  7. reading_order.py +81 -0
  8. requirements.txt +5 -0
  9. scripts/verify_benchmark_scores.py +61 -0
  10. surya/benchmark/bbox.py +22 -0
  11. surya/benchmark/metrics.py +193 -0
  12. surya/benchmark/tatr.py +117 -0
  13. surya/benchmark/tesseract.py +179 -0
  14. surya/benchmark/util.py +31 -0
  15. surya/detection.py +144 -0
  16. surya/input/langs.py +19 -0
  17. surya/input/load.py +87 -0
  18. surya/input/pdflines.py +86 -0
  19. surya/input/processing.py +118 -0
  20. surya/languages.py +102 -0
  21. surya/layout.py +229 -0
  22. surya/model/detection/config.py +51 -0
  23. surya/model/detection/model.py +767 -0
  24. surya/model/detection/processor.py +284 -0
  25. surya/model/ordering/config.py +8 -0
  26. surya/model/ordering/decoder.py +557 -0
  27. surya/model/ordering/encoder.py +83 -0
  28. surya/model/ordering/encoderdecoder.py +90 -0
  29. surya/model/ordering/model.py +34 -0
  30. surya/model/ordering/processor.py +156 -0
  31. surya/model/recognition/config.py +348 -0
  32. surya/model/recognition/decoder.py +695 -0
  33. surya/model/recognition/encoder.py +852 -0
  34. surya/model/recognition/encoderdecoder.py +145 -0
  35. surya/model/recognition/model.py +49 -0
  36. surya/model/recognition/processor.py +206 -0
  37. surya/model/recognition/tokenizer.py +120 -0
  38. surya/model/table_rec/config.py +260 -0
  39. surya/model/table_rec/decoder.py +795 -0
  40. surya/model/table_rec/encoderdecoder.py +135 -0
  41. surya/model/table_rec/model.py +34 -0
  42. surya/model/table_rec/processor.py +248 -0
  43. surya/ocr.py +114 -0
  44. surya/ordering.py +141 -0
  45. surya/postprocessing/affinity.py +165 -0
  46. surya/postprocessing/fonts.py +24 -0
  47. surya/postprocessing/heatmap.py +224 -0
  48. surya/postprocessing/math/latex.py +125 -0
  49. surya/postprocessing/math/render.py +88 -0
  50. surya/postprocessing/text.py +118 -0
app.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import subprocess
3
+ import os
4
+
5
+
6
def run_app():
    """Launch the Streamlit demo (``ocr_app.py`` located next to this file).

    The child process inherits the current environment with ``IN_STREAMLIT``
    set to ``"true"``.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    app_script = os.path.join(here, "ocr_app.py")

    # Copy the environment so the flag is only visible to the child process.
    child_env = dict(os.environ)
    child_env["IN_STREAMLIT"] = "true"

    subprocess.run(["streamlit", "run", app_script], env=child_env)


if __name__ == "__main__":
    run_app()
detect_layout.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pypdfium2 # Causes a warning if not the top import
2
+ import argparse
3
+ import copy
4
+ import json
5
+ from collections import defaultdict
6
+
7
+ from surya.detection import batch_text_detection
8
+ from surya.input.load import load_from_folder, load_from_file
9
+ from surya.layout import batch_layout_detection
10
+ from surya.model.detection.model import load_model, load_processor
11
+ from surya.postprocessing.heatmap import draw_polys_on_image
12
+ from surya.settings import settings
13
+ import os
14
+
15
+
16
def main():
    """Command-line entry point: detect page layout and write the results to disk.

    Loads the layout and text-detection models, runs text-line detection and
    then layout detection on every loaded page, and writes ``results.json``
    (plus annotated PNGs when ``--images`` is given) under
    ``<results_dir>/<input name>``.
    """
    parser = argparse.ArgumentParser(description="Detect layout of an input file or folder (PDFs or image).")
    parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect layout in.")
    parser.add_argument("--results_dir", type=str, help="Path to JSON file with layout results.", default=os.path.join(settings.RESULT_DIR, "surya"))
    parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
    parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False)
    parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
    args = parser.parse_args()

    model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    det_model = load_model()
    det_processor = load_processor()

    # normpath drops a trailing separator; without it basename("docs/") is ""
    # and the results would be written straight into results_dir.
    input_name = os.path.basename(os.path.normpath(args.input_path))
    if os.path.isdir(args.input_path):
        images, names, _ = load_from_folder(args.input_path, args.max)
        folder_name = input_name
    else:
        images, names, _ = load_from_file(args.input_path, args.max)
        # NOTE: everything after the first "." is dropped, so "a.b.pdf" -> "a".
        folder_name = input_name.split(".")[0]

    line_predictions = batch_text_detection(images, det_model, det_processor)

    layout_predictions = batch_layout_detection(images, model, processor, line_predictions)
    result_path = os.path.join(args.results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    if args.images:
        for idx, (image, layout_pred, name) in enumerate(zip(images, layout_predictions, names)):
            polygons = [p.polygon for p in layout_pred.bboxes]
            labels = [p.label for p in layout_pred.bboxes]
            # Draw on a copy so the loaded page image stays untouched.
            bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image), labels=labels)
            bbox_image.save(os.path.join(result_path, f"{name}_{idx}_layout.png"))

            if args.debug:
                heatmap = layout_pred.segmentation_map
                heatmap.save(os.path.join(result_path, f"{name}_{idx}_segmentation.png"))

    # Group predictions by source name; "page" is the 1-based index within that name.
    predictions_by_page = defaultdict(list)
    for pred, name in zip(layout_predictions, names):
        out_pred = pred.model_dump(exclude=["segmentation_map"])
        out_pred["page"] = len(predictions_by_page[name]) + 1
        predictions_by_page[name].append(out_pred)

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(predictions_by_page, f, ensure_ascii=False)

    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
detect_text.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import copy
3
+ import json
4
+ import time
5
+ from collections import defaultdict
6
+
7
+ from surya.input.load import load_from_folder, load_from_file
8
+ from surya.model.detection.model import load_model, load_processor
9
+ from surya.detection import batch_text_detection
10
+ from surya.postprocessing.affinity import draw_lines_on_image
11
+ from surya.postprocessing.heatmap import draw_polys_on_image
12
+ from surya.settings import settings
13
+ import os
14
+ from tqdm import tqdm
15
+
16
+
17
def main():
    """Command-line entry point: detect text lines and write the results to disk.

    Loads the detection model, runs text detection on every loaded page, and
    writes ``results.json`` (plus annotated PNGs when ``--images`` is given)
    under ``<results_dir>/<input name>``.
    """
    parser = argparse.ArgumentParser(description="Detect bboxes in an input file or folder (PDFs or image).")
    parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect bboxes in.")
    parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "surya"))
    parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
    parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False)
    parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
    args = parser.parse_args()

    checkpoint = settings.DETECTOR_MODEL_CHECKPOINT
    model = load_model(checkpoint=checkpoint)
    processor = load_processor(checkpoint=checkpoint)

    # normpath drops a trailing separator; without it basename("docs/") is ""
    # and the results would be written straight into results_dir.
    input_name = os.path.basename(os.path.normpath(args.input_path))
    if os.path.isdir(args.input_path):
        images, names, _ = load_from_folder(args.input_path, args.max)
        folder_name = input_name
    else:
        images, names, _ = load_from_file(args.input_path, args.max)
        # NOTE: everything after the first "." is dropped, so "a.b.pdf" -> "a".
        folder_name = input_name.split(".")[0]

    # Time only the detection call, not the directory creation that follows.
    start = time.time()
    predictions = batch_text_detection(images, model, processor)
    end = time.time()

    result_path = os.path.join(args.results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)
    if args.debug:
        print(f"Detection took {end - start} seconds")

    if args.images:
        for idx, (image, pred, name) in enumerate(zip(images, predictions, names)):
            polygons = [p.polygon for p in pred.bboxes]
            # Draw on copies so the loaded page image stays untouched.
            bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image))
            bbox_image.save(os.path.join(result_path, f"{name}_{idx}_bbox.png"))

            column_image = draw_lines_on_image(pred.vertical_lines, copy.deepcopy(image))
            column_image.save(os.path.join(result_path, f"{name}_{idx}_column.png"))

            if args.debug:
                heatmap = pred.heatmap
                heatmap.save(os.path.join(result_path, f"{name}_{idx}_heat.png"))

                affinity_map = pred.affinity_map
                affinity_map.save(os.path.join(result_path, f"{name}_{idx}_affinity.png"))

    # Group predictions by source name; "page" is the 1-based index within that name.
    predictions_by_page = defaultdict(list)
    for pred, name in zip(predictions, names):
        out_pred = pred.model_dump(exclude=["heatmap", "affinity_map"])
        out_pred["page"] = len(predictions_by_page[name]) + 1
        predictions_by_page[name].append(out_pred)

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(predictions_by_page, f, ensure_ascii=False)

    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
75
+
76
+
77
+
78
+
79
+
80
+
81
+
ocr_app.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from typing import List
3
+
4
+ import pypdfium2
5
+ import streamlit as st
6
+ from pypdfium2 import PdfiumError
7
+
8
+ from surya.detection import batch_text_detection
9
+ from surya.input.pdflines import get_page_text_lines, get_table_blocks
10
+ from surya.layout import batch_layout_detection
11
+ from surya.model.detection.model import load_model, load_processor
12
+ from surya.model.recognition.model import load_model as load_rec_model
13
+ from surya.model.recognition.processor import load_processor as load_rec_processor
14
+ from surya.model.ordering.processor import load_processor as load_order_processor
15
+ from surya.model.ordering.model import load_model as load_order_model
16
+ from surya.model.table_rec.model import load_model as load_table_model
17
+ from surya.model.table_rec.processor import load_processor as load_table_processor
18
+ from surya.ordering import batch_ordering
19
+ from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
20
+ from surya.ocr import run_ocr
21
+ from surya.postprocessing.text import draw_text_on_image
22
+ from PIL import Image
23
+ from surya.languages import CODE_TO_LANGUAGE
24
+ from surya.input.langs import replace_lang_with_code
25
+ from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult, TableResult
26
+ from surya.settings import settings
27
+ from surya.tables import batch_table_recognition
28
+ from surya.postprocessing.util import rescale_bboxes, rescale_bbox
29
+
30
+
31
@st.cache_resource()
def load_det_cached():
    """Return the (model, processor) pair for text detection, cached by Streamlit."""
    ckpt = settings.DETECTOR_MODEL_CHECKPOINT
    det_model_obj = load_model(checkpoint=ckpt)
    det_processor_obj = load_processor(checkpoint=ckpt)
    return det_model_obj, det_processor_obj
35
+
36
+
37
@st.cache_resource()
def load_rec_cached():
    """Return the (model, processor) pair for text recognition, cached by Streamlit."""
    recognition_model = load_rec_model()
    recognition_processor = load_rec_processor()
    return recognition_model, recognition_processor
40
+
41
+
42
@st.cache_resource()
def load_layout_cached():
    """Return the (model, processor) pair for layout detection, cached by Streamlit.

    Uses the detection loaders with the layout checkpoint from settings.
    """
    layout_model_obj = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    layout_processor_obj = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    return layout_model_obj, layout_processor_obj
45
+
46
@st.cache_resource()
def load_order_cached():
    """Return the (model, processor) pair for reading order, cached by Streamlit."""
    ordering_model = load_order_model()
    ordering_processor = load_order_processor()
    return ordering_model, ordering_processor
49
+
50
+
51
@st.cache_resource()
def load_table_cached():
    """Return the (model, processor) pair for table recognition, cached by Streamlit."""
    table_model_obj = load_table_model()
    table_processor_obj = load_table_processor()
    return table_model_obj, table_processor_obj
54
+
55
+
56
def text_detection(img) -> tuple[Image.Image, TextDetectionResult]:
    """Detect text lines in a single image.

    Returns a copy of ``img`` with the detected polygons drawn on it, together
    with the detection result for that image.
    """
    # The batch API takes a list; this app only ever processes one page at a time.
    pred = batch_text_detection([img], det_model, det_processor)[0]
    polygons = [p.polygon for p in pred.bboxes]
    det_img = draw_polys_on_image(polygons, img.copy())
    return det_img, pred
61
+
62
+
63
def layout_detection(img) -> tuple[Image.Image, LayoutResult]:
    """Detect layout regions in a single image.

    Text detection is run first because its result is passed to the layout
    model. Returns a copy of ``img`` with labelled layout polygons drawn on it,
    together with the layout result.
    """
    _, det_pred = text_detection(img)
    pred = batch_layout_detection([img], layout_model, layout_processor, [det_pred])[0]
    polygons = [p.polygon for p in pred.bboxes]
    labels = [p.label for p in pred.bboxes]
    layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18)
    return layout_img, pred
70
+
71
+
72
def order_detection(img) -> tuple[Image.Image, OrderResult]:
    """Predict the reading order of the layout regions in a single image.

    Layout detection is run first to obtain the region boxes. Returns a copy of
    ``img`` where each region polygon is labelled with its predicted position,
    together with the ordering result.
    """
    _, layout_pred = layout_detection(img)
    bboxes = [l.bbox for l in layout_pred.bboxes]
    pred = batch_ordering([img], [bboxes], order_model, order_processor)[0]
    polys = [l.polygon for l in pred.bboxes]
    positions = [str(l.position) for l in pred.bboxes]
    order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=18)
    return order_img, pred
80
+
81
+
82
def table_recognition(img, highres_img, filepath, page_idx: int, use_pdf_boxes: bool, skip_table_detection: bool) -> tuple[Image.Image, List[TableResult]]:
    """Recognize table structure on a page and draw the recognized cells.

    Args:
        img: Page image used for layout detection.
        highres_img: Higher-resolution render of the same page; tables are
            cropped from it and the cells are drawn on a copy of it.
        filepath: Source passed to ``get_page_text_lines`` to read embedded text.
        page_idx: Page index passed to ``get_page_text_lines``.
        use_pdf_boxes: When False, always use the text-detection model for the
            boxes inside each table instead of boxes read from the PDF.
        skip_table_detection: Treat the whole high-resolution page as one table.

    Returns:
        A copy of ``highres_img`` with every recognized cell drawn and labelled
        ``"row / col"``, and the list of per-table results.
    """
    if skip_table_detection:
        layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
        table_imgs = [highres_img]
    else:
        _, layout_pred = layout_detection(img)
        # Layout boxes are in ``img`` coordinates; rescale them to the high-res page.
        layout_tables = [
            rescale_bbox(l.bbox, img.size, highres_img.size)
            for l in layout_pred.bboxes
            if l.label == "Table"
        ]
        table_imgs = [highres_img.crop(bbox) for bbox in layout_tables]

    # Draw on a private copy so the caller's image is never modified and boxes
    # from every table accumulate on the same output image.
    table_img = highres_img.copy()
    if not table_imgs:
        # No table found: skip the models instead of calling them on empty batches.
        return table_img, []

    try:
        page_text = get_page_text_lines(filepath, [page_idx], [highres_img.size])[0]
        table_bboxes = get_table_blocks(layout_tables, page_text, highres_img.size)
    except PdfiumError:
        # This happens when we try to get text from an image
        table_bboxes = [[] for _ in layout_tables]

    # Fall back to detected boxes when PDF boxes are disabled or missing for any table.
    if not use_pdf_boxes or any(len(tb) == 0 for tb in table_bboxes):
        det_results = batch_text_detection(table_imgs, det_model, det_processor)
        table_bboxes = [[{"bbox": tb.bbox, "text": None} for tb in det_result.bboxes] for det_result in det_results]

    table_preds = batch_table_recognition(table_imgs, table_bboxes, table_model, table_processor)

    for results, table_bbox in zip(table_preds, layout_tables):
        x_off, y_off = table_bbox[0], table_bbox[1]
        # Cell boxes are relative to the table crop; shift them into page coordinates.
        adjusted_bboxes = [
            [
                item.bbox[0] + x_off,
                item.bbox[1] + y_off,
                item.bbox[2] + x_off,
                item.bbox[3] + y_off,
            ]
            for item in results.cells
        ]
        labels = [f"{item.row_id} / {item.col_id}" for item in results.cells]
        table_img = draw_bboxes_on_image(adjusted_bboxes, table_img, labels=labels, label_font_size=18)
    return table_img, table_preds
126
+
127
+
128
# Function for OCR
def ocr(img, highres_img, langs: List[str]) -> tuple[Image.Image, OCRResult]:
    """Run OCR on a single page.

    ``langs`` is passed to ``replace_lang_with_code`` before recognition.
    Returns an image of the recognized text rendered at the recognized line
    boxes (sized like ``img``), together with the OCR result.
    """
    replace_lang_with_code(langs)
    img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor, highres_images=[highres_img])[0]

    bboxes = [l.bbox for l in img_pred.text_lines]
    text = [l.text for l in img_pred.text_lines]
    rec_img = draw_text_on_image(bboxes, text, img.size, langs, has_math="_math" in langs)
    return rec_img, img_pred
137
+
138
+
139
def open_pdf(pdf_file):
    """Open an uploaded PDF as a pypdfium2 document from its in-memory bytes."""
    pdf_bytes = pdf_file.getvalue()
    return pypdfium2.PdfDocument(io.BytesIO(pdf_bytes))
142
+
143
+
144
@st.cache_data()
def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI):
    """Render one page of an uploaded PDF as an RGB PIL image.

    Args:
        pdf_file: Uploaded PDF object (must provide ``getvalue()``).
        page_num: 1-based page number.
        dpi: Render resolution; converted to a scale factor relative to 72.
    """
    doc = open_pdf(pdf_file)
    try:
        renderer = doc.render(
            pypdfium2.PdfBitmap.to_pil,
            page_indices=[page_num - 1],  # pypdfium2 page indices are 0-based
            scale=dpi / 72,
        )
        page_render = list(renderer)[0]
        # Convert before the document is closed so the returned image is complete.
        return page_render.convert("RGB")
    finally:
        # Release the document explicitly instead of waiting for garbage collection.
        doc.close()
155
+
156
+
157
@st.cache_data()
def page_count(pdf_file):
    """Return the number of pages in an uploaded PDF."""
    doc = open_pdf(pdf_file)
    try:
        return len(doc)
    finally:
        # Release the document explicitly instead of waiting for garbage collection.
        doc.close()
161
+
162
+
163
# Page layout: two equal-width columns (results on the left, input on the right).
st.set_page_config(layout="wide")
col1, col2 = st.columns([0.5, 0.5])

# Load every model/processor pair through the cached loaders.
det_model, det_processor = load_det_cached()
rec_model, rec_processor = load_rec_cached()
layout_model, layout_processor = load_layout_cached()
order_model, order_processor = load_order_cached()
table_model, table_processor = load_table_cached()


INTRO_MARKDOWN = """
# Surya OCR Demo

This app will let you try surya, a multilingual OCR model. It supports text detection + layout analysis in any language, and text recognition in 90+ languages.

Notes:
- This works best on documents with printed text.
- Preprocessing the image (e.g. increasing contrast) can improve results.
- If OCR doesn't work, try changing the resolution of your image (increase if below 2048px width, otherwise decrease).
- This supports 90+ languages, see [here](https://github.com/VikParuchuri/surya/tree/master/surya/languages.py) for a full list.

Find the project [here](https://github.com/VikParuchuri/surya).
"""
st.markdown(INTRO_MARKDOWN)
186
+
187
in_file = st.sidebar.file_uploader("PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"])
languages = st.sidebar.multiselect("Languages", sorted(list(CODE_TO_LANGUAGE.values())), default=[], max_selections=4, help="Select the languages in the image (if known) to improve OCR accuracy. Optional.")

# Nothing to do until the user uploads a file.
if in_file is None:
    st.stop()

filetype = in_file.type
if "pdf" in filetype:
    # Keep the count in its own name: assigning it to ``page_count`` would
    # replace the helper function of the same name for the rest of the run.
    num_pages = page_count(in_file)
    page_number = st.sidebar.number_input(f"Page number out of {num_pages}:", min_value=1, value=1, max_value=num_pages)

    pil_image = get_page_image(in_file, page_number, settings.IMAGE_DPI)
    pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
else:
    # Plain images have no separate high-resolution render; reuse the same image.
    pil_image = Image.open(in_file).convert("RGB")
    pil_image_highres = pil_image
    page_number = None

text_det = st.sidebar.button("Run Text Detection")
text_rec = st.sidebar.button("Run OCR")
layout_det = st.sidebar.button("Run Layout Analysis")
order_det = st.sidebar.button("Run Reading Order")
table_rec = st.sidebar.button("Run Table Rec")
use_pdf_boxes = st.sidebar.checkbox("PDF table boxes", value=True, help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.")
skip_table_detection = st.sidebar.checkbox("Skip table detection", value=False, help="Table recognition only: Skip table detection and treat the whole image/page as a table.")

if pil_image is None:
    st.stop()
216
+
217
# Text detection: annotated image plus the raw prediction (without the map images).
if text_det:
    det_img, det_result = text_detection(pil_image)
    det_json = det_result.model_dump(exclude=["heatmap", "affinity_map"])
    with col1:
        st.image(det_img, caption="Detected Text", use_column_width=True)
        st.json(det_json, expanded=True)


# Layout analysis
if layout_det:
    layout_img, layout_result = layout_detection(pil_image)
    layout_json = layout_result.model_dump(exclude=["segmentation_map"])
    with col1:
        st.image(layout_img, caption="Detected Layout", use_column_width=True)
        st.json(layout_json, expanded=True)

# OCR: rendered text image, then JSON and plain-text tabs.
if text_rec:
    rec_img, ocr_result = ocr(pil_image, pil_image_highres, languages)
    with col1:
        st.image(rec_img, caption="OCR Result", use_column_width=True)
        json_tab, text_tab = st.tabs(["JSON", "Text Lines (for debugging)"])
        with json_tab:
            st.json(ocr_result.model_dump(), expanded=True)
        with text_tab:
            recognized_lines = [line.text for line in ocr_result.text_lines]
            st.text("\n".join(recognized_lines))

# Reading order
if order_det:
    order_img, order_result = order_detection(pil_image)
    with col1:
        st.image(order_img, caption="Reading Order", use_column_width=True)
        st.json(order_result.model_dump(), expanded=True)


# Table recognition
if table_rec:
    # The sidebar page number is 1-based; images have no page number at all.
    page_idx = page_number - 1 if page_number else None
    table_img, table_results = table_recognition(pil_image, pil_image_highres, in_file, page_idx, use_pdf_boxes, skip_table_detection)
    with col1:
        st.image(table_img, caption="Table Recognition", use_column_width=True)
        st.json([table.model_dump() for table in table_results], expanded=True)

# The right column always shows the input page.
with col2:
    st.image(pil_image, caption="Uploaded Image", use_column_width=True)
ocr_text.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import json
4
+ import time
5
+ from collections import defaultdict
6
+
7
+ import torch
8
+
9
+ from surya.input.langs import replace_lang_with_code, get_unique_langs
10
+ from surya.input.load import load_from_folder, load_from_file, load_lang_file
11
+ from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
12
+ from surya.model.recognition.model import load_model as load_recognition_model
13
+ from surya.model.recognition.processor import load_processor as load_recognition_processor
14
+ from surya.model.recognition.tokenizer import _tokenize
15
+ from surya.ocr import run_ocr
16
+ from surya.postprocessing.text import draw_text_on_image
17
+ from surya.settings import settings
18
+
19
+
20
def main():
    """Command-line entry point: run OCR and write the results to disk.

    Loads the pages at the default and high-resolution DPI, resolves the
    per-image languages (``--lang_file``, ``--langs`` or none), runs OCR, and
    writes ``results.json`` (plus rendered text images when ``--images`` is
    given) under ``<results_dir>/<input name>``.
    """
    parser = argparse.ArgumentParser(description="Detect bboxes in an input file or folder (PDFs or image).")
    parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect bboxes in.")
    parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "surya"))
    parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
    parser.add_argument("--start_page", type=int, help="Page to start processing at.", default=0)
    parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False)
    parser.add_argument("--langs", type=str, help="Optional language(s) to use for OCR. Comma separate for multiple. Can be a capitalized language name, or a 2-letter ISO 639 code.", default=None)
    parser.add_argument("--lang_file", type=str, help="Optional path to file with languages to use for OCR. Should be a JSON dict with file names as keys, and the value being a list of language codes/names.", default=None)
    parser.add_argument("--debug", action="store_true", help="Enable debug logging.", default=False)
    args = parser.parse_args()

    # normpath drops a trailing separator; without it basename("docs/") is ""
    # and the results would be written straight into results_dir.
    input_name = os.path.basename(os.path.normpath(args.input_path))
    if os.path.isdir(args.input_path):
        images, names, _ = load_from_folder(args.input_path, args.max, args.start_page)
        highres_images, _, _ = load_from_folder(args.input_path, args.max, args.start_page, settings.IMAGE_DPI_HIGHRES)
        folder_name = input_name
    else:
        images, names, _ = load_from_file(args.input_path, args.max, args.start_page)
        highres_images, _, _ = load_from_file(args.input_path, args.max, args.start_page, settings.IMAGE_DPI_HIGHRES)
        # NOTE: everything after the first "." is dropped, so "a.b.pdf" -> "a".
        folder_name = input_name.split(".")[0]

    if args.lang_file:
        # We got all of our language settings from a file
        langs = load_lang_file(args.lang_file, names)
        for lang in langs:
            replace_lang_with_code(lang)
        image_langs = langs
    elif args.langs:
        # We got our language settings from the input
        langs = args.langs.split(",")
        replace_lang_with_code(langs)
        image_langs = [langs] * len(images)
    else:
        image_langs = [None] * len(images)

    det_processor = load_detection_processor()
    det_model = load_detection_model()

    rec_model = load_recognition_model()
    rec_processor = load_recognition_processor()

    result_path = os.path.join(args.results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    start = time.time()
    predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor, highres_images=highres_images)
    if args.debug:
        print(f"OCR took {time.time() - start:.2f} seconds")
        # default=0 keeps --debug from crashing when no text line was recognized.
        max_chars = max((len(l.text) for p in predictions_by_image for l in p.text_lines), default=0)
        print(f"Max chars: {max_chars}")

    if args.images:
        for idx, (name, image, pred, langs) in enumerate(zip(names, images, predictions_by_image, image_langs)):
            bboxes = [l.bbox for l in pred.text_lines]
            pred_text = [l.text for l in pred.text_lines]
            page_image = draw_text_on_image(bboxes, pred_text, image.size, langs, has_math="_math" in langs if langs else False)
            page_image.save(os.path.join(result_path, f"{name}_{idx}_text.png"))

    # Group predictions by source name; "page" is the 1-based index within that name.
    out_preds = defaultdict(list)
    for name, pred in zip(names, predictions_by_image):
        out_pred = pred.model_dump()
        out_pred["page"] = len(out_preds[name]) + 1
        out_preds[name].append(out_pred)

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(out_preds, f, ensure_ascii=False)

    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
92
+
93
+
94
+
95
+
96
+
97
+
98
+
pyproject.toml ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "surya-ocr"
3
+ version = "0.6.1"
4
+ description = "OCR, layout, reading order, and table recognition in 90+ languages"
5
+ authors = ["Vik Paruchuri <vik.paruchuri@gmail.com>"]
6
+ readme = "README.md"
7
+ license = "GPL-3.0-or-later"
8
+ repository = "https://github.com/VikParuchuri/surya"
9
+ keywords = ["ocr", "pdf", "text detection", "text recognition", "tables"]
10
+ packages = [
11
+ {include = "surya"}
12
+ ]
13
+ include = [
14
+ "detect_text.py",
15
+ "ocr_text.py",
16
+ "ocr_app.py",
17
+ "run_ocr_app.py",
18
+ "detect_layout.py",
19
+ "reading_order.py",
20
+ "table_recognition.py"
21
+ ]
22
+
23
+ [tool.poetry.dependencies]
24
+ python = ">=3.9,<3.13,!=3.9.7"
25
+ transformers = "^4.41.0"
26
+ torch = "^2.3.0"
27
+ pydantic = "^2.5.3"
28
+ pydantic-settings = "^2.1.0"
29
+ python-dotenv = "^1.0.0"
30
+ pillow = "^10.2.0"
31
+ pypdfium2 = "^4.25.0"
32
+ opencv-python = "^4.9.0.80"
33
+ tabulate = "^0.9.0"
34
+ filetype = "^1.2.0"
35
+ ftfy = "^6.1.3"
36
+ pdftext = "^0.3.12"
37
+
38
+ [tool.poetry.group.dev.dependencies]
39
+ jupyter = "^1.0.0"
40
+ pytesseract = "^0.3.10"
41
+ pymupdf = "^1.23.8"
42
+ snakeviz = "^2.2.0"
43
+ datasets = "^2.16.1"
44
+ rapidfuzz = "^3.6.1"
45
+ arabic-reshaper = "^3.0.0"
46
+ streamlit = "^1.31.0"
47
+ playwright = "^1.41.2"
48
+
49
+ [tool.poetry.scripts]
50
+ surya_detect = "detect_text:main"
51
+ surya_ocr = "ocr_text:main"
52
+ surya_layout = "detect_layout:main"
53
+ surya_gui = "run_ocr_app:run_app"
54
+ surya_order = "reading_order:main"
55
+ surya_table = "table_recognition:main"
56
+
57
+ [build-system]
58
+ requires = ["poetry-core"]
59
+ build-backend = "poetry.core.masonry.api"
reading_order.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import copy
4
+ import json
5
+ from collections import defaultdict
6
+
7
+ from surya.detection import batch_text_detection
8
+ from surya.input.load import load_from_folder, load_from_file
9
+ from surya.layout import batch_layout_detection
10
+ from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
11
+ from surya.model.ordering.model import load_model
12
+ from surya.model.ordering.processor import load_processor
13
+ from surya.ordering import batch_ordering
14
+ from surya.postprocessing.heatmap import draw_polys_on_image
15
+ from surya.settings import settings
16
+
17
+
18
def main():
    """Command-line entry point: compute reading order and write the results to disk.

    Runs text detection, layout detection and then ordering on every loaded
    page, and writes ``results.json`` (plus annotated PNGs when ``--images`` is
    given) under ``<results_dir>/<input name>``. Each output box carries the
    layout label and the boxes of a page are sorted by predicted position.
    """
    parser = argparse.ArgumentParser(description="Find reading order of an input file or folder (PDFs or image).")
    parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to find reading order in.")
    parser.add_argument("--results_dir", type=str, help="Path to JSON file with layout results.", default=os.path.join(settings.RESULT_DIR, "surya"))
    parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
    parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False)
    args = parser.parse_args()

    model = load_model()
    processor = load_processor()

    layout_model = load_det_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    layout_processor = load_det_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)

    det_model = load_det_model()
    det_processor = load_det_processor()

    # normpath drops a trailing separator; without it basename("docs/") is ""
    # and the results would be written straight into results_dir.
    input_name = os.path.basename(os.path.normpath(args.input_path))
    if os.path.isdir(args.input_path):
        images, names, _ = load_from_folder(args.input_path, args.max)
        folder_name = input_name
    else:
        images, names, _ = load_from_file(args.input_path, args.max)
        # NOTE: everything after the first "." is dropped, so "a.b.pdf" -> "a".
        folder_name = input_name.split(".")[0]

    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)
    # One list of layout boxes per page, in the same order as the pages.
    bboxes = [[l.bbox for l in layout_pred.bboxes] for layout_pred in layout_predictions]

    order_predictions = batch_ordering(images, bboxes, model, processor)
    result_path = os.path.join(args.results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    if args.images:
        for idx, (image, order_pred, name) in enumerate(zip(images, order_predictions, names)):
            polys = [l.polygon for l in order_pred.bboxes]
            labels = [str(l.position) for l in order_pred.bboxes]
            # Draw on a copy so the loaded page image stays untouched.
            bbox_image = draw_polys_on_image(polys, copy.deepcopy(image), labels=labels, label_font_size=20)
            bbox_image.save(os.path.join(result_path, f"{name}_{idx}_order.png"))

    # Group predictions by source name; "page" is the 1-based index within that name.
    predictions_by_page = defaultdict(list)
    for layout_pred, pred, name in zip(layout_predictions, order_predictions, names):
        out_pred = pred.model_dump()
        # Copy the layout label onto each ordered box (paired by index).
        for bbox, layout_bbox in zip(out_pred["bboxes"], layout_pred.bboxes):
            bbox["label"] = layout_bbox.label

        out_pred["page"] = len(predictions_by_page[name]) + 1
        predictions_by_page[name].append(out_pred)

    # Sort in reading order
    for page_list in predictions_by_page.values():
        for page_preds in page_list:
            page_preds["bboxes"] = sorted(page_preds["bboxes"], key=lambda x: x["position"])

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(predictions_by_page, f, ensure_ascii=False)

    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ torchvision
4
+ torchaudio
5
+ surya-ocr
scripts/verify_benchmark_scores.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import argparse
3
+
4
+
5
def verify_layout(data):
    """Validate layout benchmark scores.

    Reads ``data["metrics"]``, a mapping whose values each hold a
    ``"precision"`` and a ``"recall"`` entry.

    Raises:
        ValueError: if any entry has precision or recall of 0.6 or lower.
    """
    below_threshold = any(
        metrics["precision"] <= 0.6 or metrics["recall"] <= 0.6
        for metrics in data["metrics"].values()
    )
    if below_threshold:
        raise ValueError("Scores do not meet the required threshold")
10
+
11
+
12
def verify_det(data):
    """Validate detection benchmark scores stored under ``data["metrics"]["surya"]``.

    Raises:
        ValueError: if precision or recall is 0.9 or lower.
    """
    surya_scores = data["metrics"]["surya"]
    too_low = surya_scores["precision"] <= 0.9 or surya_scores["recall"] <= 0.9
    if too_low:
        raise ValueError("Scores do not meet the required threshold")
16
+
17
+
18
def verify_rec(data):
    """Validate the recognition benchmark score stored at ``data["surya"]["avg_score"]``.

    Raises:
        ValueError: if the average score is 0.9 or lower.
    """
    avg_score = data["surya"]["avg_score"]
    if avg_score <= 0.9:
        raise ValueError("Scores do not meet the required threshold")
22
+
23
+
24
def verify_order(data):
    """Validate the reading-order benchmark score stored at ``data["mean_accuracy"]``.

    Raises:
        ValueError: if the mean accuracy is strictly below 0.75.
    """
    if data["mean_accuracy"] < 0.75:
        raise ValueError("Scores do not meet the required threshold")
28
+
29
+
30
def verify_table_rec(data):
    """Validate table-recognition benchmark scores stored under ``data["surya"]``.

    Both ``"mean_row_iou"`` and ``"mean_col_iou"`` are read before comparing.

    Raises:
        ValueError: if either IoU is strictly below 0.75.
    """
    surya_scores = data["surya"]
    row_iou = surya_scores["mean_row_iou"]
    col_iou = surya_scores["mean_col_iou"]

    too_low = row_iou < 0.75 or col_iou < 0.75
    if too_low:
        raise ValueError("Scores do not meet the required threshold")
36
+
37
+
38
def verify_scores(file_path, bench_type):
    """Load a benchmark results JSON file and run the matching verifier.

    Args:
        file_path: path to the JSON file with benchmark results.
        bench_type: one of "detection", "recognition", "layout", "ordering"
            or "table_recognition".

    Raises:
        ValueError: if ``bench_type`` is not recognised, or if the selected
            verifier rejects the scores.
    """
    with open(file_path, 'r') as file:
        data = json.load(file)

    verifiers = {
        "detection": verify_det,
        "recognition": verify_rec,
        "layout": verify_layout,
        "ordering": verify_order,
        "table_recognition": verify_table_rec,
    }
    verifier = verifiers.get(bench_type)
    if verifier is None:
        raise ValueError("Invalid benchmark type")
    verifier(data)
54
+
55
+
56
if __name__ == "__main__":
    # Command-line entry point: read the results JSON given as the positional
    # argument and verify it for --bench_type (defaults to "detection").
    # verify_scores raises ValueError when the scores are rejected.
    parser = argparse.ArgumentParser(description="Verify benchmark scores")
    parser.add_argument("file_path", type=str, help="Path to the json file")
    parser.add_argument("--bench_type", type=str, help="Type of benchmark to verify", default="detection")
    args = parser.parse_args()
    verify_scores(args.file_path, args.bench_type)
surya/benchmark/bbox.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fitz as pymupdf
2
+ from surya.postprocessing.util import rescale_bbox
3
+
4
+
5
def get_pdf_lines(pdf_path, img_sizes):
    """Extract text-line bounding boxes for the leading pages of a PDF.

    Page ``idx`` of the PDF is paired with ``img_sizes[idx]``; each line bbox
    is passed through ``rescale_bbox`` together with the page size and that
    image size.

    Args:
        pdf_path: path to the PDF file.
        img_sizes: one target image size per page to process, in page order.

    Returns:
        A list with one entry per processed page, each a list of rescaled
        line bounding boxes.
    """
    # Default dict-extraction flags minus ligature and image preservation.
    text_flags = (
        pymupdf.TEXTFLAGS_DICT
        & ~pymupdf.TEXT_PRESERVE_LIGATURES
        & ~pymupdf.TEXT_PRESERVE_IMAGES
    )

    page_lines = []
    # Context manager guarantees the document is closed, even on error.
    with pymupdf.open(pdf_path) as doc:
        for idx, img_size in enumerate(img_sizes):
            page = doc[idx]
            blocks = page.get_text("dict", sort=True, flags=text_flags)["blocks"]
            line_boxes = [list(line["bbox"]) for block in blocks for line in block["lines"]]

            page_box = page.bound()
            pwidth, pheight = page_box[2] - page_box[0], page_box[3] - page_box[1]
            page_lines.append(
                [rescale_bbox(bbox, (pwidth, pheight), img_size) for bbox in line_boxes]
            )

    return page_lines
surya/benchmark/metrics.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from itertools import repeat
3
+
4
+ import numpy as np
5
+ from concurrent.futures import ProcessPoolExecutor
6
+
7
+
8
def intersection_area(box1, box2):
    """Return the overlap area of two ``(x1, y1, x2, y2)`` boxes.

    Returns ``0.0`` when the boxes do not overlap; boxes that merely touch
    yield a zero-sized product.
    """
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])

    disjoint = right < left or bottom < top
    if disjoint:
        return 0.0

    overlap_width = right - left
    overlap_height = bottom - top
    return overlap_width * overlap_height
18
+
19
def box_area(box):
    """Return the area of an ``(x1, y1, x2, y2)`` box."""
    width = box[2] - box[0]
    height = box[3] - box[1]
    return width * height
21
+
22
+
23
def calculate_iou(box1, box2, box1_only=False):
    """Return the intersection of two boxes divided by a reference area.

    Args:
        box1: first ``(x1, y1, x2, y2)`` box.
        box2: second ``(x1, y1, x2, y2)`` box.
        box1_only: when true, divide by the area of ``box1`` alone instead of
            the union of both boxes.

    Returns:
        The ratio, or ``0`` when the denominator is zero.
    """
    overlap = intersection_area(box1, box2)
    if box1_only:
        denominator = box_area(box1)
    else:
        # Grouping kept as area1 + (area2 - overlap) so float results match.
        denominator = box_area(box1) + (box_area(box2) - overlap)

    return 0 if denominator == 0 else overlap / denominator
32
+
33
+
34
def match_boxes(preds, references):
    """Greedily pair reference boxes with predicted boxes by overlap.

    The overlap score for a pair is ``calculate_iou(reference, pred,
    box1_only=True)``. Pairs are visited from highest to lowest score and
    accepted whenever neither box has been used yet.

    Returns:
        A list of ``(reference_index, pred_index, score)`` tuples. Accepted
        scores above 0.95 are rounded up to 1.0. Unmatched references appear
        as ``(i, None, -1.0)`` and unmatched predictions as ``(None, j, 0.0)``.
    """
    ref_count, pred_count = len(references), len(preds)

    overlap = np.zeros((ref_count, pred_count))
    for ref_idx, ref_box in enumerate(references):
        for pred_idx, pred_box in enumerate(preds):
            overlap[ref_idx, pred_idx] = calculate_iou(ref_box, pred_box, box1_only=True)

    # Flattened indices ordered from best to worst overlap.
    ranking = np.argsort(overlap, axis=None)[::-1]
    ranked_refs, ranked_preds = np.unravel_index(ranking, overlap.shape)

    used_refs = set()
    used_preds = set()
    matches = []
    for i, j in zip(ranked_refs, ranked_preds):
        if i in used_refs or j in used_preds:
            continue
        score = overlap[i, j]
        if score > .95:  # Account for rounding on box edges
            score = 1.0
        matches.append((i, j, score))
        used_refs.add(i)
        used_preds.add(j)

    leftover_refs = set(range(ref_count)) - used_refs
    leftover_preds = set(range(pred_count)) - used_preds
    matches.extend([(i, None, -1.0) for i in leftover_refs])
    matches.extend([(None, j, 0.0) for j in leftover_preds])

    return matches
67
+
68
def penalized_iou_score(preds, references):
    """Return the mean score over all entries produced by ``match_boxes``.

    Unmatched boxes are included in the mean with the scores ``match_boxes``
    assigns them. Raises ``ZeroDivisionError`` when there are no matches at
    all (both inputs empty).
    """
    matches = match_boxes(preds, references)
    scores = [score for _, _, score in matches]
    return sum(scores) / len(matches)
72
+
73
def intersection_pixels(box1, box2):
    """Return the set of integer ``(x, y)`` pixels inside the overlap of two boxes.

    Overlap bounds are truncated with ``int`` and the right/bottom bound is
    exclusive. Returns an empty set when the boxes do not overlap.
    """
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])

    if right < left or bottom < top:
        return set()

    columns = np.arange(int(left), int(right))
    rows = np.arange(int(top), int(bottom))
    grid_x, grid_y = np.meshgrid(columns, rows)

    return set(zip(grid_x.flat, grid_y.flat))
89
+
90
+
91
def calculate_coverage(box, other_boxes, penalize_double=False):
    """Return the fraction of ``box`` covered by ``other_boxes``, pixel-wise.

    Args:
        box: the ``(x1, y1, x2, y2)`` box to measure coverage for.
        other_boxes: boxes that may cover ``box``.
        penalize_double: when true, every pixel of ``box`` that is covered
            again by a later box is subtracted from the covered count.

    Returns:
        Covered pixel count (after the optional penalty, floored at 0)
        divided by the area of ``box``; ``0`` when ``box`` has zero area.
    """
    area = (box[2] - box[0]) * (box[3] - box[1])
    if area == 0:
        return 0

    covered_pixels = set()
    # Number of pixel hits that landed on an already-covered pixel. The
    # previous code took len() of a list holding one entry per other box,
    # which penalised by the number of boxes, not by the actual overlap.
    double_covered = 0
    for other_box in other_boxes:
        overlap = intersection_pixels(box, other_box)
        if penalize_double:
            double_covered += len(covered_pixels & overlap)
        covered_pixels |= overlap

    covered_count = max(0, len(covered_pixels) - double_covered)
    return covered_count / area
110
+
111
+
112
def calculate_coverage_fast(box, other_boxes, penalize_double=False):
    """Return the summed overlap of ``other_boxes`` with ``box`` over its area.

    Overlaps are summed per box, so overlapping ``other_boxes`` can count the
    same region more than once; the result is capped at 1.
    ``penalize_double`` is accepted but not used by this function.
    Returns ``0`` when ``box`` has zero area.
    """
    area = (box[2] - box[0]) * (box[3] - box[1])
    if area == 0:
        return 0

    overlap_total = 0
    for candidate in other_boxes:
        overlap_total += intersection_area(box, candidate)

    ratio = overlap_total / area
    return min(1, ratio)
122
+
123
+
124
def precision_recall(preds, references, threshold=.5, workers=8, penalize_double=True):
    """Compute coverage-based precision and recall of predicted boxes.

    A predicted box counts towards precision when its coverage by the
    references exceeds ``threshold``; a reference box counts towards recall
    when its coverage by the predictions exceeds ``threshold``.

    Args:
        preds: predicted boxes.
        references: reference boxes.
        threshold: coverage a box must exceed to count as a hit.
        workers: number of worker processes used for the coverage maps.
        penalize_double: selects ``calculate_coverage`` (true) or
            ``calculate_coverage_fast`` (false). The flag is forwarded only to
            the precision computation; the recall computation calls the
            coverage function with its default.

    Returns:
        ``{"precision": ..., "recall": ...}``. Both are 1 when there are no
        references, and both are 0 when there are references but no predictions.
    """
    if len(references) == 0:
        return {"precision": 1, "recall": 1}

    if len(preds) == 0:
        return {"precision": 0, "recall": 0}

    # The pixel-exact variant is only needed when double coverage is penalised.
    coverage_func = calculate_coverage if penalize_double else calculate_coverage_fast

    with ProcessPoolExecutor(max_workers=workers) as executor:
        pred_scorer = partial(coverage_func, penalize_double=penalize_double)
        pred_coverage = executor.map(pred_scorer, preds, repeat(references))
        ref_coverage = executor.map(coverage_func, references, repeat(preds))

    pred_hits = [1 if value > threshold else 0 for value in pred_coverage]
    ref_hits = [1 if value > threshold else 0 for value in ref_coverage]

    return {
        "precision": sum(pred_hits) / len(pred_hits),
        "recall": sum(ref_hits) / len(ref_hits),
    }
157
+
158
+
159
def mean_coverage(preds, references):
    """Average ``calculate_coverage`` over references and predictions.

    Each reference is scored against all predictions, then each prediction
    against all references.

    Returns:
        ``{"coverage": mean}``, or the bare integer ``0`` when both inputs are
        empty (note the differing return type in that case).
    """
    coverages = [calculate_coverage(reference, preds) for reference in references]
    coverages.extend(calculate_coverage(pred, references) for pred in preds)

    # Nothing to average over.
    if not coverages:
        return 0
    return {"coverage": sum(coverages) / len(coverages)}
175
+
176
+
177
def rank_accuracy(preds, references):
    """Return the fraction of ordered pairs ranked the same way in both inputs.

    Preds and references need to be aligned so each position refers to the
    same bbox. For every ordered pair of distinct positions ``(i, j)`` the
    comparison ``preds[i] > preds[j]`` is checked against
    ``references[i] > references[j]``.

    Raises:
        ZeroDivisionError: if ``preds`` has fewer than two entries.
    """
    # A set gives O(1) membership tests; the (i, j) part of each tuple is
    # unique, so the set has exactly as many entries as the old list.
    pred_pairs = {
        (i, j, pred > pred2)
        for i, pred in enumerate(preds)
        for j, pred2 in enumerate(preds)
        if i != j
    }

    # Find how many of the prediction rankings are correct
    correct = 0
    for i, ref in enumerate(references):
        for j, ref2 in enumerate(references):
            if (i, j, ref > ref2) in pred_pairs:
                correct += 1

    return correct / len(pred_pairs)
surya/benchmark/tatr.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import DetrFeatureExtractor, AutoModelForObjectDetection
3
+ from surya.settings import settings
4
+
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
+
9
class MaxResize(object):
    """Callable that rescales an image so its longer side equals ``max_size``."""

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        """Return ``image.resize(...)`` with both sides scaled by the same factor."""
        width, height = image.size
        scale = self.max_size / max(width, height)
        target_size = (int(round(scale * width)), int(round(scale * height)))
        return image.resize(target_size)
20
+
21
+
22
def to_tensor(image):
    """Convert an image to a float32 torch tensor scaled by 1/255.

    The input is converted with ``np.array`` and its axes are reordered with
    ``transpose((2, 0, 1))``, so a 3-axis input is required.
    """
    pixels = np.array(image).astype(np.float32)

    # Move the last axis to the front.
    channels_first = pixels.transpose((2, 0, 1))

    # Scale in place by 1/255.
    channels_first /= 255.0

    return torch.from_numpy(channels_first)
33
+
34
+
35
def normalize(tensor, mean, std):
    """Normalise ``tensor`` in place along its first axis and return it.

    Each slice along the first axis has the matching ``mean`` entry subtracted
    and is then divided by the matching ``std`` entry.
    """
    for channel, channel_mean, channel_std in zip(tensor, mean, std):
        channel.sub_(channel_mean)
        channel.div_(channel_std)
    return tensor
39
+
40
+
41
def structure_transform(image):
    """Prepare an image for the table-structure model.

    Resizes with ``MaxResize(1000)``, converts with ``to_tensor`` and applies
    ``normalize`` with the fixed mean/std values below.
    """
    resize = MaxResize(1000)
    tensor = to_tensor(resize(image))
    return normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
46
+
47
+
48
def box_cxcywh_to_xyxy(x):
    """Convert boxes from ``(center_x, center_y, width, height)`` to corner form.

    The four components are taken from the last axis of ``x`` and the result
    is stacked along ``dim=1`` as ``(x_min, y_min, x_max, y_max)``.
    """
    x_c, y_c, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = [x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h]
    return torch.stack(corners, dim=1)
52
+
53
+
54
def rescale_bboxes(out_bbox, size):
    """Convert centre-form boxes to corner form and scale them by ``size``.

    ``size`` is ``(width, height)``; x coordinates are multiplied by the
    width and y coordinates by the height.
    """
    width, height = size
    scale = torch.tensor([width, height, width, height], dtype=torch.float32)
    return box_cxcywh_to_xyxy(out_bbox) * scale
59
+
60
+
61
def outputs_to_objects(outputs, img_sizes, id2label):
    """Turn raw detection outputs into per-image row and column objects.

    Args:
        outputs: model output exposing ``logits`` and ``['pred_boxes']``.
        img_sizes: one size per image, passed to ``rescale_bboxes``.
        id2label: mapping from class index to label; entries labelled
            ``'no object'`` are dropped.

    Returns:
        One dict per image with ``"rows"`` and ``"cols"`` lists. Each element
        is a dict with ``'label'``, ``'score'`` and ``'bbox'`` for detections
        labelled "table row" / "table column" respectively.
    """
    best = outputs.logits.softmax(-1).max(-1)
    batch_labels = list(best.indices.detach().cpu().numpy())
    batch_scores = list(best.values.detach().cpu().numpy())
    batch_bboxes = outputs['pred_boxes'].detach().cpu()

    batch_objects = []
    for image_idx, img_size in enumerate(img_sizes):
        pred_bboxes = [elem.tolist() for elem in rescale_bboxes(batch_bboxes[image_idx], img_size)]
        pred_scores = batch_scores[image_idx]
        pred_labels = batch_labels[image_idx]

        detections = []
        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
            class_label = id2label[int(label)]
            if class_label == 'no object':
                continue
            detections.append({
                'label': class_label,
                'score': float(score),
                'bbox': [float(elem) for elem in bbox],
            })

        batch_objects.append({
            "rows": [cell for cell in detections if cell["label"] == "table row"],
            "cols": [cell for cell in detections if cell["label"] == "table column"],
        })

    return batch_objects
97
+
98
+
99
def load_tatr():
    """Load the pretrained table-structure model and move it to ``settings.TORCH_DEVICE_MODEL``."""
    checkpoint = "microsoft/table-transformer-structure-recognition-v1.1-all"
    model = AutoModelForObjectDetection.from_pretrained(checkpoint)
    return model.to(settings.TORCH_DEVICE_MODEL)
101
+
102
+
103
def batch_inference_tatr(model, images, batch_size):
    """Run the table-structure model over ``images`` in batches.

    Args:
        model: detection model exposing ``device`` and ``config.id2label``.
        images: images to process; each is passed through
            ``structure_transform`` and must expose ``size``.
        batch_size: number of images per forward pass.

    Returns:
        One ``{"rows": [...], "cols": [...]}`` dict per input image, as
        produced by ``outputs_to_objects``.
    """
    device = model.device

    # Work on a copy: the previous code wrote the extra "no object" entry into
    # model.config.id2label itself, once per batch, so the shared config kept
    # growing by one key on every batch.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"

    rows_cols = []
    for start in range(0, len(images), batch_size):
        batch_images = images[start:start + batch_size]
        pixel_values = torch.stack([structure_transform(img) for img in batch_images], dim=0).to(device)

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values)

        rows_cols.extend(outputs_to_objects(outputs, [img.size for img in batch_images], id2label))
    return rows_cols
surya/benchmark/tesseract.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ import numpy as np
4
+ import pytesseract
5
+ from pytesseract import Output
6
+ from tqdm import tqdm
7
+
8
+ from surya.input.processing import slice_bboxes_from_image
9
+ from surya.settings import settings
10
+ import os
11
+ from concurrent.futures import ProcessPoolExecutor
12
+ from surya.detection import get_batch_size as get_det_batch_size
13
+ from surya.recognition import get_batch_size as get_rec_batch_size
14
+ from surya.languages import CODE_TO_LANGUAGE
15
+
16
+
17
def surya_lang_to_tesseract(code: str) -> Optional[str]:
    """Map a surya language code to the matching tesseract language code.

    The code is first resolved to a language name via ``CODE_TO_LANGUAGE``
    (an unknown code raises ``KeyError``); ``None`` is returned when that
    name has no entry in ``TESS_LANGUAGE_TO_CODE``.
    """
    language_name = CODE_TO_LANGUAGE[code]
    return TESS_LANGUAGE_TO_CODE.get(language_name)
24
+
25
+
26
def tesseract_ocr(img, bboxes, lang: str):
    """Run tesseract on every bbox crop of ``img`` and return the text per crop.

    Crops come from ``slice_bboxes_from_image``; tesseract is pointed at
    ``settings.TESSDATA_PREFIX`` for its data files.
    """
    tess_config = f'--tessdata-dir "{settings.TESSDATA_PREFIX}"'
    return [
        pytesseract.image_to_string(line_img, lang=lang, config=tess_config)
        for line_img in slice_bboxes_from_image(img, bboxes)
    ]
34
+
35
+
36
def tesseract_ocr_parallel(imgs, bboxes, langs: List[str], cpus=None):
    """Run ``tesseract_ocr`` over many images in a process pool.

    Args:
        imgs: images to OCR.
        bboxes: per-image bounding boxes, aligned with ``imgs``.
        langs: per-image tesseract language strings, aligned with ``imgs``.
        cpus: CPU budget; when falsy, ``os.cpu_count()`` is used.

    Returns:
        A list with one list of recognised lines per image.
    """
    tess_parallel_cores = min(len(imgs), get_rec_batch_size())
    if not cpus:
        # os.cpu_count() returns None when the count cannot be determined,
        # which would make the min() below raise TypeError.
        cpus = os.cpu_count() or 1
    tess_parallel_cores = min(tess_parallel_cores, cpus)

    # Tesseract uses up to 4 processes per instance
    # Divide by 2 because tesseract doesn't seem to saturate all 4 cores with these small images
    tess_parallel = max(tess_parallel_cores // 2, 1)

    with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
        tess_text = tqdm(executor.map(tesseract_ocr, imgs, bboxes, langs), total=len(imgs), desc="Running tesseract OCR")
        tess_text = list(tess_text)
    return tess_text
50
+
51
+
52
def tesseract_bboxes(img):
    """Return one ``(x1, y1, x2, y2)`` box per entry reported by tesseract.

    Uses ``pytesseract.image_to_data`` and converts each
    left/top/width/height record to corner form.
    """
    ocr = pytesseract.image_to_data(np.asarray(img, dtype=np.uint8), output_type=Output.DICT)

    def corner_box(i):
        x, y = ocr['left'][i], ocr['top'][i]
        return (x, y, x + ocr['width'][i], y + ocr['height'][i])

    # It is possible to merge by line here with line number, but it gives bad results.
    return [corner_box(i) for i in range(len(ocr['level']))]
65
+
66
+
67
def tesseract_parallel(imgs):
    """Run ``tesseract_bboxes`` over many images in a process pool.

    Returns:
        A list with one list of bounding boxes per image.
    """
    tess_parallel_cores = min(len(imgs), get_det_batch_size())
    # os.cpu_count() returns None when the count cannot be determined, which
    # would make the min() below raise TypeError.
    cpus = os.cpu_count() or 1
    tess_parallel_cores = min(tess_parallel_cores, cpus)

    # Tesseract uses 4 threads per instance
    tess_parallel = max(tess_parallel_cores // 4, 1)

    with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
        tess_bboxes = tqdm(executor.map(tesseract_bboxes, imgs), total=len(imgs), desc="Running tesseract bbox detection")
        tess_bboxes = list(tess_bboxes)
    return tess_bboxes
80
+
81
+
82
+ TESS_CODE_TO_LANGUAGE = {
83
+ "afr": "Afrikaans",
84
+ "amh": "Amharic",
85
+ "ara": "Arabic",
86
+ "asm": "Assamese",
87
+ "aze": "Azerbaijani",
88
+ "bel": "Belarusian",
89
+ "ben": "Bengali",
90
+ "bod": "Tibetan",
91
+ "bos": "Bosnian",
92
+ "bre": "Breton",
93
+ "bul": "Bulgarian",
94
+ "cat": "Catalan",
95
+ "ceb": "Cebuano",
96
+ "ces": "Czech",
97
+ "chi_sim": "Chinese",
98
+ "chr": "Cherokee",
99
+ "cym": "Welsh",
100
+ "dan": "Danish",
101
+ "deu": "German",
102
+ "dzo": "Dzongkha",
103
+ "ell": "Greek",
104
+ "eng": "English",
105
+ "epo": "Esperanto",
106
+ "est": "Estonian",
107
+ "eus": "Basque",
108
+ "fas": "Persian",
109
+ "fin": "Finnish",
110
+ "fra": "French",
111
+ "fry": "Western Frisian",
112
+ "guj": "Gujarati",
113
+ "gla": "Scottish Gaelic",
114
+ "gle": "Irish",
115
+ "glg": "Galician",
116
+ "heb": "Hebrew",
117
+ "hin": "Hindi",
118
+ "hrv": "Croatian",
119
+ "hun": "Hungarian",
120
+ "hye": "Armenian",
121
+ "iku": "Inuktitut",
122
+ "ind": "Indonesian",
123
+ "isl": "Icelandic",
124
+ "ita": "Italian",
125
+ "jav": "Javanese",
126
+ "jpn": "Japanese",
127
+ "kan": "Kannada",
128
+ "kat": "Georgian",
129
+ "kaz": "Kazakh",
130
+ "khm": "Khmer",
131
+ "kir": "Kyrgyz",
132
+ "kor": "Korean",
133
+ "lao": "Lao",
134
+ "lat": "Latin",
135
+ "lav": "Latvian",
136
+ "lit": "Lithuanian",
137
+ "mal": "Malayalam",
138
+ "mar": "Marathi",
139
+ "mkd": "Macedonian",
140
+ "mlt": "Maltese",
141
+ "mon": "Mongolian",
142
+ "msa": "Malay",
143
+ "mya": "Burmese",
144
+ "nep": "Nepali",
145
+ "nld": "Dutch",
146
+ "nor": "Norwegian",
147
+ "ori": "Oriya",
148
+ "pan": "Punjabi",
149
+ "pol": "Polish",
150
+ "por": "Portuguese",
151
+ "pus": "Pashto",
152
+ "ron": "Romanian",
153
+ "rus": "Russian",
154
+ "san": "Sanskrit",
155
+ "sin": "Sinhala",
156
+ "slk": "Slovak",
157
+ "slv": "Slovenian",
158
+ "snd": "Sindhi",
159
+ "spa": "Spanish",
160
+ "sqi": "Albanian",
161
+ "srp": "Serbian",
162
+ "swa": "Swahili",
163
+ "swe": "Swedish",
164
+ "syr": "Syriac",
165
+ "tam": "Tamil",
166
+ "tel": "Telugu",
167
+ "tgk": "Tajik",
168
+ "tha": "Thai",
169
+ "tir": "Tigrinya",
170
+ "tur": "Turkish",
171
+ "uig": "Uyghur",
172
+ "ukr": "Ukrainian",
173
+ "urd": "Urdu",
174
+ "uzb": "Uzbek",
175
+ "vie": "Vietnamese",
176
+ "yid": "Yiddish"
177
+ }
178
+
179
+ TESS_LANGUAGE_TO_CODE = {v:k for k,v in TESS_CODE_TO_LANGUAGE.items()}
surya/benchmark/util.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def merge_boxes(box1, box2):
    """Return the smallest ``(x1, y1, x2, y2)`` box enclosing both inputs."""
    return (min(box1[0], box2[0]), min(box1[1], box2[1]), max(box1[2], box2[2]), max(box1[3], box2[3]))


def join_lines(bboxes, max_gap=5):
    """Merge boxes that sit directly above a horizontally enclosing box.

    A later box ``box2`` is merged into an earlier box ``box1`` when
    ``box1`` spans ``box2`` horizontally and the top of ``box1`` is within
    ``max_gap`` of the bottom of ``box2``.

    Args:
        bboxes: list of ``(x1, y1, x2, y2)`` boxes.
        max_gap: maximum vertical distance between the two edges.

    Returns:
        The boxes in their original order, with merged boxes replaced by a
        single enclosing tuple and absorbed boxes removed.
    """
    # The previous code unpacked each box as an (index, box) pair and indexed
    # bboxes[j][1]; it could not run on a plain list of boxes. enumerate()
    # supplies the indices it was meant to use.
    to_merge = {}
    for i, box1 in enumerate(bboxes):
        for offset, box2 in enumerate(bboxes[i + 1:]):
            j = i + offset + 1
            if box1 == box2:
                continue

            spans_horizontally = box1[0] <= box2[0] and box1[2] >= box2[2]
            if spans_horizontally and abs(box1[1] - box2[3]) <= max_gap:
                to_merge.setdefault(i, []).append(j)

    absorbed = set()
    merged = []
    for i, box in enumerate(bboxes):
        # Already folded into an earlier box.
        if i in absorbed:
            continue

        for j in to_merge.get(i, []):
            box = merge_boxes(box, bboxes[j])
            absorbed.add(j)

        merged.append(box)
    return merged
surya/detection.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple, Generator
2
+
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ from surya.model.detection.model import EfficientViTForSemanticSegmentation
8
+ from surya.postprocessing.heatmap import get_and_clean_boxes
9
+ from surya.postprocessing.affinity import get_vertical_lines
10
+ from surya.input.processing import prepare_image_detection, split_image, get_total_splits, convert_if_not_rgb
11
+ from surya.schema import TextDetectionResult
12
+ from surya.settings import settings
13
+ from tqdm import tqdm
14
+ from concurrent.futures import ProcessPoolExecutor
15
+ import torch.nn.functional as F
16
+
17
+
18
def get_batch_size():
    """Return the detector batch size.

    ``settings.DETECTOR_BATCH_SIZE`` wins when it is set; otherwise the value
    depends on ``settings.TORCH_DEVICE_MODEL``: 36 for "cuda", 8 for "mps"
    and for any other device.
    """
    configured = settings.DETECTOR_BATCH_SIZE
    if configured is not None:
        return configured

    if settings.TORCH_DEVICE_MODEL == "cuda":
        return 36
    # "mps" and every other device share the same default.
    return 8
27
+
28
+
29
def batch_detection(
    images: List,
    model: EfficientViTForSemanticSegmentation,
    processor,
    batch_size=None
) -> Generator[Tuple[List[List[np.ndarray]], List[Tuple[int, int]]], None, None]:
    """Run the detection model over ``images`` and yield heatmaps batch by batch.

    Images are grouped so the number of splits per group (see
    ``get_total_splits``) stays within ``batch_size``. Each image is split
    with ``split_image``, the splits are run through the model together, and
    the per-split heatmaps are stacked back together per image.

    Args:
        images: PIL images (asserted).
        model: detection model; ``model.config.num_labels`` heatmaps are
            taken from its logits for every split.
        processor: exposes ``size["height"]`` / ``size["width"]``.
        batch_size: maximum number of splits per batch; defaults to
            ``get_batch_size()``.

    Yields:
        ``(preds, orig_sizes)`` per batch, where ``preds[n]`` is the list of
        heatmaps for the n-th image of the batch and ``orig_sizes[n]`` is that
        image's original size.
    """
    assert all([isinstance(image, Image.Image) for image in images])
    if batch_size is None:
        batch_size = get_batch_size()
    heatmap_count = model.config.num_labels

    orig_sizes = [image.size for image in images]
    splits_per_image = [get_total_splits(size, processor) for size in orig_sizes]

    # Group image indices into batches by their split counts. A new batch is
    # started whenever adding the next image would exceed batch_size; an image
    # with more splits than batch_size still gets a batch of its own.
    batches = []
    current_batch_size = 0
    current_batch = []
    for i in range(len(images)):
        if current_batch_size + splits_per_image[i] > batch_size:
            if len(current_batch) > 0:
                batches.append(current_batch)
            current_batch = []
            current_batch_size = 0
        current_batch.append(i)
        current_batch_size += splits_per_image[i]

    if len(current_batch) > 0:
        batches.append(current_batch)

    for batch_idx in tqdm(range(len(batches)), desc="Detecting bboxes"):
        batch_image_idxs = batches[batch_idx]
        batch_images = [images[j].convert("RGB") for j in batch_image_idxs]

        # split_index[n] is the batch-local image index that split n belongs
        # to; split_heights[n] is the height reported by split_image for it.
        split_index = []
        split_heights = []
        image_splits = []
        for image_idx, image in enumerate(batch_images):
            image_parts, split_height = split_image(image, processor)
            image_splits.extend(image_parts)
            split_index.extend([image_idx] * len(image_parts))
            split_heights.extend(split_height)

        image_splits = [prepare_image_detection(image, processor) for image in image_splits]
        # Batch images in dim 0
        batch = torch.stack(image_splits, dim=0).to(model.dtype).to(model.device)

        with torch.inference_mode():
            pred = model(pixel_values=batch)

        # Resize logits to the processor size when the model output differs.
        logits = pred.logits
        correct_shape = [processor.size["height"], processor.size["width"]]
        current_shape = list(logits.shape[2:])
        if current_shape != correct_shape:
            logits = F.interpolate(logits, size=correct_shape, mode='bilinear', align_corners=False)

        logits = logits.cpu().detach().numpy().astype(np.float32)
        preds = []
        for i, (idx, height) in enumerate(zip(split_index, split_heights)):
            # If our current prediction length is below the image idx, that means we have a new image
            # Otherwise, we need to add to the current image
            if len(preds) <= idx:
                preds.append([logits[i][k] for k in range(heatmap_count)])
            else:
                heatmaps = preds[idx]
                pred_heatmaps = [logits[i][k] for k in range(heatmap_count)]

                if height < processor.size["height"]:
                    # Cut off padding to get original height
                    pred_heatmaps = [pred_heatmap[:height, :] for pred_heatmap in pred_heatmaps]

                # Append this split below the heatmaps accumulated so far.
                for k in range(heatmap_count):
                    heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]])
                preds[idx] = heatmaps

        yield preds, [orig_sizes[j] for j in batch_image_idxs]
104
+
105
+
106
def parallel_get_lines(preds, orig_sizes):
    """Build a ``TextDetectionResult`` for one image from its two heatmaps.

    Args:
        preds: pair ``(heatmap, affinity_map)`` of 2-D arrays.
        orig_sizes: the ``(width, height)`` of the original image.

    Returns:
        A ``TextDetectionResult`` with the boxes from ``get_and_clean_boxes``,
        the vertical lines from ``get_vertical_lines``, both maps rendered as
        8-bit images, and the full-image bbox.
    """
    heatmap, affinity_map = preds

    def as_uint8_image(array):
        return Image.fromarray((array * 255).astype(np.uint8))

    heat_img = as_uint8_image(heatmap)
    aff_img = as_uint8_image(affinity_map)

    # Array shape is reversed to get the size order the helpers are given.
    affinity_size = list(reversed(affinity_map.shape))
    heatmap_size = list(reversed(heatmap.shape))
    bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes)
    vertical_lines = get_vertical_lines(affinity_map, affinity_size, orig_sizes)

    return TextDetectionResult(
        bboxes=bboxes,
        vertical_lines=vertical_lines,
        heatmap=heat_img,
        affinity_map=aff_img,
        image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]]
    )
123
+
124
+
125
def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
    """Detect text lines in ``images`` and return one result per image.

    Heatmaps come from ``batch_detection``; each is post-processed with
    ``parallel_get_lines``. Post-processing uses a process pool unless
    ``settings.IN_STREAMLIT`` is set or there are fewer images than
    ``settings.DETECTOR_MIN_PARALLEL_THRESH``.
    """
    detection_generator = batch_detection(images, model, processor, batch_size=batch_size)

    max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
    parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

    results = []
    if not parallelize:
        for preds, orig_sizes in detection_generator:
            for pred, orig_size in zip(preds, orig_sizes):
                results.append(parallel_get_lines(pred, orig_size))
        return results

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        for preds, orig_sizes in detection_generator:
            results.extend(list(executor.map(parallel_get_lines, preds, orig_sizes)))
    return results
143
+
144
+
surya/input/langs.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from surya.languages import LANGUAGE_TO_CODE, CODE_TO_LANGUAGE
3
+
4
+
5
def replace_lang_with_code(langs: List[str]):
    """Replace language names in ``langs`` with their codes, in place.

    An entry whose title-cased form is a key of ``LANGUAGE_TO_CODE`` is
    overwritten with the mapped code; other entries are left as they are.

    Raises:
        ValueError: if an entry (after replacement) is not in ``CODE_TO_LANGUAGE``.
    """
    for idx, lang in enumerate(langs):
        titled = lang.title()
        if titled in LANGUAGE_TO_CODE:
            lang = LANGUAGE_TO_CODE[titled]
            langs[idx] = lang
        if lang not in CODE_TO_LANGUAGE:
            raise ValueError(f"Language code {lang} not found.")
11
+
12
+
13
def get_unique_langs(langs: List[List[str]]):
    """Return every distinct language across ``langs`` in first-seen order.

    Args:
        langs: a list of language lists.

    Returns:
        A flat list without duplicates, ordered by first appearance.
    """
    # dict keys keep insertion order and give O(1) duplicate checks, replacing
    # the quadratic "if lang not in uniques" list scan.
    return list(dict.fromkeys(lang for lang_list in langs for lang in lang_list))
surya/input/load.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+
3
+ from surya.input.processing import open_pdf, get_page_images
4
+ from surya.settings import settings
5
+ import os
6
+ import filetype
7
+ from PIL import Image
8
+ import json
9
+
10
+
11
+
12
def get_name_from_path(path):
    """Return the file name of ``path`` up to (not including) its first dot."""
    file_name = os.path.basename(path)
    stem, _, _ = file_name.partition(".")
    return stem
14
+
15
+
16
def load_pdf(pdf_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False):
    """Render pages of a PDF to images and optionally extract their text lines.

    Args:
        pdf_path: path to the PDF file.
        max_pages: maximum number of pages to load; falsy means no limit.
        start_page: zero-based index of the first page; falsy means 0.
        dpi: passed to ``get_page_images``.
        load_text_lines: when true, also extract text lines via
            ``get_page_text_lines``.

    Returns:
        ``(images, names, text_lines)``. ``names`` repeats the file's name
        once per loaded page. ``text_lines`` is ``None`` unless
        ``load_text_lines`` is true.

    Raises:
        AssertionError: if ``start_page`` or ``max_pages`` is out of range.
    """
    doc = open_pdf(pdf_path)
    # try/finally so the document is closed even when validation, rendering
    # or text extraction raises.
    try:
        last_page = len(doc)

        if start_page:
            assert start_page < last_page and start_page >= 0, f"Start page must be between 0 and {last_page}"
        else:
            start_page = 0

        if max_pages:
            assert max_pages >= 0, "Max pages must be greater than 0"
            last_page = min(start_page + max_pages, last_page)

        page_indices = list(range(start_page, last_page))
        images = get_page_images(doc, page_indices, dpi=dpi)
        text_lines = None
        if load_text_lines:
            from surya.input.pdflines import get_page_text_lines  # Putting import here because pypdfium2 causes warnings if its not the top import
            text_lines = get_page_text_lines(
                pdf_path,
                page_indices,
                [i.size for i in images]
            )
    finally:
        doc.close()

    names = [get_name_from_path(pdf_path) for _ in page_indices]
    return images, names, text_lines
42
+
43
+
44
def load_image(image_path):
    """Open an image as RGB.

    Returns:
        ``([image], [name], [None])`` — single-element lists so the result has
        the same three-part shape as ``load_pdf``.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    return [rgb_image], [get_name_from_path(image_path)], [None]
48
+
49
+
50
def load_from_file(input_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False):
    """Load a single PDF or image file.

    Files detected as PDF go through ``load_pdf``; everything else,
    including files whose type cannot be detected, goes through
    ``load_image``.

    Returns:
        ``(images, names, text_lines)`` as returned by the chosen loader.
    """
    input_type = filetype.guess(input_path)
    # The guess can be falsy for unrecognised content (load_from_folder guards
    # against this too); previously that raised AttributeError here.
    if input_type and input_type.extension == "pdf":
        return load_pdf(input_path, max_pages, start_page, dpi=dpi, load_text_lines=load_text_lines)
    return load_image(input_path)
56
+
57
+
58
def load_from_folder(folder_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False):
    """Load every PDF and image directly inside ``folder_path``.

    Entries whose name starts with "." and sub-directories are skipped.
    Files that cannot be opened as an image are reported and skipped.

    Returns:
        ``(images, names, text_lines)``: three lists with one entry per
        loaded page or image.
    """
    image_paths = [os.path.join(folder_path, image_name) for image_name in os.listdir(folder_path) if not image_name.startswith(".")]
    image_paths = [ip for ip in image_paths if not os.path.isdir(ip)]

    images = []
    names = []
    text_lines = []
    for path in image_paths:
        extension = filetype.guess(path)
        if extension and extension.extension == "pdf":
            image, name, text_line = load_pdf(path, max_pages, start_page, dpi=dpi, load_text_lines=load_text_lines)
            if text_line is None:
                # load_pdf returns None when text lines were not requested;
                # extending with None raised TypeError. Use one placeholder
                # per page so the three lists stay the same length.
                text_line = [None] * len(image)
        else:
            try:
                image, name, text_line = load_image(path)
            except PIL.UnidentifiedImageError:
                print(f"Could not load image {path}")
                continue

        images.extend(image)
        names.extend(name)
        text_lines.extend(text_line)
    return images, names, text_lines
82
+
83
+
84
def load_lang_file(lang_path, names):
    """Load per-document language lists from a JSON file.

    Args:
        lang_path: path to a JSON object mapping document name to a list of
            languages.
        names: document names to look up, in the desired output order.

    Returns:
        A list with a copy of the language list for each name.

    Raises:
        KeyError: if a name is missing from the file.
    """
    # JSON is UTF-8; do not rely on the platform default encoding.
    with open(lang_path, "r", encoding="utf-8") as f:
        lang_dict = json.load(f)
    return [lang_dict[name].copy() for name in names]
surya/input/pdflines.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pdftext.extraction import dictionary_output
2
+
3
+ from surya.postprocessing.text import sort_text_lines
4
+ from surya.schema import PolygonBox
5
+
6
+
7
def get_page_text_lines(filepath: str, page_idxs: list, out_sizes: list) -> list:
    """Extract text for the given PDF pages and rescale all bboxes.

    Text comes from ``dictionary_output``. For each page, the bbox of every
    line and of every character is multiplied by
    ``out_size / (page width, page height)``; the dicts are updated in place.

    Args:
        filepath: path to the PDF.
        page_idxs: page indices to extract.
        out_sizes: one ``(width, height)`` target size per page index.

    Returns:
        The list of page dicts with rescaled bboxes.
    """
    assert len(page_idxs) == len(out_sizes)
    pages_text = dictionary_output(filepath, sort=False, page_range=page_idxs, keep_chars=True)

    def rescale(bbox, w_scale, h_scale):
        return [bbox[0] * w_scale, bbox[1] * h_scale,
                bbox[2] * w_scale, bbox[3] * h_scale]

    for page, out_size in zip(pages_text, out_sizes):
        text_w_scale = out_size[0] / page["width"]
        text_h_scale = out_size[1] / page["height"]
        for block in page["blocks"]:
            for line in block["lines"]:
                line["bbox"] = rescale(line["bbox"], text_w_scale, text_h_scale)
                for span in line["spans"]:
                    for char in span["chars"]:
                        char["bbox"] = rescale(char["bbox"], text_w_scale, text_h_scale)
    return pages_text
24
+
25
+
26
def get_table_blocks(tables: list, full_text: dict, img_size: list, table_thresh=.8):
    """Collect the text spans that fall inside each table box.

    For every table, lines whose ``intersection_pct`` with the table is at
    least ``table_thresh`` are scanned character by character, and adjacent
    characters are grouped into spans.

    Args:
        tables: table boxes as ``(x1, y1, x2, y2)``.
        full_text: page dict with ``"rotation"`` and ``"blocks"`` ->
            ``"lines"`` -> ``"spans"`` -> ``"chars"``, each char holding
            ``"char"`` and ``"bbox"``.
        img_size: ``(width, height)`` used to normalise character distances.
        table_thresh: minimum line/table intersection fraction.

    Returns:
        One list per table of ``{"text", "bbox"}`` dicts passed through
        ``sort_text_lines``.
    """
    # Returns coordinates relative to input table, not full image
    table_texts = []
    for table in tables:
        table_poly = PolygonBox(polygon=[
            [table[0], table[1]],
            [table[2], table[1]],
            [table[2], table[3]],
            [table[0], table[3]]
        ])
        table_text = []
        rotation = full_text["rotation"]
        for block in full_text["blocks"]:
            for line in block["lines"]:
                line_poly = PolygonBox(polygon=[
                    [line["bbox"][0], line["bbox"][1]],
                    [line["bbox"][2], line["bbox"][1]],
                    [line["bbox"][2], line["bbox"][3]],
                    [line["bbox"][0], line["bbox"][3]]
                ])
                # Skip lines that are not sufficiently inside the table.
                if line_poly.intersection_pct(table_poly) < table_thresh:
                    continue
                curr_span = None
                curr_box = None
                for span in line["spans"]:
                    for char in span["chars"]:
                        # A char continues the current span when it is within
                        # 1% of the image size of the span's box; which edges
                        # are compared depends on the page rotation.
                        # NOTE(review): only some of the differences are
                        # wrapped in abs() — confirm that is intentional.
                        same_span = False
                        if curr_span:
                            if rotation == 90:
                                same_span = (char["bbox"][0] - curr_box[0]) / img_size[0] < 0.01 and abs(char["bbox"][1] - curr_box[3]) / img_size[1] < 0.01
                            elif rotation == 180:
                                same_span = (char["bbox"][2] - curr_box[0]) / img_size[0] < 0.01 and (char["bbox"][1] - curr_box[1]) / img_size[1] < 0.01
                            elif rotation == 270:
                                same_span = (char["bbox"][0] - curr_box[0]) / img_size[0] < 0.01 and abs(char["bbox"][3] - curr_box[1]) / img_size[1] < 0.01
                            else:
                                same_span = (char["bbox"][0] - curr_box[2]) / img_size[0] < 0.01 and (char["bbox"][1] - curr_box[1]) / img_size[1] < 0.01

                        if curr_span is None:
                            # First char of the line starts a span.
                            curr_span = char["char"]
                            curr_box = char["bbox"]
                        elif same_span:
                            # Extend the span and grow its box to enclose the char.
                            curr_span += char["char"]
                            curr_box = [min(curr_box[0], char["bbox"][0]), min(curr_box[1], char["bbox"][1]),
                                        max(curr_box[2], char["bbox"][2]), max(curr_box[3], char["bbox"][3])]
                        else:
                            # Gap detected: emit the finished span and start a new one.
                            table_text.append({"text": curr_span, "bbox": curr_box})
                            curr_span = char["char"]
                            curr_box = char["bbox"]
                # Emit the span still open at the end of the line.
                if curr_span is not None:
                    table_text.append({"text": curr_span, "bbox": curr_box})
        # Adjust to be relative to input table
        for item in table_text:
            item["bbox"] = [
                item["bbox"][0] - table[0],
                item["bbox"][1] - table[1],
                item["bbox"][2] - table[0],
                item["bbox"][3] - table[1]
            ]
        table_text = sort_text_lines(table_text)
        table_texts.append(table_text)
    return table_texts
surya/input/processing.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import math
6
+ import pypdfium2
7
+ from PIL import Image, ImageOps, ImageDraw
8
+ import torch
9
+ from surya.settings import settings
10
+
11
+
12
def convert_if_not_rgb(images: List[Image.Image]) -> List[Image.Image]:
    """Return a new list in which every image is in RGB mode.

    Images whose ``mode`` is already ``"RGB"`` are passed through as the same
    object; any other image is replaced by the result of
    ``image.convert("RGB")``. The input list itself is not modified.
    """
    return [
        picture if picture.mode == "RGB" else picture.convert("RGB")
        for picture in images
    ]
19
+
20
+
21
def get_total_splits(image_size, processor):
    """Return the number of vertical chunks an image of ``image_size`` needs.

    ``image_size`` is indexed as ``(width, height)``. Images no taller than
    ``settings.DETECTOR_IMAGE_CHUNK_HEIGHT`` count as a single chunk; taller
    images are divided into chunks of ``processor.size["height"]`` rows, with
    the count rounded up.
    """
    height = list(image_size)[1]
    height_limit = settings.DETECTOR_IMAGE_CHUNK_HEIGHT
    chunk_height = processor.size["height"]
    if height <= height_limit:
        return 1
    return math.ceil(height / chunk_height)
29
+
30
+
31
def split_image(img, processor):
    """Cut a tall image into full-width horizontal chunks.

    The caller's image is never modified or returned: the result holds either
    crops of it or, when no splitting is needed, a copy.

    Returns a ``(chunks, heights)`` pair. ``heights[i]`` is the number of rows
    of real image content in ``chunks[i]``; a final chunk shorter than
    ``processor.size["height"]`` is padded up to that height (pad colour 255,
    content anchored at the top-left).
    """
    total_height = list(img.size)[1]
    height_limit = settings.DETECTOR_IMAGE_CHUNK_HEIGHT
    chunk_height = processor.size["height"]

    # Short enough to process in one go: hand back a copy.
    if total_height <= height_limit:
        return [img.copy()], [total_height]

    chunks = []
    content_heights = []
    for chunk_idx in range(math.ceil(total_height / chunk_height)):
        top = chunk_idx * chunk_height
        bottom = min((chunk_idx + 1) * chunk_height, total_height)
        piece = img.crop((0, top, img.size[0], bottom))
        content_height = bottom - top
        if content_height < chunk_height:
            # Only the last chunk can be short; pad it to the full chunk height.
            piece = ImageOps.pad(piece, (img.size[0], chunk_height), color=255, centering=(0, 0))
        chunks.append(piece)
        content_heights.append(content_height)
    return chunks, content_heights
53
+
54
+
55
def prepare_image_detection(img, processor):
    """Resize ``img`` to the processor's input size and return it as a tensor.

    Note that ``img.thumbnail`` is called for its effect on ``img`` itself
    (its return value is not used), so the caller's image object is altered.

    Returns ``torch.from_numpy`` of the first ``"pixel_values"`` entry produced
    by ``processor``.
    """
    target_size = (processor.size["width"], processor.size["height"])

    # This double resize is actually necessary for downstream accuracy:
    # first shrink within the target box, then stretch to exactly fill it.
    img.thumbnail(target_size, Image.Resampling.LANCZOS)
    stretched = img.resize(target_size, Image.Resampling.LANCZOS)  # Stretch smaller dimension to fit new size

    pixels = np.asarray(stretched, dtype=np.uint8)
    processed = processor(pixels)["pixel_values"][0]
    return torch.from_numpy(processed)
66
+
67
+
68
def open_pdf(pdf_filepath):
    """Open ``pdf_filepath`` with pypdfium2 and return the ``PdfDocument``."""
    document = pypdfium2.PdfDocument(pdf_filepath)
    return document
70
+
71
+
72
def get_page_images(doc, indices: List, dpi=settings.IMAGE_DPI):
    """Render the pages at ``indices`` of ``doc`` and return them as RGB images.

    The render scale is ``dpi / 72``. Each rendered page is converted with
    ``.convert("RGB")`` before being returned.
    """
    rendered_pages = list(
        doc.render(
            pypdfium2.PdfBitmap.to_pil,
            page_indices=indices,
            scale=dpi / 72,
        )
    )
    return [page.convert("RGB") for page in rendered_pages]
81
+
82
+
83
def slice_bboxes_from_image(image: Image.Image, bboxes):
    """Crop one sub-image per bounding box and return the crops in order.

    Each bbox is indexed as ``(x1, y1, x2, y2)`` and passed to ``image.crop``.
    A crop of zero width is still included in the result, but a warning is
    printed for it.
    """
    crops = []
    for box in bboxes:
        left, top, right, bottom = box[0], box[1], box[2], box[3]
        crop = image.crop((left, top, right, bottom))
        if crop.size[0] == 0:
            print(f"Warning: found an empty line with bbox {box}")
        crops.append(crop)
    return crops
91
+
92
+
93
def slice_polys_from_image(image: Image.Image, polys):
    """Return one masked crop per polygon, in the order given.

    The image is converted to a ``uint8`` array once, and each polygon is cut
    out of it with ``slice_and_pad_poly``.
    """
    pixels = np.array(image, dtype=np.uint8)
    return [slice_and_pad_poly(pixels, polygon) for polygon in polys]
99
+
100
+
101
def slice_and_pad_poly(image_array: np.ndarray, coordinates):
    """Cut the bounding rectangle of a polygon out of an image and blank the rest.

    Args:
        image_array: Image pixels, indexed as ``[row, column, ...]``. The mask
            below is stacked to three channels, so this is presumably an
            H x W x 3 array - confirm against callers.
        coordinates: Polygon corners; element 0 of each corner is used as x
            and element 1 as y.

    Returns:
        A PIL image of the polygon's bounding rectangle in which every pixel
        outside the polygon is set to ``settings.RECOGNITION_PAD_VALUE``.
    """
    corners = [(corner[0], corner[1]) for corner in coordinates]
    xs = [x for x, _ in corners]
    ys = [y for _, y in corners]

    # Clamp the crop rectangle at 0. A negative slice bound is interpreted by
    # numpy as "count from the end of the axis", so a polygon that pokes past
    # the top/left edge would otherwise crop the wrong (or an empty) region.
    # Bounds beyond the bottom/right edge are already truncated by slicing.
    left, top = max(0, min(xs)), max(0, min(ys))
    right, bottom = max(0, max(xs)), max(0, max(ys))

    # We mask out anything not in the polygon
    cropped_polygon = image_array[top:bottom, left:right].copy()
    # Express the corners relative to the crop's own origin.
    local_corners = [(x - left, y - top) for x, y in corners]

    # Pad the area outside the polygon with the pad value
    mask = np.zeros(cropped_polygon.shape[:2], dtype=np.uint8)
    cv2.fillPoly(mask, [np.int32(local_corners)], 1)
    mask = np.stack([mask] * 3, axis=-1)

    cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE
    return Image.fromarray(cropped_polygon)
surya/languages.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Maps each supported language code to its English display name.
# "_math" is the one entry whose key is not a two-letter code.
CODE_TO_LANGUAGE = {
    "_math": "Math",
    'af': 'Afrikaans',
    'am': 'Amharic',
    'ar': 'Arabic',
    'as': 'Assamese',
    'az': 'Azerbaijani',
    'be': 'Belarusian',
    'bg': 'Bulgarian',
    'bn': 'Bengali',
    'br': 'Breton',
    'bs': 'Bosnian',
    'ca': 'Catalan',
    'cs': 'Czech',
    'cy': 'Welsh',
    'da': 'Danish',
    'de': 'German',
    'el': 'Greek',
    'en': 'English',
    'eo': 'Esperanto',
    'es': 'Spanish',
    'et': 'Estonian',
    'eu': 'Basque',
    'fa': 'Persian',
    'fi': 'Finnish',
    'fr': 'French',
    'fy': 'Western Frisian',
    'ga': 'Irish',
    'gd': 'Scottish Gaelic',
    'gl': 'Galician',
    'gu': 'Gujarati',
    'ha': 'Hausa',
    'he': 'Hebrew',
    'hi': 'Hindi',
    'hr': 'Croatian',
    'hu': 'Hungarian',
    'hy': 'Armenian',
    'id': 'Indonesian',
    'is': 'Icelandic',
    'it': 'Italian',
    'ja': 'Japanese',
    'jv': 'Javanese',
    'ka': 'Georgian',
    'kk': 'Kazakh',
    'km': 'Khmer',
    'kn': 'Kannada',
    'ko': 'Korean',
    'ku': 'Kurdish',
    'ky': 'Kyrgyz',
    'la': 'Latin',
    'lo': 'Lao',
    'lt': 'Lithuanian',
    'lv': 'Latvian',
    'mg': 'Malagasy',
    'mk': 'Macedonian',
    'ml': 'Malayalam',
    'mn': 'Mongolian',
    'mr': 'Marathi',
    'ms': 'Malay',
    'my': 'Burmese',
    'ne': 'Nepali',
    'nl': 'Dutch',
    'no': 'Norwegian',
    'om': 'Oromo',
    'or': 'Oriya',
    'pa': 'Punjabi',
    'pl': 'Polish',
    'ps': 'Pashto',
    'pt': 'Portuguese',
    'ro': 'Romanian',
    'ru': 'Russian',
    'sa': 'Sanskrit',
    'sd': 'Sindhi',
    'si': 'Sinhala',
    'sk': 'Slovak',
    'sl': 'Slovenian',
    'so': 'Somali',
    'sq': 'Albanian',
    'sr': 'Serbian',
    'su': 'Sundanese',
    'sv': 'Swedish',
    'sw': 'Swahili',
    'ta': 'Tamil',
    'te': 'Telugu',
    'th': 'Thai',
    'tl': 'Tagalog',
    'tr': 'Turkish',
    'ug': 'Uyghur',
    'uk': 'Ukrainian',
    'ur': 'Urdu',
    'uz': 'Uzbek',
    'vi': 'Vietnamese',
    'xh': 'Xhosa',
    'yi': 'Yiddish',
    'zh': 'Chinese',
}

# Reverse lookup built from the table above: display name -> language code.
LANGUAGE_TO_CODE = {v: k for k, v in CODE_TO_LANGUAGE.items()}
99
+
100
+
101
def is_arabic(lang_code):
    """Return True if ``lang_code`` is one of ``ar``, ``fa``, ``ps``, ``ug``, ``ur``.

    NOTE(review): presumably these are the supported languages written in
    Arabic script - confirm against callers.
    """
    arabic_codes = ("ar", "fa", "ps", "ug", "ur")
    return lang_code in arabic_codes
surya/layout.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from concurrent.futures import ProcessPoolExecutor
3
+ from typing import List, Optional
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+ from surya.detection import batch_detection
8
+ from surya.postprocessing.heatmap import keep_largest_boxes, get_and_clean_boxes, get_detected_boxes
9
+ from surya.schema import LayoutResult, LayoutBox, TextDetectionResult
10
+ from surya.settings import settings
11
+
12
+
13
def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
    """Turn per-class layout heatmaps into layout boxes snapped to text lines.

    Steps, in order:
      1. Rescale the detected text lines and vertical lines from ``orig_size``
         to heatmap coordinates.
      2. Zero the heatmaps along vertical (column) lines, where class 0 (blank)
         scores >= 0.5, and wherever ``segment_assignment`` picks another class.
      3. Extract candidate boxes per non-blank class.
      4. Assign each text line to at most one box (two passes with decreasing
         overlap thresholds) and grow boxes to cover their lines.
      5. Merge overlapping tables / pictures / figures, add leftover text
         lines as "Text" boxes, rescale back to ``orig_size`` and drop boxes
         contained in other boxes.

    Args:
        detection_result: Text detection output; its ``bboxes`` and
            ``vertical_lines`` are read. ``rescale``/``rescale_bbox`` is called
            on those objects and they are not scaled back afterwards, so the
            caller's detection result appears to be left in heatmap
            coordinates - NOTE(review): confirm whether callers rely on this.
        heatmaps: One heatmap per class; index 0 is the blank class.
        orig_size: Size of the original image, as passed to ``rescale``.
        id2label: Mapping from class index to label name.
        segment_assignment: Per-pixel winning class index.
        vertical_line_width: Extra width, in heatmap pixels, given to each
            vertical line before it is zeroed out.

    Returns:
        The layout boxes in ``orig_size`` coordinates.
    """
    logits = np.stack(heatmaps, axis=0)
    vertical_line_bboxes = detection_result.vertical_lines
    line_bboxes = detection_result.bboxes

    # Scale back to processor size
    for line in vertical_line_bboxes:
        line.rescale_bbox(orig_size, list(reversed(heatmaps[0].shape)))

    for line in line_bboxes:
        line.rescale(orig_size, list(reversed(heatmaps[0].shape)))

    for bbox in vertical_line_bboxes:
        # Give some width to the vertical lines
        vert_bbox = list(bbox.bbox)
        # vert_bbox[2] is used below as a column bound, so it must be clamped
        # to the heatmap width (shape[1]); clamping to the height (shape[0])
        # produced an empty slice for lines right of x == height on wide maps.
        vert_bbox[2] = min(heatmaps[0].shape[1], vert_bbox[2] + vertical_line_width)

        logits[:, vert_bbox[1]:vert_bbox[3], vert_bbox[0]:vert_bbox[2]] = 0  # zero out where the column lines are

    logits[:, logits[0] >= .5] = 0  # zero out where blanks are

    # Zero out where other segments are
    for i in range(logits.shape[0]):
        logits[i, segment_assignment != i] = 0

    detected_boxes = []
    for heatmap_idx in range(1, len(id2label)):  # Skip the blank class
        heatmap = logits[heatmap_idx]
        if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD:
            continue
        bboxes = get_detected_boxes(heatmap)
        bboxes = [bbox for bbox in bboxes if bbox.area > 25]
        for bb in bboxes:
            bb.fit_to_bounds([0, 0, heatmap.shape[1] - 1, heatmap.shape[0] - 1])

        for bbox in bboxes:
            detected_boxes.append(LayoutBox(polygon=bbox.polygon, label=id2label[heatmap_idx], confidence=1))

    detected_boxes = sorted(detected_boxes, key=lambda x: x.confidence, reverse=True)
    # Expand bbox to cover intersecting lines
    box_lines = defaultdict(list)
    used_lines = set()

    # We try 2 rounds of identifying the correct lines to snap to
    # First round is majority intersection, second lowers the threshold
    for thresh in [.5, .4]:
        for bbox_idx, bbox in enumerate(detected_boxes):
            for line_idx, line_bbox in enumerate(line_bboxes):
                if line_bbox.intersection_pct(bbox) > thresh and line_idx not in used_lines:
                    box_lines[bbox_idx].append(line_bbox.bbox)
                    used_lines.add(line_idx)

    new_boxes = []
    for bbox_idx, bbox in enumerate(detected_boxes):
        if bbox.label == "Picture" and bbox.area < 200:  # Remove very small figures
            continue

        # Skip if we didn't find any lines to snap to, except for Pictures and Formulas
        if bbox_idx not in box_lines and bbox.label not in ["Picture", "Formula"]:
            continue

        # Use .get() rather than indexing: indexing a defaultdict inserts the
        # key, which made the "has lines" check further down always true.
        covered_lines = box_lines.get(bbox_idx, [])
        # Snap non-picture layout boxes to correct text boundaries
        if len(covered_lines) > 0 and bbox.label not in ["Picture"]:
            min_x = min([line[0] for line in covered_lines])
            min_y = min([line[1] for line in covered_lines])
            max_x = max([line[2] for line in covered_lines])
            max_y = max([line[3] for line in covered_lines])

            # Tables and formulas can contain text, but text isn't the whole area
            if bbox.label in ["Table", "Formula"]:
                min_x_box = min([b[0] for b in bbox.polygon])
                min_y_box = min([b[1] for b in bbox.polygon])
                max_x_box = max([b[0] for b in bbox.polygon])
                max_y_box = max([b[1] for b in bbox.polygon])

                min_x = min(min_x, min_x_box)
                min_y = min(min_y, min_y_box)
                max_x = max(max_x, max_x_box)
                max_y = max(max_y, max_y_box)

            bbox.polygon = [
                [min_x, min_y],
                [max_x, min_y],
                [max_x, max_y],
                [min_x, max_y]
            ]

        # A picture that contains text lines is relabelled as a figure;
        # pictures without any text lines keep the "Picture" label.
        if covered_lines and bbox.label in ["Picture"]:
            bbox.label = "Figure"

        new_boxes.append(bbox)

    # Merge tables together (sometimes one column is detected as a separate table)
    mergeable_types = ["Table", "Picture", "Figure"]
    for ftype in mergeable_types:
        to_remove = set()
        for bbox_idx, bbox in enumerate(new_boxes):
            if bbox.label != ftype or bbox_idx in to_remove:
                continue

            for bbox_idx2, bbox2 in enumerate(new_boxes):
                if bbox2.label != ftype or bbox_idx2 in to_remove or bbox_idx == bbox_idx2:
                    continue

                if bbox.intersection_pct(bbox2, x_margin=.25) > .1:
                    bbox.merge(bbox2)
                    to_remove.add(bbox_idx2)

        # Indices in to_remove refer to the list as it was for this type, so
        # the list is rebuilt before moving on to the next type.
        new_boxes = [bbox for idx, bbox in enumerate(new_boxes) if idx not in to_remove]

    # Ensure we account for all text lines in the layout
    unused_lines = [line for idx, line in enumerate(line_bboxes) if idx not in used_lines]
    for bbox in unused_lines:
        new_boxes.append(LayoutBox(polygon=bbox.polygon, label="Text", confidence=.5))

    for bbox in new_boxes:
        bbox.rescale(list(reversed(heatmaps[0].shape)), orig_size)

    detected_boxes = [bbox for bbox in new_boxes if bbox.area > 16]

    # Remove bboxes contained inside others, unless they're captions
    # NOTE(review): two boxes that each overlap the other by >= 95% mark each
    # other as contained, so both are dropped - confirm this is intended.
    contained_bbox = set()
    for i, bbox in enumerate(detected_boxes):
        for j, bbox2 in enumerate(detected_boxes):
            if i == j:
                continue

            if bbox2.intersection_pct(bbox) >= .95 and bbox2.label not in ["Caption"]:
                contained_bbox.add(j)

    detected_boxes = [bbox for idx, bbox in enumerate(detected_boxes) if idx not in contained_bbox]

    return detected_boxes
147
+
148
+
149
def get_regions(heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment) -> List[LayoutBox]:
    """Extract layout boxes for every non-blank class straight from the heatmaps.

    For each class index from 1 upward, the class heatmap is zeroed (in place,
    so the caller's array is modified) wherever ``segment_assignment`` picks a
    different class. Classes whose peak is below
    ``settings.DETECTOR_BLANK_THRESHOLD`` are skipped. The boxes found are
    labelled via ``id2label`` and finally filtered by ``keep_largest_boxes``.
    """
    layout_boxes = []
    for label_id in range(1, len(id2label)):  # Skip the blank class
        class_map = heatmaps[label_id]
        assert class_map.shape == segment_assignment.shape
        class_map[segment_assignment != label_id] = 0  # zero out where another segment is

        # Skip processing empty labels
        if np.max(class_map) < settings.DETECTOR_BLANK_THRESHOLD:
            continue

        found_boxes = get_and_clean_boxes(class_map, list(reversed(class_map.shape)), orig_size)
        layout_boxes.extend(
            LayoutBox(polygon=found.polygon, label=id2label[label_id])
            for found in found_boxes
        )

    return keep_largest_boxes(layout_boxes)
166
+
167
+
168
def parallel_get_regions(heatmaps: List[np.ndarray], orig_size, id2label, detection_results=None) -> LayoutResult:
    """Build the ``LayoutResult`` for a single image from its class heatmaps.

    The per-pixel winning class is the argmax over the stacked heatmaps. When
    ``detection_results`` is given, boxes are snapped to its text lines via
    ``get_regions_from_detection_result``; otherwise ``get_regions`` is used.
    The segmentation map is the winning-class array stored as a ``uint8`` image.
    """
    segment_assignment = np.stack(heatmaps, axis=0).argmax(axis=0)

    if detection_results is None:
        bboxes = get_regions(heatmaps, orig_size, id2label, segment_assignment)
    else:
        bboxes = get_regions_from_detection_result(
            detection_results, heatmaps, orig_size, id2label, segment_assignment
        )

    return LayoutResult(
        bboxes=bboxes,
        segmentation_map=Image.fromarray(segment_assignment.astype(np.uint8)),
        heatmaps=heatmaps,
        image_bbox=[0, 0, orig_size[0], orig_size[1]],
    )
187
+
188
+
189
def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None) -> List[LayoutResult]:
    """Run layout detection on ``images`` and return one ``LayoutResult`` per image.

    Model inference is delegated to ``batch_detection``, which yields
    ``(preds, orig_sizes)`` batches. Post-processing of each image is done by
    ``parallel_get_regions``: in a process pool when not running under
    Streamlit and there are at least ``settings.DETECTOR_MIN_PARALLEL_THRESH``
    images, otherwise inline. Results are returned in input order.

    ``detection_results``, when given, is indexed by the running image
    position, so it must line up with ``images``.
    """
    layout_generator = batch_detection(images, model, processor, batch_size=batch_size)
    id2label = model.config.id2label

    def detection_for(position):
        # Matching text-detection result for this image, if any were supplied.
        return detection_results[position] if detection_results else None

    results = []
    max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
    parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

    position = 0
    if not parallelize:
        for preds, orig_sizes in layout_generator:
            for pred, orig_size in zip(preds, orig_sizes):
                results.append(parallel_get_regions(pred, orig_size, id2label, detection_for(position)))
                position += 1
        return results

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        for preds, orig_sizes in layout_generator:
            pending = []
            for pred, orig_size in zip(preds, orig_sizes):
                pending.append(
                    executor.submit(parallel_get_regions, pred, orig_size, id2label, detection_for(position))
                )
                position += 1

            # Collect this batch in submission order before pulling the next one.
            for future in pending:
                results.append(future.result())

    return results
surya/model/detection/config.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class EfficientViTConfig(PretrainedConfig):
    r"""
    Configuration for the EfficientViT-based detection model.

    Every named constructor argument is stored unchanged as an attribute of
    the same name; any additional keyword arguments are forwarded to
    ``PretrainedConfig``.

    NOTE(review): the meaning of the individual fields (e.g. ``widths``,
    ``depths`` and ``strides`` being per-stage values) is not established in
    this class - confirm against the model code before relying on it.
    """

    model_type = "efficientvit"

    def __init__(
        self,
        num_classes=2,
        num_channels=3,
        widths=(32, 64, 128, 256, 512),
        head_dim=32,
        num_stages=4,
        depths=(1, 1, 1, 6, 6),
        strides=(2, 2, 2, 2, 2),
        hidden_sizes=(32, 64, 160, 256),
        patch_size=(7, 7),
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        classifier_dropout_prob=0.0,
        layer_norm_eps=1e-6,
        decoder_layer_hidden_size=128,
        decoder_hidden_size=512,
        semantic_loss_ignore_index=255,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.widths = widths
        self.head_dim = head_dim

        self.num_channels = num_channels
        self.num_stages = num_stages
        self.depths = depths
        self.strides = strides
        self.hidden_sizes = hidden_sizes
        self.patch_size = patch_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.classifier_dropout_prob = classifier_dropout_prob
        self.layer_norm_eps = layer_norm_eps
        self.decoder_hidden_size = decoder_hidden_size
        self.decoder_layer_hidden_size = decoder_layer_hidden_size
        self.semantic_loss_ignore_index = semantic_loss_ignore_index

        self.initializer_range = initializer_range
surya/model/detection/model.py ADDED
@@ -0,0 +1,767 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is an implementation of efficientvit, with some modifications (decode head, etc).
3
+
4
+ Original paper at https://arxiv.org/abs/2205.14756
5
+
6
+ Code adapted from timm, https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_mit.py
7
+ Original code (that timm adapted from) at https://github.com/mit-han-lab/efficientvit
8
+ """
9
+
10
+ from typing import Optional, Union, Tuple
11
+ from functools import partial
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ from transformers import PreTrainedModel
18
+ from transformers.modeling_outputs import SemanticSegmenterOutput
19
+
20
+ from surya.model.detection.config import EfficientViTConfig
21
+ from surya.model.detection.processor import SegformerImageProcessor
22
+ from surya.settings import settings
23
+
24
+
25
def load_model(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the detection model from ``checkpoint``, ready for inference.

    The config and weights are loaded with ``from_pretrained`` (weights in
    ``dtype``, mismatched sizes ignored), then the model is moved to
    ``device`` and switched to eval mode. A one-line summary is printed.
    """
    config = EfficientViTConfig.from_pretrained(checkpoint)
    model = (
        EfficientViTForSemanticSegmentation.from_pretrained(
            checkpoint,
            torch_dtype=dtype,
            config=config,
            ignore_mismatched_sizes=True,
        )
        .to(device)
        .eval()
    )
    print(f"Loaded detection model {checkpoint} on device {device} with dtype {dtype}")
    return model
32
+
33
+
34
def load_processor(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT):
    """Load the image processor that belongs to the detection ``checkpoint``."""
    return SegformerImageProcessor.from_pretrained(checkpoint)
37
+
38
+
39
def val2list(x: list or tuple or any, repeat_time=1):
    """Return ``x`` as a list.

    A list or tuple is shallow-copied into a new list; any other value is
    repeated ``repeat_time`` times.
    """
    if not isinstance(x, (list, tuple)):
        return [x] * repeat_time
    return list(x)
43
+
44
+
45
def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1):
    """Return ``x`` as a tuple of at least ``min_len`` elements.

    ``x`` is first normalised with ``val2list``. If it is non-empty and
    shorter than ``min_len``, copies of the element at ``idx_repeat`` are
    inserted at position ``idx_repeat`` until the length is reached. An empty
    input stays empty.
    """
    items = val2list(x)
    missing = min_len - len(items)
    # repeat elements if necessary
    if items and missing > 0:
        items[idx_repeat:idx_repeat] = [items[idx_repeat]] * missing
    return tuple(items)
52
+
53
+
54
def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
    """Return ``kernel_size // 2`` for an odd kernel size.

    A tuple is handled element-wise (recursively) and yields a tuple. A
    non-tuple kernel size must be odd, otherwise the assertion fails.
    """
    if not isinstance(kernel_size, tuple):
        assert kernel_size % 2 > 0, "kernel size should be odd number"
        return kernel_size // 2
    return tuple(get_same_padding(size) for size in kernel_size)
60
+
61
+
62
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1) -> int:
    """Return ``((stride - 1) + dilation * (kernel_size - 1)) // 2``."""
    dilated_extent = dilation * (kernel_size - 1)
    return (stride - 1 + dilated_extent) // 2
65
+
66
class ConvNormAct(nn.Module):
    """Dropout -> Conv2d -> normalisation -> activation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size, stride, dilation, groups, bias: Passed to ``nn.Conv2d``;
            the padding is derived from them with ``get_padding``.
        dropout: Dropout probability applied to the input of the convolution.
        norm_layer: Normalisation class, called with
            ``num_features=out_channels``; a falsy value disables it.
        act_layer: Activation class, constructed with ``inplace=True``;
            ``None`` disables it.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        bias=False,
        dropout=0.,
        norm_layer=nn.BatchNorm2d,
        act_layer=nn.ReLU,
    ):
        super(ConvNormAct, self).__init__()
        self.dropout = nn.Dropout(dropout, inplace=False)
        padding = get_padding(kernel_size, stride, dilation)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding=padding,
        )
        self.norm = norm_layer(num_features=out_channels) if norm_layer else nn.Identity()
        self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity()

    def forward(self, x):
        # Apply the configured dropout; previously the module was built but
        # never used, so the ``dropout`` argument was silently ignored. It is
        # a no-op in eval mode and for the default probability of 0.
        x = self.dropout(x)
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x
101
+
102
+
103
class DSConv(nn.Module):
    """Depthwise ``ConvNormAct`` followed by a 1x1 pointwise ``ConvNormAct``.

    ``use_bias``, ``norm_layer`` and ``act_layer`` may each be a single value
    or a pair; they are expanded to two entries with ``val2tuple``, one per
    convolution (depthwise first, pointwise second).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)

        # groups == in_channels: one filter per input channel.
        self.depth_conv = ConvNormAct(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=in_channels,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        self.point_conv = ConvNormAct(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )

    def forward(self, x):
        return self.point_conv(self.depth_conv(x))
142
+
143
+
144
class ConvBlock(nn.Module):
    """Two stacked ``ConvNormAct`` layers with the same kernel size.

    The first convolution carries the ``stride`` and maps ``in_channels`` to
    ``mid_channels`` (defaulting to ``round(in_channels * expand_ratio)``);
    the second has stride 1 and maps to ``out_channels``. ``use_bias``,
    ``norm_layer`` and ``act_layer`` are expanded to one entry per layer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=1,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)
        hidden_channels = mid_channels or round(in_channels * expand_ratio)

        self.conv1 = ConvNormAct(
            in_channels,
            hidden_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        self.conv2 = ConvNormAct(
            hidden_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )

    def forward(self, x):
        return self.conv2(self.conv1(x))
186
+
187
+
188
class MBConv(nn.Module):
    """1x1 expansion conv, depthwise conv, then 1x1 projection conv.

    The hidden width is ``mid_channels`` or, if that is not given,
    ``round(in_channels * expand_ratio)``. Only the depthwise convolution uses
    ``kernel_size`` and ``stride``. ``use_bias``, ``norm_layer`` and
    ``act_layer`` are expanded to three entries, one per convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 3)
        norms = val2tuple(norm_layer, 3)
        acts = val2tuple(act_layer, 3)
        hidden_channels = mid_channels or round(in_channels * expand_ratio)

        self.inverted_conv = ConvNormAct(
            in_channels,
            hidden_channels,
            kernel_size=1,
            stride=1,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        self.depth_conv = ConvNormAct(
            hidden_channels,
            hidden_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=hidden_channels,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )
        self.point_conv = ConvNormAct(
            hidden_channels,
            out_channels,
            kernel_size=1,
            bias=biases[2],
            norm_layer=norms[2],
            act_layer=acts[2],
        )

    def forward(self, x):
        expanded = self.inverted_conv(x)
        filtered = self.depth_conv(expanded)
        return self.point_conv(filtered)
240
+
241
+
242
class FusedMBConv(nn.Module):
    """A spatial (k x k) expansion conv followed by a 1x1 projection conv.

    The spatial convolution carries ``kernel_size``, ``stride`` and ``groups``
    and expands to ``mid_channels`` (default
    ``round(in_channels * expand_ratio)``). ``use_bias``, ``norm_layer`` and
    ``act_layer`` are expanded to one entry per convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        groups=1,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)
        hidden_channels = mid_channels or round(in_channels * expand_ratio)

        self.spatial_conv = ConvNormAct(
            in_channels,
            hidden_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        self.point_conv = ConvNormAct(
            hidden_channels,
            out_channels,
            kernel_size=1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )

    def forward(self, x):
        return self.point_conv(self.spatial_conv(x))
285
+
286
+
287
class LiteMLA(nn.Module):
    """Lightweight multi-scale linear attention"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        heads: int or None = None,
        heads_ratio: float = 1.0,
        dim=8,
        use_bias=False,
        norm_layer=(None, nn.BatchNorm2d),
        act_layer=(None, None),
        kernel_func=nn.ReLU,
        scales=(5,),
        eps=1e-5,
    ):
        super().__init__()
        self.eps = eps
        heads = heads or int(in_channels // dim * heads_ratio)
        total_dim = heads * dim
        qkv_channels = 3 * total_dim
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)

        self.dim = dim
        # A single 1x1 conv produces q, k and v together.
        self.qkv = ConvNormAct(
            in_channels,
            qkv_channels,
            kernel_size=1,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )

        def make_aggregator(scale):
            # Depthwise (scale x scale) conv followed by a grouped 1x1 conv.
            return nn.Sequential(
                nn.Conv2d(
                    qkv_channels,
                    qkv_channels,
                    scale,
                    padding=get_same_padding(scale),
                    groups=qkv_channels,
                    bias=biases[0],
                ),
                nn.Conv2d(qkv_channels, qkv_channels, 1, groups=3 * heads, bias=biases[0]),
            )

        self.aggreg = nn.ModuleList([make_aggregator(scale) for scale in scales])
        self.kernel_func = kernel_func(inplace=False)

        # The projection sees the base branch plus one branch per scale.
        self.proj = ConvNormAct(
            total_dim * (1 + len(scales)),
            out_channels,
            kernel_size=1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )

    def _attn(self, q, k, v):
        # Computed in float32 and cast back to the dtype ``v`` came in with.
        original_dtype = v.dtype
        q, k, v = q.float(), k.float(), v.float()
        context = k.transpose(-1, -2) @ v
        attended = q @ context
        # The last channel of v is the constant 1 appended in forward(), so the
        # last output channel is the normaliser for the remaining channels.
        normalised = attended[..., :-1] / (attended[..., -1:] + self.eps)
        return normalised.to(original_dtype)

    def forward(self, x):
        # Shape is B, C, H, W
        batch, _, height, width = x.shape

        # generate multi-scale q, k, v
        qkv = self.qkv(x)
        branches = [qkv] + [aggregate(qkv) for aggregate in self.aggreg]
        stacked = torch.cat(branches, dim=1)
        stacked = stacked.reshape(batch, -1, 3 * self.dim, height * width).transpose(-1, -2)
        # Shape for each is B, C, HW, head_dim
        q, k, v = stacked.chunk(3, dim=-1)

        # lightweight global attention
        q = self.kernel_func(q)
        k = self.kernel_func(k)
        v = F.pad(v, (0, 1), mode="constant", value=1.)

        attended = self._attn(q, k, v)

        # final projection
        attended = attended.transpose(-1, -2).reshape(batch, -1, height, width)
        return self.proj(attended)
379
+
380
+
381
class EfficientVitBlock(nn.Module):
    """A residual ``LiteMLA`` attention module followed by a residual ``MBConv``.

    Both sub-modules map ``in_channels`` to ``in_channels`` and use an
    identity shortcut.
    """

    def __init__(
        self,
        in_channels,
        heads_ratio=1.0,
        head_dim=32,
        expand_ratio=4,
        norm_layer=nn.BatchNorm2d,
        act_layer=nn.Hardswish,
    ):
        super().__init__()
        attention = LiteMLA(
            in_channels=in_channels,
            out_channels=in_channels,
            heads_ratio=heads_ratio,
            dim=head_dim,
            norm_layer=(None, norm_layer),
        )
        self.context_module = ResidualBlock(attention, nn.Identity())

        # Only the last conv is normalised; the first two use a bias instead.
        feed_forward = MBConv(
            in_channels=in_channels,
            out_channels=in_channels,
            expand_ratio=expand_ratio,
            use_bias=(True, True, False),
            norm_layer=(None, None, norm_layer),
            act_layer=(act_layer, act_layer, None),
        )
        self.local_module = ResidualBlock(feed_forward, nn.Identity())

    def forward(self, x):
        return self.local_module(self.context_module(x))
418
+
419
+
420
class ResidualBlock(nn.Module):
    """Compute ``main(pre_norm(x))`` and optionally add ``shortcut(x)``.

    ``pre_norm`` defaults to the identity. When ``shortcut`` is ``None`` no
    residual is added. The shortcut receives the raw input, not the
    pre-normalised one.
    """

    def __init__(
        self,
        main: Optional[nn.Module],
        shortcut: Optional[nn.Module] = None,
        pre_norm: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.pre_norm = nn.Identity() if pre_norm is None else pre_norm
        self.main = main
        self.shortcut = shortcut

    def forward(self, x):
        transformed = self.main(self.pre_norm(x))
        if self.shortcut is None:
            return transformed
        return transformed + self.shortcut(x)
437
+
438
+
439
def build_local_block(
    in_channels: int,
    out_channels: int,
    stride: int,
    kernel_size: int,
    expand_ratio: float,
    norm_layer: str,
    act_layer: str,
    fewer_norm: bool = False,
    block_type: str = "default",
):
    """Construct the convolutional block selected by the arguments.

    Selection:
      * ``expand_ratio == 1``: ``DSConv`` for the "default" block type,
        ``ConvBlock`` otherwise.
      * any other ``expand_ratio``: ``MBConv`` for "default",
        ``FusedMBConv`` otherwise.

    With ``fewer_norm`` only the last convolution of the block is normalised
    and the earlier ones get a bias instead; otherwise every convolution is
    normalised and none has a bias. The last convolution never has an
    activation.
    """
    assert block_type in ["default", "large", "fused"]
    is_default = block_type == "default"
    shape_args = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        stride=stride,
        kernel_size=kernel_size,
    )

    # MBConv is the only three-convolution block, so it needs 3-tuples.
    if expand_ratio != 1 and is_default:
        return MBConv(
            expand_ratio=expand_ratio,
            use_bias=(True, True, False) if fewer_norm else False,
            norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer,
            act_layer=(act_layer, act_layer, None),
            **shape_args,
        )

    # All remaining blocks consist of two convolutions.
    two_conv_args = dict(
        use_bias=(True, False) if fewer_norm else False,
        norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
        act_layer=(act_layer, None),
    )
    if expand_ratio == 1:
        block_cls = DSConv if is_default else ConvBlock
        return block_cls(**shape_args, **two_conv_args)

    return FusedMBConv(expand_ratio=expand_ratio, **shape_args, **two_conv_args)
496
+
497
+
498
class Stem(nn.Sequential):
    """Input stem: one strided ConvNormAct followed by ``depth`` residual local blocks.

    Sub-modules are registered as ``in_conv`` and ``res0`` .. ``res{depth-1}``.

    Args:
        in_chs: Channel count of the incoming image tensor.
        out_chs: Channel count produced by the stem (kept by every residual block).
        depth: Number of residual blocks appended after the input convolution.
        stride: Stride of the input convolution; its kernel size is ``stride + 1``.
        norm_layer: Normalisation layer factory passed to every sub-block.
        act_layer: Activation layer factory passed to every sub-block.
        block_type: Block family forwarded to ``build_local_block``.
    """

    def __init__(self, in_chs, out_chs, depth, stride, norm_layer, act_layer, block_type='default'):
        super().__init__()
        self.stride = stride

        input_conv = ConvNormAct(
            in_chs,
            out_chs,
            kernel_size=stride + 1,
            stride=stride,
            norm_layer=norm_layer,
            act_layer=act_layer,
        )
        self.add_module('in_conv', input_conv)

        for idx in range(depth):
            # Stride-1, expand_ratio-1 block wrapped with an identity skip connection.
            local_block = build_local_block(
                in_channels=out_chs,
                out_channels=out_chs,
                stride=1,
                kernel_size=3,
                expand_ratio=1,
                norm_layer=norm_layer,
                act_layer=act_layer,
                block_type=block_type,
            )
            self.add_module(f'res{idx}', ResidualBlock(local_block, nn.Identity()))
526
+
527
+
528
class EfficientVitLargeStage(nn.Module):
    """One backbone stage: a strided entry block followed by ``depth`` same-width blocks.

    The entry block changes the channel count from ``in_chs`` to ``out_chs`` with
    the given stride and has no skip connection. The following blocks are
    ``EfficientVitBlock`` instances when ``vit_stage`` is True, otherwise residual
    local blocks with an identity shortcut.

    Args:
        in_chs: Channels entering the stage.
        out_chs: Channels leaving the stage.
        depth: Number of blocks after the entry block.
        stride: Stride of the entry block; its kernel size is ``stride + 1``.
        norm_layer: Normalisation layer factory.
        act_layer: Activation layer factory.
        head_dim: Attention head dimension; only used when ``vit_stage`` is True.
        vit_stage: Use attention blocks for the body of the stage.
        fewer_norm: Forwarded to ``build_local_block``; also selects the
            'default' block family instead of 'fused'.
    """

    def __init__(
        self,
        in_chs,
        out_chs,
        depth,
        stride,
        norm_layer,
        act_layer,
        head_dim,
        vit_stage=False,
        fewer_norm=False,
    ):
        super(EfficientVitLargeStage, self).__init__()
        local_type = 'default' if fewer_norm else 'fused'

        # Entry block: no shortcut because stride / channel count may change.
        entry_block = ResidualBlock(
            build_local_block(
                in_channels=in_chs,
                out_channels=out_chs,
                stride=stride,
                kernel_size=stride + 1,
                expand_ratio=24 if vit_stage else 16,
                norm_layer=norm_layer,
                act_layer=act_layer,
                fewer_norm=vit_stage or fewer_norm,
                block_type=local_type,
            ),
            None,
        )
        layers = [entry_block]

        if vit_stage:
            # for stage 4
            layers.extend(
                EfficientVitBlock(
                    in_channels=out_chs,
                    head_dim=head_dim,
                    expand_ratio=6,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                )
                for _ in range(depth)
            )
        else:
            # for stage 1, 2, 3
            layers.extend(
                ResidualBlock(
                    build_local_block(
                        in_channels=out_chs,
                        out_channels=out_chs,
                        stride=1,
                        kernel_size=3,
                        expand_ratio=4,
                        norm_layer=norm_layer,
                        act_layer=act_layer,
                        fewer_norm=fewer_norm,
                        block_type=local_type,
                    ),
                    nn.Identity(),
                )
                for _ in range(depth)
            )

        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        """Run the input through every block of the stage in order."""
        return self.blocks(x)
592
+
593
+
594
class EfficientVitLarge(nn.Module):
    """Backbone made of a stem plus a sequence of ``EfficientVitLargeStage`` modules.

    Index 0 of ``config.widths`` / ``config.depths`` / ``config.strides``
    configures the stem; the remaining entries configure one stage each. Stages
    with index >= 2 use ``fewer_norm`` and stages with index >= 3 are attention
    stages.

    Args:
        config: Model configuration (reads ``num_classes``, ``layer_norm_eps``,
            ``num_channels``, ``widths``, ``depths``, ``strides``, ``head_dim``).
        norm_layer: Normalisation layer class; ``eps`` is bound from the config.
        act_layer: Activation layer class.
    """

    def __init__(
        self,
        config: EfficientViTConfig,
        norm_layer=nn.BatchNorm2d,
        act_layer=nn.Hardswish,
    ):
        super(EfficientVitLarge, self).__init__()
        self.grad_checkpointing = False
        self.num_classes = config.num_classes
        self.norm_eps = config.layer_norm_eps
        norm_layer = partial(norm_layer, eps=self.norm_eps)

        # input stem
        stem_width = config.widths[0]
        stem_stride = config.strides[0]
        self.stem = Stem(
            config.num_channels,
            stem_width,
            config.depths[0],
            stem_stride,
            norm_layer,
            act_layer,
            block_type='large',
        )

        # stages
        self.feature_info = []
        self.stages = nn.Sequential()
        prev_width = stem_width
        total_stride = stem_stride  # cumulative downsampling factor so far
        stage_specs = zip(config.widths[1:], config.depths[1:], config.strides[1:])
        for idx, (width, depth, stage_stride) in enumerate(stage_specs):
            stage = EfficientVitLargeStage(
                prev_width,
                width,
                depth=depth,
                stride=stage_stride,
                norm_layer=norm_layer,
                act_layer=act_layer,
                head_dim=config.head_dim,
                vit_stage=idx >= 3,
                fewer_norm=idx >= 2,
            )
            self.stages.append(stage)
            total_stride *= stage_stride
            prev_width = width
            self.feature_info.append(
                dict(num_chs=prev_width, reduction=total_stride, module=f'stages.{idx}')
            )

        self.num_features = prev_width

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Record the gradient-checkpointing flag (only stored on the instance here)."""
        self.grad_checkpointing = enable

    def forward(self, x):
        """Return the list of feature maps produced by each stage, in stage order."""
        x = self.stem(x)
        encoder_hidden_states = []
        for stage in self.stages:
            x = stage(x)
            encoder_hidden_states.append(x)

        return encoder_hidden_states
645
+
646
+
647
class EfficientViTPreTrainedModel(PreTrainedModel):
    """
    Common base for the EfficientViT models: binds the config class and provides
    the weight-initialisation hook used by the pretrained-model machinery.
    """

    config_class = EfficientViTConfig
    base_model_prefix = "efficientvit"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialise a single sub-module in place according to its type.

        Linear / Conv2d: normal(0, initializer_range) weights, zero bias.
        Embedding: normal(0, initializer_range) weights, zeroed padding row.
        LayerNorm: zero bias, unit weight.
        Any other module type is left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
672
+
673
+
674
class DecodeMLP(nn.Module):
    """Per-position linear projection of a feature map's channel dimension.

    Args:
        input_dim: Channel count of the incoming feature map.
        output_dim: Channel count after projection.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, output_dim)

    def forward(self, hidden_states: torch.Tensor):
        """Flatten (B, C, H, W) to (B, H*W, C) and project C to ``output_dim``."""
        # Collapse the spatial dims, then move channels last so nn.Linear acts on them.
        tokens = hidden_states.flatten(start_dim=2).permute(0, 2, 1)
        return self.proj(tokens)
685
+
686
+
687
class DecodeHead(EfficientViTPreTrainedModel):
    """Segmentation head that fuses the per-stage feature maps into per-pixel scores.

    Each stage output is projected to ``config.decoder_layer_hidden_size``
    channels, upsampled to the spatial size of the first stage output,
    concatenated (deepest stage first), fused with a 1x1 convolution +
    batch norm + ReLU, and classified with a final 1x1 convolution.
    """

    def __init__(self, config: EfficientViTConfig):
        super().__init__(config)

        # One projection per stage, unifying channel counts to config.decoder_layer_hidden_size.
        self.linear_c = nn.ModuleList(
            [
                DecodeMLP(input_dim=stage_width, output_dim=config.decoder_layer_hidden_size)
                for stage_width in config.widths[1:]
            ]
        )

        # the following 3 layers implement the ConvModule of the original implementation
        self.linear_fuse = nn.Conv2d(
            in_channels=config.decoder_layer_hidden_size * config.num_stages,
            out_channels=config.decoder_hidden_size,
            kernel_size=1,
            bias=False,
        )
        self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size)
        self.activation = nn.ReLU()

        # NOTE: the dropout layer is created here but forward() does not apply it.
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)

        self.config = config

    def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
        """Fuse the stage feature maps and return the classifier output.

        Args:
            encoder_hidden_states: Sequence of (B, C_i, H_i, W_i) feature maps,
                one per stage, in stage order.

        Returns:
            Tensor of shape (B, num_labels, H_0, W_0), where H_0 x W_0 is the
            spatial size of the first feature map.
        """
        batch_size = encoder_hidden_states[-1].shape[0]
        target_size = encoder_hidden_states[0].size()[2:]

        upsampled = []
        for feature_map, mlp in zip(encoder_hidden_states, self.linear_c):
            height, width = feature_map.shape[2], feature_map.shape[3]
            # mlp gives (B, HW, C); move channels back to dim 1 and restore the 2-D grid.
            projected = mlp(feature_map).permute(0, 2, 1).reshape(batch_size, -1, height, width)
            # upsample
            upsampled.append(
                nn.functional.interpolate(
                    projected, size=target_size, mode="bilinear", align_corners=False
                )
            )

        # Concatenate along channels with the last stage first.
        fused = self.linear_fuse(torch.cat(upsampled[::-1], dim=1))
        fused = self.batch_norm(fused)
        fused = self.activation(fused)

        return self.classifier(fused)
737
+
738
+
739
class EfficientViTForSemanticSegmentation(EfficientViTPreTrainedModel):
    """EfficientViT backbone plus decode head, producing sigmoid-activated score maps."""

    def __init__(self, config, **kwargs):
        # Extra keyword arguments are accepted but not used.
        super().__init__(config)
        self.vit = EfficientVitLarge(config)
        self.decode_head = DecodeHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor
    ) -> Union[Tuple, SemanticSegmenterOutput]:
        """Run the backbone and head on a (B, C, H, W) batch.

        Returns:
            SemanticSegmenterOutput whose ``logits`` field holds the head output
            after a sigmoid (values in 0-1, not raw logits), whose
            ``hidden_states`` is the list of backbone stage outputs, and whose
            ``loss`` is always None.
        """
        stage_features = self.vit(pixel_values)
        raw_scores = self.decode_head(stage_features)

        # Apply sigmoid to get 0-1 output
        probabilities = torch.sigmoid(raw_scores)

        return SemanticSegmenterOutput(
            loss=None,
            logits=probabilities,
            hidden_states=stage_features,
        )
surya/model/detection/processor.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from typing import Any, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+
6
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
7
+ from transformers.image_transforms import to_channel_dimension_format
8
+ from transformers.image_utils import (
9
+ IMAGENET_DEFAULT_MEAN,
10
+ IMAGENET_DEFAULT_STD,
11
+ ChannelDimension,
12
+ ImageInput,
13
+ PILImageResampling,
14
+ infer_channel_dimension_format,
15
+ make_list_of_images,
16
+ )
17
+ from transformers.utils import TensorType
18
+
19
+
20
+ import PIL.Image
21
+ import torch
22
+
23
+
24
class SegformerImageProcessor(BaseImageProcessor):
    r"""
    Image processor that rescales and normalizes images and packs them into a `BatchFeature`.

    NOTE: although the resize-related options are accepted and stored for interface compatibility, this class
    never resizes an image: `_preprocess` only applies rescaling and normalization. Likewise, segmentation maps
    and `do_reduce_labels` are accepted but not processed by `preprocess`.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Stored and forwarded, but no resizing is performed by this class.
        size (`Dict[str, int]` *optional*, defaults to `{"height": 512, "width": 512}`):
            Stored (after `get_size_dict`) and forwarded, but not used to resize.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Stored and forwarded, but not used.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor applied to the pixel values when `do_rescale` is set. Can be overridden by the
            `rescale_factor` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_reduce_labels (`bool`, *optional*, defaults to `False`):
            Stored on the instance; not applied anywhere in this class. The deprecated `reduce_labels` keyword is
            still accepted and mapped onto this value.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_reduce_labels: bool = False,
        **kwargs,
    ) -> None:
        # Backwards compatibility: honour the deprecated `reduce_labels` keyword, with a warning.
        if "reduce_labels" in kwargs:
            warnings.warn(
                "The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use "
                "`do_reduce_labels` instead.",
                FutureWarning,
            )
            do_reduce_labels = kwargs.pop("reduce_labels")

        super().__init__(**kwargs)
        size = size if size is not None else {"height": 512, "width": 512}
        size = get_size_dict(size)
        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        # Fall back to the ImageNet "default" statistics when none are given.
        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
        self.do_reduce_labels = do_reduce_labels
        # Names of the arguments accepted by `preprocess`.
        self._valid_processor_keys = [
            "images",
            "segmentation_maps",
            "do_resize",
            "size",
            "resample",
            "do_rescale",
            "rescale_factor",
            "do_normalize",
            "image_mean",
            "image_std",
            "do_reduce_labels",
            "return_tensors",
            "data_format",
            "input_data_format",
        ]

    @classmethod
    def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
        """
        Overrides the `from_dict` method from the base class to make sure `do_reduce_labels` is updated if image
        processor is created using from_dict and kwargs e.g. `SegformerImageProcessor.from_pretrained(checkpoint,
        reduce_labels=True)`

        The input dictionary is copied first, so the caller's dictionary is not modified.
        """
        image_processor_dict = image_processor_dict.copy()
        if "reduce_labels" in kwargs:
            image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels")
        return super().from_dict(image_processor_dict, **kwargs)

    def _preprocess(
        self,
        image: ImageInput,
        do_resize: bool,
        do_rescale: bool,
        do_normalize: bool,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        rescale_factor: Optional[float] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Apply the optional rescale and normalize steps to one image and return the result.

        `do_resize`, `size` and `resample` are part of the signature but are not used: no resizing happens here.
        """

        if do_rescale:
            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

        if do_normalize:
            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

        return image

    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Preprocesses a single image: infer its channel layout, rescale/normalize, then convert the layout."""
        # All transformations expect numpy arrays.
        # NOTE(review): no conversion to numpy is done here, so callers presumably pass arrays already - confirm.
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        image = self._preprocess(
            image=image,
            do_resize=do_resize,
            size=size,
            resample=resample,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            input_data_format=input_data_format,
        )
        # Only reorder channels when an explicit output layout was requested.
        if data_format is not None:
            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
        return image

    def __call__(self, images, segmentation_maps=None, **kwargs):
        """
        Preprocesses a batch of images and optionally segmentation maps.

        Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
        passed in as positional arguments.
        """
        return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)

    def preprocess(
        self,
        images: ImageInput,
        segmentation_maps: Optional[ImageInput] = None,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_reduce_labels: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Every argument left as `None` falls back to the value stored on the instance. The result is a
        `BatchFeature` with a single `pixel_values` entry.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            segmentation_maps (`ImageInput`, *optional*):
                Accepted for interface compatibility; not used by this method.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Forwarded to the per-image helper, which does not resize.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Forwarded to the per-image helper, which does not resize.
            resample (`int`, *optional*, defaults to `self.resample`):
                Forwarded to the per-image helper, which does not resize.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values by `rescale_factor`.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            do_reduce_labels (`bool`, *optional*):
                Accepted for interface compatibility; not used by this method.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                    - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `BatchFeature` whose `pixel_values` holds the processed images, converted according to
            `return_tensors`.
        """
        # Resolve per-call overrides against the instance defaults.
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        resample = resample if resample is not None else self.resample
        size = size if size is not None else self.size
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        images = make_list_of_images(images)
        images = [
            self._preprocess_image(
                image=img,
                do_resize=do_resize,
                resample=resample,
                size=size,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                data_format=data_format,
                input_data_format=input_data_format,
            )
            for img in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
surya/model/ordering/config.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from transformers import MBartConfig, DonutSwinConfig
2
+
3
+
4
class MBartOrderConfig(MBartConfig):
    """Named subclass of `MBartConfig`; defines no fields or behaviour of its own."""
    pass
6
+
7
class VariableDonutSwinConfig(DonutSwinConfig):
    """Named subclass of `DonutSwinConfig`; defines no fields or behaviour of its own."""
    pass
surya/model/ordering/decoder.py ADDED
@@ -0,0 +1,557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import Optional, List, Union, Tuple
3
+
4
+ from transformers import MBartForCausalLM, MBartConfig
5
+ from torch import nn
6
+ from transformers.activations import ACT2FN
7
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask
8
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
9
+ from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel, MBartDecoder, MBartLearnedPositionalEmbedding, MBartDecoderLayer
10
+ from surya.model.ordering.config import MBartOrderConfig
11
+ import torch
12
+ import math
13
+
14
+
15
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head dimension (adapted from llama).

    Equivalent to ``torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)``: the input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With ``n_rep == 1`` the input tensor itself is
    returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
26
+
27
+
28
class MBartGQAttention(nn.Module):
    """Multi-head attention with grouped key/value heads.

    Queries use `num_heads` heads while keys and values are projected to only `num_kv_heads` heads; each
    key/value head is repeated `num_heads // num_kv_heads` times (via `repeat_kv`) before the attention product.
    The same module serves self-attention and cross-attention (when `key_value_states` is passed to `forward`).

    Args:
        embed_dim: Model width; also the input and output size of the module.
        num_heads: Number of query heads. Must divide `embed_dim`.
        num_kv_heads: Number of key/value heads. Must divide `num_heads`.
        dropout: Dropout probability applied to the attention probabilities.
        is_decoder: When True, `forward` returns the (key, value) tensors for caching.
        bias: Whether the four projection layers use a bias.
        is_causal: Stored on the instance; not read inside this class.
        config: Stored on the instance; not read inside this class.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[MBartConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        # How many query heads share each key/value head.
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0, f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
        assert embed_dim % self.num_kv_heads == 0, f"embed_dim ({self.embed_dim}) must be divisible by num_kv_heads ({self.num_kv_heads})"

        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # Queries are pre-multiplied by 1/sqrt(head_dim) in forward().
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        # Key/value projections are narrower than the query projection: num_kv_heads * head_dim outputs.
        self.k_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape a query projection (bsz, seq_len, embed_dim) to (bsz, num_heads, seq_len, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _shape_key_value(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape a key/value projection to (bsz, num_kv_heads, seq_len, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: Query source of shape (bsz, tgt_len, embed_dim).
            key_value_states: Source for keys/values; when given, the layer acts as cross-attention.
            past_key_value: Cached (key, value) pair with `num_kv_heads` heads each.
            attention_mask: Additive mask of shape (bsz, 1, tgt_len, src_len).
            layer_head_mask: Multiplicative per-head mask of shape (num_heads,).
            output_attentions: Whether to also return the attention weights.

        Returns:
            Tuple of (attention output of shape (bsz, tgt_len, embed_dim), attention weights of shape
            (bsz, num_heads, tgt_len, src_len) or None, the (key, value) pair when `is_decoder` is set -
            otherwise `past_key_value` as passed in).

        Raises:
            ValueError: If the attention weights, mask, head mask or attention output have unexpected sizes.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape_key_value(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)
            # Append the new positions to the cached ones along the sequence axis.
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            # The cache is taken before repeat_kv, so it holds num_kv_heads heads, not num_heads.
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)

        # Expand kv heads, then match query shape
        key_states = repeat_kv(key_states, self.num_kv_groups)
        value_states = repeat_kv(value_states, self.num_kv_groups)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            # The mask is additive and broadcast over the head axis.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            # The head mask scales the post-softmax probabilities of each head.
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge the heads back: (bsz * heads, tgt, head_dim) -> (bsz, tgt, heads, head_dim).
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
196
+
197
+
198
# Lookup table from `config._attn_implementation` to the attention class the
# decoder layers instantiate. No flash-attention class is registered here, so
# that key maps to None.
MBART_ATTENTION_CLASSES = dict(
    eager=MBartGQAttention,
    flash_attention_2=None,
)
202
+
203
+
204
class MBartOrderDecoderLayer(MBartDecoderLayer):
    """MBart decoder layer built with the attention class from `MBART_ATTENTION_CLASSES`.

    Only the constructor is overridden. It initialises `nn.Module` directly instead
    of calling the parent constructor, then creates the self-attention,
    cross-attention, feed-forward and layer-norm submodules itself.
    """

    def __init__(self, config: MBartConfig):
        nn.Module.__init__(self)
        attention_cls = MBART_ATTENTION_CLASSES[config._attn_implementation]
        d_model = config.d_model
        self.embed_dim = d_model

        # Causal self-attention over the decoder sequence.
        self.self_attn = attention_cls(
            embed_dim=d_model,
            num_heads=config.decoder_attention_heads,
            num_kv_heads=config.kv_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        # Cross-attention over the encoder output (no `is_causal` flag is passed).
        self.encoder_attn = attention_cls(
            d_model,
            config.decoder_attention_heads,
            num_kv_heads=config.kv_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(d_model)

        # Position-wise feed-forward block.
        ffn_dim = config.decoder_ffn_dim
        self.fc1 = nn.Linear(d_model, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, d_model)
        self.final_layer_norm = nn.LayerNorm(d_model)
235
+
236
+
237
class BboxEmbedding(nn.Module):
    """Embed integer bounding boxes as the sum of eight per-feature embeddings.

    The features are the four corners (x1, y1, x2, y2), the width, the height and
    the truncated centre coordinates. On the first decoding pass a learned
    box-position embedding is added to the box span of every sample.
    """

    def __init__(self, config):
        super().__init__()
        dim = config.d_model
        horizontal, vertical = config.max_width, config.max_height
        self.x1_embed = nn.Embedding(horizontal, dim)
        self.y1_embed = nn.Embedding(vertical, dim)
        self.x2_embed = nn.Embedding(horizontal, dim)
        self.y2_embed = nn.Embedding(vertical, dim)
        self.w_embed = nn.Embedding(horizontal, dim)
        self.h_embed = nn.Embedding(vertical, dim)
        self.cx_embed = nn.Embedding(horizontal, dim)
        self.cy_embed = nn.Embedding(vertical, dim)
        self.box_pos_embed = nn.Embedding(config.max_position_embeddings, dim)

    def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor, past_key_values_length: int):
        """Return the summed embeddings for `boxes`.

        Args:
            boxes: integer tensor whose last dimension holds (x1, y1, x2, y2).
            input_box_counts: per-sample pair; column 0 is the index of the first
                box and column 1 the end index including the trailing sep token.
            past_key_values_length: when non-zero, no box-position embedding is added.
        """
        x1, y1, x2, y2 = boxes.unbind(dim=-1)
        width = x2 - x1
        height = y2 - y1
        # True division followed by a cast back to long (truncation).
        center_x = ((x1 + x2) / 2).long()
        center_y = ((y1 + y2) / 2).long()

        terms = (
            self.x1_embed(x1),
            self.y1_embed(y1),
            self.x2_embed(x2),
            self.y2_embed(y2),
            self.w_embed(width),
            self.h_embed(height),
            self.cx_embed(center_x),
            self.cy_embed(center_y),
        )
        # Left-to-right accumulation, one (..., d_model) tensor per feature.
        embedded = sum(terms[1:], terms[0])

        if past_key_values_length != 0:
            return embedded

        # First pass: add positional embeddings to the boxes of each sample.
        for row in range(embedded.shape[0]):
            first = input_box_counts[row, 0]
            last = input_box_counts[row, 1] - 1  # the sep token gets no box position
            span_len = last - first
            embedded[row, first:last] = embedded[row, first:last] + self.box_pos_embed.weight[:span_len]

        return embedded
273
+
274
+
275
class MBartOrderDecoder(MBartDecoder):
    """MBart decoder that consumes bounding boxes instead of token ids.

    Inputs are embedded by `BboxEmbedding` (unless another embedding module is
    supplied), and on the first decoding pass every box position is made visible
    to every query position on top of the causal mask.
    """

    def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
        MBartPreTrainedModel.__init__(self, config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        # Use the supplied embedding module as-is. The previous extra
        # `self.embed_tokens.weight = embed_tokens.weight` was a self-assignment
        # and raised AttributeError for modules without a `.weight` attribute
        # (such as BboxEmbedding).
        self.embed_tokens = BboxEmbedding(config) if embed_tokens is None else embed_tokens

        self.embed_positions = MBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )
        self.layers = nn.ModuleList([MBartOrderDecoderLayer(config) for _ in range(config.decoder_layers)])
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.layernorm_embedding = nn.LayerNorm(config.d_model)
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_boxes: torch.LongTensor = None,
        input_boxes_mask: Optional[torch.Tensor] = None,
        input_boxes_counts: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        """Run the decoder stack over a sequence of boxes.

        Args:
            input_boxes: box coordinates; the last dimension is dropped to obtain
                the (batch, sequence) input shape. Mutually exclusive with
                `inputs_embeds`.
            input_boxes_mask: padding mask handed to the causal-mask builder.
            input_boxes_counts: per-sample pair whose column 0 / column 1 are read
                as the start / end index of the span that is made fully visible.
                Required whenever there is no cache (first pass).
            encoder_hidden_states, encoder_attention_mask: cross-attention inputs.
            head_mask, cross_attn_head_mask: per-layer head masks; their first
                dimension must equal the number of layers.
            past_key_values: per-layer cache from a previous call.
            inputs_embeds: precomputed embeddings used instead of `input_boxes`.
            use_cache, output_attentions, output_hidden_states, return_dict:
                default to the corresponding config values when None.

        Returns:
            A `BaseModelOutputWithPastAndCrossAttentions`, or a tuple of its
            non-None fields when `return_dict` is false.

        Raises:
            ValueError: if both or neither of `input_boxes` / `inputs_embeds` are
                given, if `input_boxes_counts` is missing on the first pass, or if
                a head mask does not cover every layer.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Resolve which input was given; `position_input` is only used for its
        # leading shape by the positional embedding.
        if input_boxes is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_boxes and inputs_embeds at the same time")
        elif input_boxes is not None:
            position_input = input_boxes
            input_shape = input_boxes.size()[:-1]  # Shape (batch_size, num_boxes)
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            position_input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either input_boxes or inputs_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if past_key_values_length == 0 and input_boxes_counts is None:
            # Both the box embedding and the mask adjustment below index into it.
            raise ValueError("input_boxes_counts is required when no past_key_values are provided")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_boxes, input_boxes_counts, past_key_values_length) * self.embed_scale

        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            attention_mask = input_boxes_mask if (input_boxes_mask is not None and 0 in input_boxes_mask) else None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                input_boxes_mask, input_shape, inputs_embeds, past_key_values_length
            )

        if past_key_values_length == 0:
            box_ends = input_boxes_counts[:, 1]
            box_starts = input_boxes_counts[:, 0]
            input_shape_arranged = torch.arange(input_shape[1], device=attention_mask.device)[None, :]
            # Enable all boxes to attend to each other (before the sep token)
            # Ensure that the boxes are not attending to the padding tokens
            boxes_end_mask = input_shape_arranged < box_ends[:, None]
            boxes_start_mask = input_shape_arranged >= box_starts[:, None]
            boxes_mask = boxes_end_mask & boxes_start_mask
            boxes_mask = boxes_mask.unsqueeze(1).unsqueeze(1)  # Enable proper broadcasting
            # A mask value of 0 means "visible": key positions inside the span
            # become visible to every query position.
            attention_mask = attention_mask.masked_fill(boxes_mask, 0)

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if self._use_flash_attention_2:
                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        # embed positions
        positions = self.embed_positions(position_input, past_key_values_length)

        hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
        hidden_states = self.layernorm_embedding(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # Caching is switched off while gradient checkpointing is active in training.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != len(self.layers):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}."
                    )
        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache sits after the two attention outputs when they are returned.
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
470
+
471
+
472
class MBartOrderDecoderWrapper(MBartPreTrainedModel):
    """
    Thin pretrained-model shell around `MBartOrderDecoder`.

    It exists so that checkpoints load correctly when the causal language model
    is combined with the [`EncoderDecoderModel`] framework; every call is
    forwarded unchanged to the wrapped decoder.
    """

    def __init__(self, config):
        super().__init__(config)
        self.decoder = MBartOrderDecoder(config)

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped decoder and return its result as-is."""
        decoder_outputs = self.decoder(*args, **kwargs)
        return decoder_outputs
484
+
485
+
486
class MBartOrder(MBartForCausalLM):
    """Causal-LM head on top of the box-ordering decoder."""

    config_class = MBartOrderConfig
    _tied_weights_keys = []

    def __init__(self, config, **kwargs):
        # Work on a private copy so the caller's config is left untouched.
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        MBartPreTrainedModel.__init__(self, decoder_config)

        self.model = MBartOrderDecoderWrapper(decoder_config)
        self.lm_head = nn.Linear(decoder_config.hidden_size, decoder_config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_boxes: torch.LongTensor = None,
        input_boxes_mask: Optional[torch.Tensor] = None,
        input_boxes_counts: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        """Decode the boxes and project the hidden states to vocabulary logits.

        `labels` and any extra keyword arguments are accepted but not used: no
        loss is computed, so the returned loss is always None.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # The decoder returns (features, cache, hidden states, attentions).
        decoder_outputs = self.model.decoder(
            input_boxes=input_boxes,
            input_boxes_mask=input_boxes_mask,
            input_boxes_counts=input_boxes_counts,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(decoder_outputs[0])

        if not return_dict:
            # Tuple form: logits followed by whatever else the decoder returned.
            return (logits,) + decoder_outputs[1:]

        return CausalLMOutputWithCrossAttentions(
            loss=None,
            logits=logits,
            past_key_values=decoder_outputs.past_key_values,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
        )
surya/model/ordering/encoder.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ from typing import Optional, Tuple, Union
4
+ import collections
5
+ import math
6
+
7
+ from transformers import DonutSwinPreTrainedModel
8
+ from transformers.models.donut.modeling_donut_swin import DonutSwinPatchEmbeddings, DonutSwinEmbeddings, DonutSwinModel, \
9
+ DonutSwinEncoder
10
+
11
+ from surya.model.ordering.config import VariableDonutSwinConfig
12
+
13
class VariableDonutSwinEmbeddings(DonutSwinEmbeddings):
    """
    Patch embeddings with optional absolute and optional row/column position
    embeddings, plus an optional mask token.
    """

    def __init__(self, config, use_mask_token=False, **kwargs):
        super().__init__(config, use_mask_token)

        width = config.embed_dim
        self.patch_embeddings = DonutSwinPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, width)) if use_mask_token else None

        # Flat (1D) position table, only when enabled in the config.
        self.position_embeddings = (
            nn.Parameter(torch.zeros(1, num_patches + 1, width)) if config.use_absolute_embeddings else None
        )

        # Separate row and column tables, only when enabled in the config.
        if config.use_2d_embeddings:
            self.row_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[0] + 1, width))
            self.column_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[1] + 1, width))
        else:
            self.row_embeddings = None
            self.column_embeddings = None

        self.norm = nn.LayerNorm(width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None, **kwargs
    ) -> Tuple[torch.Tensor]:
        """Embed `pixel_values` and return `(embeddings, output_dimensions)`."""
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        # Normalise each patch vector over its last dimension.
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            # Swap the masked patches for the learned mask token.
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            embeddings = embeddings + self.position_embeddings[:, :seq_len, :]

        has_2d_tables = self.row_embeddings is not None and self.column_embeddings is not None
        if has_2d_tables:
            n_rows, n_cols = output_dimensions[0], output_dimensions[1]
            # Each row entry is held for a full row of patches: 0, 0, ..., 1, 1, ...
            row_embeddings = self.row_embeddings[:, :n_rows, :].repeat_interleave(n_cols, dim=1)
            # The column entries are tiled once per row: 0, 1, 2, ..., 0, 1, 2, ...
            column_embeddings = self.column_embeddings[:, :n_cols, :].repeat(1, n_rows, 1)

            embeddings = embeddings + row_embeddings + column_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings, output_dimensions
67
+
68
+
69
class VariableDonutSwinModel(DonutSwinModel):
    """Donut Swin model that uses `VariableDonutSwinEmbeddings` for its embeddings."""

    config_class = VariableDonutSwinConfig

    def __init__(self, config, add_pooling_layer=True, use_mask_token=False, **kwargs):
        super().__init__(config)
        self.config = config

        stage_count = len(config.depths)
        self.num_layers = stage_count
        # The feature width doubles with every stage after the first.
        self.num_features = int(config.embed_dim * 2 ** (stage_count - 1))

        # Replace the embeddings and rebuild the encoder on their patch grid.
        self.embeddings = VariableDonutSwinEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)

        if add_pooling_layer:
            self.pooler = nn.AdaptiveAvgPool1d(1)
        else:
            self.pooler = None

        # Initialize weights and apply final processing
        self.post_init()
surya/model/ordering/encoderdecoder.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union, Tuple, List
2
+
3
+ import torch
4
+ from transformers import VisionEncoderDecoderModel
5
+ from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
6
+
7
+
8
class OrderVisionEncoderDecoderModel(VisionEncoderDecoderModel):
    """Vision encoder-decoder whose decoder is fed bounding boxes rather than token ids."""

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_boxes: torch.LongTensor = None,
        decoder_input_boxes_mask: torch.LongTensor = None,
        decoder_input_boxes_counts: torch.LongTensor = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[List[List[int]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Encode the image (unless `encoder_outputs` is given) and decode the boxes.

        Args:
            pixel_values: image batch for the encoder; required when
                `encoder_outputs` is None.
            decoder_input_boxes: box coordinates forwarded as `input_boxes`.
            decoder_input_boxes_mask: forwarded as `input_boxes_mask`
                (presumably 0 for padding, 1 otherwise; confirm against the processor).
            decoder_input_boxes_counts: forwarded as `input_boxes_counts`.
            encoder_outputs: precomputed encoder outputs; a plain tuple is
                wrapped in `BaseModelOutput`.
            **kwargs: keys starting with ``decoder_`` go to the decoder with the
                prefix stripped; all others go to the encoder.

        Raises:
            ValueError: if neither `pixel_values` nor `encoder_outputs` is given.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Route the extra keyword arguments by their prefix.
        prefix = "decoder_"
        kwargs_encoder = {}
        kwargs_decoder = {}
        for argument, value in kwargs.items():
            if argument.startswith(prefix):
                kwargs_decoder[argument[len(prefix):]] = value
            else:
                kwargs_encoder[argument] = value

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # Project the encoder states only when the widths differ and the decoder
        # declares no cross-attention width of its own.
        needs_projection = (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        )
        if needs_projection:
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # No encoder attention mask is passed to the decoder.
        encoder_attention_mask = None

        decoder_outputs = self.decoder(
            input_boxes=decoder_input_boxes,
            input_boxes_mask=decoder_input_boxes_mask,
            input_boxes_counts=decoder_input_boxes_counts,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            labels=labels,
            **kwargs_decoder,
        )

        if return_dict:
            return Seq2SeqLMOutput(
                loss=decoder_outputs.loss,
                logits=decoder_outputs.logits,
                past_key_values=decoder_outputs.past_key_values,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )

        # Tuple form: decoder outputs followed by encoder outputs.
        return decoder_outputs + encoder_outputs
surya/model/ordering/model.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DetrConfig, BeitConfig, DetrImageProcessor, VisionEncoderDecoderConfig, AutoModelForCausalLM, \
2
+ AutoModel
3
+ from surya.model.ordering.config import MBartOrderConfig, VariableDonutSwinConfig
4
+ from surya.model.ordering.decoder import MBartOrder
5
+ from surya.model.ordering.encoder import VariableDonutSwinModel
6
+ from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
7
+ from surya.model.ordering.processor import OrderImageProcessor
8
+ from surya.settings import settings
9
+
10
+
11
def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the reading-order encoder-decoder from `checkpoint`.

    The checkpoint's generic sub-configs are rebuilt as the project's config
    classes and the custom model classes are registered with the transformers
    auto-classes, so that `from_pretrained` instantiates them. The model is
    returned on `device`, in eval mode.
    """
    config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)

    # Re-create each sub-config as its custom class from the stored attributes.
    config.decoder = MBartOrderConfig(**vars(config.decoder))
    config.encoder = VariableDonutSwinConfig(**vars(config.encoder))

    # Get transformers to load custom model
    registrations = (
        (AutoModel, MBartOrderConfig, MBartOrder),
        (AutoModelForCausalLM, MBartOrderConfig, MBartOrder),
        (AutoModel, VariableDonutSwinConfig, VariableDonutSwinModel),
    )
    for auto_class, config_class, model_class in registrations:
        auto_class.register(config_class, model_class)

    model = OrderVisionEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)
    assert isinstance(model.decoder, MBartOrder)
    assert isinstance(model.encoder, VariableDonutSwinModel)

    model = model.to(device)
    model = model.eval()
    print(f"Loaded reading order model {checkpoint} on device {device} with dtype {dtype}")
    return model
surya/model/ordering/processor.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+ from typing import Dict, Union, Optional, List, Tuple
3
+
4
+ import torch
5
+ from torch import TensorType
6
+ from transformers import DonutImageProcessor, DonutProcessor
7
+ from transformers.image_processing_utils import BatchFeature
8
+ from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, \
9
+ valid_images, to_numpy_array
10
+ import numpy as np
11
+ from PIL import Image
12
+ import PIL
13
+ from surya.settings import settings
14
+
15
+
16
def load_processor(checkpoint=settings.ORDER_MODEL_CHECKPOINT):
    """Load the order image processor and attach the box and token settings it needs."""
    box_size = 1024
    max_tokens = 256
    # The special ids are placed directly after the token and box ranges.
    special_id_base = max_tokens + box_size

    processor = OrderImageProcessor.from_pretrained(checkpoint)
    processor.size = settings.ORDER_IMAGE_SIZE
    processor.token_sep_id = special_id_base + 1
    processor.token_pad_id = special_id_base + 2
    # One slot is held back from the configured maximum.
    processor.max_boxes = settings.ORDER_MAX_BOXES - 1
    processor.box_size = {"height": box_size, "width": box_size}
    return processor
26
+
27
+
28
class OrderImageProcessor(DonutImageProcessor):
    """Image processor for the reading-order model: resizes pages and prepares their boxes.

    The methods read `token_sep_id`, `token_pad_id`, `max_boxes` and `box_size`,
    which are not set in `__init__`; they must be assigned on the instance before
    use (`load_processor` does this).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.patch_size = kwargs.get("patch_size", (4, 4))

    def process_inner(self, images: List[np.ndarray]):
        """Convert HWC images to CHW float32 arrays, then rescale and normalize them."""
        images = [img.transpose(2, 0, 1) for img in images]  # convert to CHW format

        assert images[0].shape[0] == 3  # RGB input images, channel dim first after the transpose

        # Convert to float32 for rescale/normalize
        images = [img.astype(np.float32) for img in images]

        # Rescale and normalize
        images = [
            self.rescale(img, scale=self.rescale_factor, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]
        images = [
            self.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]

        return images

    def process_boxes(self, boxes):
        """Append a sep box to every sample and left-pad all samples to one length.

        Returns:
            `(padded_boxes, box_masks, box_counts)`. Each mask holds 0 for pad
            slots and 1 otherwise; each count is `[pad_len, padded_len]`, i.e. the
            index of the first real box and the end index including the sep box.
        """
        with_sep = []
        for sample in boxes:
            sample_boxes = deepcopy(sample)  # never modify the caller's boxes
            sample_boxes.append([self.token_sep_id] * 4)  # Sep token to indicate start of label predictions
            with_sep.append(sample_boxes)

        # `default=0` keeps an empty batch from raising inside max().
        max_boxes = max((len(sample_boxes) for sample_boxes in with_sep), default=0)

        padded_boxes = []
        box_masks = []
        box_counts = []
        for sample_boxes in with_sep:
            box_len = len(sample_boxes)
            pad_len = max_boxes - box_len
            # Left pad for generation; every pad box is its own list (no aliasing).
            padding = [[self.token_pad_id] * 4 for _ in range(pad_len)]
            padded_boxes.append(padding + sample_boxes)
            box_masks.append([0] * pad_len + [1] * box_len)
            box_counts.append([pad_len, max_boxes])

        return padded_boxes, box_masks, box_counts

    def resize_img_and_boxes(self, img, boxes):
        """Resize `img` to the configured size and rescale `boxes` to the box grid.

        Neither argument is modified: the image is copied before the in-place
        `thumbnail` call and rescaled boxes are written to new lists. All four
        coordinates are clamped to `[0, box_size]`.

        Returns:
            `(image_array, scaled_boxes)` where `image_array` is a uint8 array.
        """
        width, height = img.size
        new_size = (self.size["width"], self.size["height"])
        img = img.copy()  # Image.thumbnail() resizes in place; protect the caller's image
        img.thumbnail(new_size, Image.Resampling.LANCZOS)  # Shrink largest dimension to fit new size
        img = img.resize(new_size, Image.Resampling.LANCZOS)  # Stretch smaller dimension to fit new size

        img = np.asarray(img, dtype=np.uint8)

        box_width, box_height = self.box_size["width"], self.box_size["height"]
        limits = (box_width, box_height, box_width, box_height)
        extents = (width, height, width, height)

        scaled_boxes = []
        for box in boxes:
            scaled = list(box)
            for i in range(4):
                # Rescale from image pixels to the box grid, then clamp.
                value = box[i] / extents[i] * limits[i]
                scaled[i] = min(max(value, 0), limits[i])
            scaled_boxes.append(scaled)

        return img, scaled_boxes

    def preprocess(
        self,
        images: ImageInput,
        boxes: List[List[int]],
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_thumbnail: bool = None,
        do_align_long_axis: bool = None,
        do_pad: bool = None,
        random_padding: bool = False,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Build model inputs from page images and their boxes.

        Only `images`, `boxes` and `return_tensors` are read; the remaining
        parameters are accepted for signature compatibility and ignored.

        Args:
            images: one image or a list of images.
            boxes: one list of boxes per image.
            return_tensors: tensor type passed on to `BatchFeature`.

        Returns:
            A `BatchFeature` with `pixel_values`, `input_boxes`,
            `input_boxes_mask` and `input_boxes_counts`.

        Raises:
            ValueError: if the images are invalid or a sample has more than
                `max_boxes` boxes.
        """
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        new_images = []
        new_boxes = []
        for img, box in zip(images, boxes):
            if len(box) > self.max_boxes:
                raise ValueError(f"Too many boxes, max is {self.max_boxes}")
            img, box = self.resize_img_and_boxes(img, box)
            new_images.append(img)
            new_boxes.append(box)

        # The resized images are already numpy arrays; process_inner casts them
        # to float32, which creates fresh arrays.
        images = self.process_inner(new_images)
        boxes, box_mask, box_counts = self.process_boxes(new_boxes)
        data = {
            "pixel_values": images,
            "input_boxes": boxes,
            "input_boxes_mask": box_mask,
            "input_boxes_counts": box_counts,
        }
        return BatchFeature(data=data, tensor_type=return_tensors)
surya/model/recognition/config.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from transformers import PretrainedConfig
5
+ from transformers.utils import ModelOutput
6
+
7
+
8
class SuryaOCRConfig(PretrainedConfig):
    """Composite configuration holding an encoder config and a decoder config.

    Both ``encoder`` and ``decoder`` are required keyword arguments. The
    decoder config may be a plain dict or an object with attributes; in either
    case its BOS / PAD / EOS token ids are copied onto this config.
    """

    model_type = "vision-encoder-decoder"
    is_composition = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Both keys are mandatory: a missing one raises KeyError.
        encoder_cfg = kwargs.pop("encoder")
        decoder_cfg = kwargs.pop("decoder")

        self.encoder = encoder_cfg
        self.decoder = decoder_cfg
        self.is_encoder_decoder = True

        # Read token ids uniformly from either a dict or an attribute object.
        if isinstance(decoder_cfg, dict):
            def token_id(name):
                return decoder_cfg[name]
        else:
            def token_id(name):
                return getattr(decoder_cfg, name)

        # The decoder's BOS token doubles as the decoder start token.
        self.decoder_start_token_id = token_id("bos_token_id")
        self.pad_token_id = token_id("pad_token_id")
        self.eos_token_id = token_id("eos_token_id")
30
+
31
+
32
class DonutSwinConfig(PretrainedConfig):
    """Configuration for the Donut-Swin style vision encoder.

    The per-stage sequence arguments (``depths``, ``num_heads``,
    ``num_kv_heads``) are copied into fresh lists, so configs never share or
    mutate the signature defaults or a caller's list.
    """

    model_type = "donut-swin"

    # Lets generic code read `num_attention_heads` / `num_hidden_layers`.
    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        image_size=(256, 896),
        patch_size=4,
        num_channels=3,
        embed_dim=128,
        depths=(2, 2, 14, 2),
        num_heads=(4, 8, 16, 32),
        num_kv_heads=(1, 2, 4, 8),
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        use_absolute_embeddings=True,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        encoder_length=256,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        # Copy the sequences: the defaults are immutable tuples, and a
        # caller-supplied list must not be aliased by the config.
        self.depths = list(depths)
        self.num_layers = len(self.depths)
        self.num_heads = list(num_heads)
        self.num_kv_heads = list(num_kv_heads)
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.use_absolute_embeddings = use_absolute_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        self.hidden_size = int(embed_dim * 2 ** (len(self.depths) - 1))
        self.encoder_length = encoder_length
86
+
87
+
88
class SuryaOCRDecoderConfig(PretrainedConfig):
    """Hyper-parameters for the Surya OCR text decoder.

    Derived values set here: ``lru_width`` falls back to ``hidden_size``,
    ``head_dim`` is ``hidden_size // num_attention_heads``,
    ``num_key_value_heads`` falls back to ``num_attention_heads``, and
    ``final_w_init_variance_scale`` is ``2.0 / num_hidden_layers``.

    Raises:
        ValueError: If ``num_key_value_heads`` exceeds ``num_attention_heads``.
    """

    model_type = "surya_ocr"

    def __init__(
        self,
        num_hidden_layers=10,
        vocab_size=65792,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        num_attention_heads=16,
        lru_width=None,
        attention_window_size=16,
        conv1d_width=4,
        logits_soft_cap=30.0,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=1,
        hidden_activation="gelu_pytorch_tanh",
        rope_theta=10000.0,
        block_types=("attention",),
        cross_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9),
        self_attn_layers=(0, 1, 3, 5, 7, 9),
        global_attn_layers=(0, 1, 3, 5, 7, 9),
        attention_dropout=0.0,
        num_key_value_heads=2,
        attention_bias=False,
        w_init_variance_scale=0.01,
        init_std=0.02,
        tie_word_embeddings=False,
        aux_heads=0,  # How many n-token-ahead heads to add
        encoder_hidden_size=1024,
        causal=False,
        **kwargs,
    ):
        # Resolve the optional arguments up front.
        if lru_width is None:
            lru_width = hidden_size
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # Sizes.
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.lru_width = lru_width

        # Block-level settings.
        self.attention_window_size = attention_window_size
        self.conv1d_width = conv1d_width
        self.logits_soft_cap = logits_soft_cap
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.block_types = list(block_types)
        self.hidden_activation = hidden_activation

        # Attention head layout.
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        if self.num_key_value_heads > self.num_attention_heads:
            raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`")

        # Which layers carry which attention blocks.
        self.cross_attn_layers = cross_attn_layers
        self.self_attn_layers = self_attn_layers
        self.global_attn_layers = global_attn_layers
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias

        # Initialisation.
        self.w_init_variance_scale = w_init_variance_scale
        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
        self.init_std = init_std

        self.tie_word_embeddings = tie_word_embeddings
        self.aux_heads = aux_heads
        self.encoder_hidden_size = encoder_hidden_size
        self.causal = causal

        # Special-token ids and remaining kwargs are handled by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        """Block types tiled 100 times, then cut to one entry per hidden layer."""
        tiled = self.block_types * 100
        return tiled[: self.num_hidden_layers]
165
+
166
+
167
class SuryaOCRTextEncoderConfig(PretrainedConfig):
    """Hyper-parameters for the Surya OCR text encoder.

    Takes the same arguments as the decoder config plus ``iteration_count``
    and ``query_token_count``. Derived values set here: ``lru_width`` falls
    back to ``hidden_size``, ``head_dim`` is
    ``hidden_size // num_attention_heads``, ``num_key_value_heads`` falls back
    to ``num_attention_heads``, and ``final_w_init_variance_scale`` is
    ``2.0 / num_hidden_layers``.

    Raises:
        ValueError: If ``num_key_value_heads`` exceeds ``num_attention_heads``.
    """

    model_type = "surya_ocr"

    def __init__(
        self,
        num_hidden_layers=10,
        vocab_size=65792,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        num_attention_heads=16,
        lru_width=None,
        attention_window_size=16,
        conv1d_width=4,
        logits_soft_cap=30.0,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=1,
        hidden_activation="gelu_pytorch_tanh",
        rope_theta=10000.0,
        block_types=("attention",),
        cross_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9),
        self_attn_layers=(0, 1, 3, 5, 7, 9),
        global_attn_layers=(0, 1, 3, 5, 7, 9),
        attention_dropout=0.0,
        num_key_value_heads=2,
        attention_bias=False,
        w_init_variance_scale=0.01,
        init_std=0.02,
        tie_word_embeddings=False,
        aux_heads=0,  # How many n-token-ahead heads to add
        encoder_hidden_size=1024,
        iteration_count=1,
        causal=False,
        query_token_count=128,
        **kwargs,
    ):
        # Resolve the optional arguments up front.
        if lru_width is None:
            lru_width = hidden_size
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # Sizes.
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.lru_width = lru_width

        # Block-level settings.
        self.attention_window_size = attention_window_size
        self.conv1d_width = conv1d_width
        self.logits_soft_cap = logits_soft_cap
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.block_types = list(block_types)
        self.hidden_activation = hidden_activation

        # Attention head layout.
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        if self.num_key_value_heads > self.num_attention_heads:
            raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`")

        # Which layers carry which attention blocks.
        self.cross_attn_layers = cross_attn_layers
        self.self_attn_layers = self_attn_layers
        self.global_attn_layers = global_attn_layers
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias

        # Initialisation.
        self.w_init_variance_scale = w_init_variance_scale
        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
        self.init_std = init_std

        self.tie_word_embeddings = tie_word_embeddings
        self.aux_heads = aux_heads
        self.encoder_hidden_size = encoder_hidden_size
        self.iteration_count = iteration_count
        self.causal = causal
        self.query_token_count = query_token_count

        # Special-token ids and remaining kwargs are handled by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        """Block types tiled 100 times, then cut to one entry per hidden layer."""
        tiled = self.block_types * 100
        return tiled[: self.num_hidden_layers]
248
+
249
# Vocabulary layout: TOTAL_TOKENS regular tokens, TOKEN_OFFSET reserved ids,
# and SPECIAL_TOKENS special tokens.
TOTAL_TOKENS = 65536
TOKEN_OFFSET = 3  # Pad, eos, bos
SPECIAL_TOKENS = 253
TOTAL_VOCAB_SIZE = TOTAL_TOKENS + TOKEN_OFFSET + SPECIAL_TOKENS

# Language code -> index. The index is the code's position in the tuple below,
# so the order must not change; "_math" is the final entry.
LANGUAGE_MAP = {
    code: index
    for index, code in enumerate(
        (
            "af", "am", "ar", "as", "az", "be", "bg", "bn", "br", "bs",
            "ca", "cs", "cy", "da", "de", "el", "en", "eo", "es", "et",
            "eu", "fa", "fi", "fr", "fy", "ga", "gd", "gl", "gu", "ha",
            "he", "hi", "hr", "hu", "hy", "id", "is", "it", "ja", "jv",
            "ka", "kk", "km", "kn", "ko", "ku", "ky", "la", "lo", "lt",
            "lv", "mg", "mk", "ml", "mn", "mr", "ms", "my", "ne", "nl",
            "no", "om", "or", "pa", "pl", "ps", "pt", "ro", "ru", "sa",
            "sd", "si", "sk", "sl", "so", "sq", "sr", "su", "sv", "sw",
            "ta", "te", "th", "tl", "tr", "ug", "uk", "ur", "uz", "vi",
            "xh", "yi", "zh", "_math",
        )
    )
}
surya/model/recognition/decoder.py ADDED
@@ -0,0 +1,695 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.utils.checkpoint
6
+ from torch import nn
7
+ from transformers.utils import ModelOutput
8
+
9
+ from surya.model.recognition.config import SuryaOCRDecoderConfig, SuryaOCRTextEncoderConfig
10
+ from transformers import PreTrainedModel
11
+ from transformers.activations import ACT2FN
12
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
13
+ from transformers.modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput
14
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
15
+
16
+ from surya.settings import settings
17
+
18
+ _MAX_SQRT_GRADIENT = 1000.0
19
+
20
+
21
@dataclass
class OCRModelOutput(ModelOutput):
    """Output container for the OCR decoder.

    ``logits`` is required; ``aux_logits`` and ``hidden_states`` default to
    ``None``.
    """

    logits: torch.Tensor
    aux_logits: Optional[torch.Tensor] = None
    hidden_states: Optional[torch.Tensor] = None
26
+
27
+
28
class SuryaOCRDecoderRMSNorm(nn.Module):
    """RMS normalisation with a learned gain stored as an offset from one.

    ``weight`` starts at zero and the applied gain is ``1 + weight``, so a
    freshly constructed layer performs plain RMS normalisation.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS over its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Llama does x.to(float16) * w whilst SuryaOCRDecoder is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        gain = 1.0 + self.weight.float()
        normalized = self._norm(x.float())
        # Cast back to the input dtype only after applying the gain in float32.
        return (normalized * gain).type_as(x)

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.eps}"
46
+
47
+
48
+ ALL_LAYERNORM_LAYERS.append(SuryaOCRDecoderRMSNorm)
49
+
50
+
51
class SuryaOCRDecoderRotaryEmbedding(nn.Module):
    """Produces the cos/sin tables used for rotary position embeddings.

    Args:
        dim: Rotary dimension; the inverse-frequency table has ``dim // 2``
            entries (one per even index in ``range(0, dim, 2)``).
        base: Base of the geometric frequency progression.
        device: Accepted for signature compatibility; not used here.
    """

    def __init__(self, dim, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        # Non-persistent: recomputed at construction, never stored in checkpoints.
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaOCRDecoder
    def forward(self, x, position_ids, seq_len=None):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype``.

        Args:
            x: Tensor used only for its device and dtype.
            position_ids: Indexed as ``[batch, seq_len]``.
            seq_len: Unused.

        Returns:
            Two tensors of shape ``[batch, seq_len, 2 * len(inv_freq)]``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        # `Tensor.to` is not in-place: the original call discarded its result,
        # so the buffer was never actually moved. Use the returned tensor.
        inv_freq = self.inv_freq.to(x.device)
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        # Keep both matmul operands on the same device as `x`.
        position_ids_expanded = position_ids[:, None, :].float().to(x.device)

        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
72
+
73
+
74
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
75
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    midpoint = x.shape[-1] // 2
    front = x[..., :midpoint]
    back = x[..., midpoint:]
    # The negated second half goes first, followed by the untouched first half.
    return torch.cat((-back, front), dim=-1)
80
+
81
+
82
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
83
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into `cos` and `sin` so they broadcast against `q` and `k`. With `cos`/`sin` of
            shape [batch_size, seq_len, head_dim]: use 1 when `q` and `k` are
            [batch_size, heads, seq_len, head_dim], and 2 when they are
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Standard rotary form: x * cos + rotate_half(x) * sin.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
106
+
107
+
108
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
109
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    # Unpacking first also enforces a 4-D input, even when n_rep == 1.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis after the head axis, broadcast it, then fold it into the heads.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
119
+
120
+
121
class SuryaOCRDecoderSdpaCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper
    Modified for GQA

    Queries come from the decoder hidden states; keys and values are projected
    from the encoder hidden states. The projected keys/values can be cached on
    the module, so later calls skip the key/value projections until
    ``_setup_cache`` clears them.
    """

    def __init__(self, config: SuryaOCRDecoderConfig):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
        # Not used by `forward` in this class; kept so the module structure is unchanged.
        self.rotary_emb = SuryaOCRDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

        # Start with an empty cache so `forward` works even if `_setup_cache`
        # was never called (previously that raised AttributeError).
        self.value_states = None
        self.key_states = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> torch.Tensor:
        """Attend from decoder states to encoder states.

        Args:
            hidden_states: Decoder states, unpacked as ``(bsz, q_len, _)``.
            encoder_hidden_states: Encoder states, unpacked as ``(_, v_len, _)``.
            attention_mask: Accepted but not used (attention runs unmasked).
            encoder_attention_mask: Accepted but not used.
            use_cache: When true and no keys/values are cached yet, the freshly
                projected keys/values are stored on the module.

        Returns:
            The attention output after the output projection, shaped
            ``(bsz, q_len, hidden_size)``.
        """
        # Encoder attention mask currently ignored

        bsz, q_len, _ = hidden_states.size()
        _, v_len, _ = encoder_hidden_states.size()

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)

        if self.key_states is None:
            # No cached projection yet: compute keys/values from the encoder output.
            key_states = self.k_proj(encoder_hidden_states)
            value_states = self.v_proj(encoder_hidden_states)
            key_states = key_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            if use_cache:
                self._update_cache(key_states, value_states)
        else:
            # Reuse the cached projections; `encoder_hidden_states` is not re-projected.
            key_states = self.key_states
            value_states = self.value_states

        # Expand the grouped key/value heads to match the query head count.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states.contiguous(),
            key_states.contiguous(),
            value_states.contiguous(),
            attn_mask=None,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _setup_cache(self, batch_size, device, dtype=None):
        """Clear the cached keys/values; the arguments are not used."""
        # Setup initial caches
        self.value_states = None
        self.key_states = None

    @torch.no_grad()
    def _update_cache(self, key_states, value_states, **cache_kwargs):
        """Store the projected keys/values for reuse by later `forward` calls."""
        self.value_states = value_states
        self.key_states = key_states
198
+
199
+
200
class SuryaOCRDecoderSdpaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Self-attention with rotary position embeddings, grouped key/value heads
    and a per-module key/value cache. The cache is a pre-allocated static
    buffer when ``settings.RECOGNITION_STATIC_CACHE`` is set, otherwise it
    grows by concatenation.
    """

    def __init__(self, config: SuryaOCRDecoderConfig):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
        self.rotary_emb = SuryaOCRDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: bool = False,
        window_attn: bool = False,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Unpacked as ``(bsz, q_len, _)``.
            position_ids: Positions passed to the rotary embedding.
            attention_mask: Optional additive mask, sliced on its last axis to
                the key length.
            cache_position: Cache slots written by this call; its last element
                is taken as the current position for static-cache masking.
            use_cache: Update and read the key/value cache. Only takes effect
                once ``_setup_cache`` has created the cache attributes.
            window_attn: Forwarded in the cache kwargs; not otherwise read here.

        Returns:
            The attention output after the output projection, shaped
            ``(bsz, q_len, hidden_size)``.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Final is bsz, num_attention_heads, seq_len, head_dim
        query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # The cache attributes only exist after `_setup_cache`; without them, caching is skipped.
        if use_cache and hasattr(self, "key_states"):
            cache_kwargs = {"cache_position": cache_position, "window_attn": window_attn}
            key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            # Mask is batch, head, seq_len, kv_len
            causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]
            current_cache_position = cache_position[-1].item() if cache_position is not None else None
            # Compare against None explicitly: position 0 is a valid position,
            # and a truthiness test would leave the unfilled static-cache slots
            # unmasked at that step.
            if current_cache_position is not None and settings.RECOGNITION_STATIC_CACHE:
                # Mask out future cache positions
                position_mask = torch.ones_like(causal_mask, dtype=torch.bool, device=causal_mask.device)
                position_mask[:, :, :, :current_cache_position + 1] = False
                causal_mask = torch.where(position_mask, torch.finfo(causal_mask.dtype).min, causal_mask)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states.contiguous(),
            key_states.contiguous(),
            value_states.contiguous(),
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _setup_cache(self, batch_size, device, dtype=None):
        """Create (or reset) the key/value cache.

        ``dtype`` falls back to ``config.torch_dtype`` and then to float32.
        With the static cache enabled, zero buffers of shape
        ``(batch_size, num_key_value_heads, RECOGNITION_MAX_TOKENS, head_dim)``
        are allocated; otherwise the cache starts as ``None``.
        """
        if dtype is None and self.config.torch_dtype is not None:
            dtype = self.config.torch_dtype
        dtype = dtype if dtype is not None else torch.float32

        # Setup initial caches
        self.value_states = None
        self.key_states = None

        if settings.RECOGNITION_STATIC_CACHE:
            cache_shape = (batch_size, self.num_key_value_heads, settings.RECOGNITION_MAX_TOKENS, self.head_dim)
            self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device)
            self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device)

    def _update_static_cache(self, key_states, value_states, **cache_kwargs):
        """Write the new keys/values at ``cache_position`` and return the full buffers."""
        cache_position = cache_kwargs.get("cache_position")
        k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)

        k_out[:, :, cache_position] = key_states.to(k_out.dtype)
        v_out[:, :, cache_position] = value_states.to(v_out.dtype)

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

    def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs):
        """Append the new keys/values along dim 2 and return the grown tensors."""
        k_out = key_states
        if self.key_states is not None:
            k_out = torch.cat([self.key_states, key_states], dim=2)

        v_out = value_states
        if self.value_states is not None:
            v_out = torch.cat([self.value_states, value_states], dim=2)

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

    @torch.no_grad()
    def _update_cache(self, key_states, value_states, **cache_kwargs):
        """Dispatch to the static or dynamic cache update based on settings."""
        if settings.RECOGNITION_STATIC_CACHE:
            return self._update_static_cache(key_states, value_states, **cache_kwargs)

        return self._update_dynamic_cache(key_states, value_states, **cache_kwargs)
319
+
320
+
321
class SuryaOCRDecoderMlp(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))``, all without bias."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        model_dim, inner_dim = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)

        # A missing activation is written back onto the config as the default.
        if config.hidden_activation is None:
            config.hidden_activation = "gelu_pytorch_tanh"
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
337
+
338
+
339
class SuryaOCRDecoderLayer(nn.Module):
    """One decoder layer: optional cross-attention, optional self-attention, then an MLP.

    Whether the layer has a cross-attention or self-attention ("temporal")
    block is decided by membership of ``layer_idx`` in
    ``config.cross_attn_layers`` / ``config.self_attn_layers``. Layers not in
    ``config.global_attn_layers`` are flagged as window-attention layers.
    """

    def __init__(self, config, layer_idx):
        # The base-class initialiser was previously called twice; once is enough.
        super().__init__()
        self.cross_pre_norm = SuryaOCRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.temporal_pre_norm = SuryaOCRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.temporal_block = None
        if layer_idx in config.self_attn_layers:
            self.temporal_block = SuryaOCRDecoderSdpaAttention(config)

        self.cross_attn_block = None
        if layer_idx in config.cross_attn_layers:
            self.cross_attn_block = SuryaOCRDecoderSdpaCrossAttention(config)

        self.window_attn = layer_idx not in config.global_attn_layers
        self.channel_pre_norm = SuryaOCRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp_block = SuryaOCRDecoderMlp(config)

    def forward(
        self,
        activations: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        encoder_attention_mask: torch.Tensor = None,
        cache_position: torch.Tensor = None,
        use_cache: bool = None,
    ) -> torch.Tensor:
        """Apply the layer to ``activations`` and return the new hidden states.

        Args:
            activations: Input hidden states for this layer.
            position_ids: Forwarded to the self-attention block.
            attention_mask: Forwarded to both attention blocks.
            encoder_hidden_states: Encoder output for the cross-attention block.
            encoder_attention_mask: Forwarded to the cross-attention block.
            cache_position: Forwarded to the self-attention block.
            use_cache: Forwarded to both attention blocks.

        Returns:
            A single tensor: the MLP output plus its residual input.
        """
        raw_activations = activations

        if self.cross_attn_block is not None:
            # Do cross-attention on encoder outputs
            cross_attn_inputs = self.cross_pre_norm(activations)
            cross_attn_path = self.cross_attn_block(
                cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache
            )
            cross_attn_output = cross_attn_path + raw_activations
        else:
            cross_attn_output = raw_activations

        if self.temporal_block is not None:
            inputs_normalized = self.temporal_pre_norm(cross_attn_output)  # RMSNorm introduces slight differences
            hidden_states = self.temporal_block(
                inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache, window_attn=self.window_attn
            )

            # NOTE(review): the residual adds the layer input (`raw_activations`),
            # not `cross_attn_output`. Kept exactly as written; presumably trained
            # weights depend on it - confirm before changing.
            residual = hidden_states + raw_activations
        else:
            residual = cross_attn_output

        hidden_states = self.channel_pre_norm(residual)
        hidden_states = self.mlp_block(hidden_states)

        hidden_states = hidden_states + residual
        return hidden_states
395
+
396
+
397
class SuryaOCRDecoderPreTrainedModel(PreTrainedModel):
    """Shared base for the Surya OCR decoder models.

    Provides weight initialisation and per-layer cache setup. Cache reset and
    weight tying are deliberately no-ops.
    """

    config_class = SuryaOCRDecoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SuryaOCRDecoderLayer"]
    _skip_keys_device_placement = ["cache"]
    _supports_flash_attn_2 = False
    _supports_sdpa = False  # we can't compare with eager for now
    _supports_cache_class = True
    _supports_quantized_cache = True

    def _init_weights(self, module):
        """Initialise ``module`` with normal weights of std ``config.init_std``."""
        std = self.config.init_std

        if isinstance(module, SuryaOCRDecoderSdpaAttention):
            # Only the four projection weights are touched; biases are left alone.
            for projection in (module.q_proj, module.k_proj, module.v_proj, module.o_proj):
                torch.nn.init.normal_(projection.weight, mean=0.0, std=std)
            return

        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if getattr(module, "bias", None) is not None:
                torch.nn.init.zeros_(module.bias)
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            # The padding row is zeroed after the normal init.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _setup_cache(self, config, batch, device, dtype):
        """Set up the cache of every attention block found in the layers."""
        # Use `self.model.layers` when wrapped, else `self.layers`.
        decoder = getattr(self, "model", self)
        for layer in decoder.layers:
            for block in (layer.temporal_block, layer.cross_attn_block):
                if block:
                    block._setup_cache(batch, device, dtype)

    def reset_cache(self, batch, device, dtype):
        """No-op."""
        pass

    def _tie_weights(self):
        """No-op."""
        pass

    def tie_weights(self):
        """No-op."""
        pass
440
+
441
+
442
class SuryaOCRDecoderModel(SuryaOCRDecoderPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuryaOCRDecoderLayer`]

    Args:
        config: SuryaOCRDecoderConfig
    """

    def __init__(self, config: SuryaOCRDecoderConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # When False, `_update_causal_mask` returns None and no causal mask is built.
        self.causal = config.causal

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [SuryaOCRDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.final_norm = SuryaOCRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # sqrt(hidden_size) kept as a non-persistent buffer; it is not referenced by the methods of this class.
        self.register_buffer(
            "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32), persistent=False
        )
        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings
    def get_input_embeddings(self):
        return self.embed_tokens

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        prefill: bool = False
    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
        """
        Embed `input_ids`, run them through every decoder layer and apply the final norm.

        Args:
            input_ids: token ids, looked up in `embed_tokens`.
            position_ids: forwarded to each layer; defaults to `cache_position.unsqueeze(0)`.
            attention_mask: optional mask merged into the causal mask by `_update_causal_mask`.
            encoder_hidden_states / encoder_attention_mask: forwarded unchanged to each layer.
            cache_position: defaults to `arange(sequence_length)` on the embeddings' device.
            use_cache: falls back to `config.use_cache`; forced to False when gradient
                checkpointing is active during training.
            output_hidden_states: when truthy, collect the input of every layer plus the final
                normalised output.
            return_dict: falls back to `config.use_return_dict`.
            prefill: when True together with `use_cache`, the layer caches are (re)initialised
                before the layers run.

        Returns:
            `BaseModelOutputWithNoAttention`, or a tuple of its non-None members when
            `return_dict` is False.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        # Caches are sized from the batch dimension of the embeddings on the prefill step only.
        if use_cache and prefill:
            self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype)

        if cache_position is None:
            cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

        all_hidden_states = () if output_hidden_states else None
        for i, residual_block in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    residual_block.__call__, hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache
                )
            else:
                hidden_states = residual_block(hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache)

        hidden_states = self.final_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )

    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
    # Ignore copy
    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        """
        Build the additive causal mask for `input_tensor`, or return None when the model is not causal.

        The mask has one row per query token and `max(settings.RECOGNITION_MAX_TOKENS, sequence_length)`
        key columns; it holds 0 where attention is allowed and the dtype's minimum elsewhere, and is
        expanded to `(batch, 1, sequence_length, target_length)`.
        """
        if not self.causal:
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # Key axis is at least RECOGNITION_MAX_TOKENS wide regardless of the current input length.
        target_length = max(settings.RECOGNITION_MAX_TOKENS, sequence_length)

        diagonal = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        causal_mask = diagonal
        if sequence_length != 1:
            # Select the upper triangular part of the matrix, but unmask current token (the diagonal)
            # triu will be the min_dtype, everything else is 0 (attended to)
            causal_mask = torch.triu(diagonal, diagonal=1)

        # Multiply by a boolean grid: columns at or before the row's cache position become 0 (visible),
        # columns after it keep whatever value they already had.
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            if attention_mask.dim() == 2:
                # Mask positions in the causal mask that are masked in the attention mask
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

        if attention_mask is not None and attention_mask.device.type == "cuda":
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
572
+
573
+
574
class SuryaOCRDecoder(SuryaOCRDecoderPreTrainedModel):
    """
    Decoder model topped with a single linear projection that produces the logits of the
    primary head plus `config.aux_heads` auxiliary heads, each `vocab_size` wide.
    """

    _tied_weights_keys = None

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.model = SuryaOCRDecoderModel(config)
        self.vocab_size = config.vocab_size

        # One primary head, plus however many auxiliary heads the config asks for (None -> 0).
        extra_heads = 0 if config.aux_heads is None else config.aux_heads
        head_count = extra_heads + 1
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size * head_count, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    # Ignore copy
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        prefill: bool = False,
        **kwargs
    ) -> Union[Tuple, OCRModelOutput]:
        """
        Run the decoder and project its last hidden state to per-head logits.

        Returns an `OCRModelOutput` whose `logits` are the first `vocab_size`-wide slice of
        the projection, whose `aux_logits` is a tuple with the remaining slices (None when
        there is only one head) and whose `hidden_states` are all decoder hidden states.
        Extra keyword arguments are accepted and ignored.
        """
        decoder_out = self.model(
            input_ids=input_ids,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_hidden_states=True,
            return_dict=True,
            prefill=prefill,
        )

        projected = self.lm_head(decoder_out[0])
        # One chunk per head along the vocabulary axis.
        head_logits = torch.split(projected, self.vocab_size, dim=-1)

        aux_logits = None
        if len(head_logits) > 1:
            aux_logits = head_logits[1:]

        return OCRModelOutput(
            logits=head_logits[0],
            aux_logits=aux_logits,
            hidden_states=decoder_out.hidden_states,
        )
641
+
642
@dataclass
class TextEncoderOutput(CausalLMOutput):
    # Output container returned by SuryaOCRTextEncoder.forward.
    # That forward fills this field with the wrapped model's `last_hidden_state`,
    # i.e. a single tensor; it defaults to None.
    hidden_states: torch.FloatTensor = None
645
+
646
+
647
class SuryaOCRTextEncoder(SuryaOCRDecoderPreTrainedModel):
    """
    Wraps a `SuryaOCRDecoderModel` and exposes only its final hidden state, without any
    output projection.
    """

    _tied_weights_keys = None
    config_class = SuryaOCRTextEncoderConfig

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.model = SuryaOCRDecoderModel(config)
        self.vocab_size = config.vocab_size

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    # Ignore copy
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutput]:
        """
        Run the wrapped model and return its last hidden state as `TextEncoderOutput.hidden_states`.

        Extra keyword arguments are accepted and ignored.
        """
        model_inputs = dict(
            input_ids=input_ids,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_hidden_states=True,
            return_dict=True,
        )
        encoded = self.model(**model_inputs)

        return TextEncoderOutput(hidden_states=encoded.last_hidden_state)
surya/model/recognition/encoder.py ADDED
@@ -0,0 +1,852 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ EfficientViT (by MIT Song Han's Lab)
2
+
3
+ Paper: `Efficientvit: Enhanced linear attention for high-resolution low-computation visual recognition`
4
+ - https://arxiv.org/abs/2205.14756
5
+
6
+ Code adapted from timm, https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_mit.py
7
+ Original code (that timm adapted from) at https://github.com/mit-han-lab/efficientvit
8
+ """
9
+
10
+ import collections.abc
11
+ import math
12
+ from dataclasses import dataclass
13
+ from typing import Optional, Tuple, Union
14
+
15
+ import torch
16
+ import torch.utils.checkpoint
17
+ from torch import nn
18
+
19
+ from transformers.activations import ACT2FN
20
+ from transformers.modeling_utils import PreTrainedModel
21
+ from transformers.pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
22
+ from transformers.utils import ModelOutput
23
+ from surya.model.recognition.config import DonutSwinConfig
24
+
25
+ _EXPECTED_OUTPUT_SHAPE = [1, 49, 1024]
26
+
27
+
28
@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
class DonutSwinEncoderOutput(ModelOutput):
    # Result container for the encoder; every field defaults to None.
    # NOTE(review): field semantics presumably match SwinEncoderOutput in transformers
    # (see the "Copied from" marker above) — confirm against that class.

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
36
+
37
+
38
@dataclass
class DonutSwinModelOutput(ModelOutput):
    # Minimal model output: a single tensor field, None by default.
    last_hidden_state: torch.FloatTensor = None
41
+
42
+
43
+ # Copied from transformers.models.swin.modeling_swin.window_partition
44
def window_partition(input_feature, window_size):
    """
    Cut a channels-last feature map into non-overlapping square windows.

    `input_feature` is unpacked as `(batch_size, height, width, num_channels)`; the result
    has shape `(-1, window_size, window_size, num_channels)`, with windows ordered by
    batch, then window row, then window column.
    """
    batch_size, height, width, num_channels = input_feature.shape
    rows_of_windows = height // window_size
    cols_of_windows = width // window_size

    tiled = input_feature.view(
        batch_size, rows_of_windows, window_size, cols_of_windows, window_size, num_channels
    )
    # Put the two window-grid axes next to each other so every window becomes one contiguous slab.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, num_channels)
54
+
55
+
56
+ # Copied from transformers.models.swin.modeling_swin.window_reverse
57
def window_reverse(windows, window_size, height, width):
    """
    Stitch square windows back into a `(-1, height, width, num_channels)` feature map.

    Expects the window ordering produced by partitioning a channels-last map row by row
    (batch, then window row, then window column).
    """
    num_channels = windows.shape[-1]
    rows_of_windows = height // window_size
    cols_of_windows = width // window_size

    grid = windows.view(-1, rows_of_windows, cols_of_windows, window_size, window_size, num_channels)
    # Interleave window rows with in-window rows (and likewise for columns) before flattening.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, height, width, num_channels)
65
+
66
+
67
+ # Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin
68
class DonutSwinEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.
    """

    def __init__(self, config, use_mask_token=False):
        super().__init__()

        self.patch_embeddings = DonutSwinPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None

        if config.use_absolute_embeddings:
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
        else:
            self.position_embeddings = None

        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        The first slot of `self.position_embeddings` is kept as-is; the remaining slots are laid out on the
        configured patch grid (`self.patch_grid`), resized bicubically to the grid implied by
        `height`/`width`, and concatenated back after the first slot.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        # Patch size and grid shape come from the patch-embedding module, which stores both as
        # (height, width) pairs. The module never stores `config`, so reading `self.config` here
        # would raise AttributeError, and a square-root based grid only fits square images.
        patch_height, patch_width = self.patch_embeddings.patch_size
        grid_height, grid_width = self.patch_grid
        h0 = height // patch_height
        w0 = width // patch_width
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        patch_pos_embed = patch_pos_embed.reshape(1, grid_height, grid_width, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / grid_height, w0 / grid_width),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor],
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Tuple[torch.Tensor]:
        """
        Turn `pixel_values` into normalised patch embeddings, optionally masked and position-encoded.

        Args:
            pixel_values: unpacked as `(batch, num_channels, height, width)`.
            bool_masked_pos: when given, masked positions are replaced by `self.mask_token`
                (the module must have been built with `use_mask_token=True`).
            interpolate_pos_encoding: when True and absolute position embeddings exist, add the
                interpolated table instead of the first `seq_len` stored positions.

        Returns:
            `(embeddings, output_dimensions)` where `output_dimensions` is the patch-grid
            `(height, width)` reported by the patch-embedding module.
        """
        _, num_channels, height, width = pixel_values.shape
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            if interpolate_pos_encoding:
                # NOTE(review): the interpolated table keeps a leading extra slot, so it appears to be
                # one position longer than `embeddings`; confirm the intended layout before relying
                # on this branch.
                embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
            else:
                embeddings = embeddings + self.position_embeddings[:, :seq_len]

        embeddings = self.dropout(embeddings)

        return embeddings, output_dimensions
147
+
148
+
149
+ # Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin
150
class DonutSwinPatchEmbeddings(nn.Module):
    """
    Projects `pixel_values` of shape `(batch_size, num_channels, height, width)` to a sequence of
    patch embeddings of shape `(batch_size, seq_length, hidden_size)` using a strided convolution
    whose kernel and stride both equal the patch size.
    """

    def __init__(self, config):
        super().__init__()

        def _as_pair(value):
            # Scalars become (value, value); iterables are kept untouched.
            return value if isinstance(value, collections.abc.Iterable) else (value, value)

        image_size = _as_pair(config.image_size)
        patch_size = _as_pair(config.patch_size)
        grid_rows = image_size[0] // patch_size[0]
        grid_cols = image_size[1] // patch_size[1]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = grid_cols * grid_rows
        self.grid_size = (grid_rows, grid_cols)

        self.projection = nn.Conv2d(
            config.num_channels, config.embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def maybe_pad(self, pixel_values, height, width):
        """Zero-pad on the right/bottom so that width and height are multiples of the patch size."""
        patch_h, patch_w = self.patch_size[0], self.patch_size[1]

        width_remainder = width % patch_w
        if width_remainder != 0:
            pixel_values = nn.functional.pad(pixel_values, (0, patch_w - width_remainder))

        height_remainder = height % patch_h
        if height_remainder != 0:
            pixel_values = nn.functional.pad(pixel_values, (0, 0, 0, patch_h - height_remainder))

        return pixel_values

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
        """Return `(embeddings, (grid_height, grid_width))` for the (padded) input image batch."""
        _, _, in_height, in_width = pixel_values.shape
        # pad the input to be divisible by self.patch_size, if needed
        padded = self.maybe_pad(pixel_values, in_height, in_width)

        feature_map = self.projection(padded)
        _, _, grid_height, grid_width = feature_map.shape

        # (batch, hidden, gh, gw) -> (batch, gh * gw, hidden)
        sequence = feature_map.flatten(2).transpose(1, 2)
        return sequence, (grid_height, grid_width)
191
+
192
+
193
+ # Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
194
class DonutSwinPatchMerging(nn.Module):
    """
    Patch Merging Layer: concatenates every 2x2 neighbourhood along the channel axis,
    normalises it and linearly reduces `4 * dim` channels to `2 * dim`.

    Args:
        input_resolution (`Tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def maybe_pad(self, input_feature, height, width):
        """Append one zero row and/or column when `height` and/or `width` is odd."""
        extra_row, extra_col = height % 2, width % 2
        if extra_row == 1 or extra_col == 1:
            # Pad spec runs from the last axis backwards: channels, width, height.
            input_feature = nn.functional.pad(input_feature, (0, 0, 0, extra_col, 0, extra_row))
        return input_feature

    def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
        """Merge 2x2 patch neighbourhoods of a `(batch, height * width, channels)` sequence."""
        height, width = input_dimensions
        batch_size, _, num_channels = input_feature.shape

        grid = input_feature.view(batch_size, height, width, num_channels)
        # pad input to be divisible by 2 along width and height, if needed
        grid = self.maybe_pad(grid, height, width)

        # The four members of each 2x2 block, as (row offset, column offset); this order
        # fixes the channel layout seen by `norm` and `reduction`.
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        corners = [grid[:, row::2, col::2, :] for row, col in offsets]

        # batch_size, height/2 * width/2, 4 * num_channels
        merged = torch.cat(corners, -1).view(batch_size, -1, 4 * num_channels)

        return self.reduction(self.norm(merged))
246
+
247
+
248
+ # Copied from transformers.models.beit.modeling_beit.drop_path
249
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Stochastic depth: zero out whole samples of `input` with probability `drop_prob` and
    rescale the survivors by `1 / (1 - drop_prob)`.

    The input is returned untouched when not training or when `drop_prob` is 0. One random
    draw is made per element of the first axis and broadcast over all remaining axes.
    """
    if not training or drop_prob == 0.0:
        return input

    keep_prob = 1 - drop_prob
    # One value per sample, broadcastable over every other axis (works for any rank).
    per_sample_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    keep_mask = keep_prob + torch.rand(per_sample_shape, dtype=input.dtype, device=input.device)
    keep_mask.floor_()  # values in [keep_prob, 1 + keep_prob) -> 0 or 1
    return input.div(keep_prob) * keep_mask
267
+
268
+
269
+ # Copied from transformers.models.swin.modeling_swin.SwinDropPath
270
class DonutSwinDropPath(nn.Module):
    """Module wrapper around `drop_path` that follows the module's train/eval mode."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Dropping only happens while `self.training` is True; `drop_path` handles the rest.
        return drop_path(hidden_states, drop_prob=self.drop_prob, training=self.training)

    def extra_repr(self) -> str:
        probability = self.drop_prob
        return "p={}".format(probability)
282
+
283
+
284
+ # Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DonutSwin
285
class DonutSwinSelfAttention(nn.Module):
    """
    Window self-attention with a learned relative position bias, computed through
    `torch.nn.functional.scaled_dot_product_attention`.

    Keys and values are projected to `num_kv_heads` heads and tiled `num_heads // num_kv_heads`
    times so that they line up with the `num_heads` query heads.
    """

    def __init__(self, config, dim, num_heads, num_kv_heads, window_size):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
            )

        self.num_attention_heads = num_heads
        self.num_kv_heads = num_kv_heads
        # How many times the kv heads are tiled to reach the number of query heads.
        self.kv_repeats = self.num_attention_heads // self.num_kv_heads
        self.attention_head_size = int(dim / num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.kv_head_size = self.num_kv_heads * self.attention_head_size
        self.window_size = (
            window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
        )

        # One learnable bias per head for each possible (dy, dx) offset inside a window.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
        )

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # Shift both offsets to start at 0, then fold (dy, dx) into a single row index of the table.
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(self.all_head_size, self.kv_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(self.all_head_size, self.kv_head_size, bias=config.qkv_bias)

        self.dropout_p = config.attention_probs_dropout_prob

    def transpose_for_scores(self, x):
        """Split the last axis into `(num_attention_heads, attention_head_size)` and move heads to axis 1."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def transpose_kv_for_scores(self, x, repeats):
        """Like `transpose_for_scores` for keys/values, tiling the kv heads `repeats` times along the head axis."""
        new_x_shape = x.size()[:-1] + (self.num_kv_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        x = x.repeat(1, 1, repeats, 1)  # repeat the values for each key-value head to match query dim
        return x.permute(0, 2, 1, 3).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """
        Attend within each window and return a 1-tuple holding the attention output, reshaped
        back to the shape of `hidden_states`.

        `head_mask` and `output_attentions` are accepted but not used by this method.
        """
        # NOTE: `dim` here is the second axis of the input (tokens per window), not the channel count.
        batch_size, dim, num_channels = hidden_states.shape
        mixed_query_layer = self.query(hidden_states)

        # Final is (batch_size, num_attention_heads, seq_len, attention_head_size)
        key_layer = self.transpose_kv_for_scores(self.key(hidden_states), self.kv_repeats)
        value_layer = self.transpose_kv_for_scores(self.value(hidden_states), self.kv_repeats)
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Look up the bias for every (query token, key token) pair and lay it out as (1, heads, N, N).
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        relative_position_bias = relative_position_bias.view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
        if attention_mask is None:
            attention_mask = relative_position_bias
        else:
            # assumes attention_mask is (num_windows, N, N) and batch_size is a multiple of num_windows — TODO confirm
            mask_shape = attention_mask.shape[0]
            repeat_count = (batch_size // mask_shape)
            attention_mask = attention_mask.repeat(repeat_count, 1, 1).unsqueeze(1)
            attention_mask = attention_mask + relative_position_bias

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer.contiguous(),
            key_layer.contiguous(),
            value_layer.contiguous(),
            attn_mask=attention_mask,
            dropout_p=self.dropout_p if self.training else 0.0,
            scale=self.attention_head_size**-0.5,
        )

        # (batch, heads, seq, head_size) -> (batch, seq, heads * head_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, dim, num_channels)

        outputs = (attn_output,)
        return outputs
379
+
380
+ # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
381
class DonutSwinSelfOutput(nn.Module):
    """Output projection of the attention block: a `dim -> dim` linear layer followed by dropout."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # `input_tensor` is part of the signature but is not used here.
        projected = self.dense(hidden_states)
        return self.dropout(projected)
392
+
393
+
394
+ # Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin
395
class DonutSwinAttention(nn.Module):
    """Pairs the window self-attention (`self.self`) with its output projection (`self.output`)."""

    def __init__(self, config, dim, num_heads, num_kv_heads, window_size):
        super().__init__()
        self.self = DonutSwinSelfAttention(config, dim, num_heads, num_kv_heads, window_size)
        self.output = DonutSwinSelfOutput(config, dim)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """
        Remove the given attention heads from the query/key/value and output projections.

        NOTE(review): the same index is applied to `key`/`value` as to `query`; this looks
        like it presumes as many key/value heads as query heads — confirm before pruning a
        model configured with fewer key/value heads.
        """
        if len(heads) == 0:
            return

        attention = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attention.num_attention_heads, attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        attention.query = prune_linear_layer(attention.query, index)
        attention.key = prune_linear_layer(attention.key, index)
        attention.value = prune_linear_layer(attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        attention.num_attention_heads = attention.num_attention_heads - len(heads)
        attention.all_head_size = attention.attention_head_size * attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Return `(projected attention output, *any extra items the self-attention returned)`."""
        attention_results = self.self(hidden_states, attention_mask, head_mask, output_attentions)
        context, extras = attention_results[0], attention_results[1:]
        projected = self.output(context, hidden_states)
        # add attentions if we output them
        return (projected,) + extras
431
+
432
+
433
+ # Copied from transformers.models.swin.modeling_swin.SwinIntermediate
434
class DonutSwinIntermediate(nn.Module):
    """First half of the feed-forward block: expand `dim` by `config.mlp_ratio`, then apply the activation."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        # `hidden_act` may be a name to look up in ACT2FN or an already-callable activation.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.dense(hidden_states)
        return self.intermediate_act_fn(expanded)
447
+
448
+
449
+ # Copied from transformers.models.swin.modeling_swin.SwinOutput
450
class DonutSwinOutput(nn.Module):
    """Second half of the block MLP: project the widened features back to ``dim`` and apply dropout."""

    def __init__(self, config, dim):
        super().__init__()
        widened_dim = int(config.mlp_ratio * dim)
        self.dense = nn.Linear(widened_dim, dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(dense(hidden_states))``."""
        projected = self.dense(hidden_states)
        return self.dropout(projected)
460
+
461
+
462
+ # Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin
463
class DonutSwinLayer(nn.Module):
    """One transformer block: windowed self-attention followed by an MLP, each wrapped in a residual connection.

    When ``shift_size`` is greater than zero the feature map is rolled before it is cut into
    windows and rolled back afterwards, and an additive attention mask is built for the
    rolled layout. With ``shift_size == 0`` neither the roll nor the mask is used.
    """

    def __init__(self, config, dim, input_resolution, num_heads, num_kv_heads, shift_size=0):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.shift_size = shift_size
        self.window_size = config.window_size
        self.input_resolution = input_resolution
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = DonutSwinAttention(config, dim, num_heads, num_kv_heads, window_size=self.window_size)
        # Stochastic-depth module is only instantiated when the configured rate is positive.
        self.drop_path = DonutSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = DonutSwinIntermediate(config, dim)
        self.output = DonutSwinOutput(config, dim)

    def set_shift_and_window_size(self, input_resolution):
        """Shrink the window to the input and disable shifting when the input is not larger than the window.

        NOTE: this overwrites ``self.shift_size`` and ``self.window_size`` on the module, so the
        change persists for later calls.
        """
        if min(input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = int(0)
            self.window_size = (
                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
            )

    def get_attn_mask(self, height, width, dtype, device):
        """Build the additive mask used for shifted windows, or return ``None`` when ``shift_size`` is 0.

        The (padded) map is labelled with a distinct integer per region produced by the three
        height slices x three width slices. Inside each window, position pairs whose labels
        differ receive -100.0 and pairs with equal labels receive 0.0.
        """
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
            height_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            width_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            count = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    img_mask[:, height_slice, width_slice, :] = count
                    count += 1

            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # Pairwise label difference: zero means both positions share a region.
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        return attn_mask

    def maybe_pad(self, hidden_states, height, width):
        """Pad ``hidden_states`` so height and width become multiples of the window size.

        Returns the padded tensor and the pad tuple handed to ``nn.functional.pad``; entry 3 is
        the width padding and entry 5 the height padding (the last dimension is never padded).
        """
        pad_right = (self.window_size - width % self.window_size) % self.window_size
        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
        hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block on a ``(batch, height * width, channels)`` sequence.

        Args:
            hidden_states: flattened feature map; it is viewed as
                ``(batch, height, width, channels)`` using ``input_dimensions``.
            input_dimensions: ``(height, width)`` of the feature map.
            head_mask: forwarded unchanged to the attention module.
            output_attentions: when true, the attention module's second output is returned too.
            always_partition: when true, the window/shift adjustment for small inputs is skipped.

        Returns:
            ``(layer_output,)`` or ``(layer_output, attention_outputs[1])`` when
            ``output_attentions`` is set.
        """
        if not always_partition:
            self.set_shift_and_window_size(input_dimensions)
        else:
            pass
        height, width = input_dimensions
        batch_size, _, channels = hidden_states.size()
        # Kept for the first residual connection.
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)

        hidden_states = hidden_states.view(batch_size, height, width, channels)

        # pad hidden_states to multiples of window size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)

        _, height_pad, width_pad, _ = hidden_states.shape
        # cyclic shift
        if self.shift_size > 0:
            shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_hidden_states = hidden_states

        # partition windows
        hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
        hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
        attn_mask = self.get_attn_mask(
            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
        )

        attention_outputs = self.attention(
            hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
        )

        attention_output = attention_outputs[0]

        # Stitch the per-window outputs back into a single padded map.
        attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
        shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)

        # reverse cyclic shift
        if self.shift_size > 0:
            attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            attention_windows = shifted_windows

        # Crop away the padding added by maybe_pad.
        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_windows = attention_windows[:, :height, :width, :].contiguous()

        attention_windows = attention_windows.view(batch_size, height * width, channels)

        # First residual: input + (optionally dropped) attention branch.
        hidden_states = shortcut + self.drop_path(attention_windows)

        # Second residual: MLP branch on the normalized result.
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)
        layer_output = hidden_states + self.output(layer_output)

        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
        return layer_outputs
586
+
587
+
588
+ # Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin
589
class DonutSwinStage(nn.Module):
    """A run of ``depth`` DonutSwinLayer blocks, optionally followed by a downsampling module.

    Even-indexed blocks use ``shift_size=0`` and odd-indexed blocks use
    ``config.window_size // 2``. The ``drop_path`` argument is accepted but not referenced here.
    """

    def __init__(self, config, dim, input_resolution, depth, num_heads, num_kv_heads, drop_path, downsample):
        super().__init__()
        self.config = config
        self.dim = dim

        stage_blocks = []
        for block_index in range(depth):
            stage_blocks.append(
                DonutSwinLayer(
                    config=config,
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    num_kv_heads=num_kv_heads,
                    shift_size=0 if (block_index % 2 == 0) else config.window_size // 2,
                )
            )
        self.blocks = nn.ModuleList(stage_blocks)

        # patch merging layer
        self.downsample = (
            None if downsample is None else downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
        )

        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Apply every block, then the downsampler if one is configured.

        Returns ``(hidden_states, hidden_states_before_downsampling, output_dimensions)``, where
        ``output_dimensions`` is ``(height, width, out_height, out_width)``. When
        ``output_attentions`` is set, the extra outputs of the last block are appended.
        """
        height, width = input_dimensions
        for block_index, block in enumerate(self.blocks):
            block_head_mask = None if head_mask is None else head_mask[block_index]
            layer_outputs = block(
                hidden_states, input_dimensions, block_head_mask, output_attentions, always_partition
            )
            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is None:
            output_dimensions = (height, width, height, width)
        else:
            # Halve each spatial side, rounding up.
            output_dimensions = (height, width, (height + 1) // 2, (width + 1) // 2)
            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)

        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
        if output_attentions:
            stage_outputs += layer_outputs[1:]
        return stage_outputs
647
+
648
+
649
+ # Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin
650
class DonutSwinEncoder(nn.Module):
    """Stack of DonutSwinStage modules, one per entry of ``config.depths``.

    Stage ``i`` is built with width ``config.embed_dim * 2**i`` and input resolution
    ``grid_size // 2**i``; every stage except the last is given ``DonutSwinPatchMerging`` as
    its downsampler.
    """

    def __init__(self, config, grid_size):
        super().__init__()
        self.num_layers = len(config.depths)
        self.config = config
        # One drop-path rate per block across all stages, linearly spaced from 0 to config.drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
        self.layers = nn.ModuleList(
            [
                DonutSwinStage(
                    config=config,
                    dim=int(config.embed_dim * 2**i_layer),
                    input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
                    depth=config.depths[i_layer],
                    num_heads=config.num_heads[i_layer],
                    num_kv_heads=config.num_kv_heads[i_layer],
                    # Slice of dpr belonging to this stage's blocks.
                    drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
                    downsample=DonutSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
                )
                for i_layer in range(self.num_layers)
            ]
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        always_partition: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, DonutSwinEncoderOutput]:
        """Run all stages in order, threading the spatial dimensions from one stage to the next.

        Args:
            hidden_states: ``(batch, height * width, channels)`` sequence.
            input_dimensions: ``(height, width)`` matching ``hidden_states``.
            head_mask: indexed per stage when provided.
            output_attentions: collect the attention outputs returned by each stage.
            output_hidden_states: collect hidden states (flat and reshaped to ``b c h w``).
            output_hidden_states_before_downsampling: when collecting hidden states, record each
                stage's pre-downsampling output instead of its final output.
            always_partition: forwarded to every stage.
            return_dict: when false, a tuple of the non-``None`` values among
                ``(hidden_states, all_hidden_states, all_self_attentions)`` is returned; the
                reshaped hidden states are not part of that tuple.
        """
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        # Record the encoder input as the first hidden state.
        if output_hidden_states:
            batch_size, _, hidden_size = hidden_states.shape
            # rearrange b (h w) c -> b c h w
            reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for i, layer_module in enumerate(self.layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                # NOTE(review): _gradient_checkpointing_func is not defined in this class;
                # presumably it is attached by the owning pretrained model — confirm.
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    input_dimensions,
                    layer_head_mask,
                    output_attentions,
                    always_partition,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
                )

            hidden_states = layer_outputs[0]
            hidden_states_before_downsampling = layer_outputs[1]
            output_dimensions = layer_outputs[2]

            # The last two entries are the (possibly downsampled) dimensions fed to the next stage.
            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

            if output_hidden_states and output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
                # rearrange b (h w) c -> b c h w
                # here we use the original (not downsampled) height and width
                reshaped_hidden_state = hidden_states_before_downsampling.view(
                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
                )
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states_before_downsampling,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)
            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                # rearrange b (h w) c -> b c h w
                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
                # Entries 0-2 are states/dimensions; anything after that is attention output.
                all_self_attentions += layer_outputs[3:]

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)

        return DonutSwinEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )
750
+
751
+
752
+ # Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin
753
class DonutSwinPreTrainedModel(PreTrainedModel):
    """
    Abstract base class that supplies weight initialization and the shared interface for downloading and loading
    pretrained models.
    """

    config_class = DonutSwinConfig
    base_model_prefix = "swin"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DonutSwinStage"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.LayerNorm):
            # Layer norms start as the identity transform.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return

        # Plain normal init (the TF version uses truncated_normal instead),
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
776
+
777
+
778
class DonutSwinModel(DonutSwinPreTrainedModel):
    """Image encoder: patch embeddings -> DonutSwinEncoder -> learned position embeddings added to the output."""

    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
        # NOTE: add_pooling_layer is accepted but not referenced in this constructor.
        super().__init__(config)
        self.config = config
        self.num_layers = len(config.depths)
        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))

        self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)

        # Learned offsets added to the encoder output in forward(); zero-initialized here.
        self.position_embeddings = nn.Parameter(torch.zeros(1, config.encoder_length, config.hidden_size))

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the patch-embedding module."""
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # NOTE(review): this reads `self.encoder.layer`, while DonutSwinEncoder registers its
        # stages as `layers` (and attention lives on each stage's blocks) — this path looks
        # like it would raise AttributeError; confirm intended behaviour before relying on it.
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, DonutSwinModelOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns a `DonutSwinModelOutput` carrying only `last_hidden_state`. `return_dict` is passed to the
        encoder but does not change the type returned here.

        Raises `ValueError` when `pixel_values` is `None`.
        """
        # Fall back to the config defaults for any flag left as None.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, len(self.config.depths))

        embedding_output, input_dimensions = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # In-place add: the position embeddings are truncated to the output's sequence length.
        last_hidden_state += self.position_embeddings[:, :last_hidden_state.size(1), :]

        return DonutSwinModelOutput(
            last_hidden_state=last_hidden_state,
        )
surya/model/recognition/encoderdecoder.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union, Tuple
2
+
3
+ import torch
4
+ from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig
5
+ from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
6
+ from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import shift_tokens_right
7
+ from surya.model.recognition.encoder import DonutSwinModel
8
+ from surya.model.recognition.decoder import SuryaOCRDecoder, SuryaOCRTextEncoder
9
+
10
+
11
class OCREncoderDecoderModel(PreTrainedModel):
    """Vision encoder / text decoder wrapper used for OCR recognition.

    Holds three sub-models — an image ``encoder``, a ``decoder`` and a ``text_encoder`` — and
    points each sub-model's ``config`` at the matching entry of the shared config so later
    config updates stay in sync.
    """

    config_class = VisionEncoderDecoderConfig
    base_model_prefix = "vision_encoder_decoder"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _supports_param_buffer_assignment = False

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
        text_encoder: Optional[PreTrainedModel] = None,
    ):
        """Build the wrapper, creating any sub-model that was not passed in from ``config``.

        Raises:
            ValueError: if ``config`` is ``None`` (it is required even though the signature
                keeps a ``None`` default for compatibility).
        """
        if config is None:
            # Fail with a clear message instead of an AttributeError on `None.tie_word_embeddings`.
            raise ValueError("A `config` has to be provided to initialize OCREncoderDecoderModel.")

        # initialize with config
        # make sure input & output embeddings is not tied
        config.tie_word_embeddings = False
        config.decoder.tie_word_embeddings = False
        super().__init__(config)

        if encoder is None:
            encoder = DonutSwinModel(config.encoder)

        if decoder is None:
            decoder = SuryaOCRDecoder(config.decoder, attn_implementation=config._attn_implementation)

        if text_encoder is None:
            text_encoder = SuryaOCRTextEncoder(config.text_encoder, attn_implementation=config._attn_implementation)

        self.encoder = encoder
        self.decoder = decoder
        self.text_encoder = text_encoder

        # make sure that the individual model's config refers to the shared config
        # so that the updates to the config will be synced
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder
        self.text_encoder.config = self.config.text_encoder

    def get_encoder(self):
        """Return the image encoder."""
        return self.encoder

    def get_decoder(self):
        """Return the decoder."""
        return self.decoder

    def get_output_embeddings(self):
        """Delegate to the decoder's output embeddings."""
        return self.decoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Delegate to the decoder's output-embedding setter."""
        return self.decoder.set_output_embeddings(new_embeddings)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_cache_position: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Encode the image (unless ``encoder_outputs`` is supplied) and run the decoder on top of it.

        Extra keyword arguments prefixed with ``decoder_`` are forwarded to the decoder with the
        prefix stripped; all other extra keyword arguments go to the encoder.

        Raises:
            ValueError: if neither ``encoder_outputs`` nor ``pixel_values`` is given.
        """
        kwargs_encoder = {}
        kwargs_decoder = {}
        for argument, value in kwargs.items():
            if argument.startswith("decoder_"):
                kwargs_decoder[argument[len("decoder_"):]] = value
            else:
                kwargs_encoder[argument] = value

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # optionally project encoder_hidden_states
        # NOTE(review): `enc_to_dec_proj` is never created in __init__, so this branch raises
        # AttributeError if the encoder and decoder hidden sizes ever differ while
        # `cross_attention_hidden_size` is None.
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # No mask is applied over the encoder positions.
        encoder_attention_mask = None

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            cache_position=decoder_cache_position,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            **kwargs_decoder,
        )

        return Seq2SeqLMOutput(
            logits=decoder_outputs.logits,
            decoder_hidden_states=decoder_outputs.hidden_states,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Shift ``labels`` right using the configured pad and decoder-start token ids."""
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        """Assemble the keyword dictionary for one generation step from the decoder's own preparation."""
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
        input_dict = {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }
        return input_dict

    def resize_token_embeddings(self, *args, **kwargs):
        """Always raises: resizing must be done on the wrapped decoder instead."""
        raise NotImplementedError(
            "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the"
            " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
        )

    def _reorder_cache(self, past_key_values, beam_idx):
        # apply decoder cache reordering here
        return self.decoder._reorder_cache(past_key_values, beam_idx)
surya/model/recognition/model.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import torch
4
+
5
+ warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated")
6
+
7
+ import logging
8
+ logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
9
+
10
+ from typing import List, Optional, Tuple
11
+ from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel
12
+ from surya.model.recognition.config import DonutSwinConfig, SuryaOCRConfig, SuryaOCRDecoderConfig, SuryaOCRTextEncoderConfig
13
+ from surya.model.recognition.encoder import DonutSwinModel
14
+ from surya.model.recognition.decoder import SuryaOCRDecoder, SuryaOCRTextEncoder
15
+ from surya.settings import settings
16
+
17
# Import-time side effect: when efficient attention is disabled in settings, tell the user and
# switch torch's CUDA scaled-dot-product-attention backends (memory-efficient off, flash and math on).
if not settings.ENABLE_EFFICIENT_ATTENTION:
    print("Efficient attention is disabled. This will use significantly more VRAM.")
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_math_sdp(True)
22
+
23
+
24
def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the recognition model from ``checkpoint``, move it to ``device`` and put it in eval mode.

    The decoder, encoder and text-encoder sections of the loaded config are each expanded as
    keyword arguments into their dedicated config class before the weights are loaded.
    """
    config = SuryaOCRConfig.from_pretrained(checkpoint)

    # Same order as the sub-models are declared: decoder, encoder, text encoder.
    sub_config_classes = (
        ("decoder", SuryaOCRDecoderConfig),
        ("encoder", DonutSwinConfig),
        ("text_encoder", SuryaOCRTextEncoderConfig),
    )
    for section, config_cls in sub_config_classes:
        raw_section = getattr(config, section)
        setattr(config, section, config_cls(**raw_section))

    model = OCREncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)

    assert isinstance(model.decoder, SuryaOCRDecoder)
    assert isinstance(model.encoder, DonutSwinModel)
    assert isinstance(model.text_encoder, SuryaOCRTextEncoder)

    model = model.to(device)
    model = model.eval()

    print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}")
    return model
surya/model/recognition/processor.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Union, Optional, List, Iterable
2
+
3
+ import cv2
4
+ from torch import TensorType
5
+ from transformers import DonutImageProcessor, DonutProcessor
6
+ from transformers.image_processing_utils import BatchFeature
7
+ from transformers.image_transforms import pad, normalize
8
+ from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, get_image_size
9
+ import numpy as np
10
+ from PIL import Image
11
+ import PIL
12
+ from surya.model.recognition.tokenizer import Byt5LangTokenizer
13
+ from surya.settings import settings
14
+
15
+
16
def load_processor():
    """Create a SuryaProcessor configured for inference from the recognition settings."""
    processor = SuryaProcessor()

    image_processor = processor.image_processor
    image_processor.train = False
    image_processor.max_size = settings.RECOGNITION_IMAGE_SIZE

    processor.tokenizer.model_max_length = settings.RECOGNITION_MAX_TOKENS
    return processor
22
+
23
+
24
class SuryaImageProcessor(DonutImageProcessor):
    """Image processor for the recognition model.

    Pipeline (see ``process_inner``): orient each image to match the target size, resize it to
    ``max_size`` with OpenCV, convert to float32, pad, rescale and normalize.
    """

    def __init__(self, *args, max_size=None, train=False, **kwargs):
        super().__init__(*args, **kwargs)

        self.patch_size = kwargs.get("patch_size", (4, 4))
        self.max_size = max_size
        self.train = train

    @classmethod
    def numpy_resize(cls, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4):
        """Resize a channel-last image to ``size["width"]`` x ``size["height"]`` and return it channel-first."""
        max_width, max_height = size["width"], size["height"]

        resized_image = cv2.resize(image, (max_width, max_height), interpolation=interpolation)
        # height x width x channel -> channel x height x width
        resized_image = resized_image.transpose(2, 0, 1)

        return resized_image

    def process_inner(self, images: List[np.ndarray]):
        """Run the full preprocessing pipeline on channel-last RGB arrays and return channel-first float arrays."""
        assert images[0].shape[2] == 3  # RGB input images, channel dim last

        # Rotate images whose orientation (portrait/landscape) disagrees with the target size
        images = [SuryaImageProcessor.align_long_axis(image, size=self.max_size, input_data_format=ChannelDimension.LAST) for image in images]

        # Verify that the image is wider than it is tall
        for img in images:
            assert img.shape[1] >= img.shape[0]

        # This also applies the right channel dim format, to channel x height x width
        images = [SuryaImageProcessor.numpy_resize(img, self.max_size, self.resample) for img in images]
        assert images[0].shape[0] == 3  # RGB input images, channel dim first

        # Convert to float32 for rescale/normalize
        images = [img.astype(np.float32) for img in images]

        # Pads with the configured pad value (settings.RECOGNITION_PAD_VALUE)
        # Pad to max size to improve performance
        max_size = self.max_size
        images = [
            SuryaImageProcessor.pad_image(
                image=image,
                size=max_size,
                input_data_format=ChannelDimension.FIRST,
                pad_value=settings.RECOGNITION_PAD_VALUE
            )
            for image in images
        ]
        # Rescale and normalize
        images = [img * self.rescale_factor for img in images]
        images = [
            SuryaImageProcessor.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]

        return images

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_thumbnail: bool = None,
        do_align_long_axis: bool = None,
        do_pad: bool = None,
        random_padding: bool = False,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one image or a list of images into a ``BatchFeature`` with key ``pixel_values``.

        Only ``images`` and ``return_tensors`` are used; the remaining parameters are kept for
        signature compatibility and are ignored — the instance attributes drive the pipeline.
        """
        images = make_list_of_images(images)

        # Convert to numpy for later processing steps
        images = [np.array(img) for img in images]
        images = self.process_inner(images)

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @classmethod
    def pad_image(
        cls,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        pad_value: float = 0.0,
    ) -> np.ndarray:
        """Center ``image`` inside a ``size["height"]`` x ``size["width"]`` canvas filled with ``pad_value``.

        The image must not exceed the target size (checked with an assertion). When the
        padding is odd, the extra pixel goes to the bottom/right.
        """
        output_height, output_width = size["height"], size["width"]
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)

        delta_width = output_width - input_width
        delta_height = output_height - input_height

        assert delta_width >= 0 and delta_height >= 0

        pad_top = delta_height // 2
        pad_left = delta_width // 2

        pad_bottom = delta_height - pad_top
        pad_right = delta_width - pad_left

        padding = ((pad_top, pad_bottom), (pad_left, pad_right))
        return pad(image, padding, data_format=data_format, input_data_format=input_data_format, constant_values=pad_value)

    @classmethod
    def align_long_axis(
        cls,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Rotate ``image`` with ``np.rot90(image, 3)`` when its orientation differs from the target's.

        The image's first two axes are read as (height, width). ``data_format`` and
        ``input_data_format`` are accepted but not used.
        """
        input_height, input_width = image.shape[:2]
        output_height, output_width = size["height"], size["width"]

        if (output_width < output_height and input_width > input_height) or (
            output_width > output_height and input_width < input_height
        ):
            image = np.rot90(image, 3)

        return image

    @classmethod
    def normalize(
        cls,
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """Thin wrapper around ``transformers.image_transforms.normalize``."""
        return normalize(
            image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
        )
166
+
167
+
168
class SuryaProcessor(DonutProcessor):
    """Pairs the Surya image processor with the language-aware ByT5 tokenizer."""

    def __init__(self, image_processor=None, tokenizer=None, train=False, **kwargs):
        # NOTE: the image_processor / tokenizer arguments are replaced unconditionally, so
        # whatever the caller passes is ignored and the Surya defaults are always used.
        image_processor = SuryaImageProcessor.from_pretrained(settings.RECOGNITION_MODEL_CHECKPOINT)
        tokenizer = Byt5LangTokenizer()
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor
        self._in_target_context_manager = False

    def __call__(self, *args, **kwargs):
        """Process images and/or text.

        ``images``, ``text`` and ``langs`` are taken from the keyword arguments; a first
        positional argument, if present, overrides ``images``. Returns the image features, the
        text encodings, or — when both are given — the image features with ``labels`` and
        ``langs`` added from the encodings.

        Raises:
            ValueError: if neither images nor text is provided.
        """
        images = kwargs.pop("images", None)
        text = kwargs.pop("text", None)
        langs = kwargs.pop("langs", None)

        if args:
            images, args = args[0], args[1:]

        has_images = images is not None
        has_text = text is not None
        if not (has_images or has_text):
            raise ValueError("You need to specify either an `images` or `text` input to process.")

        if has_images:
            inputs = self.image_processor(images, *args, **kwargs)
        if has_text:
            encodings = self.tokenizer(text, langs, **kwargs)

        if not has_text:
            return inputs
        if not has_images:
            return encodings

        inputs["labels"] = encodings["input_ids"]
        inputs["langs"] = encodings["langs"]
        return inputs
surya/model/recognition/tokenizer.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from itertools import chain
2
+ import random
3
+ from typing import List, Optional, Tuple, Union
4
+ from tokenizers import AddedToken
5
+ from transformers import ByT5Tokenizer
6
+ import numpy as np
7
+ import torch
8
+ from surya.model.recognition.config import LANGUAGE_MAP, TOTAL_TOKENS, TOKEN_OFFSET
9
+
10
+
11
def text_to_utf16_numbers(text):
    """Return the UTF-16 code units of `text` as a list of integers.

    The text is encoded as little-endian UTF-16 so each consecutive byte pair
    maps to one 16-bit number; characters outside the BMP yield two numbers
    (a surrogate pair).
    """
    encoded = text.encode('utf-16le')  # Little-endian keeps the byte order fixed
    return [
        int.from_bytes(encoded[start:start + 2], "little")
        for start in range(0, len(encoded), 2)
    ]
23
+
24
+
25
def utf16_numbers_to_text(numbers):
    """Decode a sequence of UTF-16 code units back into a string.

    Each number contributes its low byte followed by its next-higher byte
    (anything above 16 bits is masked off). Byte sequences that are not valid
    little-endian UTF-16 are dropped because decoding uses `errors="ignore"`.
    """
    raw = bytearray()
    for unit in numbers:
        # Low byte first, then high byte: little-endian layout
        raw.extend((unit & 0xFF, (unit >> 8) & 0xFF))
    return raw.decode('utf-16le', errors="ignore")
34
+
35
+
36
def _tokenize(text: str, langs: List[str] | None, eos_token_id: int = 1, add_eos: bool = False, add_bos: bool = True):
    """Convert text plus optional language names into a flat list of token ids.

    Text tokens are UTF-16 code units shifted by `TOKEN_OFFSET`. Language tokens
    are `LANGUAGE_MAP[lang] + TOKEN_OFFSET + TOTAL_TOKENS` and are placed before
    the text tokens.

    Args:
        text: Text to tokenize.
        langs: Language names to prepend as language tokens; None or empty for none.
        eos_token_id: Id used for both the leading (BOS) and trailing (EOS) marker.
        add_eos: When True, append `eos_token_id` at the end. This flag was
            previously accepted but ignored.
        add_bos: When True, insert `eos_token_id` at the very front.

    Returns:
        A `(tokens, lang_list)` tuple: the full token sequence and the language
        tokens on their own.

    Raises:
        KeyError: If a language name is missing from `LANGUAGE_MAP`.
    """
    # Offset accounts for the special pad/eos/etc. tokens at the start of the vocab
    text_tokens = [unit + TOKEN_OFFSET for unit in text_to_utf16_numbers(text)]
    lang_list = [LANGUAGE_MAP[lang] + TOKEN_OFFSET + TOTAL_TOKENS for lang in (langs or ())]

    tokens = lang_list + text_tokens

    if add_bos:
        tokens.insert(0, eos_token_id)
    if add_eos:
        tokens.append(eos_token_id)

    return tokens, lang_list
52
+
53
+
54
class Byt5LangTokenizer(ByT5Tokenizer):
    """Tokenizer mapping text to offset UTF-16 code units plus language tokens.

    Tokenization is delegated to `_tokenize`; decoding keeps only ids in the text
    range `[TOKEN_OFFSET, TOKEN_OFFSET + TOTAL_TOKENS)` and turns them back into
    text.
    """

    def __init__(self,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        model_max_length=None,
        **kwargs,
    ):
        """Store special-token settings and id constants, then initialise the base class.

        `**kwargs` is accepted but not forwarded; the base initialiser is called
        without arguments.
        """
        self.pad_token = pad_token
        self.eos_token = eos_token
        self.unk_token = unk_token
        self.bos_token = eos_token
        self.offset = TOKEN_OFFSET

        self.pad_id = 0
        self.eos_id = 1
        self.unk_id = 2

        self.model_max_length = model_max_length
        # Ids at or above this value are language/special tokens rather than text
        self.special_token_start = TOKEN_OFFSET + TOTAL_TOKENS

        super().__init__()

    def __call__(self, texts: List[str] | str, langs: List[List[str]] | List[str] | None = None, pad_token_id: int = 0, **kwargs):
        """Tokenize one text or a batch of texts.

        Args:
            texts: A single string or a list of strings.
            langs: Languages per text. A flat list of strings is treated as the
                language list for a single text. None means no languages.
            pad_token_id: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A dict with `"input_ids"` and `"langs"`. For a single string input
            both values are flat lists; for a list input they are lists of lists.
        """
        tokenized = []
        all_langs = []

        is_list = True
        # Convert to list of lists format
        if isinstance(texts, str):
            texts = [texts]
            is_list = False

        if langs is None:
            langs = [None] * len(texts)

        # Guard against an empty batch: indexing langs[0] on [] raised IndexError
        if langs and isinstance(langs[0], str):
            langs = [langs]

        assert len(langs) == len(texts)

        for text, lang in zip(texts, langs):
            tokens, lang_list = _tokenize(text, lang)
            tokenized.append(tokens)
            all_langs.append(lang_list)

        # Convert back to flat format
        if not is_list:
            tokenized = tokenized[0]
            all_langs = all_langs[0]

        return {"input_ids": tokenized, "langs": all_langs}

    def decode(
        self,
        token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool = None,
        **kwargs,
    ) -> str:
        """Decode token ids back to text.

        Ids outside the text range (below `TOKEN_OFFSET`, or at/above
        `special_token_start`) are dropped regardless of `skip_special_tokens`,
        which, like `clean_up_tokenization_spaces`, is accepted but not used.

        Args:
            token_ids: A single id, a list of ids, or a numpy array / torch tensor
                of ids.

        Returns:
            The decoded string.
        """
        if isinstance(token_ids, (np.ndarray, torch.Tensor)):
            token_ids = token_ids.tolist()

        # A bare int (or a 0-dim array/tensor after tolist) is a one-token sequence;
        # iterating it directly raised TypeError.
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        text_units = [t - TOKEN_OFFSET for t in token_ids if TOKEN_OFFSET <= t < self.special_token_start]
        return utf16_numbers_to_text(text_units)
surya/model/table_rec/config.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from surya.settings import settings
3
+
4
+ BOX_DIM = 1024
5
+ SPECIAL_TOKENS = 7
6
+ MAX_ROWS = 384
7
+
8
+
9
class SuryaTableRecConfig(PretrainedConfig):
    """Composite configuration holding encoder, decoder and text-encoder configs.

    The required keyword arguments `encoder`, `decoder` and `text_encoder` are
    stored as attributes, and the decoder's BOS/pad/EOS token ids are copied onto
    this config. The decoder config may be a dict or an object with attributes.
    """

    model_type = "vision-encoder-decoder"
    is_composition = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All three sub-configs are mandatory; a missing one raises KeyError
        encoder_cfg = kwargs.pop("encoder")
        decoder_cfg = kwargs.pop("decoder")
        text_encoder_cfg = kwargs.pop("text_encoder")

        self.encoder = encoder_cfg
        self.decoder = decoder_cfg
        self.text_encoder = text_encoder_cfg
        self.is_encoder_decoder = True

        # Pick the accessor once so dict and object configs share one code path
        if isinstance(decoder_cfg, dict):
            read = decoder_cfg.__getitem__
        else:
            def read(key):
                return getattr(decoder_cfg, key)

        self.decoder_start_token_id = read("bos_token_id")
        self.pad_token_id = read("pad_token_id")
        self.eos_token_id = read("eos_token_id")
33
+
34
+
35
class DonutSwinTableRecConfig(PretrainedConfig):
    """Configuration for the Donut-Swin image encoder used for table recognition.

    All constructor arguments are stored as same-named attributes. `num_layers`
    is derived from `len(depths)` and `hidden_size` from
    `embed_dim * 2 ** (len(depths) - 1)`.
    """

    model_type = "donut-swin"

    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        image_size=(settings.TABLE_REC_IMAGE_SIZE["width"], settings.TABLE_REC_IMAGE_SIZE["height"]),
        patch_size=4,
        num_channels=3,
        embed_dim=128,
        depths=[2, 2, 14, 2],
        num_heads=[4, 8, 16, 32],
        num_kv_heads=[4, 8, 16, 32],
        window_size=8,
        mlp_ratio=4.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        use_absolute_embeddings=True,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        encoder_length=1024,
        **kwargs,
    ):
        super().__init__(**kwargs)

        def _own_copy(value):
            # The list defaults in the signature are created once and shared by
            # every call; copy lists so mutating one config cannot leak into the
            # defaults or into other instances.
            return list(value) if isinstance(value, list) else value

        depths = _own_copy(depths)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.depths = depths
        self.num_layers = len(depths)
        self.num_heads = _own_copy(num_heads)
        self.num_kv_heads = _own_copy(num_kv_heads)
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.use_absolute_embeddings = use_absolute_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
        self.encoder_length = encoder_length
89
+
90
+
91
class SuryaTableRecDecoderConfig(PretrainedConfig):
    """Configuration for the table-recognition decoder.

    Constructor arguments are stored as same-named attributes. Derived values:
    `lru_width` falls back to `hidden_size`, `num_key_value_heads` falls back to
    `num_attention_heads`, `head_dim` is `hidden_size // num_attention_heads`, and
    `final_w_init_variance_scale` is `2.0 / num_hidden_layers`.

    Raises:
        ValueError: If `num_key_value_heads` exceeds `num_attention_heads`.
    """

    model_type = "surya_tablerec"

    def __init__(
        self,
        num_hidden_layers=3,
        vocab_size=settings.TABLE_REC_MAX_ROWS + SPECIAL_TOKENS,
        hidden_size=512,
        intermediate_size=4 * 512,
        encoder_hidden_size=1024,
        num_attention_heads=8,
        lru_width=None,
        attention_window_size=16,
        conv1d_width=4,
        logits_soft_cap=30.0,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        hidden_activation="gelu_pytorch_tanh",
        rope_theta=10000.0,
        block_types=("attention",),
        cross_attn_layers=(0, 1, 2, 3),
        encoder_cross_attn_layers=(0, 1, 2, 3),
        self_attn_layers=(0, 1, 2, 3),
        global_attn_layers=(0, 1, 2, 3),
        attention_dropout=0.0,
        num_key_value_heads=4,
        attention_bias=False,
        w_init_variance_scale=0.01,
        init_std=0.02,
        tie_word_embeddings=False,
        aux_heads=0,  # How many n-token-ahead heads to add
        causal=True,
        max_classes=2 + SPECIAL_TOKENS,
        max_width=1024 + SPECIAL_TOKENS,
        max_height=1024 + SPECIAL_TOKENS,
        out_box_size=1024,
        **kwargs,
    ):
        # Resolve optional arguments to their fallbacks up front
        if lru_width is None:
            lru_width = hidden_size
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # Core dimensions
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.lru_width = lru_width
        self.attention_window_size = attention_window_size
        self.conv1d_width = conv1d_width
        self.logits_soft_cap = logits_soft_cap
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.block_types = list(block_types)
        self.hidden_activation = hidden_activation

        # Attention head layout
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        if self.num_attention_heads < self.num_key_value_heads:
            raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`")

        # Which layers get which kind of attention
        self.cross_attn_layers = cross_attn_layers
        self.self_attn_layers = self_attn_layers
        self.global_attn_layers = global_attn_layers
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias

        # Initialisation settings
        self.w_init_variance_scale = w_init_variance_scale
        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
        self.init_std = init_std
        self.tie_word_embeddings = tie_word_embeddings
        self.aux_heads = aux_heads
        self.encoder_hidden_size = encoder_hidden_size
        self.causal = causal
        self.encoder_cross_attn_layers = encoder_cross_attn_layers

        # Output-space sizes
        self.max_classes = max_classes
        self.max_width = max_width
        self.max_height = max_height
        self.out_box_size = out_box_size

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        """Block type per layer: `block_types` repeated, cut to `num_hidden_layers`."""
        repeated = self.block_types * 100
        return repeated[: self.num_hidden_layers]
178
+
179
+
180
class SuryaTableRecTextEncoderConfig(PretrainedConfig):
    """Configuration for the table-recognition text encoder.

    Constructor arguments are stored as same-named attributes. Derived values:
    `lru_width` falls back to `hidden_size`, `num_key_value_heads` falls back to
    `num_attention_heads`, `head_dim` is `hidden_size // num_attention_heads`, and
    `final_w_init_variance_scale` is `2.0 / num_hidden_layers`.

    Raises:
        ValueError: If `num_key_value_heads` exceeds `num_attention_heads`.
    """

    model_type = "surya_tablerec"

    def __init__(
        self,
        num_hidden_layers=4,
        vocab_size=settings.TABLE_REC_MAX_ROWS + SPECIAL_TOKENS,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        encoder_hidden_size=1024,
        num_attention_heads=16,
        lru_width=None,
        attention_window_size=16,
        conv1d_width=4,
        logits_soft_cap=30.0,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        hidden_activation="gelu_pytorch_tanh",
        rope_theta=10000.0,
        block_types=("attention",),
        cross_attn_layers=(0, 1, 2, 3, 4, 5),
        self_attn_layers=(0, 1, 2, 3, 4, 5),
        global_attn_layers=(0, 1, 2, 3, 4, 5),
        attention_dropout=0.0,
        num_key_value_heads=16,
        attention_bias=False,
        w_init_variance_scale=0.01,
        init_std=0.02,
        tie_word_embeddings=False,
        causal=False,
        max_width=BOX_DIM + SPECIAL_TOKENS,
        max_height=BOX_DIM + SPECIAL_TOKENS,
        max_position_embeddings=1024,
        **kwargs,
    ):
        # Resolve optional arguments to their fallbacks up front
        if lru_width is None:
            lru_width = hidden_size
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # Core dimensions
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.lru_width = lru_width
        self.attention_window_size = attention_window_size
        self.conv1d_width = conv1d_width
        self.logits_soft_cap = logits_soft_cap
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.block_types = list(block_types)
        self.hidden_activation = hidden_activation

        # Attention head layout
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        if self.num_attention_heads < self.num_key_value_heads:
            raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`")

        # Which layers get which kind of attention
        self.cross_attn_layers = cross_attn_layers
        self.self_attn_layers = self_attn_layers
        self.global_attn_layers = global_attn_layers
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias

        # Initialisation settings
        self.w_init_variance_scale = w_init_variance_scale
        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
        self.init_std = init_std
        self.tie_word_embeddings = tie_word_embeddings
        self.encoder_hidden_size = encoder_hidden_size
        self.causal = causal

        # Input-space sizes
        self.max_width = max_width
        self.max_height = max_height
        self.max_position_embeddings = max_position_embeddings

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        """Block type per layer: `block_types` repeated, cut to `num_hidden_layers`."""
        repeated = self.block_types * 100
        return repeated[: self.num_hidden_layers]
surya/model/table_rec/decoder.py ADDED
@@ -0,0 +1,795 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.utils.checkpoint
6
+ from torch import nn
7
+ from transformers.utils import ModelOutput
8
+
9
+ from surya.model.table_rec.config import SuryaTableRecDecoderConfig, SuryaTableRecTextEncoderConfig
10
+ from transformers import PreTrainedModel
11
+ from transformers.activations import ACT2FN
12
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
13
+ from transformers.modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput
14
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
15
+
16
+ from surya.settings import settings
17
+
18
+ _MAX_SQRT_GRADIENT = 1000.0
19
+
20
@dataclass
class TableRecModelOutput(ModelOutput):
    """Output container with bounding-box logits and optional class logits / hidden states.

    NOTE(review): the tensor shapes and the producer of this object are not
    visible in this section -- presumably the table-recognition decoder's
    forward pass; confirm there.
    """
    # Bounding-box logits; the only required field.
    bbox_logits: torch.Tensor
    # Class logits; optional, defaults to None.
    class_logits: torch.Tensor | None = None
    # Hidden states; optional, defaults to None.
    hidden_states: torch.Tensor | None = None
25
+
26
+
27
class SuryaTableRecDecoderRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    The learned `weight` starts at zero and is applied as `(1 + weight)`, so a
    freshly constructed layer performs plain RMS normalisation. Computation is
    done in float32 and the result is cast back to the input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divide `x` by the RMS of its last dimension (stabilised by `eps`)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normalized = self._norm(x.float())
        # Llama does x.to(float16) * w whilst SuryaTableRecDecoder is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        scale = 1.0 + self.weight.float()
        return (normalized * scale).type_as(x)

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.eps}"
45
+
46
+
47
+ ALL_LAYERNORM_LAYERS.append(SuryaTableRecDecoderRMSNorm)
48
+
49
+
50
class SuryaTableRecDecoderRotaryEmbedding(nn.Module):
    """Computes the cos/sin tables used for rotary position embeddings.

    Inverse frequencies `1 / base ** (2i / dim)` are precomputed once and kept as
    a non-persistent buffer.
    """

    def __init__(self, dim, base=10000, device=None):
        """Precompute the inverse frequencies.

        Args:
            dim: Rotary dimension; `dim // 2` frequencies are generated for even `dim`.
            base: Base of the geometric frequency progression.
            device: Accepted for interface compatibility; not used here.
        """
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    # Adapted from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaTableRecDecoder
    def forward(self, x, position_ids, seq_len=None):
        """Return `(cos, sin)` for the given positions, cast to `x.dtype`.

        Args:
            x: Tensor whose device and dtype determine those of the result.
            position_ids: 2-D tensor of positions, indexed as `[batch, position]`.
            seq_len: Accepted for interface compatibility; not used here.

        Returns:
            Two tensors of shape `[batch, positions, 2 * len(inv_freq)]`.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Tensor.to() is not in-place: the previous bare `self.inv_freq.to(x.device)`
        # discarded its result, so the frequencies never actually moved devices.
        inv_freq = self.inv_freq.to(x.device)
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        # Duplicate the frequencies so the table spans both rotated halves
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
71
+
72
+
73
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
74
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For last-dimension contents `[a, b]` (split at `size // 2`) the result is
    `[-b, a]`.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)
79
+
80
+
81
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
82
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate the query and key tensors with rotary position embeddings.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into `cos` and `sin` so they broadcast against `q` and `k`.
            With `cos`/`sin` shaped [batch_size, seq_len, head_dim], use 1 when `q` and `k`
            are [batch_size, heads, seq_len, head_dim] and 2 when they are
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` holding the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Standard rotary mix: keep-component plus rotated-component
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
105
+
106
+
107
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
108
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension, like
    torch.repeat_interleave(x, dim=1, repeats=n_rep). Input is (batch,
    num_key_value_heads, seqlen, head_dim); output is (batch,
    num_key_value_heads * n_rep, seqlen, head_dim). With n_rep == 1 the input
    tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
118
+
119
+
120
class SuryaTableRecDecoderSdpaCrossAttention(nn.Module):
    """Multi-headed cross-attention from 'Attention Is All You Need' paper
    Modified for GQA

    Queries come from the decoder hidden states; keys and values come from the
    encoder hidden states and can be cached across calls.
    """

    def __init__(self, config: SuryaTableRecDecoderConfig):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
        # Not used by forward() in this class; kept so the module's attributes are unchanged
        self.rotary_emb = SuryaTableRecDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

        # forward() reads these unconditionally; without this initialisation a call
        # made before _setup_cache() raised AttributeError.
        self.value_states = None
        self.key_states = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> torch.Tensor:
        """Attend from `hidden_states` to `encoder_hidden_states`.

        Args:
            hidden_states: Decoder activations, `[batch, q_len, hidden_size]`.
            encoder_hidden_states: Encoder outputs, `[batch, v_len, encoder_hidden_size]`.
            attention_mask: Ignored; attention is unmasked.
            encoder_attention_mask: Ignored.
            use_cache: When True and no keys/values are cached yet, store the
                projected encoder keys/values for reuse by later calls.

        Returns:
            Attention output of shape `[batch, q_len, hidden_size]`.
        """
        # Encoder attention mask currently ignored

        bsz, q_len, _ = hidden_states.size()
        _, v_len, _ = encoder_hidden_states.size()

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)

        if self.key_states is None:
            # No cache yet: project the encoder states now
            key_states = self.k_proj(encoder_hidden_states)
            value_states = self.v_proj(encoder_hidden_states)
            key_states = key_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            if use_cache:
                self._update_cache(key_states, value_states)
        else:
            # Cached projections are reused as-is; encoder_hidden_states is not re-projected
            key_states = self.key_states
            value_states = self.value_states

        # Expand grouped key/value heads to match the number of query heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states.contiguous(),
            key_states.contiguous(),
            value_states.contiguous(),
            attn_mask=None,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _setup_cache(self, batch_size, device, dtype=None):
        """Clear the cached keys/values; the arguments are not used here."""
        # Setup initial caches
        self.value_states = None
        self.key_states = None

    @torch.no_grad()
    def _update_cache(self, key_states, value_states, **cache_kwargs):
        """Store the given key/value projections as the cache."""
        self.value_states = value_states
        self.key_states = key_states
197
+
198
+
199
class SuryaTableRecDecoderSdpaAttention(nn.Module):
    """Multi-headed self-attention from 'Attention Is All You Need' paper.

    Uses grouped key/value heads (expanded with `repeat_kv`), rotary position
    embeddings, and an optional key/value cache that is either preallocated
    ("static") or grown by concatenation ("dynamic"), selected by
    `settings.RECOGNITION_STATIC_CACHE`.
    """

    def __init__(self, config: SuryaTableRecDecoderConfig):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads sharing each key/value head
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
        self.rotary_emb = SuryaTableRecDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: bool = False,
        window_attn: bool = False,
    ) -> torch.Tensor:
        """Run self-attention over `hidden_states`.

        Args:
            hidden_states: Input activations, `[batch, q_len, hidden_size]`.
            position_ids: Positions passed to the rotary embedding.
            attention_mask: Optional mask indexed as `[batch, head, seq_len, kv_len]`;
                it is cut to the current key length before use.
            cache_position: Cache slots for the current tokens; its last element is
                taken as the current position for static-cache masking.
            use_cache: When True and the cache attributes exist (i.e. after
                `_setup_cache`), keys/values are merged with the cache.
            window_attn: Only forwarded to the cache update as a keyword; the
                cache-update methods of this class do not read it.

        Returns:
            Attention output of shape `[batch, q_len, hidden_size]`.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Final is bsz, num_attention_heads, seq_len, head_dim
        query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # value_states is passed to the rotary embedding as its first (`x`) argument
        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # The cache attributes only exist once _setup_cache() has been called
        if use_cache and hasattr(self, "key_states"):
            cache_kwargs = {"cache_position": cache_position, "window_attn": window_attn}
            key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            # Mask is batch, head, seq_len, kv_len
            causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]
            current_cache_position = cache_position[-1].item() if cache_position is not None else None
            # NOTE(review): this is a truthiness test, so a current position of 0 is
            # treated like None and the extra masking below is skipped -- confirm intended.
            if current_cache_position and settings.RECOGNITION_STATIC_CACHE:
                # Mask out future cache positions
                position_mask = torch.ones_like(causal_mask, dtype=torch.bool, device=causal_mask.device)
                position_mask[:, :, :, :current_cache_position + 1] = False
                causal_mask = torch.where(position_mask, torch.finfo(causal_mask.dtype).min, causal_mask)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states.contiguous(),
            key_states.contiguous(),
            value_states.contiguous(),
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _setup_cache(self, batch_size, device, dtype=None):
        """Reset the key/value cache.

        With the static cache enabled, zero-filled buffers of shape
        `(batch_size, num_key_value_heads, settings.RECOGNITION_MAX_TOKENS, head_dim)`
        are allocated; otherwise both cache attributes are set to None.
        `dtype` falls back to `config.torch_dtype`, then to float32.
        """
        if dtype is None and self.config.torch_dtype is not None:
            dtype = self.config.torch_dtype
        dtype = dtype if dtype is not None else torch.float32

        # Setup initial caches
        self.value_states = None
        self.key_states = None

        if settings.RECOGNITION_STATIC_CACHE:
            cache_shape = (batch_size, self.num_key_value_heads, settings.RECOGNITION_MAX_TOKENS, self.head_dim)
            self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device)
            self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device)

    def _update_static_cache(self, key_states, value_states, **cache_kwargs):
        """Write the new keys/values into the preallocated cache at `cache_position`.

        Returns the full cache buffers (not only the written slots).
        """
        cache_position = cache_kwargs.get("cache_position")
        k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)

        # Index along the sequence axis (dim 2) of the cache
        k_out[:, :, cache_position] = key_states.to(k_out.dtype)
        v_out[:, :, cache_position] = value_states.to(v_out.dtype)

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

    def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs):
        """Append the new keys/values to the cache along the sequence axis (dim 2).

        Returns the concatenated keys/values, which also become the new cache.
        """
        k_out = key_states
        if self.key_states is not None:
            k_out = torch.cat([self.key_states, key_states], dim=2)

        v_out = value_states
        if self.value_states is not None:
            v_out = torch.cat([self.value_states, value_states], dim=2)

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

    @torch.no_grad()
    def _update_cache(self, key_states, value_states, **cache_kwargs):
        """Dispatch to the static or dynamic cache update based on settings."""
        if settings.RECOGNITION_STATIC_CACHE:
            return self._update_static_cache(key_states, value_states, **cache_kwargs)

        return self._update_dynamic_cache(key_states, value_states, **cache_kwargs)
318
+
319
+
320
class SuryaTableRecDecoderMlp(nn.Module):
    """Gated feed-forward block: `down_proj(act(gate_proj(x)) * up_proj(x))`.

    All three projections are bias-free. If `config.hidden_activation` is None it
    is set on the config to "gelu_pytorch_tanh" before the activation is looked up.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        width, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(width, inner, bias=False)
        self.up_proj = nn.Linear(width, inner, bias=False)
        self.down_proj = nn.Linear(inner, width, bias=False)

        # Fill in the default activation on the config itself when it is unset
        if config.hidden_activation is None:
            config.hidden_activation = "gelu_pytorch_tanh"
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
336
+
337
+
338
class SuryaTableRecDecoderLayer(nn.Module):
    """One decoder layer: optional cross-attention, optional self-attention, then an MLP.

    Whether the layer has a cross-attention and/or self-attention block depends on
    `layer_idx` being listed in `config.cross_attn_layers` / `config.self_attn_layers`.
    """

    def __init__(self, config, layer_idx):
        # The original called super().__init__() twice in a row; once is sufficient.
        super().__init__()
        self.cross_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.temporal_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.temporal_block = None
        if layer_idx in config.self_attn_layers:
            self.temporal_block = SuryaTableRecDecoderSdpaAttention(config)

        self.cross_attn_block = None
        if layer_idx in config.cross_attn_layers:
            self.cross_attn_block = SuryaTableRecDecoderSdpaCrossAttention(config)

        # Layers not listed as global are flagged as windowed for the self-attention call
        self.window_attn = layer_idx not in config.global_attn_layers
        self.channel_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp_block = SuryaTableRecDecoderMlp(config)

    def forward(
        self,
        activations: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        encoder_attention_mask: torch.Tensor = None,
        cache_position: torch.Tensor = None,
        use_cache: bool = None,
    ) -> torch.Tensor:
        """Apply the layer to `activations` and return the new hidden states.

        Args:
            activations: Input hidden states.
            position_ids: Positions forwarded to the self-attention block.
            attention_mask: Mask forwarded to both attention blocks.
            encoder_hidden_states: Encoder outputs for cross-attention.
            encoder_attention_mask: Forwarded to cross-attention.
            cache_position: Forwarded to self-attention.
            use_cache: Forwarded to both attention blocks.

        Returns:
            Hidden states after attention and the MLP.
        """
        raw_activations = activations

        if self.cross_attn_block is not None:
            # Do cross-attention on encoder outputs
            cross_attn_inputs = self.cross_pre_norm(activations)
            cross_attn_path = self.cross_attn_block(
                cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache
            )
            cross_attn_output = cross_attn_path + raw_activations
        else:
            cross_attn_output = raw_activations

        if self.temporal_block is not None:
            inputs_normalized = self.temporal_pre_norm(cross_attn_output)  # RMSNorm introduces slight differences
            hidden_states = self.temporal_block(
                inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache, window_attn=self.window_attn
            )

            # NOTE(review): the skip connection adds the layer input, not
            # cross_attn_output. Left exactly as-is, since changing the wiring
            # would change the model's numerics -- confirm this is intended.
            residual = hidden_states + raw_activations
        else:
            residual = cross_attn_output

        hidden_states = self.channel_pre_norm(residual)
        hidden_states = self.mlp_block(hidden_states)

        hidden_states = hidden_states + residual
        return hidden_states
394
+
395
+
396
class SuryaTableRecDecoderPreTrainedModel(PreTrainedModel):
    """Shared base for the table-recognition decoder / text-encoder models.

    Provides weight initialisation, per-layer cache setup, and disables
    weight tying (both ``tie_weights`` hooks are no-ops).
    """

    config_class = SuryaTableRecDecoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SuryaTableRecDecoderLayer"]
    _skip_keys_device_placement = ["cache"]
    _supports_flash_attn_2 = False
    _supports_sdpa = False  # we can't compare with eager for now
    _supports_cache_class = True
    _supports_quantized_cache = True

    def _init_weights(self, module):
        """Initialise ``module`` with a normal distribution of std ``config.init_std``."""
        std = self.config.init_std
        if isinstance(module, SuryaTableRecDecoderSdpaAttention):
            # Same order as before (q, k, v, o) so seeded initialisation is unchanged.
            for projection in (module.q_proj, module.k_proj, module.v_proj, module.o_proj):
                torch.nn.init.normal_(projection.weight, mean=0.0, std=std)
        elif isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if getattr(module, "bias", None) is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _setup_cache(self, config, batch, device, dtype):
        """Ask every attention block of every layer to allocate its cache.

        Works both on the wrapper (layers live on ``self.model``) and on the
        inner model itself (layers live on ``self``).
        """
        layers = getattr(self, "model", self).layers
        for layer in layers:
            # Explicit None checks: the blocks are optional attributes, and
            # module truthiness is not a reliable "is present" test.
            if layer.temporal_block is not None:
                layer.temporal_block._setup_cache(batch, device, dtype)
            if layer.cross_attn_block is not None:
                layer.cross_attn_block._setup_cache(batch, device, dtype)

    def reset_cache(self, batch, device, dtype):
        """No-op placeholder."""
        pass

    def _tie_weights(self):
        """No-op: nothing is tied."""
        pass

    def tie_weights(self):
        """No-op: nothing is tied."""
        pass
439
+
440
+
441
class LabelEmbedding(nn.Module):
    """Embed ``(cx, cy, w, h, class)`` labels as a sum of learned lookup tables.

    Each label contributes nine embeddings: the four derived corner
    coordinates, width, height, the two centre coordinates and the class id.
    All indices are clamped into the valid range of their table first.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.vocab_size = config.vocab_size
        # Tables are created in a fixed order so seeded initialisation stays stable.
        self.x1_embed = nn.Embedding(config.max_width, hidden)
        self.y1_embed = nn.Embedding(config.max_height, hidden)
        self.x2_embed = nn.Embedding(config.max_width, hidden)
        self.y2_embed = nn.Embedding(config.max_height, hidden)
        self.w_embed = nn.Embedding(config.max_width, hidden)
        self.h_embed = nn.Embedding(config.max_height, hidden)
        self.cx_embed = nn.Embedding(config.max_width, hidden)
        self.cy_embed = nn.Embedding(config.max_height, hidden)
        self.class_embed = nn.Embedding(config.max_classes, hidden)
        self.max_width = config.max_width
        self.max_height = config.max_height
        self.max_classes = config.max_classes

    def forward(self, labels: torch.LongTensor, input_box_counts: torch.LongTensor):
        """Return the summed embeddings for ``labels``.

        Args:
            labels: Tensor whose last dimension holds ``cx, cy, w, h, class``.
            input_box_counts: Accepted for signature parity with
                ``BboxEmbedding``; not used here.
        """
        cx, cy, w, h, class_ = labels.to(torch.long).unbind(dim=-1)
        last_x = self.max_width - 1
        last_y = self.max_height - 1
        half_w = w // 2
        half_h = h // 2

        # Corners are derived from the unclamped centre/size, then clamped.
        x1 = torch.clamp((cx - half_w).long(), 0, last_x)
        y1 = torch.clamp((cy - half_h).long(), 0, last_y)
        x2 = torch.clamp((cx + half_w).long(), 0, last_x)
        y2 = torch.clamp((cy + half_h).long(), 0, last_y)

        class_ = torch.clamp(class_, 0, self.max_classes - 1).long()

        w = torch.clamp(w, 0, last_x).long()
        h = torch.clamp(h, 0, last_y).long()
        cx = torch.clamp(cx, 0, last_x).long()
        cy = torch.clamp(cy, 0, last_y).long()

        # Accumulate left to right in the original order (floating-point
        # addition is order sensitive).
        terms = (
            self.x1_embed(x1),
            self.y1_embed(y1),
            self.x2_embed(x2),
            self.y2_embed(y2),
            self.w_embed(w),
            self.h_embed(h),
            self.cx_embed(cx),
            self.cy_embed(cy),
            self.class_embed(class_),
        )
        embedded = terms[0]
        for term in terms[1:]:
            embedded = embedded + term

        return embedded
482
+
483
+
484
class BboxEmbedding(nn.Module):
    """Embed ``(x1, y1, x2, y2)`` boxes as a sum of learned lookup tables.

    Corner coordinates are clamped to the table sizes; width, height and the
    centre are derived from the clamped corners and embedded as well. When
    ``embed_positions`` is true, a learned per-index embedding is added to the
    box range given by ``input_box_counts``.
    """

    def __init__(self, config, embed_positions=False):
        super().__init__()
        hidden = config.hidden_size
        # Tables are created in a fixed order so seeded initialisation stays stable.
        self.x1_embed = nn.Embedding(config.max_width, hidden)
        self.y1_embed = nn.Embedding(config.max_height, hidden)
        self.x2_embed = nn.Embedding(config.max_width, hidden)
        self.y2_embed = nn.Embedding(config.max_height, hidden)
        self.w_embed = nn.Embedding(config.max_width, hidden)
        self.h_embed = nn.Embedding(config.max_height, hidden)
        self.cx_embed = nn.Embedding(config.max_width, hidden)
        self.cy_embed = nn.Embedding(config.max_height, hidden)
        self.box_pos_embed = nn.Embedding(config.max_position_embeddings, hidden)
        self.max_width = config.max_width
        self.max_height = config.max_height
        self.embed_positions = embed_positions

    def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor):
        """Return the summed embeddings for ``boxes``.

        Args:
            boxes: Tensor whose last dimension holds ``x1, y1, x2, y2``.
            input_box_counts: Per-row ``[start, end]`` pairs; only read when
                ``embed_positions`` is true.
        """
        last_x = self.max_width - 1
        last_y = self.max_height - 1

        x1, y1, x2, y2 = boxes.unbind(dim=-1)
        x1 = torch.clamp(x1, 0, last_x).long()
        y1 = torch.clamp(y1, 0, last_y).long()
        x2 = torch.clamp(x2, 0, last_x).long()
        y2 = torch.clamp(y2, 0, last_y).long()

        # Size and centre come from the already-clamped corners; the centre
        # is truncated to an integer index.
        cx = ((x1 + x2) / 2).long()
        cy = ((y1 + y2) / 2).long()
        w = torch.clamp(x2 - x1, 0, last_x).long()
        h = torch.clamp(y2 - y1, 0, last_y).long()
        cx = torch.clamp(cx, 0, last_x).long()
        cy = torch.clamp(cy, 0, last_y).long()

        # Accumulate left to right in the original order (floating-point
        # addition is order sensitive).
        terms = (
            self.x1_embed(x1),
            self.y1_embed(y1),
            self.x2_embed(x2),
            self.y2_embed(y2),
            self.w_embed(w),
            self.h_embed(h),
            self.cx_embed(cx),
            self.cy_embed(cy),
        )
        embedded = terms[0]
        for term in terms[1:]:
            embedded = embedded + term

        # Add in positional embeddings for the boxes and labels
        if self.embed_positions:
            position_table = self.box_pos_embed.weight
            for row in range(embedded.shape[0]):
                first = input_box_counts[row, 0]
                last = input_box_counts[row, 1] - 1  # Skip the sep token
                span = last - first
                embedded[row, first:last] = embedded[row, first:last] + position_table[:span]

        return embedded
533
+
534
+
535
class SuryaTableRecDecoderModel(SuryaTableRecDecoderPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuryaTableRecDecoderLayer`]

    Args:
        config: SuryaTableRecDecoderConfig
        embed_labels: Use `LabelEmbedding` for the inputs when true, `BboxEmbedding` otherwise.
        embed_positions: Forwarded to `BboxEmbedding`; ignored when `embed_labels` is true.
    """

    def __init__(self, config: SuryaTableRecDecoderConfig, embed_labels=False, embed_positions=True):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.causal = config.causal

        self.embed_tokens = (
            LabelEmbedding(config) if embed_labels else BboxEmbedding(config, embed_positions=embed_positions)
        )

        decoder_layers = [SuryaTableRecDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.final_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Non-persistent buffer: not stored in checkpoints.
        sqrt_hidden = torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32)
        self.register_buffer("normalizer", sqrt_hidden, persistent=False)
        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings
    def get_input_embeddings(self):
        return self.embed_tokens

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        input_boxes_counts: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        prefill: bool = False
    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
        """Embed the inputs, run every decoder layer, and apply the final norm.

        ``use_cache`` and ``return_dict`` fall back to the config when ``None``.
        With ``use_cache`` and ``prefill`` both true, the per-layer caches are
        (re)created before the layers run. Returns a
        ``BaseModelOutputWithNoAttention`` or, when ``return_dict`` is false,
        a tuple of the non-``None`` outputs.
        """
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing and use_cache:
            use_cache = False

        inputs_embeds = self.embed_tokens(input_ids, input_boxes_counts)
        hidden_states = inputs_embeds

        if use_cache and prefill:
            self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype)

        # Default positions: 0 .. seq_len - 1, shared by the whole batch.
        if cache_position is None:
            cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

        collected = [] if output_hidden_states else None
        for residual_block in self.layers:
            if output_hidden_states:
                collected.append(hidden_states)
            layer_args = (
                hidden_states,
                position_ids,
                causal_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cache_position,
                use_cache,
            )
            if checkpointing:
                hidden_states = self._gradient_checkpointing_func(residual_block.__call__, *layer_args)
            else:
                hidden_states = residual_block(*layer_args)

        hidden_states = self.final_norm(hidden_states)

        # add hidden states from the last decoder layer
        all_hidden_states = None
        if output_hidden_states:
            collected.append(hidden_states)
            all_hidden_states = tuple(collected)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )

    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
    # Ignore copy
    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        """Build the additive attention mask, or return ``None`` for non-causal models.

        The mask is 0 where attention is allowed and the dtype's minimum where
        it is blocked, shaped ``(batch, 1, query_len, key_len)`` with
        ``key_len = max(settings.TABLE_REC_MAX_BOXES, query_len)``.
        """
        if not self.causal:
            return None

        dtype = input_tensor.dtype
        device = input_tensor.device
        blocked = torch.finfo(dtype).min
        batch_size = input_tensor.shape[0]
        query_len = input_tensor.shape[1]
        key_len = max(settings.TABLE_REC_MAX_BOXES, query_len)

        causal_mask = torch.full((query_len, key_len), fill_value=blocked, dtype=dtype, device=device)
        if query_len != 1:
            # Select the upper triangular part of the matrix, but unmask current token (the diagonal)
            # triu will be the min_dtype, everything else is 0 (attended to)
            causal_mask = torch.triu(causal_mask, diagonal=1)

        # Zero out (unblock) every key position that is not after the query's cache position.
        key_positions = torch.arange(key_len, device=device)
        causal_mask *= key_positions > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

        if attention_mask is None:
            return causal_mask

        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        if attention_mask.dim() == 2:
            # Mask positions in the causal mask that are masked in the attention mask
            mask_length = attention_mask.shape[-1]
            currently_open = causal_mask[..., :mask_length].eq(0.0)
            is_padding = attention_mask[:, None, None, :].eq(0.0)
            padding_mask = currently_open * is_padding
            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, blocked)

        if attention_mask.device.type == "cuda":
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, blocked)

        return causal_mask
670
+
671
+
672
class SuryaTableRecDecoder(SuryaTableRecDecoderPreTrainedModel):
    """Label decoder with two linear heads: box coordinates and class.

    Wraps a ``SuryaTableRecDecoderModel`` built with label embeddings and
    projects its last hidden state to ``bbox_logits`` (four groups of
    ``config.max_width`` logits per position) and ``class_logits``
    (``config.max_classes`` logits per position).
    """

    _tied_weights_keys = None

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.model = SuryaTableRecDecoderModel(config, embed_labels=True, embed_positions=False)
        self.vocab_size = config.vocab_size

        self.bbox_head = nn.Linear(config.hidden_size, config.max_width * 4, bias=False)
        self.class_head = nn.Linear(config.hidden_size, config.max_classes, bias=False)
        self.max_width = config.max_width

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        # This class never creates an `lm_head` (it has bbox/class heads
        # instead), so reading the attribute directly raised AttributeError.
        # Return None unless `set_output_embeddings` installed one.
        return getattr(self, "lm_head", None)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    # Ignore copy
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        prefill: bool = False,
        **kwargs
    ) -> Union[Tuple, TableRecModelOutput]:
        """Decode ``input_ids`` and return box/class logits.

        Extra keyword arguments are accepted and ignored.

        Returns:
            ``TableRecModelOutput`` with ``bbox_logits`` reshaped to
            ``(batch, seq_len, 4, max_width)``, ``class_logits`` and the
            final hidden states.
        """
        outputs = self.model(
            input_ids=input_ids,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_hidden_states=True,
            return_dict=True,
            prefill=prefill,
        )

        hidden_states = outputs[0]
        bbox_logits = self.bbox_head(hidden_states)
        class_logits = self.class_head(hidden_states)
        bsz, seq_len = class_logits.shape[:2]
        # Split the flat bbox projection into one logit group per box coordinate.
        bbox_logits = bbox_logits.view(bsz, seq_len, 4, self.max_width)

        return TableRecModelOutput(
            bbox_logits=bbox_logits,
            class_logits=class_logits,
            hidden_states=hidden_states,
        )
740
@dataclass
class TextEncoderOutput(CausalLMOutput):
    """Output container returned by ``SuryaTableRecTextEncoder.forward``."""

    # Overrides the inherited field so it holds the encoder's last hidden state.
    hidden_states: Optional[torch.FloatTensor] = None
743
+
744
+
745
class SuryaTableRecTextEncoder(SuryaTableRecDecoderPreTrainedModel):
    """Box encoder: runs ``SuryaTableRecDecoderModel`` over box coordinates.

    The inner model is built with box embeddings plus positional embeddings,
    and only its last hidden state is returned.
    """

    _tied_weights_keys = None
    config_class = SuryaTableRecTextEncoderConfig

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.model = SuryaTableRecDecoderModel(config, embed_labels=False, embed_positions=True)
        self.vocab_size = config.vocab_size

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    # Ignore copy
    def forward(
        self,
        input_boxes: Optional[torch.LongTensor] = None,
        input_boxes_counts: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutput]:
        """Encode ``input_boxes`` and return the last hidden state.

        ``input_boxes`` is passed to the inner model as ``input_ids``; extra
        keyword arguments are accepted and ignored.
        """
        model_inputs = dict(
            input_ids=input_boxes,
            input_boxes_counts=input_boxes_counts,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_hidden_states=True,
            return_dict=True,
        )
        encoded = self.model(**model_inputs)

        return TextEncoderOutput(hidden_states=encoded.last_hidden_state)
surya/model/table_rec/encoderdecoder.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from dataclasses import dataclass
3
+ from typing import Optional, Union, Tuple
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import CrossEntropyLoss
8
+ from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig
9
+ from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
10
+ from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import shift_tokens_right
11
+ from surya.model.table_rec.decoder import SuryaTableRecTextEncoder, SuryaTableRecDecoder
12
+ from surya.model.recognition.encoder import DonutSwinModel
13
+ import torch.nn.functional as F
14
+ from transformers.utils import ModelOutput
15
+
16
+
17
@dataclass
class TableRecOutput(ModelOutput):
    """Output of ``TableRecEncoderDecoderModel.forward``."""

    row_logits: torch.FloatTensor = None
    col_logits: torch.FloatTensor = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Added (optional, appended last so existing keyword construction is
    # unaffected): the logits the decoder actually produces.
    bbox_logits: Optional[torch.FloatTensor] = None
    class_logits: Optional[torch.FloatTensor] = None


class TableRecEncoderDecoderModel(PreTrainedModel):
    """Container bundling the image encoder, the box (text) encoder and the decoder.

    Sub-models that are not passed in are built from the matching
    sub-config (``config.encoder``, ``config.text_encoder``,
    ``config.decoder``).
    """

    config_class = VisionEncoderDecoderConfig
    base_model_prefix = "vision_encoder_decoder"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _supports_param_buffer_assignment = False

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        text_encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
    ):
        # The config is dereferenced immediately below; fail with a clear
        # message instead of an AttributeError on None.
        if config is None:
            raise ValueError("A `config` must be provided to initialize TableRecEncoderDecoderModel.")

        # initialize with config
        # make sure input & output embeddings is not tied
        config.tie_word_embeddings = False
        config.decoder.tie_word_embeddings = False
        super().__init__(config)

        if encoder is None:
            encoder = DonutSwinModel(config.encoder)

        if text_encoder is None:
            text_encoder = SuryaTableRecTextEncoder(config.text_encoder, attn_implementation=config._attn_implementation)

        if decoder is None:
            decoder = SuryaTableRecDecoder(config.decoder, attn_implementation=config._attn_implementation)

        self.encoder = encoder
        self.decoder = decoder
        self.text_encoder = text_encoder

        # make sure that the individual model's config refers to the shared config
        # so that the updates to the config will be synced
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder
        self.text_encoder.config = self.config.text_encoder

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def get_output_embeddings(self):
        return self.decoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        return self.decoder.set_output_embeddings(new_embeddings)

    def forward(
        self,
        decoder_input_ids: torch.LongTensor = None,
        decoder_cache_position: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], TableRecOutput]:
        """Run the decoder over ``decoder_input_ids`` given precomputed encoder outputs.

        Extra keyword arguments prefixed with ``decoder_`` are forwarded to
        the decoder with the prefix stripped; all other extras are ignored.
        ``return_dict`` is accepted for interface compatibility but a
        ``TableRecOutput`` is always returned.
        """
        prefix = "decoder_"
        kwargs_decoder = {
            argument[len(prefix):]: value for argument, value in kwargs.items() if argument.startswith(prefix)
        }

        # Decode. The decoder's parameter is `input_ids`; the previous
        # `input_labels=` keyword was swallowed by its **kwargs, leaving
        # `input_ids` as None.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            cache_position=decoder_cache_position,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs,
            encoder_attention_mask=None,
            use_cache=use_cache,
            **kwargs_decoder,
        )

        # NOTE(review): the decoder in this package returns bbox/class logits;
        # row/col logits are read defensively because they do not appear to be
        # populated by it - confirm whether these fields are still needed.
        return TableRecOutput(
            row_logits=getattr(decoder_outputs, "row_logits", None),
            col_logits=getattr(decoder_outputs, "col_logits", None),
            decoder_hidden_states=decoder_outputs.hidden_states,
            bbox_logits=getattr(decoder_outputs, "bbox_logits", None),
            class_logits=getattr(decoder_outputs, "class_logits", None),
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        """Assemble the keyword arguments for one generation step from the decoder's own preparation."""
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
        decoder_attention_mask = decoder_inputs.get("attention_mask")
        input_dict = {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }
        return input_dict

    def resize_token_embeddings(self, *args, **kwargs):
        raise NotImplementedError(
            "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the"
            " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
        )

    def _reorder_cache(self, past_key_values, beam_idx):
        # apply decoder cache reordering here
        return self.decoder._reorder_cache(past_key_values, beam_idx)
surya/model/table_rec/model.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from surya.model.recognition.encoder import DonutSwinModel
2
+ from surya.model.table_rec.config import SuryaTableRecConfig, SuryaTableRecDecoderConfig, DonutSwinTableRecConfig, \
3
+ SuryaTableRecTextEncoderConfig
4
+ from surya.model.table_rec.decoder import SuryaTableRecDecoder, SuryaTableRecTextEncoder
5
+ from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
6
+ from surya.settings import settings
7
+
8
+
9
def load_model(checkpoint=settings.TABLE_REC_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the table-recognition model from ``checkpoint`` in eval mode.

    Args:
        checkpoint: Identifier/path passed to ``from_pretrained``.
        device: Device the model is moved to after loading.
        dtype: Passed to ``from_pretrained`` as ``torch_dtype``.

    Returns:
        The loaded ``TableRecEncoderDecoderModel`` on ``device``, in eval mode.
    """
    config = SuryaTableRecConfig.from_pretrained(checkpoint)

    # The loaded sub-configs are expanded with ** below, so they are
    # mappings at this point; promote each to its typed config class.
    config.decoder = SuryaTableRecDecoderConfig(**config.decoder)
    config.encoder = DonutSwinTableRecConfig(**config.encoder)
    config.text_encoder = SuryaTableRecTextEncoderConfig(**config.text_encoder)

    model = TableRecEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)

    # Sanity checks that the expected sub-model classes were built.
    assert isinstance(model.decoder, SuryaTableRecDecoder)
    assert isinstance(model.encoder, DonutSwinModel)
    assert isinstance(model.text_encoder, SuryaTableRecTextEncoder)

    model = model.to(device)
    model = model.eval()

    # Message previously said "recognition model", which mislabelled this loader.
    print(f"Loaded table recognition model {checkpoint} on device {device} with dtype {dtype}")
    return model
surya/model/table_rec/processor.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Dict, Union, Optional, List, Iterable
3
+
4
+ import cv2
5
+ import torch
6
+ from torch import TensorType
7
+ from transformers import DonutImageProcessor, DonutProcessor
8
+ from transformers.image_processing_utils import BatchFeature
9
+ from transformers.image_transforms import pad, normalize
10
+ from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, get_image_size
11
+ import numpy as np
12
+ from PIL import Image
13
+ import PIL
14
+ from surya.model.recognition.tokenizer import Byt5LangTokenizer
15
+ from surya.settings import settings
16
+ from surya.model.table_rec.config import BOX_DIM, SPECIAL_TOKENS
17
+
18
+
19
def load_processor():
    """Build a ``SuryaProcessor`` configured for table-recognition inference.

    Puts the image processor in non-training mode with the table-recognition
    image size, then attaches the special token ids, the box coordinate
    space and the special-token offset to the processor.
    """
    processor = SuryaProcessor()

    image_processor = processor.image_processor
    image_processor.train = False
    image_processor.max_size = settings.TABLE_REC_IMAGE_SIZE

    # Special token ids, assigned in the same order as before.
    special_ids = (
        ("token_pad_id", 0),
        ("token_eos_id", 1),
        ("token_bos_id", 2),
        ("token_row_id", 3),
        ("token_unused_id", 4),
    )
    for attribute, token_id in special_ids:
        setattr(processor, attribute, token_id)

    processor.box_size = (BOX_DIM, BOX_DIM)
    processor.special_token_count = SPECIAL_TOKENS
    return processor
32
+
33
+
34
class SuryaImageProcessor(DonutImageProcessor):
    """Image processor that resizes every image to a fixed ``max_size`` with OpenCV.

    ``preprocess`` ignores most of its Donut-style options: each image is
    resized to ``self.max_size``, converted to float32, padded, rescaled by
    ``self.rescale_factor`` and normalized with ``self.image_mean`` /
    ``self.image_std``.
    """

    def __init__(self, *args, max_size=None, train=False, **kwargs):
        super().__init__(*args, **kwargs)

        self.patch_size = kwargs.get("patch_size", (4, 4))
        self.max_size = max_size
        self.train = train

    @classmethod
    def numpy_resize(cls, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4):
        """Resize a channel-last image to ``size`` and return it channel-first."""
        max_width, max_height = size["width"], size["height"]

        resized_image = cv2.resize(image, (max_width, max_height), interpolation=interpolation)
        resized_image = resized_image.transpose(2, 0, 1)

        return resized_image

    def process_inner(self, images: List[np.ndarray]):
        """Resize, pad, rescale and normalize a list of channel-last RGB arrays."""
        assert images[0].shape[2] == 3  # RGB input images, channel dim last

        # This also applies the right channel dim format, to channel x height x width
        # NOTE(review): self.resample is handed to cv2.resize as its
        # interpolation flag - confirm the value is meant as an OpenCV constant.
        images = [SuryaImageProcessor.numpy_resize(img, self.max_size, self.resample) for img in images]
        assert images[0].shape[0] == 3  # RGB input images, channel dim first

        # Convert to float32 for rescale/normalize
        images = [img.astype(np.float32) for img in images]

        # Pads with settings.RECOGNITION_PAD_VALUE
        # Pad to max size to improve performance
        max_size = self.max_size
        images = [
            SuryaImageProcessor.pad_image(
                image=image,
                size=max_size,
                input_data_format=ChannelDimension.FIRST,
                pad_value=settings.RECOGNITION_PAD_VALUE
            )
            for image in images
        ]
        # Rescale and normalize
        images = [img * self.rescale_factor for img in images]
        images = [
            SuryaImageProcessor.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]

        return images

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_thumbnail: bool = None,
        do_align_long_axis: bool = None,
        do_pad: bool = None,
        random_padding: bool = False,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess ``images`` into a ``BatchFeature`` holding ``pixel_values``.

        Only ``images`` and ``return_tensors`` are used; the remaining
        parameters exist for signature compatibility and are ignored.
        """
        images = make_list_of_images(images)

        # Convert to numpy for later processing steps
        images = [np.array(img) for img in images]
        images = self.process_inner(images)

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @classmethod
    def pad_image(
        cls,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        pad_value: float = 0.0,
    ) -> np.ndarray:
        """Pad ``image`` to ``size``, centring it.

        When the padding is odd, the extra pixel goes to the bottom/right.
        The image must not exceed ``size`` in either dimension.
        """
        output_height, output_width = size["height"], size["width"]
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)

        delta_width = output_width - input_width
        delta_height = output_height - input_height

        assert delta_width >= 0 and delta_height >= 0

        pad_top = delta_height // 2
        pad_left = delta_width // 2

        pad_bottom = delta_height - pad_top
        pad_right = delta_width - pad_left

        padding = ((pad_top, pad_bottom), (pad_left, pad_right))
        return pad(image, padding, data_format=data_format, input_data_format=input_data_format, constant_values=pad_value)

    @classmethod
    def align_long_axis(
        cls,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Rotate ``image`` with ``np.rot90(image, 3)`` when its orientation differs from ``size``'s."""
        input_height, input_width = image.shape[:2]
        output_height, output_width = size["height"], size["width"]

        if (output_width < output_height and input_width > input_height) or (
            output_width > output_height and input_width < input_height
        ):
            image = np.rot90(image, 3)

        return image

    @classmethod
    def normalize(
        cls,
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """Thin wrapper around ``transformers.image_transforms.normalize``."""
        return normalize(
            image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
        )
169
+
170
+
171
class SuryaProcessor(DonutProcessor):
    """Processor that pairs table images with their input boxes.

    ``__call__`` rescales each image's boxes into the ``self.box_size``
    coordinate space, shifts them past the special-token ids, frames and
    pads them to a fixed length, and returns them with the processed
    pixel values.

    ``box_size``, ``special_token_count`` and the ``token_*_id`` attributes
    are not set in ``__init__``; they must be assigned on the instance
    before ``__call__`` is used.
    """

    def __init__(self, image_processor=None, tokenizer=None, train=False, **kwargs):
        # The image processor and tokenizer are always rebuilt here; the
        # values passed in are not used.
        image_processor = SuryaImageProcessor.from_pretrained(settings.RECOGNITION_MODEL_CHECKPOINT)
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")

        tokenizer = Byt5LangTokenizer()
        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor
        self._in_target_context_manager = False
        self.max_input_boxes = kwargs.get("max_input_boxes", 256)
        self.extra_input_boxes = kwargs.get("extra_input_boxes", 32)

    def resize_boxes(self, img, boxes):
        """Rescale ``boxes`` from ``img`` pixel space to ``self.box_size`` space.

        The boxes are modified in place (and also returned). Coordinates are
        clipped to ``[0, box_width]`` / ``[0, box_height]`` on the sides
        checked below.
        """
        width, height = img.size
        box_width, box_height = self.box_size
        for box in boxes:
            # Rescale into the box coordinate space
            box[0] = box[0] / width * box_width
            box[1] = box[1] / height * box_height
            box[2] = box[2] / width * box_width
            box[3] = box[3] / height * box_height

            if box[0] < 0:
                box[0] = 0
            if box[1] < 0:
                box[1] = 0
            if box[2] > box_width:
                box[2] = box_width
            if box[3] > box_height:
                box[3] = box_height

        return boxes

    def __call__(self, *args, **kwargs):
        """Build model inputs from ``images`` and ``boxes``.

        ``images`` may be given as the ``images=`` keyword or as the first
        positional argument; ``boxes`` is the ``boxes=`` keyword holding one
        list of ``[x1, y1, x2, y2]`` boxes per image. The passed ``boxes``
        list is modified in place (downsampling and rescaling).

        Returns:
            The image processor's output with ``input_boxes``,
            ``input_boxes_mask`` and ``input_boxes_counts`` added.

        Raises:
            ValueError: If the number of images and box lists differ.
        """
        images = kwargs.pop("images", [])
        boxes = kwargs.pop("boxes", [])

        if len(args) > 0:
            images = args[0]
            args = args[1:]

        # Validate only after positional images are resolved; checking
        # earlier compared the empty keyword default against `boxes` and
        # rejected every positional call.
        if len(images) != len(boxes):
            raise ValueError(
                f"Expected one list of boxes per image, got {len(images)} images and {len(boxes)} box lists."
            )

        # Thin out over-long box lists by keeping every n-th box.
        for i, image_boxes in enumerate(boxes):
            if len(image_boxes) > self.max_input_boxes:
                downsample_ratio = math.ceil(len(image_boxes) / self.max_input_boxes)
                boxes[i] = image_boxes[::downsample_ratio]

        new_boxes = []
        max_len = self.max_input_boxes + self.extra_input_boxes
        box_masks = []
        box_ends = []
        for image, image_boxes in zip(images, boxes):
            nb = self.resize_boxes(image, image_boxes)
            nb = [[b + self.special_token_count for b in box] for box in nb]  # shift up
            nb = nb[:self.max_input_boxes - 1]

            nb.insert(0, [self.token_row_id] * 4)  # Insert special token for max rows/cols
            nb.extend([self.token_unused_id] * 4 for _ in range(self.extra_input_boxes))

            pad_length = max_len - len(nb)
            # NOTE(review): padded positions are also marked 1 in the mask -
            # kept as-is; confirm this is what the model expects.
            box_mask = [1] * len(nb) + [1] * (pad_length)
            box_ends.append(len(nb))
            nb = nb + [[self.token_unused_id] * 4] * pad_length

            new_boxes.append(nb)
            box_masks.append(box_mask)

        box_ends = torch.tensor(box_ends, dtype=torch.long)
        box_starts = torch.tensor([0] * len(boxes), dtype=torch.long)
        box_ranges = torch.stack([box_starts, box_ends], dim=1)

        inputs = self.image_processor(images, *args, **kwargs)
        # Rescaled coordinates are floats; converting to long truncates them.
        inputs["input_boxes"] = torch.tensor(new_boxes, dtype=torch.long)
        inputs["input_boxes_mask"] = torch.tensor(box_masks, dtype=torch.long)
        inputs["input_boxes_counts"] = box_ranges
        return inputs
surya/ocr.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+ from typing import List
3
+ from PIL import Image
4
+
5
+ from surya.detection import batch_text_detection
6
+ from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image, convert_if_not_rgb
7
+ from surya.postprocessing.text import sort_text_lines
8
+ from surya.recognition import batch_recognition
9
+ from surya.schema import TextLine, OCRResult
10
+
11
+
12
def run_recognition(images: List[Image.Image], langs: List[List[str] | None], rec_model, rec_processor, bboxes: List[List[List[int]]] = None, polygons: List[List[List[List[int]]]] = None, batch_size=None) -> List[OCRResult]:
    """Recognize text in caller-supplied regions of each image.

    Regions are taken from ``polygons`` when it is given, otherwise from
    ``bboxes``. All crops from all images are recognized in one
    ``batch_recognition`` call and then split back per image.

    Returns one ``OCRResult`` per image, in input order.
    """
    # Polygons need to be in corner format - [[x1, y1], [x2, y2], [x3, y3], [x4, y4]], bboxes in [x1, y1, x2, y2] format
    assert bboxes is not None or polygons is not None
    assert len(images) == len(langs), "You need to pass in one list of languages for each image"

    images = convert_if_not_rgb(images)
    use_polygons = polygons is not None

    crops_per_image = []
    flat_crops = []
    flat_langs = []
    for page_idx, (page, page_langs) in enumerate(zip(images, langs)):
        if use_polygons:
            crops = slice_polys_from_image(page, polygons[page_idx])
        else:
            crops = slice_bboxes_from_image(page, bboxes[page_idx])
        crops_per_image.append(len(crops))
        flat_crops.extend(crops)
        flat_langs.extend([deepcopy(page_langs)] * len(crops))

    texts, _ = batch_recognition(flat_crops, flat_langs, rec_model, rec_processor, batch_size=batch_size)

    results = []
    cursor = 0
    for page_idx, (page, page_langs) in enumerate(zip(images, langs)):
        # Take this image's share of the flat prediction list
        next_cursor = cursor + crops_per_image[page_idx]
        page_texts = texts[cursor:next_cursor]
        cursor = next_cursor

        text_lines = []
        for line_idx, text in enumerate(page_texts):
            if use_polygons:
                poly = polygons[page_idx][line_idx]
            else:
                # Expand [x1, y1, x2, y2] into its four corners
                bbox = bboxes[page_idx][line_idx]
                poly = [[bbox[0], bbox[1]], [bbox[2], bbox[1]], [bbox[2], bbox[3]], [bbox[0], bbox[3]]]
            text_lines.append(TextLine(text=text, polygon=poly))

        results.append(OCRResult(
            text_lines=text_lines,
            languages=page_langs,
            image_bbox=[0, 0, page.size[0], page.size[1]]
        ))

    return results
61
+
62
+
63
def run_ocr(images: List[Image.Image], langs: List[List[str] | None], det_model, det_processor, rec_model, rec_processor, batch_size=None, highres_images: List[Image.Image] | None = None) -> List[OCRResult]:
    """Detect text lines in each image, then recognize them.

    Detection always runs on ``images``. When ``highres_images`` is given,
    the detected polygons are scaled to the matching high-resolution image
    and the crops for recognition are taken from it instead.

    Returns one ``OCRResult`` per image, in input order, with lines passed
    through ``sort_text_lines``.

    Raises AssertionError when ``langs`` (or ``highres_images``, if given)
    does not have one entry per image; without this check the ``zip`` calls
    below would silently drop the unmatched images.
    """
    assert len(images) == len(langs), "You need to pass in one list of languages for each image"
    if highres_images is not None:
        assert len(highres_images) == len(images), "You need to pass in one highres image for each image"

    images = convert_if_not_rgb(images)
    highres_images = convert_if_not_rgb(highres_images) if highres_images is not None else [None] * len(images)
    det_predictions = batch_text_detection(images, det_model, det_processor)

    all_slices = []
    slice_map = []
    all_langs = []

    for idx, (det_pred, image, highres_image, lang) in enumerate(zip(det_predictions, images, highres_images, langs)):
        polygons = [p.polygon for p in det_pred.bboxes]
        if highres_image:
            # Map polygon coordinates from the detection image to the highres one
            width_scaler = highres_image.size[0] / image.size[0]
            height_scaler = highres_image.size[1] / image.size[1]
            scaled_polygons = [[[int(p[0] * width_scaler), int(p[1] * height_scaler)] for p in polygon] for polygon in polygons]
            slices = slice_polys_from_image(highres_image, scaled_polygons)
        else:
            slices = slice_polys_from_image(image, polygons)
        slice_map.append(len(slices))
        all_langs.extend([lang] * len(slices))
        all_slices.extend(slices)

    rec_predictions, confidence_scores = batch_recognition(all_slices, all_langs, rec_model, rec_processor, batch_size=batch_size)

    predictions_by_image = []
    slice_start = 0
    for idx, (image, det_pred, lang) in enumerate(zip(images, det_predictions, langs)):
        # Take this image's share of the flat prediction lists
        slice_end = slice_start + slice_map[idx]
        image_lines = rec_predictions[slice_start:slice_end]
        line_confidences = confidence_scores[slice_start:slice_end]
        slice_start = slice_end

        assert len(image_lines) == len(det_pred.bboxes)

        lines = []
        for text_line, confidence, bbox in zip(image_lines, line_confidences, det_pred.bboxes):
            lines.append(TextLine(
                text=text_line,
                polygon=bbox.polygon,
                bbox=bbox.bbox,
                confidence=confidence
            ))

        lines = sort_text_lines(lines)

        predictions_by_image.append(OCRResult(
            text_lines=lines,
            languages=lang,
            image_bbox=det_pred.image_bbox
        ))

    return predictions_by_image
surya/ordering.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+ from typing import List
3
+ import torch
4
+ from PIL import Image
5
+
6
+ from surya.input.processing import convert_if_not_rgb
7
+ from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
8
+ from surya.schema import OrderBox, OrderResult
9
+ from surya.settings import settings
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+
13
+
14
def get_batch_size():
    """Return the ordering batch size.

    ``settings.ORDER_BATCH_SIZE`` wins when it is set; otherwise the value
    is 32 on the "cuda" device and 8 on every other device.
    """
    configured = settings.ORDER_BATCH_SIZE
    if configured is not None:
        return configured
    if settings.TORCH_DEVICE_MODEL == "cuda":
        return 32
    return 8
23
+
24
+
25
def rank_elements(arr):
    """Return, for each element of ``arr``, its 0-based rank in ascending order.

    Ties keep their original relative order because the sort is stable.
    """
    indexed = list(enumerate(arr))
    indexed.sort(key=lambda pair: pair[1])

    rank = [0] * len(arr)
    for position, (source_index, _) in enumerate(indexed):
        rank[source_index] = position
    return rank
33
+
34
+
35
def batch_ordering(images: List, bboxes: List[List[List[float]]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[OrderResult]:
    """Predict a reading-order position for every box of every image.

    Images are processed in batches of ``batch_size`` (``get_batch_size()``
    when None). For each batch the model is called repeatedly, and each call
    picks one not-yet-chosen box index per image by argmax over the last
    logits.

    Returns one ``OrderResult`` per image, in input order; each ``OrderBox``
    keeps the caller's bbox and gets the step at which it was chosen as its
    ``position``.
    """
    assert all([isinstance(image, Image.Image) for image in images])
    assert len(images) == len(bboxes)
    if batch_size is None:
        batch_size = get_batch_size()

    output_order = []
    for start in tqdm(range(0, len(images), batch_size), desc="Finding reading order"):
        stop = start + batch_size
        batch_bboxes = deepcopy(bboxes[start:stop])
        batch_images = [image.convert("RGB") for image in images[start:stop]]  # also copies the images
        orig_sizes = [image.size for image in batch_images]

        model_inputs = processor(images=batch_images, boxes=batch_bboxes)
        device = model.device

        batch_bboxes = torch.from_numpy(np.array(model_inputs["input_boxes"], dtype=np.int32)).to(device)
        batch_bbox_mask = torch.from_numpy(np.array(model_inputs["input_boxes_mask"], dtype=np.int32)).to(device)
        batch_pixel_values = torch.tensor(np.array(model_inputs["pixel_values"]), dtype=model.dtype).to(device)
        batch_bbox_counts = torch.tensor(np.array(model_inputs["input_boxes_counts"]), dtype=torch.long).to(device)

        past_key_values = None
        encoder_outputs = None
        batch_predictions = [[] for _ in batch_images]
        done = torch.zeros(len(batch_images), dtype=torch.bool, device=device)
        min_val = torch.finfo(model.dtype).min

        with torch.inference_mode():
            # At most ORDER_MAX_BOXES decoding steps; one box is chosen per image per step
            for _ in range(settings.ORDER_MAX_BOXES):
                return_dict = model(
                    pixel_values=batch_pixel_values,
                    decoder_input_boxes=batch_bboxes,
                    decoder_input_boxes_mask=batch_bbox_mask,
                    decoder_input_boxes_counts=batch_bbox_counts,
                    encoder_outputs=encoder_outputs,
                    past_key_values=past_key_values,
                )
                logits = return_dict["logits"].detach()

                last_tokens = []
                last_token_mask = []
                for j in range(logits.shape[0]):
                    row_preds = batch_predictions[j]
                    label_count = batch_bbox_counts[j, 1] - batch_bbox_counts[j, 0] - 1  # Subtract 1 for the sep token
                    row_logits = logits[j, -1]
                    row_logits[row_preds] = min_val  # Mask out already predicted tokens, we can only predict each token once
                    row_logits[label_count:] = min_val  # Mask out all logit positions above the number of bboxes
                    pred = int(torch.argmax(row_logits, dim=-1).item())

                    # Add one to avoid colliding with the 1000 height/width token for bboxes
                    last_tokens.append([[pred + processor.box_size["height"] + 1] * 4])

                    final_slot = label_count - 1  # Minus one since we're appending the final label
                    if len(row_preds) > final_slot:
                        # Row already complete (or has no boxes): feed a masked token
                        # NOTE(review): a row with label_count == 0 never sets done[j].
                        last_token_mask.append([0])
                    elif len(row_preds) == final_slot:
                        last_token_mask.append([0])
                        row_preds.append(pred)
                        done[j] = True
                    else:
                        last_token_mask.append([1])
                        row_preds.append(pred)  # Get rank prediction for given position

                if done.all():
                    break

                past_key_values = return_dict["past_key_values"]
                encoder_outputs = (return_dict["encoder_last_hidden_state"],)

                # Next step only feeds the freshly chosen tokens; the mask grows by one column
                batch_bboxes = torch.tensor(last_tokens, dtype=torch.long).to(device)
                token_bbox_mask = torch.tensor(last_token_mask, dtype=torch.long).to(device)
                batch_bbox_mask = torch.cat([batch_bbox_mask, token_bbox_mask], dim=1)

        for j, row_pred in enumerate(batch_predictions):
            row_bboxes = bboxes[start + j]
            assert len(row_pred) == len(row_bboxes), f"Mismatch between logits and bboxes. Logits: {len(row_pred)}, Bboxes: {len(row_bboxes)}"

            # row_pred[k] is the box chosen at step k; invert it to box -> step
            ranks = [0] * len(row_bboxes)
            for box_idx in range(len(row_bboxes)):
                ranks[row_pred[box_idx]] = box_idx

            order_boxes = [
                OrderBox(bbox=row_bbox, position=rank)
                for row_bbox, rank in zip(row_bboxes, ranks)
            ]

            page_width, page_height = orig_sizes[j][0], orig_sizes[j][1]
            output_order.append(OrderResult(
                bboxes=order_boxes,
                image_bbox=[0, 0, page_width, page_height],
            ))
    return output_order
136
+
137
+
138
+
139
+
140
+
141
+
surya/postprocessing/affinity.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+ from PIL import Image, ImageDraw
7
+
8
+ from surya.postprocessing.util import get_line_angle, rescale_bbox
9
+ from surya.schema import ColumnLine
10
+
11
+
12
def get_detected_lines_sobel(image, vertical=True):
    """Return an 8-bit Sobel edge map of ``image``, cleaned with erode/dilate.

    With ``vertical=True`` the derivative is taken along x (dx=1, dy=0),
    otherwise along y (dx=0, dy=1). The absolute response is scaled to
    0-255, eroded once and dilated three times with a 20x1 kernel.

    A response that is zero everywhere yields an all-zero map instead of
    dividing by zero.
    """
    # Apply Sobel operator with a kernel size of 3 along the requested axis
    dx, dy = (1, 0) if vertical else (0, 1)
    sobelx = cv2.Sobel(image, cv2.CV_32F, dx, dy, ksize=3)

    # Absolute Sobel (to capture both edges)
    abs_sobelx = np.absolute(sobelx)

    # Convert to 8-bit image; guard the normalisation against a flat input
    peak = np.max(abs_sobelx)
    if peak > 0:
        scaled_sobel = np.uint8(255 * abs_sobelx / peak)
    else:
        scaled_sobel = np.zeros(abs_sobelx.shape, dtype=np.uint8)

    # NOTE(review): the kernel is 20 rows x 1 column regardless of `vertical`.
    kernel = np.ones((20, 1), np.uint8)
    eroded = cv2.erode(scaled_sobel, kernel, iterations=1)
    scaled_sobel = cv2.dilate(eroded, kernel, iterations=3)

    return scaled_sobel
35
+
36
+
37
def get_detected_lines(image, slope_tol_deg=2, vertical=False, horizontal=False) -> List[ColumnLine]:
    """Find line segments in ``image`` with Canny + probabilistic Hough.

    ``image`` is multiplied by 255 before processing. When ``vertical`` or
    ``horizontal`` is set, a Sobel pass runs first and only segments
    classified with that orientation are returned; both flags together are
    rejected. A segment counts as vertical/horizontal when its angle is
    within ``slope_tol_deg`` degrees of +-90 / 0.

    Returns ``ColumnLine`` objects whose bbox is normalised so that
    x1 <= x2 and y1 <= y2.
    """
    assert not (vertical and horizontal)
    working = image.astype(np.float32) * 255  # Convert to 0-255 range
    if vertical or horizontal:
        working = get_detected_lines_sobel(working, vertical)
    working = working.astype(np.uint8)

    edges = cv2.Canny(working, 150, 200, apertureSize=3)
    max_gap, min_length = (100, 10) if vertical else (10, 4)

    segments = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=150, minLineLength=min_length, maxLineGap=max_gap)

    line_info = []
    if segments is not None:
        for segment in segments:
            x1, y1, x2, y2 = segment[0]
            is_vertical = False
            is_horizontal = False

            if x2 == x1:
                is_vertical = True
            else:
                angle = get_line_angle(x1, y1, x2, y2)
                near_up = 90 - slope_tol_deg < angle < 90 + slope_tol_deg
                near_down = -90 - slope_tol_deg < angle < -90 + slope_tol_deg
                if near_up or near_down:
                    is_vertical = True
                elif -slope_tol_deg < angle < slope_tol_deg:
                    is_horizontal = True

            # Order the endpoints so the bbox reads top-left to bottom-right
            bbox = [x1, y1, x2, y2]
            if bbox[3] < bbox[1]:
                bbox[1], bbox[3] = bbox[3], bbox[1]
            if bbox[2] < bbox[0]:
                bbox[0], bbox[2] = bbox[2], bbox[0]
            line_info.append(ColumnLine(bbox=bbox, vertical=is_vertical, horizontal=is_horizontal))

    if vertical:
        line_info = [found for found in line_info if found.vertical]

    if horizontal:
        line_info = [found for found in line_info if found.horizontal]

    return line_info
87
+
88
+
89
def draw_lines_on_image(line_info: List[ColumnLine], img):
    """Draw every vertical line in ``line_info`` onto ``img`` in red.

    Coordinates are floored to a multiple of 20 (200 when the line is
    flagged horizontal) before drawing. ``img`` is modified in place and
    returned.
    """
    canvas = ImageDraw.Draw(img)

    for entry in line_info:
        step = 200 if entry.horizontal else 20
        x1, y1, x2, y2 = [coord // step * step for coord in entry.bbox]
        if entry.vertical:
            canvas.line((x1, y1, x2, y2), fill="red", width=3)

    return img
101
+
102
+
103
def get_vertical_lines(image, processor_size, image_size, divisor=20, x_tolerance=40, y_tolerance=20) -> List[ColumnLine]:
    """Detect vertical lines and merge segments that belong together.

    Steps: detect vertical segments, rescale them from ``processor_size``
    to ``image_size``, sort by x, round with ``divisor``, then
      1. merge segments with identical x whose y extents overlap once the
         first is grown by ``y_tolerance`` on both ends;
      2. merge segments closer than ``x_tolerance`` in x whose y extents
         overlap, keeping the longer one.
    Finally the first remaining line is stretched to start at y = 0.

    Y extents are compared as half-open integer ranges; overlap and length
    are computed arithmetically rather than by materialising point sets.
    """
    def y_span(top, bottom):
        # Half-open integer extent [int(top), int(bottom)); empty if reversed
        return range(int(top), int(bottom))

    def spans_overlap(first, second):
        # True exactly when the two ranges share at least one integer
        return max(first.start, second.start) < min(first.stop, second.stop)

    vertical_lines = get_detected_lines(image, vertical=True)
    for line in vertical_lines:
        line.rescale_bbox(processor_size, image_size)
    vertical_lines = sorted(vertical_lines, key=lambda x: x.bbox[0])
    for line in vertical_lines:
        line.round_bbox(divisor)

    # Merge adjacent line segments together
    to_remove = set()
    for i, line in enumerate(vertical_lines):
        for j in range(i + 1, len(vertical_lines)):
            line2 = vertical_lines[j]
            if line.bbox[0] != line2.bbox[0]:
                continue

            expanded = y_span(line.bbox[1] - y_tolerance, line.bbox[3] + y_tolerance)
            other = y_span(line2.bbox[1], line2.bbox[3])

            if spans_overlap(expanded, other):
                # Fold line i into line j; line i is dropped afterwards
                vertical_lines[j].bbox[1] = min(line.bbox[1], line2.bbox[1])
                vertical_lines[j].bbox[3] = max(line.bbox[3], line2.bbox[3])
                to_remove.add(i)

    vertical_lines = [line for i, line in enumerate(vertical_lines) if i not in to_remove]

    # Remove redundant segments
    to_remove = set()
    for i, line in enumerate(vertical_lines):
        if i in to_remove:
            continue
        for j in range(i + 1, len(vertical_lines)):
            if j in to_remove:
                continue
            line2 = vertical_lines[j]
            close_in_x = abs(line.bbox[0] - line2.bbox[0]) < x_tolerance
            span1 = y_span(line.bbox[1], line.bbox[3])
            span2 = y_span(line2.bbox[1], line2.bbox[3])

            if close_in_x and spans_overlap(span1, span2):
                # Keep the longer line and extend it
                if len(span2) > len(span1):
                    vertical_lines[j].bbox[1] = min(line.bbox[1], line2.bbox[1])
                    vertical_lines[j].bbox[3] = max(line.bbox[3], line2.bbox[3])
                    to_remove.add(i)
                else:
                    vertical_lines[i].bbox[1] = min(line.bbox[1], line2.bbox[1])
                    vertical_lines[i].bbox[3] = max(line.bbox[3], line2.bbox[3])
                    to_remove.add(j)

    vertical_lines = [line for i, line in enumerate(vertical_lines) if i not in to_remove]

    if len(vertical_lines) > 0:
        # Always start with top left of page
        vertical_lines[0].bbox[1] = 0

    return vertical_lines
surya/postprocessing/fonts.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ import os
3
+ import requests
4
+
5
+ from surya.settings import settings
6
+
7
+
8
def get_font_path(langs: Optional[List[str]] = None) -> str:
    """Return the local path of the render font, downloading it if missing.

    The "all" font from ``settings.RECOGNITION_RENDER_FONTS`` is used unless
    ``langs`` holds exactly one entry that is a key of that mapping, in
    which case the font for that key is used.

    A missing font is fetched from ``settings.RECOGNITION_FONT_DL_BASE``.
    The download is written to a temporary ``.part`` file and moved into
    place only after it completed, so a failed request never leaves an
    empty or truncated file at the final path.

    Raises whatever ``requests`` raises for a failed request (including
    ``raise_for_status`` errors) and OSError for filesystem failures.
    """
    font_path = settings.RECOGNITION_RENDER_FONTS["all"]
    if langs is not None:
        for k in settings.RECOGNITION_RENDER_FONTS:
            if k in langs and len(langs) == 1:
                font_path = settings.RECOGNITION_RENDER_FONTS[k]
                break

    if not os.path.exists(font_path):
        font_dir = os.path.dirname(font_path)
        if font_dir:  # a bare filename has no directory to create
            os.makedirs(font_dir, exist_ok=True)
        font_dl_path = f"{settings.RECOGNITION_FONT_DL_BASE}/{os.path.basename(font_path)}"
        partial_path = f"{font_path}.part"
        try:
            with requests.get(font_dl_path, stream=True) as r:
                # Check the status before touching the filesystem
                r.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(partial_path, font_path)
        finally:
            # Only present here if the download or the move failed
            if os.path.exists(partial_path):
                os.remove(partial_path)

    return font_path
surya/postprocessing/heatmap.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ import numpy as np
4
+ import cv2
5
+ import math
6
+ from PIL import ImageDraw, ImageFont
7
+
8
+ from surya.postprocessing.fonts import get_font_path
9
+ from surya.postprocessing.util import rescale_bbox
10
+ from surya.schema import PolygonBox
11
+ from surya.settings import settings
12
+ from surya.postprocessing.text import get_text_size
13
+
14
+
15
def keep_largest_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]:
    """Drop boxes that are mostly covered by a larger box.

    A box is dropped when some other box (different polygon and different
    bbox) has ``intersection_pct`` above 0.9 with it and a strictly larger
    bbox area. Surviving boxes keep their input order.
    """
    def rect_area(rect):
        return (rect[2] - rect[0]) * (rect[3] - rect[1])

    def is_dominated(candidate):
        candidate_rect = candidate.bbox
        candidate_area = rect_area(candidate_rect)
        for rival in boxes:
            if rival.polygon == candidate.polygon:
                continue
            rival_rect = rival.bbox
            if candidate_rect == rival_rect:
                continue
            # find overlap percentage
            overlap = candidate.intersection_pct(rival)
            if overlap > .9 and candidate_area < rect_area(rival_rect):
                return True
        return False

    return [candidate for candidate in boxes if not is_dominated(candidate)]
37
+
38
+
39
def clean_contained_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]:
    """Drop boxes whose bbox lies entirely inside another box's bbox.

    Boxes with an identical polygon or an identical bbox never eliminate
    each other. Surviving boxes keep their input order.
    """
    def lies_inside_another(candidate):
        inner = candidate.bbox
        for rival in boxes:
            if rival.polygon == candidate.polygon:
                continue
            outer = rival.bbox
            if inner == outer:
                continue
            fits_left_top = inner[0] >= outer[0] and inner[1] >= outer[1]
            if fits_left_top and inner[2] <= outer[2] and inner[3] <= outer[3]:
                return True
        return False

    return [candidate for candidate in boxes if not lies_inside_another(candidate)]
57
+
58
+
59
def get_dynamic_thresholds(linemap, text_threshold, low_text, typical_top10_avg=0.7):
    """Scale both thresholds by how bright the map's top decile is.

    The mean of the top 10% of ``linemap`` values is divided by
    ``typical_top10_avg``, clipped to [0, 1] and square-rooted. Both
    thresholds are multiplied by that factor and clipped: ``text_threshold``
    to [0.15, 0.8] and ``low_text`` to [0.1, 0.6].

    Returns ``(text_threshold, low_text)``.
    """
    # Find average intensity of top 10% pixels
    values = linemap.ravel()
    cutoff = int(len(values) * 0.9)
    brightest = np.partition(values, cutoff)[cutoff:]
    avg_intensity = np.mean(brightest)

    scaling_factor = np.clip(avg_intensity / typical_top10_avg, 0, 1) ** (1 / 2)

    scaled_low_text = np.clip(low_text * scaling_factor, 0.1, 0.6)
    scaled_text_threshold = np.clip(text_threshold * scaling_factor, 0.15, 0.8)
    return scaled_text_threshold, scaled_low_text
70
+
71
+
72
def detect_boxes(linemap, text_threshold, low_text):
    """Extract one rotated box per connected text region of ``linemap``.

    Thresholds are first adapted with ``get_dynamic_thresholds``. Pixels
    above ``low_text`` are grouped into 4-connected components; components
    smaller than 10 pixels, or whose peak value is below ``text_threshold``,
    are skipped. Each remaining component is dilated and wrapped in a
    minimum-area rectangle.

    Returns ``(boxes, confidences)``: a list of 4x2 corner arrays and a list
    with each region's peak value divided by the largest peak found.
    """
    # From CRAFT - https://github.com/clovaai/CRAFT-pytorch
    # Modified to return boxes and for speed, accuracy
    img_h, img_w = linemap.shape

    text_threshold, low_text = get_dynamic_thresholds(linemap, text_threshold, low_text)

    binary_map = (linemap > low_text).astype(np.uint8)
    label_count, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_map, connectivity=4)

    buffer = 1
    det = []
    confidences = []
    max_confidence = 0

    # Label 0 is skipped; iteration starts at 1
    for k in range(1, label_count):
        # size filtering
        if stats[k, cv2.CC_STAT_AREA] < 10:
            continue

        # make segmentation map
        x, y, w, h = stats[k, [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT]]

        try:
            niter = int(np.sqrt(min(w, h)))
        except ValueError:
            niter = 0

        # Window around the component, grown by niter + buffer and clamped to the map
        sx, sy = max(0, x - niter - buffer), max(0, y - niter - buffer)
        ex, ey = min(img_w, x + w + niter + buffer), min(img_h, y + h + niter + buffer)

        mask = (labels[sy:ey, sx:ex] == k)
        line_max = np.max(linemap[sy:ey, sx:ex][mask])

        # thresholding
        if line_max < text_threshold:
            continue

        ksize = buffer + niter
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (ksize, ksize))
        grown = cv2.dilate(mask.astype(np.uint8), kernel)

        # make box
        rows, cols = np.nonzero(grown)
        np_contours = np.column_stack((cols + sx, rows + sy))
        box = cv2.boxPoints(cv2.minAreaRect(np_contours))

        # align diamond-shape
        side_a, side_b = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
        box_ratio = max(side_a, side_b) / (min(side_a, side_b) + 1e-5)
        if abs(1 - box_ratio) <= 0.1:
            # Near-square result: fall back to the axis-aligned extent
            l, r = min(np_contours[:, 0]), max(np_contours[:, 0])
            t, b = min(np_contours[:, 1]), max(np_contours[:, 1])
            box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)

        # make clock-wise order
        startidx = box.sum(axis=1).argmin()
        box = np.array(np.roll(box, 4 - startidx, 0))

        max_confidence = max(max_confidence, line_max)
        confidences.append(line_max)
        det.append(box)

    if max_confidence > 0:
        confidences = [c / max_confidence for c in confidences]
    return det, confidences
148
+
149
+
150
def get_detected_boxes(textmap, text_threshold=None, low_text=None) -> List[PolygonBox]:
    """Run ``detect_boxes`` on a float32 copy of ``textmap``.

    Thresholds left as None fall back to ``settings.DETECTOR_TEXT_THRESHOLD``
    and ``settings.DETECTOR_BLANK_THRESHOLD``. Each detected polygon is
    wrapped in a ``PolygonBox`` together with its confidence.
    """
    if text_threshold is None:
        text_threshold = settings.DETECTOR_TEXT_THRESHOLD
    if low_text is None:
        low_text = settings.DETECTOR_BLANK_THRESHOLD

    heatmap = textmap.copy().astype(np.float32)
    polygons, scores = detect_boxes(heatmap, text_threshold, low_text)

    # From point form to box form
    return [PolygonBox(polygon=polygon, confidence=score) for polygon, score in zip(polygons, scores)]
163
+
164
+
165
def get_and_clean_boxes(textmap, processor_size, image_size, text_threshold=None, low_text=None) -> List[PolygonBox]:
    """Detect boxes, map them to image coordinates and drop contained ones.

    Every box is rescaled from ``processor_size`` to ``image_size`` and
    fitted to the image bounds before ``clean_contained_boxes`` runs.
    """
    detected = get_detected_boxes(textmap, text_threshold, low_text)
    for polygon_box in detected:
        polygon_box.rescale(processor_size, image_size)
        polygon_box.fit_to_bounds([0, 0, image_size[0], image_size[1]])

    return clean_contained_boxes(detected)
173
+
174
+
175
+
176
def draw_bboxes_on_image(bboxes, image, labels=None, label_font_size=10, color: str | list='red'):
    """Draw ``[x1, y1, x2, y2]`` boxes by delegating to ``draw_polys_on_image``.

    Each box is expanded into its four corners in clockwise order starting
    at the top-left.
    """
    polys = [
        [[bb[0], bb[1]], [bb[2], bb[1]], [bb[2], bb[3]], [bb[0], bb[3]]]
        for bb in bboxes
    ]
    return draw_polys_on_image(polys, image, labels, label_font_size=label_font_size, color=color)
189
+
190
+
191
def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offset=1, label_font_size=10, color: str | list='red'):
    """Outline each polygon on ``image`` and optionally label it.

    ``color`` is either one colour for everything or a list with one colour
    per polygon. When ``labels`` is given, label ``i`` is drawn on a white
    background near the top-left of polygon ``i``. ``image`` is modified in
    place and returned.
    """
    draw = ImageDraw.Draw(image)
    label_font = ImageFont.truetype(get_font_path(), label_font_size)
    per_polygon_color = isinstance(color, list)

    for i in range(len(corners)):
        points = [(int(p[0]), int(p[1])) for p in corners[i]]
        draw.polygon(points, outline=color[i] if per_polygon_color else color, width=1)

        if labels is None:
            continue

        label = labels[i]
        # Anchor the text just inside the polygon's top-left extent
        text_position = (
            min([p[0] for p in points]) + label_offset,
            min([p[1] for p in points]) + label_offset
        )
        text_size = get_text_size(label, label_font)
        box_position = (
            text_position[0] - box_padding + label_offset,
            text_position[1] - box_padding + label_offset,
            text_position[0] + text_size[0] + box_padding + label_offset,
            text_position[1] + text_size[1] + box_padding + label_offset
        )
        draw.rectangle(box_position, fill="white")
        draw.text(
            text_position,
            label,
            fill=color[i] if per_polygon_color else color,
            font=label_font
        )

    return image
223
+
224
+
surya/postprocessing/math/latex.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from ftfy import fix_text
3
+
4
+
5
def contains_math(text):
    """Return True when ``text`` begins or ends with a ``$`` character."""
    if text.startswith("$"):
        return True
    return text.endswith("$")
7
+
8
+
9
def fix_math(text):
    """Clean up a math string.

    Applies, in order: ``fix_text``, ``remove_labels``,
    ``replace_katex_invalid`` and ``fix_fences``.
    """
    # Fix any issues with the text
    repaired = fix_text(text)

    # Remove LaTeX labels and references
    without_labels = remove_labels(repaired)
    katex_ready = replace_katex_invalid(without_labels)
    return fix_fences(katex_ready)
18
+
19
+
20
def remove_labels(text):
    """Strip ``\\label{...}``, ``\\ref{...}`` and ``\\pageref{...}`` commands.

    The three removals run one after another, in that order.
    """
    for command_pattern in (
        r'\\label\{[^}]*\}',
        r'\\ref\{[^}]*\}',
        r'\\pageref\{[^}]*\}',
    ):
        text = re.sub(command_pattern, '', text)
    return text
30
+
31
+
32
def replace_katex_invalid(string):
    """Rewrite LaTeX constructs that KaTeX cannot render.

    ``\\tag{...}`` is removed; ``\\Big{...}``-style wrappers and
    ``\\mbox{...}`` (with an optional leading ``\\quad``) are replaced by
    their contents; single ``$`` inside ``$$...$$`` blocks are dropped.
    """
    # KaTeX cannot render all LaTeX, so we need to replace some things
    rewrites = (
        (r'\\tag\{.*?\}', ''),
        (r'\\(?:Bigg?|bigg?)\{(.*?)\}', r'\1'),
        (r'\\quad\\mbox\{(.*?)\}', r'\1'),
        (r'\\mbox\{(.*?)\}', r'\1'),
    )
    for pattern, replacement in rewrites:
        string = re.sub(pattern, replacement, string)

    # Same rewrite as remove_inner_dollars: keep the $$ fences, drop inner $
    return re.sub(
        r'\$\$(.*?)\$\$',
        lambda match: '$$' + match.group(1).replace('$', '') + '$$',
        string,
        flags=re.DOTALL,
    )
40
+
41
+
42
def remove_inner_dollars(text):
    """Remove ``$`` characters found inside ``$$...$$`` blocks.

    The ``$$`` fences themselves are kept; text outside such blocks is left
    untouched.
    """
    return re.sub(
        r'\$\$(.*?)\$\$',
        # Replace single $ with nothing, keep $$ intact
        lambda match: '$$' + match.group(1).replace('$', '') + '$$',
        text,
        flags=re.DOTALL,
    )
50
+
51
+
52
def extract_latex_with_positions(text):
    """Find ``$$...$$`` and ``$...$`` spans in ``text``.

    Returns a list of ``(matched_text, start, end)`` tuples in order of
    appearance.
    """
    fenced = re.finditer(r'(\$\$.*?\$\$|\$.*?\$)', text, re.DOTALL)
    return [(hit.group(), hit.start(), hit.end()) for hit in fenced]
58
+
59
+
60
def slice_latex(text):
    """Split ``text`` into alternating plain-text and LaTeX chunks.

    Returns a list of ``{"text": ..., "type": "text" | "latex"}`` dicts that
    cover ``text`` from start to end; empty text chunks are omitted.
    """
    chunks = []
    cursor = 0
    # Same fence pattern that extract_latex_with_positions uses
    for hit in re.finditer(r'(\$\$.*?\$\$|\$.*?\$)', text, re.DOTALL):
        begin, finish = hit.start(), hit.end()
        # Add text before the current LaTeX block, if any
        if begin > cursor:
            chunks.append({"text": text[cursor:begin], "type": "text"})
        # Add the LaTeX block
        chunks.append({"text": hit.group(), "type": "latex"})
        cursor = finish

    # Add remaining text after the last LaTeX block, if any
    if cursor < len(text):
        chunks.append({"text": text[cursor:], "type": "text"})

    return chunks
78
+
79
+
80
def is_latex(text):
    """Return True when ``text`` contains something that looks like LaTeX.

    That is: a ``\\begin{}``/``\\end{}`` pair member, a ``$...$`` or
    ``$$...$$`` span, or a backslash followed by any character.
    """
    latex_patterns = (
        r'\\(?:begin|end)\{[a-zA-Z]*\}',
        r'\$.*?\$',
        r'\$\$.*?\$\$',
        r'\\[a-zA-Z]+',
        r'\\[^a-zA-Z]',
    )
    found = re.search('|'.join(latex_patterns), text, re.DOTALL)
    return found is not None
94
+
95
+
96
def fix_fences(text):
    """Complete unbalanced ``$`` fences at the ends of ``text``.

    Text that opens or closes with ``$`` on one side only is extended so
    that it both starts and ends with ``$$``. Text with no ``$`` at either
    end, or already fenced on both ends, is returned unchanged.
    """
    # Opens with $$ but is not closed with $$: complete the closing fence
    if text.startswith("$$") and not text.endswith("$$"):
        text = text + ("$" if text[-1] == "$" else "$$")

    # Closes with $$ but is not opened with $$: complete the opening fence
    if text.endswith("$$") and not text.startswith("$$"):
        text = ("$" if text[0] == "$" else "$$") + text

    # Single $ on the left only
    if text.startswith("$") and not text.endswith("$"):
        text = "$" + text + "$$"

    # Single $ on the right only
    if text.endswith("$") and not text.startswith("$"):
        text = "$$" + text + "$"

    return text
116
+
117
+
118
def strip_fences(text):
    """Remove every leading and trailing ``$`` character from ``text``."""
    return text.strip("$")
124
+
125
+
surya/postprocessing/math/render.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from playwright.sync_api import sync_playwright
2
+ from PIL import Image
3
+ import io
4
+
5
+
6
def latex_to_pil(latex_code, target_width, target_height, fontsize=18):
    """Render ``latex_code`` to a PIL image with KaTeX in headless Chromium.

    The page viewport is ``target_width`` x ``target_height``. Starting at
    ``fontsize``, the font size is increased one pixel at a time (while it
    is <= 30) until the rendered content reaches either target dimension;
    the last size tried before that is used for the final screenshot.
    """
    # NOTE(review): the stylesheet is loaded from katex@0.13.0 but the script
    # from katex@0.12.0 - confirm whether the mismatch is intended.
    html_template = """
    <!DOCTYPE html>
    <html>
    <head>
        <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.0/dist/katex.min.css">
        <script src="https://cdn.jsdelivr.net/npm/katex@0.12.0/dist/katex.min.js"></script>
        <style>
            body {
                margin: 0;
                padding: 0;
                display: flex;
            }
            #content {
                font-size: {fontsize}px;
            }
        </style>
    </head>
    <body>
        <div id="content">{content}</div>
        <script>
            function renderMath() {
                let content = document.getElementById('content');
                let html = content.innerHTML;

                // Replace display equations
                html = html.replace(/\\$\\$(.*?)\\$\\$/gs, (match, equation) => {
                    let span = document.createElement('span');
                    katex.render(equation, span, { displayMode: true, throwOnError: false, errorColor: '#000' });
                    return span.outerHTML;
                });

                // Replace inline equations
                html = html.replace(/\\$(.*?)\\$/g, (match, equation) => {
                    if(match.startsWith('\\\\$')) return match;  // Ignore escaped dollars
                    let span = document.createElement('span');
                    katex.render(equation, span, { displayMode: false, throwOnError: false, errorColor: '#000' });
                    return span.outerHTML;
                });

                content.innerHTML = html;
            }

            renderMath();
        </script>
    </body>
    </html>
    """

    measure_script = """() => {
                const render = document.getElementById('content');
                return {
                    width: render.offsetWidth,
                    height: render.offsetHeight
                };
            }"""

    formatted_latex = latex_code.replace('\n', '\\n').replace('"', '\\"')

    def build_html(size):
        # Fill the two placeholders of the template for one font size
        return html_template.replace("{content}", formatted_latex).replace("{fontsize}", str(size))

    with sync_playwright() as p:
        browser = p.chromium.launch()
        page = browser.new_page()
        page.set_viewport_size({'width': target_width, 'height': target_height})

        while fontsize <= 30:
            page.set_content(build_html(fontsize))
            dimensions = page.evaluate(measure_script)

            too_wide = dimensions['width'] >= target_width
            if too_wide or dimensions['height'] >= target_height:
                # Step back to the last size that was tried before this one
                fontsize -= 1
                break
            fontsize += 1

        page.set_content(build_html(fontsize))

        screenshot_bytes = page.screenshot()
        browser.close()

    pil_image = Image.open(io.BytesIO(screenshot_bytes))
    pil_image.load()
    return pil_image
surya/postprocessing/text.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Tuple
3
+
4
+ import requests
5
+ from PIL import Image, ImageDraw, ImageFont
6
+
7
+ from surya.postprocessing.fonts import get_font_path
8
+ from surya.schema import TextLine
9
+ from surya.settings import settings
10
+ from surya.postprocessing.math.latex import is_latex
11
+
12
+
13
def sort_text_lines(lines: List[TextLine] | List[dict], tolerance=1.25):
    """Sort text lines into an approximate reading order.

    Lines are bucketed into rows by quantising their top coordinate
    (``bbox[1]``) to multiples of ``tolerance``. Rows are emitted top to
    bottom, and the lines inside each row left to right (``bbox[0]``).
    Not 100% accurate: this should only be used as a starting point for
    more advanced sorting.

    Args:
        lines: Items exposing a bounding box either under a ``"bbox"`` key
            (dicts) or as a ``bbox`` attribute (e.g. ``TextLine``).
        tolerance: Size of the vertical quantisation step, in the same
            units as the bounding boxes.

    Returns:
        A new list containing the same items in sorted order.
    """

    def bbox_of(line):
        # Dicts carry the box under a key; everything else is expected to
        # expose it as an attribute.
        return line["bbox"] if isinstance(line, dict) else line.bbox

    vertical_groups = {}
    for line in lines:
        # Pick the coordinate first and quantise afterwards. Folding both
        # steps into one ``a if cond else b / tolerance`` expression divides
        # only the ``else`` branch, because the conditional expression binds
        # more loosely than ``/``.
        top = bbox_of(line)[1]
        group_key = round(top / tolerance) * tolerance
        vertical_groups.setdefault(group_key, []).append(line)

    # Sort each row horizontally and flatten the rows into a single list.
    sorted_lines = []
    for group_key in sorted(vertical_groups):
        row = vertical_groups[group_key]
        sorted_lines.extend(sorted(row, key=lambda item: bbox_of(item)[0]))

    return sorted_lines
30
+
31
+
32
def truncate_repetitions(text: str, min_len=15):
    """Strip a repeating tail from ``text`` (adapted from nougat).

    The function looks for the largest width ``w`` with
    ``min_len <= w < int(len(text) / 2)`` such that the last ``w``
    characters are identical to the ``w`` characters just before them.
    If such a width exists, every trailing copy of that ``w``-character
    chunk is removed and the remaining prefix is returned.

    Args:
        text: The string to clean up.
        min_len: Smallest chunk width considered a repetition.

    Returns:
        ``text`` unchanged when it is shorter than ``2 * min_len`` or no
        repeating tail is found; otherwise the prefix left after removing
        all trailing copies of the repeated chunk.
    """
    total = len(text)
    if total < 2 * min_len:
        return text

    def tail_repeats(width):
        # Compare the final `width` characters with the `width` characters
        # immediately preceding them.
        return text[total - 2 * width:total - width] == text[total - width:total]

    # Scan from the widest candidate down; the first hit is the maximum.
    candidate_widths = range(min_len, int(total / 2))
    max_rep_len = next(
        (width for width in reversed(candidate_widths) if tail_repeats(width)),
        None,
    )
    if max_rep_len is None:
        return text

    repeated_chunk = text[-max_rep_len:]

    # Peel copies of the chunk off the end until the tail no longer matches.
    remaining = text
    while remaining.endswith(repeated_chunk):
        remaining = remaining[:-max_rep_len]

    return remaining
61
+
62
+
63
def get_text_size(text, font):
    """Measure ``text`` rendered with ``font``.

    A zero-sized palette image is created only to obtain a drawing context;
    nothing is actually drawn on it.

    Returns:
        A ``(width, height)`` tuple taken from the last two values of
        ``ImageDraw.textbbox`` for text anchored at ``(0, 0)``.
    """
    scratch = ImageDraw.Draw(Image.new(mode="P", size=(0, 0)))
    _, _, right, bottom = scratch.textbbox((0, 0), text=text, font=font)
    return right, bottom
68
+
69
+
70
def render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size):
    """Draw ``text`` into a box, shrinking the font until it fits.

    Starting at ``box_font_size``, the font size is reduced one point at a
    time while the measured text is wider or taller than the box, stopping
    at a floor of 6. The text is then drawn in black, left-aligned at
    ``s_bbox[0]`` and vertically centred within ``bbox_height``.

    Args:
        draw: Drawing context whose ``text`` method is used for output.
        text: The string to render.
        s_bbox: Box coordinates; only ``s_bbox[0]`` (left) and ``s_bbox[1]``
            (top) are read here.
        bbox_width: Available width for the text.
        bbox_height: Available height for the text.
        font_path: Path passed to ``ImageFont.truetype``.
        box_font_size: Initial font size to try.
    """
    font = ImageFont.truetype(font_path, box_font_size)
    text_width, text_height = get_text_size(text, font)
    while (text_width > bbox_width or text_height > bbox_height) and box_font_size > 6:
        box_font_size -= 1
        font = ImageFont.truetype(font_path, box_font_size)
        text_width, text_height = get_text_size(text, font)

    # The loop leaves the measurement for the final font in scope, so no
    # further measuring is needed. Left-align, centre vertically.
    x = s_bbox[0]
    y = s_bbox[1] + (bbox_height - text_height) / 2

    draw.text((x, y), text, fill="black", font=font)
84
+
85
+
86
def render_math(image, draw, text, s_bbox, bbox_width, bbox_height, font_path):
    """Paste a rendered LaTeX image into the box, falling back to plain text.

    The LaTeX renderer is imported inside the ``try`` so that an import
    failure, like any other exception raised while rendering or pasting,
    triggers the fallback. In that case the error is printed and ``text``
    is drawn with ``render_text`` instead.

    Args:
        image: Target image that receives the pasted rendering.
        draw: Drawing context used by the text fallback.
        text: LaTeX source to render.
        s_bbox: Box coordinates; ``s_bbox[0]`` and ``s_bbox[1]`` give the
            paste position.
        bbox_width: Box width, also the maximum width of the rendering.
        bbox_height: Box height, also the maximum height of the rendering.
        font_path: Font used by the text fallback.
    """
    try:
        from surya.postprocessing.math.render import latex_to_pil

        # Font size is 20% of the box height, clamped to the range 10..24.
        math_font_size = min(24, max(10, int(0.2 * bbox_height)))
        rendered = latex_to_pil(text, bbox_width, bbox_height, fontsize=math_font_size)
        rendered.thumbnail((bbox_width, bbox_height))
        top_left = (s_bbox[0], s_bbox[1])
        image.paste(rendered, top_left)
    except Exception as e:
        print(f"Failed to render math: {e}")
        # Plain-text fallback uses 75% of the box height, clamped to 10..24.
        fallback_font_size = min(24, max(10, int(0.75 * bbox_height)))
        render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, fallback_font_size)
97
+
98
+
99
def draw_text_on_image(bboxes, texts, image_size: Tuple[int, int], langs: List[str], font_path=None, max_font_size=60, res_upscale=2, has_math=False):
    """Render recognised text onto a blank white canvas.

    The canvas is ``image_size`` scaled by ``res_upscale``. Each bounding
    box is scaled by the same factor and its text is drawn inside it.

    Args:
        bboxes: Bounding boxes, paired positionally with ``texts``.
        texts: Strings to draw, one per bounding box.
        image_size: ``(width, height)`` of the unscaled page.
        langs: Languages passed to ``get_font_path`` when no font is given.
        font_path: Font file to use; looked up from ``langs`` when ``None``.
        max_font_size: Upper bound on the initial font size for plain text.
        res_upscale: Multiplier applied to the canvas and to every box.
        has_math: When true, texts for which ``is_latex`` returns true are
            drawn with ``render_math`` instead of ``render_text``.

    Returns:
        The RGB image with all texts drawn.
    """
    if font_path is None:
        font_path = get_font_path(langs)

    canvas_width = image_size[0] * res_upscale
    canvas_height = image_size[1] * res_upscale
    image = Image.new('RGB', (canvas_width, canvas_height), color='white')
    draw = ImageDraw.Draw(image)

    for bbox, text in zip(bboxes, texts):
        # Scale the box to canvas resolution, truncating to whole pixels.
        s_bbox = [int(coord * res_upscale) for coord in bbox]
        bbox_width = s_bbox[2] - s_bbox[0]
        bbox_height = s_bbox[3] - s_bbox[1]

        if has_math and is_latex(text):
            render_math(image, draw, text, s_bbox, bbox_width, bbox_height, font_path)
            continue

        # Start at 75% of the box height, clamped to 6..max_font_size;
        # render_text shrinks the font further if the text does not fit.
        initial_font_size = max(6, min(int(.75 * bbox_height), max_font_size))
        render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, initial_font_size)

    return image