Spaces:
Running
Running
Upload 56 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +13 -0
- detect_layout.py +67 -0
- detect_text.py +81 -0
- ocr_app.py +257 -0
- ocr_text.py +98 -0
- pyproject.toml +59 -0
- reading_order.py +81 -0
- requirements.txt +5 -0
- scripts/verify_benchmark_scores.py +61 -0
- surya/benchmark/bbox.py +22 -0
- surya/benchmark/metrics.py +193 -0
- surya/benchmark/tatr.py +117 -0
- surya/benchmark/tesseract.py +179 -0
- surya/benchmark/util.py +31 -0
- surya/detection.py +144 -0
- surya/input/langs.py +19 -0
- surya/input/load.py +87 -0
- surya/input/pdflines.py +86 -0
- surya/input/processing.py +118 -0
- surya/languages.py +102 -0
- surya/layout.py +229 -0
- surya/model/detection/config.py +51 -0
- surya/model/detection/model.py +767 -0
- surya/model/detection/processor.py +284 -0
- surya/model/ordering/config.py +8 -0
- surya/model/ordering/decoder.py +557 -0
- surya/model/ordering/encoder.py +83 -0
- surya/model/ordering/encoderdecoder.py +90 -0
- surya/model/ordering/model.py +34 -0
- surya/model/ordering/processor.py +156 -0
- surya/model/recognition/config.py +348 -0
- surya/model/recognition/decoder.py +695 -0
- surya/model/recognition/encoder.py +852 -0
- surya/model/recognition/encoderdecoder.py +145 -0
- surya/model/recognition/model.py +49 -0
- surya/model/recognition/processor.py +206 -0
- surya/model/recognition/tokenizer.py +120 -0
- surya/model/table_rec/config.py +260 -0
- surya/model/table_rec/decoder.py +795 -0
- surya/model/table_rec/encoderdecoder.py +135 -0
- surya/model/table_rec/model.py +34 -0
- surya/model/table_rec/processor.py +248 -0
- surya/ocr.py +114 -0
- surya/ordering.py +141 -0
- surya/postprocessing/affinity.py +165 -0
- surya/postprocessing/fonts.py +24 -0
- surya/postprocessing/heatmap.py +224 -0
- surya/postprocessing/math/latex.py +125 -0
- surya/postprocessing/math/render.py +88 -0
- surya/postprocessing/text.py +118 -0
app.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import subprocess
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
def run_app():
    """Launch the Streamlit OCR demo that sits next to this file.

    Runs ``streamlit run <this dir>/ocr_app.py`` as a child process using the
    current environment plus ``IN_STREAMLIT=true``, and waits for it to exit.
    """
    app_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ocr_app.py")

    # Copy the parent environment so the flag is only visible to the child.
    child_env = dict(os.environ)
    child_env["IN_STREAMLIT"] = "true"

    subprocess.run(["streamlit", "run", app_script], env=child_env)


if __name__ == "__main__":
    run_app()
|
detect_layout.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pypdfium2 # Causes a warning if not the top import
|
2 |
+
import argparse
|
3 |
+
import copy
|
4 |
+
import json
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
from surya.detection import batch_text_detection
|
8 |
+
from surya.input.load import load_from_folder, load_from_file
|
9 |
+
from surya.layout import batch_layout_detection
|
10 |
+
from surya.model.detection.model import load_model, load_processor
|
11 |
+
from surya.postprocessing.heatmap import draw_polys_on_image
|
12 |
+
from surya.settings import settings
|
13 |
+
import os
|
14 |
+
|
15 |
+
|
16 |
+
def main():
    """Run layout detection on a PDF, an image, or a folder of them.

    Parses the command line, runs text-line detection followed by layout
    detection on every loaded page, and writes ``results.json`` (predictions
    grouped by input name, with a 1-based ``page`` field) into
    ``<results_dir>/<input name>``. With ``--images`` an annotated PNG is
    saved per page; adding ``--debug`` also saves each page's segmentation map.
    """
    parser = argparse.ArgumentParser(description="Detect layout of an input file or folder (PDFs or image).")
    parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect layout in.")
    # The value is used as a directory (see os.makedirs below), not as a JSON file.
    parser.add_argument("--results_dir", type=str, help="Path to directory to save layout results in.", default=os.path.join(settings.RESULT_DIR, "surya"))
    parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
    parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False)
    parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
    args = parser.parse_args()

    model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    det_model = load_model()
    det_processor = load_processor()

    if os.path.isdir(args.input_path):
        images, names, _ = load_from_folder(args.input_path, args.max)
        # normpath strips a trailing separator; basename("docs/") would be ""
        # and the results would be written straight into results_dir.
        folder_name = os.path.basename(os.path.normpath(args.input_path))
    else:
        images, names, _ = load_from_file(args.input_path, args.max)
        folder_name = os.path.basename(args.input_path).split(".")[0]

    line_predictions = batch_text_detection(images, det_model, det_processor)

    layout_predictions = batch_layout_detection(images, model, processor, line_predictions)
    result_path = os.path.join(args.results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    if args.images:
        for idx, (image, layout_pred, name) in enumerate(zip(images, layout_predictions, names)):
            polygons = [p.polygon for p in layout_pred.bboxes]
            labels = [p.label for p in layout_pred.bboxes]
            # Draw on a deep copy so the loaded page image is left untouched.
            bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image), labels=labels)
            bbox_image.save(os.path.join(result_path, f"{name}_{idx}_layout.png"))

            if args.debug:
                heatmap = layout_pred.segmentation_map
                heatmap.save(os.path.join(result_path, f"{name}_{idx}_segmentation.png"))

    # Group predictions by input name; "page" counts pages seen so far for that name.
    predictions_by_page = defaultdict(list)
    for pred, name in zip(layout_predictions, names):
        out_pred = pred.model_dump(exclude=["segmentation_map"])
        out_pred["page"] = len(predictions_by_page[name]) + 1
        predictions_by_page[name].append(out_pred)

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(predictions_by_page, f, ensure_ascii=False)

    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
|
detect_text.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import copy
|
3 |
+
import json
|
4 |
+
import time
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
from surya.input.load import load_from_folder, load_from_file
|
8 |
+
from surya.model.detection.model import load_model, load_processor
|
9 |
+
from surya.detection import batch_text_detection
|
10 |
+
from surya.postprocessing.affinity import draw_lines_on_image
|
11 |
+
from surya.postprocessing.heatmap import draw_polys_on_image
|
12 |
+
from surya.settings import settings
|
13 |
+
import os
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
|
17 |
+
def main():
    """Run text-line detection on a PDF, an image, or a folder of them.

    Parses the command line, runs the detection model on every loaded page,
    and writes ``results.json`` (predictions grouped by input name, with a
    1-based ``page`` field) into ``<results_dir>/<input name>``. With
    ``--images`` a bbox PNG and a column-line PNG are saved per page; adding
    ``--debug`` also saves the heatmap and affinity map and prints the
    detection time.
    """
    parser = argparse.ArgumentParser(description="Detect bboxes in an input file or folder (PDFs or image).")
    parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect bboxes in.")
    # The value is used as a directory (see os.makedirs below), not as a JSON file.
    parser.add_argument("--results_dir", type=str, help="Path to directory to save detection results in.", default=os.path.join(settings.RESULT_DIR, "surya"))
    parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
    parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False)
    parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
    args = parser.parse_args()

    checkpoint = settings.DETECTOR_MODEL_CHECKPOINT
    model = load_model(checkpoint=checkpoint)
    processor = load_processor(checkpoint=checkpoint)

    if os.path.isdir(args.input_path):
        images, names, _ = load_from_folder(args.input_path, args.max)
        # normpath strips a trailing separator; basename("docs/") would be ""
        # and the results would be written straight into results_dir.
        folder_name = os.path.basename(os.path.normpath(args.input_path))
    else:
        images, names, _ = load_from_file(args.input_path, args.max)
        folder_name = os.path.basename(args.input_path).split(".")[0]

    # Time only the detection call, not the directory creation that follows.
    start = time.time()
    predictions = batch_text_detection(images, model, processor)
    end = time.time()

    result_path = os.path.join(args.results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)
    if args.debug:
        print(f"Detection took {end - start} seconds")

    if args.images:
        for idx, (image, pred, name) in enumerate(zip(images, predictions, names)):
            polygons = [p.polygon for p in pred.bboxes]
            # Each overlay is drawn on its own deep copy of the page image.
            bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image))
            bbox_image.save(os.path.join(result_path, f"{name}_{idx}_bbox.png"))

            column_image = draw_lines_on_image(pred.vertical_lines, copy.deepcopy(image))
            column_image.save(os.path.join(result_path, f"{name}_{idx}_column.png"))

            if args.debug:
                heatmap = pred.heatmap
                heatmap.save(os.path.join(result_path, f"{name}_{idx}_heat.png"))

                affinity_map = pred.affinity_map
                affinity_map.save(os.path.join(result_path, f"{name}_{idx}_affinity.png"))

    # Group predictions by input name; "page" counts pages seen so far for that name.
    predictions_by_page = defaultdict(list)
    for pred, name in zip(predictions, names):
        out_pred = pred.model_dump(exclude=["heatmap", "affinity_map"])
        out_pred["page"] = len(predictions_by_page[name]) + 1
        predictions_by_page[name].append(out_pred)

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(predictions_by_page, f, ensure_ascii=False)

    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
|
ocr_app.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import pypdfium2
|
5 |
+
import streamlit as st
|
6 |
+
from pypdfium2 import PdfiumError
|
7 |
+
|
8 |
+
from surya.detection import batch_text_detection
|
9 |
+
from surya.input.pdflines import get_page_text_lines, get_table_blocks
|
10 |
+
from surya.layout import batch_layout_detection
|
11 |
+
from surya.model.detection.model import load_model, load_processor
|
12 |
+
from surya.model.recognition.model import load_model as load_rec_model
|
13 |
+
from surya.model.recognition.processor import load_processor as load_rec_processor
|
14 |
+
from surya.model.ordering.processor import load_processor as load_order_processor
|
15 |
+
from surya.model.ordering.model import load_model as load_order_model
|
16 |
+
from surya.model.table_rec.model import load_model as load_table_model
|
17 |
+
from surya.model.table_rec.processor import load_processor as load_table_processor
|
18 |
+
from surya.ordering import batch_ordering
|
19 |
+
from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
|
20 |
+
from surya.ocr import run_ocr
|
21 |
+
from surya.postprocessing.text import draw_text_on_image
|
22 |
+
from PIL import Image
|
23 |
+
from surya.languages import CODE_TO_LANGUAGE
|
24 |
+
from surya.input.langs import replace_lang_with_code
|
25 |
+
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult, TableResult
|
26 |
+
from surya.settings import settings
|
27 |
+
from surya.tables import batch_table_recognition
|
28 |
+
from surya.postprocessing.util import rescale_bboxes, rescale_bbox
|
29 |
+
|
30 |
+
|
31 |
+
@st.cache_resource()
def load_det_cached():
    """Return the (model, processor) pair for the text-detection checkpoint.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the loaded pair.
    """
    ckpt = settings.DETECTOR_MODEL_CHECKPOINT
    detector = load_model(checkpoint=ckpt)
    detector_processor = load_processor(checkpoint=ckpt)
    return detector, detector_processor
|
35 |
+
|
36 |
+
|
37 |
+
@st.cache_resource()
def load_rec_cached():
    """Return the (model, processor) pair for text recognition.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the loaded pair.
    """
    recognizer = load_rec_model()
    recognizer_processor = load_rec_processor()
    return recognizer, recognizer_processor
|
40 |
+
|
41 |
+
|
42 |
+
@st.cache_resource()
def load_layout_cached():
    """Return the (model, processor) pair for the layout checkpoint.

    Uses the detection loaders pointed at ``settings.LAYOUT_MODEL_CHECKPOINT``;
    wrapped in ``st.cache_resource`` so Streamlit can reuse the loaded pair.
    """
    layout_net = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    layout_proc = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    return layout_net, layout_proc
|
45 |
+
|
46 |
+
@st.cache_resource()
def load_order_cached():
    """Return the (model, processor) pair for reading-order prediction.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the loaded pair.
    """
    orderer = load_order_model()
    orderer_processor = load_order_processor()
    return orderer, orderer_processor
|
49 |
+
|
50 |
+
|
51 |
+
@st.cache_resource()
def load_table_cached():
    """Return the (model, processor) pair for table recognition.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the loaded pair.
    """
    table_net = load_table_model()
    table_proc = load_table_processor()
    return table_net, table_proc
|
54 |
+
|
55 |
+
|
56 |
+
def text_detection(img) -> (Image.Image, TextDetectionResult):
    """Detect text lines on one image and draw their polygons.

    Uses the module-level ``det_model`` / ``det_processor``.

    Returns:
        A pair ``(annotated_image, detection_result)``; the polygons are drawn
        on a copy of ``img``.
    """
    batch_results = batch_text_detection([img], det_model, det_processor)
    detection = batch_results[0]

    line_polys = []
    for box in detection.bboxes:
        line_polys.append(box.polygon)

    canvas = img.copy()
    annotated = draw_polys_on_image(line_polys, canvas)
    return annotated, detection
|
61 |
+
|
62 |
+
|
63 |
+
def layout_detection(img) -> (Image.Image, LayoutResult):
    """Run layout analysis on one image and draw the labelled regions.

    Text detection is run first because its result is passed to
    ``batch_layout_detection`` alongside the image.

    Returns:
        A pair ``(annotated_image, layout_result)``; the regions are drawn on
        a copy of ``img``.
    """
    _, line_result = text_detection(img)
    layout_result = batch_layout_detection([img], layout_model, layout_processor, [line_result])[0]

    region_polys = []
    for region in layout_result.bboxes:
        region_polys.append(region.polygon)
    region_labels = []
    for region in layout_result.bboxes:
        region_labels.append(region.label)

    annotated = draw_polys_on_image(region_polys, img.copy(), labels=region_labels, label_font_size=18)
    return annotated, layout_result
|
70 |
+
|
71 |
+
|
72 |
+
def order_detection(img) -> (Image.Image, OrderResult):
    """Predict reading order for the layout regions of one image.

    Layout detection supplies the region boxes; ``batch_ordering`` then
    assigns each box a position, which is drawn as the box label.

    Returns:
        A pair ``(annotated_image, order_result)``; the boxes are drawn on a
        copy of ``img``.
    """
    _, layout_result = layout_detection(img)

    region_boxes = []
    for region in layout_result.bboxes:
        region_boxes.append(region.bbox)

    order_result = batch_ordering([img], [region_boxes], order_model, order_processor)[0]

    ordered_polys = []
    for entry in order_result.bboxes:
        ordered_polys.append(entry.polygon)
    position_labels = []
    for entry in order_result.bboxes:
        position_labels.append(str(entry.position))

    annotated = draw_polys_on_image(ordered_polys, img.copy(), labels=position_labels, label_font_size=18)
    return annotated, order_result
|
80 |
+
|
81 |
+
|
82 |
+
def table_recognition(img, highres_img, filepath, page_idx: int, use_pdf_boxes: bool, skip_table_detection: bool) -> (Image.Image, List[TableResult]):
    """Recognise table cells on one page and draw them.

    Args:
        img: Page image used for layout detection.
        highres_img: Higher-resolution image of the same page; tables are
            cropped from it and the cells are drawn on a copy of it.
        filepath: Source file handed to ``get_page_text_lines``.
        page_idx: Page index handed to ``get_page_text_lines``.
        use_pdf_boxes: When true, prefer text boxes read from the file over
            the text-detection model.
        skip_table_detection: When true, treat the whole ``highres_img`` as a
            single table instead of running layout detection.

    Returns:
        A pair ``(annotated_image, table_results)``. When no table is found
        the annotated image is an unmodified copy of ``img`` and the list is
        empty.
    """
    if skip_table_detection:
        layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
        table_imgs = [highres_img]
    else:
        _, layout_pred = layout_detection(img)
        layout_tables_lowres = [l.bbox for l in layout_pred.bboxes if l.label == "Table"]
        table_imgs = []
        layout_tables = []
        for tb in layout_tables_lowres:
            # Layout boxes are in img coordinates; crops come from highres_img.
            highres_bbox = rescale_bbox(tb, img.size, highres_img.size)
            table_imgs.append(highres_img.crop(highres_bbox))
            layout_tables.append(highres_bbox)

    if not table_imgs:
        # Nothing to recognise: skip the model calls on empty batches.
        return img.copy(), []

    try:
        page_text = get_page_text_lines(filepath, [page_idx], [highres_img.size])[0]
        table_bboxes = get_table_blocks(layout_tables, page_text, highres_img.size)
    except PdfiumError:
        # This happens when we try to get text from an image
        table_bboxes = [[] for _ in layout_tables]

    # Fall back to detected text boxes when asked to, or when any table got
    # no boxes from the file.
    if not use_pdf_boxes or any(len(tb) == 0 for tb in table_bboxes):
        det_results = batch_text_detection(table_imgs, det_model, det_processor)
        table_bboxes = [[{"bbox": tb.bbox, "text": None} for tb in det_result.bboxes] for det_result in det_results]

    table_preds = batch_table_recognition(table_imgs, table_bboxes, table_model, table_processor)

    # Draw on a private copy. The caller may pass the same object as both img
    # and highres_img and display it afterwards, so it must stay clean.
    # NOTE(review): assumes draw_bboxes_on_image returns the image it drew on
    # (the original code relied on the same return value) - confirm.
    table_img = highres_img.copy()

    for results, table_bbox in zip(table_preds, layout_tables):
        adjusted_bboxes = []
        labels = []

        for item in results.cells:
            # Cell boxes are relative to the table crop; shift to page coordinates.
            adjusted_bboxes.append([
                (item.bbox[0] + table_bbox[0]),
                (item.bbox[1] + table_bbox[1]),
                (item.bbox[2] + table_bbox[0]),
                (item.bbox[3] + table_bbox[1])
            ])
            labels.append(f"{item.row_id} / {item.col_id}")
        # Feed the running canvas back in so every table ends up on one image.
        table_img = draw_bboxes_on_image(adjusted_bboxes, table_img, labels=labels, label_font_size=18)
    return table_img, table_preds
|
126 |
+
|
127 |
+
|
128 |
+
# Function for OCR
|
129 |
+
def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult):
    """Run OCR on one page and render the recognised text.

    ``langs`` is first passed to ``replace_lang_with_code`` and then used both
    for recognition and for rendering.

    Returns:
        A pair ``(rendered_text_image, ocr_result)``; the rendering has the
        size of ``img``.
    """
    replace_lang_with_code(langs)
    page_results = run_ocr(
        [img],
        [langs],
        det_model,
        det_processor,
        rec_model,
        rec_processor,
        highres_images=[highres_img],
    )
    page_result = page_results[0]

    line_boxes = []
    for line in page_result.text_lines:
        line_boxes.append(line.bbox)
    line_texts = []
    for line in page_result.text_lines:
        line_texts.append(line.text)

    uses_math = "_math" in langs
    rendered = draw_text_on_image(line_boxes, line_texts, img.size, langs, has_math=uses_math)
    return rendered, page_result
|
137 |
+
|
138 |
+
|
139 |
+
def open_pdf(pdf_file):
    """Open an uploaded PDF as a ``pypdfium2.PdfDocument``.

    The bytes from ``pdf_file.getvalue()`` are wrapped in an in-memory stream.
    """
    raw_bytes = pdf_file.getvalue()
    return pypdfium2.PdfDocument(io.BytesIO(raw_bytes))
|
142 |
+
|
143 |
+
|
144 |
+
@st.cache_data()
def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI):
    """Render one page of an uploaded PDF as an RGB PIL image.

    Args:
        pdf_file: Uploaded PDF, opened through ``open_pdf``.
        page_num: 1-based page number (converted to a 0-based index here).
        dpi: Render resolution; the render scale is ``dpi / 72``.
    """
    document = open_pdf(pdf_file)
    zoom = dpi / 72
    zero_based_page = page_num - 1

    rendered_pages = list(
        document.render(
            pypdfium2.PdfBitmap.to_pil,
            page_indices=[zero_based_page],
            scale=zoom,
        )
    )
    return rendered_pages[0].convert("RGB")
|
155 |
+
|
156 |
+
|
157 |
+
@st.cache_data()
def page_count(pdf_file):
    """Return the number of pages in an uploaded PDF."""
    return len(open_pdf(pdf_file))
|
161 |
+
|
162 |
+
|
163 |
+
# ---- Streamlit page: runs top to bottom on every interaction ----
st.set_page_config(layout="wide")
col1, col2 = st.columns([.5, .5])

# Models are fetched through the cached loaders and used as module globals
# by the helper functions above.
det_model, det_processor = load_det_cached()
rec_model, rec_processor = load_rec_cached()
layout_model, layout_processor = load_layout_cached()
order_model, order_processor = load_order_cached()
table_model, table_processor = load_table_cached()


st.markdown("""
# Surya OCR Demo

This app will let you try surya, a multilingual OCR model. It supports text detection + layout analysis in any language, and text recognition in 90+ languages.

Notes:
- This works best on documents with printed text.
- Preprocessing the image (e.g. increasing contrast) can improve results.
- If OCR doesn't work, try changing the resolution of your image (increase if below 2048px width, otherwise decrease).
- This supports 90+ languages, see [here](https://github.com/VikParuchuri/surya/tree/master/surya/languages.py) for a full list.

Find the project [here](https://github.com/VikParuchuri/surya).
""")

in_file = st.sidebar.file_uploader("PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"])
languages = st.sidebar.multiselect("Languages", sorted(list(CODE_TO_LANGUAGE.values())), default=[], max_selections=4, help="Select the languages in the image (if known) to improve OCR accuracy. Optional.")

if in_file is None:
    st.stop()

filetype = in_file.type
whole_image = False
if "pdf" in filetype:
    # Keep the count in its own name so the page_count() function is not
    # rebound to an int for the rest of the run.
    num_pages = page_count(in_file)
    page_number = st.sidebar.number_input(f"Page number out of {num_pages}:", min_value=1, value=1, max_value=num_pages)

    pil_image = get_page_image(in_file, page_number, settings.IMAGE_DPI)
    pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
else:
    pil_image = Image.open(in_file).convert("RGB")
    # Plain images have a single resolution, so both names refer to the same object.
    pil_image_highres = pil_image
    page_number = None

text_det = st.sidebar.button("Run Text Detection")
text_rec = st.sidebar.button("Run OCR")
layout_det = st.sidebar.button("Run Layout Analysis")
order_det = st.sidebar.button("Run Reading Order")
table_rec = st.sidebar.button("Run Table Rec")
use_pdf_boxes = st.sidebar.checkbox("PDF table boxes", value=True, help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.")
skip_table_detection = st.sidebar.checkbox("Skip table detection", value=False, help="Table recognition only: Skip table detection and treat the whole image/page as a table.")

if pil_image is None:
    st.stop()

# Run Text Detection
if text_det:
    det_img, pred = text_detection(pil_image)
    with col1:
        st.image(det_img, caption="Detected Text", use_column_width=True)
        st.json(pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True)


# Run layout
if layout_det:
    layout_img, pred = layout_detection(pil_image)
    with col1:
        st.image(layout_img, caption="Detected Layout", use_column_width=True)
        st.json(pred.model_dump(exclude=["segmentation_map"]), expanded=True)

# Run OCR
if text_rec:
    rec_img, pred = ocr(pil_image, pil_image_highres, languages)
    with col1:
        st.image(rec_img, caption="OCR Result", use_column_width=True)
        json_tab, text_tab = st.tabs(["JSON", "Text Lines (for debugging)"])
        with json_tab:
            st.json(pred.model_dump(), expanded=True)
        with text_tab:
            st.text("\n".join([p.text for p in pred.text_lines]))

# Run reading order
if order_det:
    order_img, pred = order_detection(pil_image)
    with col1:
        st.image(order_img, caption="Reading Order", use_column_width=True)
        st.json(pred.model_dump(), expanded=True)


# Run table recognition (page index is 0-based; None for plain images)
if table_rec:
    table_img, pred = table_recognition(pil_image, pil_image_highres, in_file, page_number - 1 if page_number else None, use_pdf_boxes, skip_table_detection)
    with col1:
        st.image(table_img, caption="Table Recognition", use_column_width=True)
        st.json([p.model_dump() for p in pred], expanded=True)

with col2:
    st.image(pil_image, caption="Uploaded Image", use_column_width=True)
|
ocr_text.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import json
|
4 |
+
import time
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from surya.input.langs import replace_lang_with_code, get_unique_langs
|
10 |
+
from surya.input.load import load_from_folder, load_from_file, load_lang_file
|
11 |
+
from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
|
12 |
+
from surya.model.recognition.model import load_model as load_recognition_model
|
13 |
+
from surya.model.recognition.processor import load_processor as load_recognition_processor
|
14 |
+
from surya.model.recognition.tokenizer import _tokenize
|
15 |
+
from surya.ocr import run_ocr
|
16 |
+
from surya.postprocessing.text import draw_text_on_image
|
17 |
+
from surya.settings import settings
|
18 |
+
|
19 |
+
|
20 |
+
def main():
    """Run OCR on a PDF, an image, or a folder of them.

    Parses the command line, loads each page at the default and the high
    resolution, resolves the per-page languages (from ``--lang_file``,
    ``--langs``, or none), runs ``run_ocr`` and writes ``results.json``
    (predictions grouped by input name, with a 1-based ``page`` field) into
    ``<results_dir>/<input name>``. With ``--images`` a rendering of the
    recognised text is saved per page; ``--debug`` prints timing and the
    longest recognised line length.
    """
    parser = argparse.ArgumentParser(description="Detect bboxes in an input file or folder (PDFs or image).")
    parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect bboxes in.")
    # The value is used as a directory (see os.makedirs below), not as a JSON file.
    parser.add_argument("--results_dir", type=str, help="Path to directory to save OCR results in.", default=os.path.join(settings.RESULT_DIR, "surya"))
    parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
    parser.add_argument("--start_page", type=int, help="Page to start processing at.", default=0)
    parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False)
    parser.add_argument("--langs", type=str, help="Optional language(s) to use for OCR. Comma separate for multiple. Can be a capitalized language name, or a 2-letter ISO 639 code.", default=None)
    parser.add_argument("--lang_file", type=str, help="Optional path to file with languages to use for OCR. Should be a JSON dict with file names as keys, and the value being a list of language codes/names.", default=None)
    parser.add_argument("--debug", action="store_true", help="Enable debug logging.", default=False)
    args = parser.parse_args()

    if os.path.isdir(args.input_path):
        images, names, _ = load_from_folder(args.input_path, args.max, args.start_page)
        highres_images, _, _ = load_from_folder(args.input_path, args.max, args.start_page, settings.IMAGE_DPI_HIGHRES)
        # normpath strips a trailing separator; basename("docs/") would be ""
        # and the results would be written straight into results_dir.
        folder_name = os.path.basename(os.path.normpath(args.input_path))
    else:
        images, names, _ = load_from_file(args.input_path, args.max, args.start_page)
        highres_images, _, _ = load_from_file(args.input_path, args.max, args.start_page, settings.IMAGE_DPI_HIGHRES)
        folder_name = os.path.basename(args.input_path).split(".")[0]

    if args.lang_file:
        # We got all of our language settings from a file
        langs = load_lang_file(args.lang_file, names)
        for lang in langs:
            replace_lang_with_code(lang)
        image_langs = langs
    elif args.langs:
        # We got our language settings from the input
        langs = args.langs.split(",")
        replace_lang_with_code(langs)
        image_langs = [langs] * len(images)
    else:
        image_langs = [None] * len(images)

    det_processor = load_detection_processor()
    det_model = load_detection_model()

    rec_model = load_recognition_model()
    rec_processor = load_recognition_processor()

    result_path = os.path.join(args.results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    start = time.time()
    predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor, highres_images=highres_images)
    if args.debug:
        print(f"OCR took {time.time() - start:.2f} seconds")
        # default=0 keeps --debug from raising ValueError when no text line
        # was recognised on any page.
        max_chars = max((len(l.text) for p in predictions_by_image for l in p.text_lines), default=0)
        print(f"Max chars: {max_chars}")

    if args.images:
        for idx, (name, image, pred, page_langs) in enumerate(zip(names, images, predictions_by_image, image_langs)):
            bboxes = [l.bbox for l in pred.text_lines]
            pred_text = [l.text for l in pred.text_lines]
            page_image = draw_text_on_image(bboxes, pred_text, image.size, page_langs, has_math="_math" in page_langs if page_langs else False)
            page_image.save(os.path.join(result_path, f"{name}_{idx}_text.png"))

    # Group predictions by input name; "page" counts pages seen so far for that name.
    out_preds = defaultdict(list)
    for name, pred in zip(names, predictions_by_image):
        out_pred = pred.model_dump()
        out_pred["page"] = len(out_preds[name]) + 1
        out_preds[name].append(out_pred)

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(out_preds, f, ensure_ascii=False)

    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
|
pyproject.toml
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.poetry]
|
2 |
+
name = "surya-ocr"
|
3 |
+
version = "0.6.1"
|
4 |
+
description = "OCR, layout, reading order, and table recognition in 90+ languages"
|
5 |
+
authors = ["Vik Paruchuri <vik.paruchuri@gmail.com>"]
|
6 |
+
readme = "README.md"
|
7 |
+
license = "GPL-3.0-or-later"
|
8 |
+
repository = "https://github.com/VikParuchuri/surya"
|
9 |
+
keywords = ["ocr", "pdf", "text detection", "text recognition", "tables"]
|
10 |
+
packages = [
|
11 |
+
{include = "surya"}
|
12 |
+
]
|
13 |
+
include = [
|
14 |
+
"detect_text.py",
|
15 |
+
"ocr_text.py",
|
16 |
+
"ocr_app.py",
|
17 |
+
"run_ocr_app.py",
|
18 |
+
"detect_layout.py",
|
19 |
+
"reading_order.py",
|
20 |
+
"table_recognition.py"
|
21 |
+
]
|
22 |
+
|
23 |
+
[tool.poetry.dependencies]
|
24 |
+
python = ">=3.9,<3.13,!=3.9.7"
|
25 |
+
transformers = "^4.41.0"
|
26 |
+
torch = "^2.3.0"
|
27 |
+
pydantic = "^2.5.3"
|
28 |
+
pydantic-settings = "^2.1.0"
|
29 |
+
python-dotenv = "^1.0.0"
|
30 |
+
pillow = "^10.2.0"
|
31 |
+
pypdfium2 = "^4.25.0"
|
32 |
+
opencv-python = "^4.9.0.80"
|
33 |
+
tabulate = "^0.9.0"
|
34 |
+
filetype = "^1.2.0"
|
35 |
+
ftfy = "^6.1.3"
|
36 |
+
pdftext = "^0.3.12"
|
37 |
+
|
38 |
+
[tool.poetry.group.dev.dependencies]
|
39 |
+
jupyter = "^1.0.0"
|
40 |
+
pytesseract = "^0.3.10"
|
41 |
+
pymupdf = "^1.23.8"
|
42 |
+
snakeviz = "^2.2.0"
|
43 |
+
datasets = "^2.16.1"
|
44 |
+
rapidfuzz = "^3.6.1"
|
45 |
+
arabic-reshaper = "^3.0.0"
|
46 |
+
streamlit = "^1.31.0"
|
47 |
+
playwright = "^1.41.2"
|
48 |
+
|
49 |
+
[tool.poetry.scripts]
|
50 |
+
surya_detect = "detect_text:main"
|
51 |
+
surya_ocr = "ocr_text:main"
|
52 |
+
surya_layout = "detect_layout:main"
|
53 |
+
surya_gui = "run_ocr_app:run_app"
|
54 |
+
surya_order = "reading_order:main"
|
55 |
+
surya_table = "table_recognition:main"
|
56 |
+
|
57 |
+
[build-system]
|
58 |
+
requires = ["poetry-core"]
|
59 |
+
build-backend = "poetry.core.masonry.api"
|
reading_order.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import copy
|
4 |
+
import json
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
from surya.detection import batch_text_detection
|
8 |
+
from surya.input.load import load_from_folder, load_from_file
|
9 |
+
from surya.layout import batch_layout_detection
|
10 |
+
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
|
11 |
+
from surya.model.ordering.model import load_model
|
12 |
+
from surya.model.ordering.processor import load_processor
|
13 |
+
from surya.ordering import batch_ordering
|
14 |
+
from surya.postprocessing.heatmap import draw_polys_on_image
|
15 |
+
from surya.settings import settings
|
16 |
+
|
17 |
+
|
18 |
+
def main():
    """Compute the reading order of a PDF/image file or a folder of them.

    Runs text-line detection, layout detection and ordering on every loaded
    page, then writes ``results.json`` (and optionally one annotated PNG per
    page) into ``<results_dir>/<input name>``.
    """
    parser = argparse.ArgumentParser(description="Find reading order of an input file or folder (PDFs or image).")
    parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to find reading order in.")
    # The value is used as an output directory below, not as a JSON file.
    parser.add_argument("--results_dir", type=str, help="Path to directory where results are saved.", default=os.path.join(settings.RESULT_DIR, "surya"))
    parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
    parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False)
    args = parser.parse_args()

    model = load_model()
    processor = load_processor()

    layout_model = load_det_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    layout_processor = load_det_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)

    det_model = load_det_model()
    det_processor = load_det_processor()

    # normpath strips a trailing separator; without it basename("docs/") is ""
    # and the results would be written straight into results_dir.
    normalized_input = os.path.normpath(args.input_path)
    if os.path.isdir(args.input_path):
        images, names, _ = load_from_folder(args.input_path, args.max)
        folder_name = os.path.basename(normalized_input)
    else:
        images, names, _ = load_from_file(args.input_path, args.max)
        folder_name = os.path.basename(normalized_input).split(".")[0]

    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)
    bboxes = [[l.bbox for l in layout_pred.bboxes] for layout_pred in layout_predictions]

    order_predictions = batch_ordering(images, bboxes, model, processor)
    result_path = os.path.join(args.results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    if args.images:
        for idx, (image, _, order_pred, name) in enumerate(zip(images, layout_predictions, order_predictions, names)):
            polys = [l.polygon for l in order_pred.bboxes]
            labels = [str(l.position) for l in order_pred.bboxes]
            # Draw on a copy so the loaded page image is left untouched.
            bbox_image = draw_polys_on_image(polys, copy.deepcopy(image), labels=labels, label_font_size=20)
            bbox_image.save(os.path.join(result_path, f"{name}_{idx}_order.png"))

    predictions_by_page = defaultdict(list)
    for layout_pred, pred, name, _ in zip(layout_predictions, order_predictions, names, images):
        out_pred = pred.model_dump()
        # Copy each layout label onto the ordering box at the same index.
        for bbox, layout_bbox in zip(out_pred["bboxes"], layout_pred.bboxes):
            bbox["label"] = layout_bbox.label

        # Pages are numbered from 1 within each input name.
        out_pred["page"] = len(predictions_by_page[name]) + 1
        predictions_by_page[name].append(out_pred)

    # Sort in reading order
    for page_list in predictions_by_page.values():
        for page_preds in page_list:
            page_preds["bboxes"] = sorted(page_preds["bboxes"], key=lambda x: x["position"])

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(predictions_by_page, f, ensure_ascii=False)

    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
torch
|
3 |
+
torchvision
|
4 |
+
torchaudio
|
5 |
+
surya-ocr
|
scripts/verify_benchmark_scores.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import argparse
|
3 |
+
|
4 |
+
|
5 |
+
def verify_layout(data):
    """Check every layout class in ``data["metrics"]`` against the 0.6 floor.

    Raises:
        ValueError: if any class has a precision or recall at or below 0.6.
    """
    below_floor = any(
        class_metrics["precision"] <= 0.6 or class_metrics["recall"] <= 0.6
        for class_metrics in data["metrics"].values()
    )
    if below_floor:
        raise ValueError("Scores do not meet the required threshold")
|
10 |
+
|
11 |
+
|
12 |
+
def verify_det(data):
    """Check the ``surya`` detection metrics against the 0.9 floor.

    Raises:
        ValueError: if precision or recall is at or below 0.9.
    """
    surya_metrics = data["metrics"]["surya"]
    # Precision is checked first, then recall, mirroring a short-circuit "or".
    for metric_name in ("precision", "recall"):
        if surya_metrics[metric_name] <= 0.9:
            raise ValueError("Scores do not meet the required threshold")
|
16 |
+
|
17 |
+
|
18 |
+
def verify_rec(data):
    """Check the ``surya`` recognition average score against the 0.9 floor.

    Raises:
        ValueError: if ``data["surya"]["avg_score"]`` is at or below 0.9.
    """
    average_score = data["surya"]["avg_score"]
    if average_score <= 0.9:
        raise ValueError("Scores do not meet the required threshold")
|
22 |
+
|
23 |
+
|
24 |
+
def verify_order(data):
    """Check the ordering benchmark's mean accuracy.

    Raises:
        ValueError: if ``data["mean_accuracy"]`` is below 0.75.
    """
    mean_accuracy = data["mean_accuracy"]
    if mean_accuracy < 0.75:
        raise ValueError("Scores do not meet the required threshold")
|
28 |
+
|
29 |
+
|
30 |
+
def verify_table_rec(data):
    """Check the ``surya`` table-recognition row and column IoU.

    Raises:
        ValueError: if the mean row IoU or mean column IoU is below 0.75.
    """
    surya_scores = data["surya"]
    mean_ious = (surya_scores["mean_row_iou"], surya_scores["mean_col_iou"])

    if any(value < 0.75 for value in mean_ious):
        raise ValueError("Scores do not meet the required threshold")
|
36 |
+
|
37 |
+
|
38 |
+
def verify_scores(file_path, bench_type):
    """Load a benchmark result JSON file and run the check for ``bench_type``.

    Args:
        file_path: Path of the JSON results file.
        bench_type: One of "detection", "recognition", "layout", "ordering"
            or "table_recognition".

    Raises:
        ValueError: if ``bench_type`` is not recognised, or if the selected
            check rejects the scores.
    """
    with open(file_path, 'r') as results_file:
        data = json.load(results_file)

    # The file is loaded before the benchmark type is validated, so a missing
    # file is reported ahead of an invalid type.
    checks = (
        ("detection", verify_det),
        ("recognition", verify_rec),
        ("layout", verify_layout),
        ("ordering", verify_order),
        ("table_recognition", verify_table_rec),
    )
    for known_type, check in checks:
        if bench_type == known_type:
            check(data)
            return

    raise ValueError("Invalid benchmark type")
|
54 |
+
|
55 |
+
|
56 |
+
if __name__ == "__main__":
    # Command-line entry point: parse the result path and benchmark type,
    # then let verify_scores raise if the scores are too low.
    cli = argparse.ArgumentParser(description="Verify benchmark scores")
    cli.add_argument("file_path", type=str, help="Path to the json file")
    cli.add_argument("--bench_type", type=str, help="Type of benchmark to verify", default="detection")
    options = cli.parse_args()
    verify_scores(options.file_path, options.bench_type)
|
surya/benchmark/bbox.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import fitz as pymupdf
|
2 |
+
from surya.postprocessing.util import rescale_bbox
|
3 |
+
|
4 |
+
|
5 |
+
def get_pdf_lines(pdf_path, img_sizes):
    """Extract the text-line boxes of the leading PDF pages, rescaled per page.

    Args:
        pdf_path: Path of the PDF to open with PyMuPDF.
        img_sizes: One target size per page to read; page ``i`` of the
            document is paired with ``img_sizes[i]``.

    Returns:
        A list with one entry per item of ``img_sizes``; each entry is the
        list of line boxes of that page after ``rescale_bbox`` from the page
        size to the matching image size.
    """
    # Text extraction without ligature or image preservation.
    text_flags = pymupdf.TEXTFLAGS_DICT & ~pymupdf.TEXT_PRESERVE_LIGATURES & ~pymupdf.TEXT_PRESERVE_IMAGES

    page_lines = []
    # The context manager closes the document even if a page fails to parse;
    # previously the handle was never closed.
    with pymupdf.open(pdf_path) as doc:
        for idx, img_size in enumerate(img_sizes):
            page = doc[idx]
            blocks = page.get_text("dict", sort=True, flags=text_flags)["blocks"]

            line_boxes = []
            for block in blocks:
                line_boxes.extend(list(line["bbox"]) for line in block["lines"])

            page_box = page.bound()
            pwidth, pheight = page_box[2] - page_box[0], page_box[3] - page_box[1]
            page_lines.append([rescale_bbox(bbox, (pwidth, pheight), img_size) for bbox in line_boxes])

    return page_lines
|
surya/benchmark/metrics.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
from itertools import repeat
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from concurrent.futures import ProcessPoolExecutor
|
6 |
+
|
7 |
+
|
8 |
+
def intersection_area(box1, box2):
    """Return the overlap area of two ``(x1, y1, x2, y2)`` boxes.

    Returns ``0.0`` when the boxes do not overlap on either axis.
    """
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])

    if right < left or bottom < top:
        return 0.0

    overlap_width = right - left
    overlap_height = bottom - top
    return overlap_width * overlap_height
|
18 |
+
|
19 |
+
def box_area(box):
    """Return width times height of an ``(x1, y1, x2, y2)`` box."""
    width = box[2] - box[0]
    height = box[3] - box[1]
    return width * height
|
21 |
+
|
22 |
+
|
23 |
+
def calculate_iou(box1, box2, box1_only=False):
    """Return the overlap of two boxes relative to a reference area.

    Args:
        box1: First ``(x1, y1, x2, y2)`` box.
        box2: Second ``(x1, y1, x2, y2)`` box.
        box1_only: When true, divide by the area of ``box1`` alone instead of
            by the union of both boxes.

    Returns:
        ``intersection / denominator``, or ``0`` when the denominator is zero.
    """
    overlap = intersection_area(box1, box2)
    if box1_only:
        denominator = box_area(box1)
    else:
        denominator = box_area(box1) + (box_area(box2) - overlap)

    return 0 if denominator == 0 else overlap / denominator
|
32 |
+
|
33 |
+
|
34 |
+
def match_boxes(preds, references):
    """Greedily pair reference boxes with predicted boxes by descending overlap.

    The overlap used is ``calculate_iou(reference, pred, box1_only=True)``.

    Returns:
        A list of ``(reference_index, pred_index, score)`` tuples. Paired
        boxes carry their overlap (snapped to ``1.0`` above 0.95), references
        left without a prediction get ``(i, None, -1.0)`` and predictions left
        without a reference get ``(None, j, 0.0)``.
    """
    num_actual, num_predicted = len(references), len(preds)

    iou_matrix = np.zeros((num_actual, num_predicted))
    for row, reference in enumerate(references):
        for col, pred in enumerate(preds):
            iou_matrix[row, col] = calculate_iou(reference, pred, box1_only=True)

    # Visit every (reference, prediction) cell from highest to lowest overlap.
    descending = np.argsort(iou_matrix, axis=None)[::-1]
    actual_indices, predicted_indices = np.unravel_index(descending, iou_matrix.shape)

    assigned_actual, assigned_pred = set(), set()
    matches = []
    for i, j in zip(actual_indices, predicted_indices):
        if i in assigned_actual or j in assigned_pred:
            continue
        iou_val = iou_matrix[i, j]
        if iou_val > .95:  # Account for rounding on box edges
            iou_val = 1.0
        matches.append((i, j, iou_val))
        assigned_actual.add(i)
        assigned_pred.add(j)

    unassigned_actual = set(range(num_actual)) - assigned_actual
    unassigned_pred = set(range(num_predicted)) - assigned_pred
    matches.extend([(i, None, -1.0) for i in unassigned_actual])
    matches.extend([(None, j, 0.0) for j in unassigned_pred])

    return matches
|
67 |
+
|
68 |
+
def penalized_iou_score(preds, references):
    """Return the mean score over all tuples produced by ``match_boxes``.

    Unmatched boxes take part in the mean with the penalty scores that
    ``match_boxes`` assigns to them.
    """
    matched = match_boxes(preds, references)
    scores = [score for _, _, score in matched]
    return sum(scores) / len(matched)
|
72 |
+
|
73 |
+
def intersection_pixels(box1, box2):
    """Return the set of integer ``(x, y)`` pixels inside the overlap of two boxes.

    Overlap bounds are truncated with ``int`` and the right/bottom bound is
    exclusive. An empty set is returned when the boxes do not overlap.
    """
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])

    if right < left or bottom < top:
        return set()

    columns = np.arange(int(left), int(right))
    rows = np.arange(int(top), int(bottom))

    grid_x, grid_y = np.meshgrid(columns, rows)
    return set(zip(grid_x.flat, grid_y.flat))
|
89 |
+
|
90 |
+
|
91 |
+
def calculate_coverage(box, other_boxes, penalize_double=False):
    """Return the fraction of ``box`` covered by ``other_boxes``, counted per pixel.

    Args:
        box: The ``(x1, y1, x2, y2)`` box whose coverage is measured.
        other_boxes: Boxes that may cover ``box``.
        penalize_double: When true, every pixel that is covered again by a
            later box is subtracted from the covered count.

    Returns:
        ``covered_pixels / area`` of ``box`` (never negative), or ``0`` when
        the box has zero area.
    """
    area = (box[2] - box[0]) * (box[3] - box[1])
    if area == 0:
        return 0

    covered_pixels = set()
    double_covered = 0
    for other_box in other_boxes:
        pixels = intersection_pixels(box, other_box)
        if penalize_double:
            # Count the pixels that were already covered by an earlier box.
            # The previous code appended one list per box and then used the
            # list length, so the penalty was the number of boxes instead.
            double_covered += len(covered_pixels & pixels)
        covered_pixels |= pixels

    covered_pixels_count = max(0, len(covered_pixels) - double_covered)
    return covered_pixels_count / area
|
110 |
+
|
111 |
+
|
112 |
+
def calculate_coverage_fast(box, other_boxes, penalize_double=False):
    """Approximate the covered fraction of ``box`` by summing overlap areas.

    Overlaps are summed without de-duplication and the ratio is capped at 1.
    ``penalize_double`` is accepted for signature parity with
    ``calculate_coverage`` but is not used.
    """
    area = (box[2] - box[0]) * (box[3] - box[1])
    if area == 0:
        return 0

    overlap_total = 0
    for candidate in other_boxes:
        overlap_total += intersection_area(box, candidate)

    ratio = overlap_total / area
    return min(1, ratio)
|
122 |
+
|
123 |
+
|
124 |
+
def precision_recall(preds, references, threshold=.5, workers=8, penalize_double=True):
    """Compute coverage-based precision and recall for predicted boxes.

    A prediction counts towards precision when its coverage by the references
    exceeds ``threshold``; a reference counts towards recall when its coverage
    by the predictions exceeds ``threshold``. Coverage is computed in a
    process pool with ``workers`` workers.

    Returns:
        ``{"precision": ..., "recall": ...}``. Both are ``1`` when there are
        no references, and both are ``0`` when there are references but no
        predictions.
    """
    if len(references) == 0:
        return {"precision": 1, "recall": 1}

    if len(preds) == 0:
        return {"precision": 0, "recall": 0}

    # If we're not penalizing double coverage, we can use a faster calculation
    coverage_func = calculate_coverage if penalize_double else calculate_coverage_fast

    with ProcessPoolExecutor(max_workers=workers) as executor:
        # Only the precision side receives the penalize_double flag.
        precision_func = partial(coverage_func, penalize_double=penalize_double)
        pred_coverages = executor.map(precision_func, preds, repeat(references))
        ref_coverages = executor.map(coverage_func, references, repeat(preds))

        precision_hits = [int(score > threshold) for score in pred_coverages]
        recall_hits = [int(score > threshold) for score in ref_coverages]

    return {
        "precision": sum(precision_hits) / len(precision_hits),
        "recall": sum(recall_hits) / len(recall_hits),
    }
|
157 |
+
|
158 |
+
|
159 |
+
def mean_coverage(preds, references):
    """Return the mean coverage over all reference and predicted boxes.

    Each reference is scored by how much of it the predictions cover, and each
    prediction by how much of it the references cover.

    Returns:
        ``{"coverage": mean}``. With no boxes at all the mean is reported as
        ``0`` in the same dict shape (previously a bare ``0`` was returned,
        which broke callers indexing the result).
    """
    coverages = [calculate_coverage(reference, preds) for reference in references]
    coverages.extend(calculate_coverage(pred, references) for pred in preds)

    # Calculate the average coverage over all comparisons
    if not coverages:
        return {"coverage": 0}
    return {"coverage": sum(coverages) / len(coverages)}
|
175 |
+
|
176 |
+
|
177 |
+
def rank_accuracy(preds, references):
    """Return the fraction of ordered pairs that are ranked the same way.

    Preds and references need to be aligned so each position refers to the
    same bbox. For every ordered pair ``(i, j)`` with ``i != j`` the
    comparison ``preds[i] > preds[j]`` must equal
    ``references[i] > references[j]`` for the pair to count as correct.

    Raises:
        ZeroDivisionError: if ``preds`` has fewer than two elements, because
            there are then no pairs to score.
    """
    # A set makes each membership test below O(1); the earlier list made the
    # whole function quartic in the number of boxes.
    pairs = {
        (i, j, pred > pred2)
        for i, pred in enumerate(preds)
        for j, pred2 in enumerate(preds)
        if i != j
    }

    # Find how many of the prediction rankings are correct
    correct = 0
    for i, ref in enumerate(references):
        for j, ref2 in enumerate(references):
            if (i, j, ref > ref2) in pairs:
                correct += 1

    return correct / len(pairs)
|
surya/benchmark/tatr.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import DetrFeatureExtractor, AutoModelForObjectDetection
|
3 |
+
from surya.settings import settings
|
4 |
+
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
class MaxResize:
    """Callable that rescales an image so its longer side equals ``max_size``."""

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        """Return ``image`` resized with its aspect ratio kept (sizes rounded)."""
        width, height = image.size
        scale = self.max_size / max(width, height)
        target_size = (int(round(scale * width)), int(round(scale * height)))
        return image.resize(target_size)
|
20 |
+
|
21 |
+
|
22 |
+
def to_tensor(image):
    """Convert an H x W x C image into a float32 tensor in C x H x W order.

    Values are divided by 255, so 8-bit input ends up in ``[0.0, 1.0]``.
    """
    pixels = np.array(image).astype(np.float32)

    # Move the channel axis to the front: [H, W, C] -> [C, H, W].
    channels_first = pixels.transpose((2, 0, 1))
    channels_first /= 255.0

    return torch.from_numpy(channels_first)
|
33 |
+
|
34 |
+
|
35 |
+
def normalize(tensor, mean, std):
    """Normalize ``tensor`` in place along its first axis and return it.

    Each slice ``tensor[k]`` has ``mean[k]`` subtracted and is then divided by
    ``std[k]``.
    """
    for channel, channel_mean, channel_std in zip(tensor, mean, std):
        channel.sub_(channel_mean)
        channel.div_(channel_std)
    return tensor
|
39 |
+
|
40 |
+
|
41 |
+
def structure_transform(image):
    """Prepare one image for the table-structure model.

    Resizes so the longer side is 1000, converts to a channel-first float
    tensor, then normalizes with the fixed per-channel mean and std below.
    """
    resized = MaxResize(1000)(image)
    return normalize(to_tensor(resized), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
46 |
+
|
47 |
+
|
48 |
+
def box_cxcywh_to_xyxy(x):
    """Convert boxes from ``(center_x, center_y, width, height)`` to corner form.

    The last axis of ``x`` holds the four box values; the result stacks
    ``(x1, y1, x2, y2)`` along dimension 1.
    """
    x_c, y_c, w, h = x.unbind(-1)
    half_w, half_h = 0.5 * w, 0.5 * h
    corners = [x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h]
    return torch.stack(corners, dim=1)
|
52 |
+
|
53 |
+
|
54 |
+
def rescale_bboxes(out_bbox, size):
    """Convert center-format boxes to corner format and scale them by ``size``.

    ``size`` is ``(width, height)``; x values are multiplied by the width and
    y values by the height.
    """
    width, height = size
    corner_boxes = box_cxcywh_to_xyxy(out_bbox)
    scale = torch.tensor([width, height, width, height], dtype=torch.float32)
    return corner_boxes * scale
|
59 |
+
|
60 |
+
|
61 |
+
def outputs_to_objects(outputs, img_sizes, id2label):
    """Turn raw detection outputs into per-image lists of table rows and columns.

    Args:
        outputs: Model output exposing ``logits`` and ``['pred_boxes']``.
        img_sizes: One size per image, passed on to ``rescale_bboxes``.
        id2label: Mapping from class index to label name; entries labelled
            'no object' are dropped.

    Returns:
        One ``{"rows": [...], "cols": [...]}`` dict per image. Each element is
        a dict with 'label', 'score' and 'bbox' keys.
    """
    best = outputs.logits.softmax(-1).max(-1)
    batch_labels = list(best.indices.detach().cpu().numpy())
    batch_scores = list(best.values.detach().cpu().numpy())
    batch_bboxes = outputs['pred_boxes'].detach().cpu()

    batch_objects = []
    for img_idx, img_size in enumerate(img_sizes):
        pred_bboxes = [elem.tolist() for elem in rescale_bboxes(batch_bboxes[img_idx], img_size)]
        pred_scores = batch_scores[img_idx]
        pred_labels = batch_labels[img_idx]

        objects = []
        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
            class_label = id2label[int(label)]
            if class_label == 'no object':
                continue
            objects.append({
                'label': class_label,
                'score': float(score),
                'bbox': [float(elem) for elem in bbox],
            })

        # Only row and column detections are kept; other labels are discarded.
        cols = [cell for cell in objects if cell["label"] == "table column"]
        rows = [cell for cell in objects if cell["label"] == "table row"]
        batch_objects.append({
            "rows": rows,
            "cols": cols
        })

    return batch_objects
|
97 |
+
|
98 |
+
|
99 |
+
def load_tatr():
    """Load the pretrained table-structure model and move it to the configured device."""
    checkpoint = "microsoft/table-transformer-structure-recognition-v1.1-all"
    tatr_model = AutoModelForObjectDetection.from_pretrained(checkpoint)
    return tatr_model.to(settings.TORCH_DEVICE_MODEL)
|
101 |
+
|
102 |
+
|
103 |
+
def batch_inference_tatr(model, images, batch_size):
    """Run the table-structure model over ``images`` in batches.

    Args:
        model: Loaded detection model; its ``device`` and ``config.id2label``
            are read.
        images: Images to process; each must expose ``size``.
        batch_size: Number of images per forward pass.

    Returns:
        One ``{"rows": [...], "cols": [...]}`` dict per image, in input order.
    """
    device = model.device

    # Work on a copy with one extra "no object" class. The old code wrote into
    # model.config.id2label directly, adding a new key on every batch.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"

    rows_cols = []
    for start in range(0, len(images), batch_size):
        batch_images = images[start:start + batch_size]
        pixel_values = torch.stack([structure_transform(img) for img in batch_images], dim=0).to(device)

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values)

        rows_cols.extend(outputs_to_objects(outputs, [img.size for img in batch_images], id2label))
    return rows_cols
|
surya/benchmark/tesseract.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import pytesseract
|
5 |
+
from pytesseract import Output
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
from surya.input.processing import slice_bboxes_from_image
|
9 |
+
from surya.settings import settings
|
10 |
+
import os
|
11 |
+
from concurrent.futures import ProcessPoolExecutor
|
12 |
+
from surya.detection import get_batch_size as get_det_batch_size
|
13 |
+
from surya.recognition import get_batch_size as get_rec_batch_size
|
14 |
+
from surya.languages import CODE_TO_LANGUAGE
|
15 |
+
|
16 |
+
|
17 |
+
def surya_lang_to_tesseract(code: str) -> Optional[str]:
    """Map a surya language code to the tesseract language code.

    Returns ``None`` when the language name has no tesseract entry. An unknown
    surya ``code`` still raises ``KeyError`` from the first lookup.
    """
    language_name = CODE_TO_LANGUAGE[code]
    return TESS_LANGUAGE_TO_CODE.get(language_name)
|
24 |
+
|
25 |
+
|
26 |
+
def tesseract_ocr(img, bboxes, lang: str):
    """Run tesseract on every box of ``img`` and return one string per box."""
    tess_config = f'--tessdata-dir "{settings.TESSDATA_PREFIX}"'
    return [
        pytesseract.image_to_string(line_img, lang=lang, config=tess_config)
        for line_img in slice_bboxes_from_image(img, bboxes)
    ]
|
34 |
+
|
35 |
+
|
36 |
+
def tesseract_ocr_parallel(imgs, bboxes, langs: List[str], cpus=None):
    """Run ``tesseract_ocr`` over many images in a process pool.

    Args:
        imgs: Images to process.
        bboxes: Per-image boxes, aligned with ``imgs``.
        langs: Per-image tesseract language codes, aligned with ``imgs``.
        cpus: Upper bound on cores to use; defaults to the detected CPU count.

    Returns:
        A list with one result of ``tesseract_ocr`` per image, in input order.
    """
    if not cpus:
        # os.cpu_count() may return None when the count cannot be determined.
        cpus = os.cpu_count() or 1
    tess_parallel_cores = min(len(imgs), get_rec_batch_size(), cpus)

    # Tesseract uses up to 4 processes per instance
    # Divide by 2 because tesseract doesn't seem to saturate all 4 cores with these small images
    tess_parallel = max(tess_parallel_cores // 2, 1)

    with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
        tess_text = tqdm(executor.map(tesseract_ocr, imgs, bboxes, langs), total=len(imgs), desc="Running tesseract OCR")
        tess_text = list(tess_text)
    return tess_text
|
50 |
+
|
51 |
+
|
52 |
+
def tesseract_bboxes(img):
    """Return every box reported by tesseract for ``img`` as ``(x1, y1, x2, y2)``.

    One box is emitted per entry of tesseract's data output.
    """
    pixels = np.asarray(img, dtype=np.uint8)
    ocr = pytesseract.image_to_data(pixels, output_type=Output.DICT)

    # It is possible to merge by line here with line number, but it gives bad results.
    columns = zip(ocr['level'], ocr['text'], ocr['left'], ocr['top'], ocr['width'], ocr['height'])
    return [(x, y, x + w, y + h) for _level, _text, x, y, w, h in columns]
|
65 |
+
|
66 |
+
|
67 |
+
def tesseract_parallel(imgs):
    """Run ``tesseract_bboxes`` over many images in a process pool.

    Returns:
        A list with one list of boxes per image, in input order.
    """
    # os.cpu_count() may return None when the count cannot be determined.
    cpus = os.cpu_count() or 1
    tess_parallel_cores = min(len(imgs), get_det_batch_size(), cpus)

    # Tesseract uses 4 threads per instance
    tess_parallel = max(tess_parallel_cores // 4, 1)

    with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
        tess_bboxes = tqdm(executor.map(tesseract_bboxes, imgs), total=len(imgs), desc="Running tesseract bbox detection")
        tess_bboxes = list(tess_bboxes)
    return tess_bboxes
|
80 |
+
|
81 |
+
|
82 |
+
TESS_CODE_TO_LANGUAGE = {
|
83 |
+
"afr": "Afrikaans",
|
84 |
+
"amh": "Amharic",
|
85 |
+
"ara": "Arabic",
|
86 |
+
"asm": "Assamese",
|
87 |
+
"aze": "Azerbaijani",
|
88 |
+
"bel": "Belarusian",
|
89 |
+
"ben": "Bengali",
|
90 |
+
"bod": "Tibetan",
|
91 |
+
"bos": "Bosnian",
|
92 |
+
"bre": "Breton",
|
93 |
+
"bul": "Bulgarian",
|
94 |
+
"cat": "Catalan",
|
95 |
+
"ceb": "Cebuano",
|
96 |
+
"ces": "Czech",
|
97 |
+
"chi_sim": "Chinese",
|
98 |
+
"chr": "Cherokee",
|
99 |
+
"cym": "Welsh",
|
100 |
+
"dan": "Danish",
|
101 |
+
"deu": "German",
|
102 |
+
"dzo": "Dzongkha",
|
103 |
+
"ell": "Greek",
|
104 |
+
"eng": "English",
|
105 |
+
"epo": "Esperanto",
|
106 |
+
"est": "Estonian",
|
107 |
+
"eus": "Basque",
|
108 |
+
"fas": "Persian",
|
109 |
+
"fin": "Finnish",
|
110 |
+
"fra": "French",
|
111 |
+
"fry": "Western Frisian",
|
112 |
+
"guj": "Gujarati",
|
113 |
+
"gla": "Scottish Gaelic",
|
114 |
+
"gle": "Irish",
|
115 |
+
"glg": "Galician",
|
116 |
+
"heb": "Hebrew",
|
117 |
+
"hin": "Hindi",
|
118 |
+
"hrv": "Croatian",
|
119 |
+
"hun": "Hungarian",
|
120 |
+
"hye": "Armenian",
|
121 |
+
"iku": "Inuktitut",
|
122 |
+
"ind": "Indonesian",
|
123 |
+
"isl": "Icelandic",
|
124 |
+
"ita": "Italian",
|
125 |
+
"jav": "Javanese",
|
126 |
+
"jpn": "Japanese",
|
127 |
+
"kan": "Kannada",
|
128 |
+
"kat": "Georgian",
|
129 |
+
"kaz": "Kazakh",
|
130 |
+
"khm": "Khmer",
|
131 |
+
"kir": "Kyrgyz",
|
132 |
+
"kor": "Korean",
|
133 |
+
"lao": "Lao",
|
134 |
+
"lat": "Latin",
|
135 |
+
"lav": "Latvian",
|
136 |
+
"lit": "Lithuanian",
|
137 |
+
"mal": "Malayalam",
|
138 |
+
"mar": "Marathi",
|
139 |
+
"mkd": "Macedonian",
|
140 |
+
"mlt": "Maltese",
|
141 |
+
"mon": "Mongolian",
|
142 |
+
"msa": "Malay",
|
143 |
+
"mya": "Burmese",
|
144 |
+
"nep": "Nepali",
|
145 |
+
"nld": "Dutch",
|
146 |
+
"nor": "Norwegian",
|
147 |
+
"ori": "Oriya",
|
148 |
+
"pan": "Punjabi",
|
149 |
+
"pol": "Polish",
|
150 |
+
"por": "Portuguese",
|
151 |
+
"pus": "Pashto",
|
152 |
+
"ron": "Romanian",
|
153 |
+
"rus": "Russian",
|
154 |
+
"san": "Sanskrit",
|
155 |
+
"sin": "Sinhala",
|
156 |
+
"slk": "Slovak",
|
157 |
+
"slv": "Slovenian",
|
158 |
+
"snd": "Sindhi",
|
159 |
+
"spa": "Spanish",
|
160 |
+
"sqi": "Albanian",
|
161 |
+
"srp": "Serbian",
|
162 |
+
"swa": "Swahili",
|
163 |
+
"swe": "Swedish",
|
164 |
+
"syr": "Syriac",
|
165 |
+
"tam": "Tamil",
|
166 |
+
"tel": "Telugu",
|
167 |
+
"tgk": "Tajik",
|
168 |
+
"tha": "Thai",
|
169 |
+
"tir": "Tigrinya",
|
170 |
+
"tur": "Turkish",
|
171 |
+
"uig": "Uyghur",
|
172 |
+
"ukr": "Ukrainian",
|
173 |
+
"urd": "Urdu",
|
174 |
+
"uzb": "Uzbek",
|
175 |
+
"vie": "Vietnamese",
|
176 |
+
"yid": "Yiddish"
|
177 |
+
}
|
178 |
+
|
179 |
+
TESS_LANGUAGE_TO_CODE = {v:k for k,v in TESS_CODE_TO_LANGUAGE.items()}
|
surya/benchmark/util.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def merge_boxes(box1, box2):
    """Return the smallest ``(x1, y1, x2, y2)`` box enclosing both boxes."""
    return (min(box1[0], box2[0]), min(box1[1], box2[1]), max(box1[2], box2[2]), max(box1[3], box2[3]))


def join_lines(bboxes, max_gap=5):
    """Merge boxes that sit directly above a wider box later in the list.

    Args:
        bboxes: Sequence of ``(index, box)`` pairs, e.g.
            ``list(enumerate(boxes))``; only the box of each pair is used.
        max_gap: Largest allowed distance between the top of a box and the
            bottom of the box merged into it.

    Returns:
        A list of boxes. A later box is merged into an earlier one when the
        earlier box spans it horizontally and the vertical gap is within
        ``max_gap``; merged boxes do not appear on their own.
    """
    to_merge = {}
    for i, (_, box1) in enumerate(bboxes):
        # Positions come from enumerate so that j always indexes bboxes.
        # The old code added the pair's own index element to i, which pointed
        # at the wrong box or past the end of the list.
        for offset, (_, box2) in enumerate(bboxes[i + 1:]):
            j = i + offset + 1
            if box1 == box2:
                continue

            spans_horizontally = box1[0] <= box2[0] and box1[2] >= box2[2]
            if spans_horizontally and abs(box1[1] - box2[3]) <= max_gap:
                to_merge.setdefault(i, []).append(j)

    merged_boxes = set()
    merged = []
    for i, (_, box) in enumerate(bboxes):
        if i in merged_boxes:
            continue

        for j in to_merge.get(i, ()):
            box = merge_boxes(box, bboxes[j][1])
            merged_boxes.add(j)

        merged.append(box)
    return merged
|
surya/detection.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple, Generator
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
from surya.model.detection.model import EfficientViTForSemanticSegmentation
|
8 |
+
from surya.postprocessing.heatmap import get_and_clean_boxes
|
9 |
+
from surya.postprocessing.affinity import get_vertical_lines
|
10 |
+
from surya.input.processing import prepare_image_detection, split_image, get_total_splits, convert_if_not_rgb
|
11 |
+
from surya.schema import TextDetectionResult
|
12 |
+
from surya.settings import settings
|
13 |
+
from tqdm import tqdm
|
14 |
+
from concurrent.futures import ProcessPoolExecutor
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
|
18 |
+
def get_batch_size():
    """Return the detector batch size.

    Uses ``settings.DETECTOR_BATCH_SIZE`` when it is set; otherwise 36 on
    "cuda" and 8 on every other device.
    """
    configured = settings.DETECTOR_BATCH_SIZE
    if configured is not None:
        return configured
    return 36 if settings.TORCH_DEVICE_MODEL == "cuda" else 8
|
27 |
+
|
28 |
+
|
29 |
+
def batch_detection(
    images: List,
    model: EfficientViTForSemanticSegmentation,
    processor,
    batch_size=None
) -> Generator[Tuple[List[List[np.ndarray]], List[Tuple[int, int]]], None, None]:
    """Run the detection model over ``images`` and yield heatmaps batch by batch.

    Images are grouped into batches by their number of splits (as reported by
    ``get_total_splits``) so that a batch does not exceed ``batch_size``
    splits, unless a single image alone needs more. Each image is split with
    ``split_image``, every split is run through the model, and the per-split
    heatmaps are stacked vertically back into one set of heatmaps per image.

    Args:
        images: PIL images; anything else fails the assertion below.
        model: Detection model; ``config.num_labels`` gives the number of
            heatmaps per image.
        processor: Processor whose ``size["height"]`` / ``size["width"]`` give
            the expected logit shape.
        batch_size: Maximum number of splits per batch; defaults to
            ``get_batch_size()``.

    Yields:
        ``(preds, sizes)`` per batch, where ``preds[n]`` is the list of
        heatmaps of the n-th image in the batch and ``sizes[n]`` is that
        image's original ``size``.
    """
    assert all([isinstance(image, Image.Image) for image in images])
    if batch_size is None:
        batch_size = get_batch_size()
    heatmap_count = model.config.num_labels

    orig_sizes = [image.size for image in images]
    splits_per_image = [get_total_splits(size, processor) for size in orig_sizes]

    # Greedily pack image indices into batches, counting splits not images.
    batches = []
    current_batch_size = 0
    current_batch = []
    for i in range(len(images)):
        if current_batch_size + splits_per_image[i] > batch_size:
            # Adding this image would overflow the batch: flush what we have.
            if len(current_batch) > 0:
                batches.append(current_batch)
            current_batch = []
            current_batch_size = 0
        current_batch.append(i)
        current_batch_size += splits_per_image[i]

    if len(current_batch) > 0:
        batches.append(current_batch)

    for batch_idx in tqdm(range(len(batches)), desc="Detecting bboxes"):
        batch_image_idxs = batches[batch_idx]
        batch_images = [images[j].convert("RGB") for j in batch_image_idxs]

        # split_index[s] is the batch-local image index that split s belongs to.
        split_index = []
        split_heights = []
        image_splits = []
        for image_idx, image in enumerate(batch_images):
            image_parts, split_height = split_image(image, processor)
            image_splits.extend(image_parts)
            split_index.extend([image_idx] * len(image_parts))
            split_heights.extend(split_height)

        image_splits = [prepare_image_detection(image, processor) for image in image_splits]
        # Batch images in dim 0
        batch = torch.stack(image_splits, dim=0).to(model.dtype).to(model.device)

        with torch.inference_mode():
            pred = model(pixel_values=batch)

        logits = pred.logits
        correct_shape = [processor.size["height"], processor.size["width"]]
        current_shape = list(logits.shape[2:])
        # Resize the logits when the model output is not at processor resolution.
        if current_shape != correct_shape:
            logits = F.interpolate(logits, size=correct_shape, mode='bilinear', align_corners=False)

        logits = logits.cpu().detach().numpy().astype(np.float32)
        preds = []
        for i, (idx, height) in enumerate(zip(split_index, split_heights)):
            # If our current prediction length is below the image idx, that means we have a new image
            # Otherwise, we need to add to the current image
            if len(preds) <= idx:
                preds.append([logits[i][k] for k in range(heatmap_count)])
            else:
                heatmaps = preds[idx]
                pred_heatmaps = [logits[i][k] for k in range(heatmap_count)]

                if height < processor.size["height"]:
                    # Cut off padding to get original height
                    pred_heatmaps = [pred_heatmap[:height, :] for pred_heatmap in pred_heatmaps]

                # Append this split below the splits already collected for the image.
                for k in range(heatmap_count):
                    heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]])
                preds[idx] = heatmaps

        yield preds, [orig_sizes[j] for j in batch_image_idxs]
|
104 |
+
|
105 |
+
|
106 |
+
def parallel_get_lines(preds, orig_sizes):
    """Turn one image's detection maps into a ``TextDetectionResult``.

    ``preds`` is unpacked into two arrays: the text heatmap and the affinity
    map. ``orig_sizes`` is forwarded to the box/line extractors and its first
    two entries form the far corner of ``image_bbox``.
    """
    text_map, link_map = preds

    # Render both maps as 8-bit images before any box extraction runs.
    text_map_image = Image.fromarray((text_map * 255).astype(np.uint8))
    link_map_image = Image.fromarray((link_map * 255).astype(np.uint8))

    # The extractors receive each map's shape in reversed order.
    link_map_size = list(reversed(link_map.shape))
    text_map_size = list(reversed(text_map.shape))

    text_boxes = get_and_clean_boxes(text_map, text_map_size, orig_sizes)
    column_lines = get_vertical_lines(link_map, link_map_size, orig_sizes)

    return TextDetectionResult(
        bboxes=text_boxes,
        vertical_lines=column_lines,
        heatmap=text_map_image,
        affinity_map=link_map_image,
        image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]],
    )
|
123 |
+
|
124 |
+
|
125 |
+
def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
    """Run text detection on ``images`` and post-process every prediction.

    Predictions come from ``batch_detection`` one batch at a time. Each
    (prediction, original size) pair is post-processed by
    ``parallel_get_lines``, either in a process pool or inline, depending on
    the settings and on how many images were passed.

    Returns:
        One ``TextDetectionResult`` per prediction, in generator order.
    """
    batches = batch_detection(images, model, processor, batch_size=batch_size)

    collected = []
    worker_count = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
    use_pool = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

    if not use_pool:
        # Inline path: post-process each page in this process.
        for batch_preds, batch_sizes in batches:
            for page_pred, page_size in zip(batch_preds, batch_sizes):
                collected.append(parallel_get_lines(page_pred, page_size))
        return collected

    with ProcessPoolExecutor(max_workers=worker_count) as pool:
        for batch_preds, batch_sizes in batches:
            # map() keeps input order, so results line up with the images.
            collected.extend(list(pool.map(parallel_get_lines, batch_preds, batch_sizes)))

    return collected
|
143 |
+
|
144 |
+
|
surya/input/langs.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
from surya.languages import LANGUAGE_TO_CODE, CODE_TO_LANGUAGE
|
3 |
+
|
4 |
+
|
5 |
+
def replace_lang_with_code(langs: List[str]):
    """Rewrite language names in ``langs`` to their codes, in place.

    An entry whose title-cased form is a key of ``LANGUAGE_TO_CODE`` is
    replaced by the mapped code; other entries are left untouched.

    Raises:
        ValueError: if an entry (after replacement) is not a key of
            ``CODE_TO_LANGUAGE``. Entries before it have already been rewritten.
    """
    for position, entry in enumerate(langs):
        titled = entry.title()
        if titled in LANGUAGE_TO_CODE:
            entry = LANGUAGE_TO_CODE[titled]
            langs[position] = entry
        if entry not in CODE_TO_LANGUAGE:
            raise ValueError(f"Language code {entry} not found.")
|
11 |
+
|
12 |
+
|
13 |
+
def get_unique_langs(langs: List[List[str]]):
    """Return every distinct language across all lists, in first-seen order.

    Args:
        langs: A list of language lists (one per image/page).

    Returns:
        A flat list with duplicates removed; order follows first appearance.
    """
    # dict preserves insertion order, giving an O(n) ordered de-duplication
    # instead of repeated linear scans of the result list.
    return list(dict.fromkeys(lang for lang_list in langs for lang in lang_list))
|
surya/input/load.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL
|
2 |
+
|
3 |
+
from surya.input.processing import open_pdf, get_page_images
|
4 |
+
from surya.settings import settings
|
5 |
+
import os
|
6 |
+
import filetype
|
7 |
+
from PIL import Image
|
8 |
+
import json
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
def get_name_from_path(path):
    """Return the base name of ``path`` cut off at its first ``.``."""
    file_name = os.path.basename(path)
    stem, _, _ = file_name.partition(".")
    return stem
|
14 |
+
|
15 |
+
|
16 |
+
def load_pdf(pdf_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False):
    """Render a range of pages from a PDF and optionally extract its text lines.

    Args:
        pdf_path: Path of the PDF to open.
        max_pages: Maximum number of pages to load; a falsy value means no limit.
        start_page: Zero-based first page; a falsy value means page 0.
        dpi: Render resolution, passed on to ``get_page_images``.
        load_text_lines: When true, also extract text lines for the same pages.

    Returns:
        ``(images, names, text_lines)``. ``names`` repeats the file's base name
        once per page. ``text_lines`` is ``None`` unless ``load_text_lines``
        is true.

    Raises:
        AssertionError: if ``start_page`` or ``max_pages`` is out of range.
    """
    doc = open_pdf(pdf_path)
    # Close the document even when validation, rendering or text extraction
    # fails, so the handle is never leaked.
    try:
        last_page = len(doc)

        if start_page:
            assert start_page < last_page and start_page >= 0, f"Start page must be between 0 and {last_page}"
        else:
            start_page = 0

        if max_pages:
            assert max_pages >= 0, "Max pages must be greater than 0"
            last_page = min(start_page + max_pages, last_page)

        page_indices = list(range(start_page, last_page))
        images = get_page_images(doc, page_indices, dpi=dpi)
        text_lines = None
        if load_text_lines:
            from surya.input.pdflines import get_page_text_lines  # Putting import here because pypdfium2 causes warnings if its not the top import
            text_lines = get_page_text_lines(
                pdf_path,
                page_indices,
                [i.size for i in images]
            )
    finally:
        doc.close()

    name = get_name_from_path(pdf_path)
    names = [name for _ in page_indices]
    return images, names, text_lines
|
42 |
+
|
43 |
+
|
44 |
+
def load_image(image_path):
    """Load one image file as RGB.

    Returns:
        ``([image], [name], [None])`` — single-element lists so the result has
        the same three-part shape as ``load_pdf``; the text-lines slot is a
        ``None`` placeholder.
    """
    # Image.open keeps the file handle open; convert() produces an independent
    # in-memory image, so the source file can be closed straight away.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    name = get_name_from_path(image_path)
    return [image], [name], [None]
|
48 |
+
|
49 |
+
|
50 |
+
def load_from_file(input_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False):
    """Load a single PDF or image file.

    The file type is sniffed with ``filetype.guess``. PDFs go through
    ``load_pdf`` (honouring the page/dpi/text-line arguments); everything else
    is handed to ``load_image``.

    Returns:
        The ``(images, names, text_lines)`` tuple of the chosen loader.
    """
    input_type = filetype.guess(input_path)
    # guess() returns None for unrecognised content; treat that as "not a PDF"
    # (as load_from_folder does) instead of failing on ``None.extension``.
    if input_type is not None and input_type.extension == "pdf":
        return load_pdf(input_path, max_pages, start_page, dpi=dpi, load_text_lines=load_text_lines)
    else:
        return load_image(input_path)
|
56 |
+
|
57 |
+
|
58 |
+
def load_from_folder(folder_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False):
    """Load every PDF and image directly inside ``folder_path``.

    Entries whose name starts with ``.`` and sub-directories are skipped.
    Files that are neither a PDF nor an image readable by PIL are reported
    with a printed message and skipped.

    Returns:
        ``(images, names, text_lines)`` — three parallel lists with one entry
        per loaded page/image. ``text_lines`` entries are ``None`` wherever no
        text lines were loaded.
    """
    image_paths = [os.path.join(folder_path, image_name) for image_name in os.listdir(folder_path) if not image_name.startswith(".")]
    image_paths = [ip for ip in image_paths if not os.path.isdir(ip)]

    images = []
    names = []
    text_lines = []
    for path in image_paths:
        extension = filetype.guess(path)
        if extension and extension.extension == "pdf":
            image, name, text_line = load_pdf(path, max_pages, start_page, dpi=dpi, load_text_lines=load_text_lines)
            if text_line is None:
                # load_pdf returns None when text lines were not requested;
                # list.extend(None) would raise TypeError, so use one None per
                # page to keep the three lists aligned (as load_image does).
                text_line = [None] * len(image)
        else:
            try:
                image, name, text_line = load_image(path)
            except PIL.UnidentifiedImageError:
                print(f"Could not load image {path}")
                continue

        images.extend(image)
        names.extend(name)
        text_lines.extend(text_line)
    return images, names, text_lines
|
82 |
+
|
83 |
+
|
84 |
+
def load_lang_file(lang_path, names):
    """Look up the language list for each name in a JSON mapping file.

    Args:
        lang_path: Path to a JSON file holding an object keyed by name.
        names: Names to look up, in output order.

    Returns:
        A list with a fresh copy of the mapped value for every name, so callers
        may mutate entries independently even when names repeat.

    Raises:
        KeyError: if a name is missing from the file.
    """
    # JSON is UTF-8 by specification; don't depend on the platform's default encoding.
    with open(lang_path, "r", encoding="utf-8") as f:
        lang_dict = json.load(f)
    return [lang_dict[name].copy() for name in names]
|
surya/input/pdflines.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pdftext.extraction import dictionary_output
|
2 |
+
|
3 |
+
from surya.postprocessing.text import sort_text_lines
|
4 |
+
from surya.schema import PolygonBox
|
5 |
+
|
6 |
+
|
7 |
+
def get_page_text_lines(filepath: str, page_idxs: list, out_sizes: list) -> list:
    """Extract per-page text and rescale its boxes to the given output sizes.

    Each page dict returned by ``dictionary_output`` is paired with the entry
    of ``out_sizes`` at the same position. Every line and character ``bbox``
    is multiplied by ``out_size[0] / page width`` horizontally and
    ``out_size[1] / page height`` vertically, in place.

    Returns:
        The list of page dicts with rescaled boxes.
    """
    assert len(page_idxs) == len(out_sizes)
    pages_text = dictionary_output(filepath, sort=False, page_range=page_idxs, keep_chars=True)

    for page, target_size in zip(pages_text, out_sizes):
        page_width = page["width"]
        page_height = page["height"]
        x_scale = target_size[0] / page_width
        y_scale = target_size[1] / page_height

        def rescaled(box):
            # x entries use the horizontal factor, y entries the vertical one.
            return [box[0] * x_scale, box[1] * y_scale,
                    box[2] * x_scale, box[3] * y_scale]

        for block in page["blocks"]:
            for text_line in block["lines"]:
                text_line["bbox"] = rescaled(text_line["bbox"])
                for span in text_line["spans"]:
                    for glyph in span["chars"]:
                        glyph["bbox"] = rescaled(glyph["bbox"])
    return pages_text
|
24 |
+
|
25 |
+
|
26 |
+
def get_table_blocks(tables: list, full_text: dict, img_size: list, table_thresh=.8):
    """Collect the text spans that lie inside each table box.

    For every entry of ``tables`` (a four-element box used as the corners of a
    rectangle — presumably ``[x0, y0, x1, y1]``), each line in
    ``full_text["blocks"]`` whose ``intersection_pct`` with the table is at
    least ``table_thresh`` is walked character by character. Consecutive
    characters are merged into one span while a rotation-dependent proximity
    test (distances normalised by ``img_size``) holds; otherwise the current
    span is emitted and a new one starts.

    Returns:
        One list per table, each the result of ``sort_text_lines`` over dicts
        ``{"text": str, "bbox": [..4 numbers..]}`` whose boxes are shifted so
        that the table's top-left corner is the origin.
    """
    # Returns coordinates relative to input table, not full image
    table_texts = []
    for table in tables:
        table_poly = PolygonBox(polygon=[
            [table[0], table[1]],
            [table[2], table[1]],
            [table[2], table[3]],
            [table[0], table[3]]
        ])
        table_text = []
        rotation = full_text["rotation"]
        for block in full_text["blocks"]:
            for line in block["lines"]:
                line_poly = PolygonBox(polygon=[
                    [line["bbox"][0], line["bbox"][1]],
                    [line["bbox"][2], line["bbox"][1]],
                    [line["bbox"][2], line["bbox"][3]],
                    [line["bbox"][0], line["bbox"][3]]
                ])
                # Ignore lines that do not sufficiently overlap this table.
                if line_poly.intersection_pct(table_poly) < table_thresh:
                    continue
                # Span accumulator; reset for every line so spans never cross lines.
                curr_span = None
                curr_box = None
                for span in line["spans"]:
                    for char in span["chars"]:
                        same_span = False
                        # NOTE(review): truthiness test — an accumulated empty
                        # string skips the proximity check and falls through to
                        # the "emit and restart" branch below. Confirm intended.
                        if curr_span:
                            # A char joins the current span when both its x and y
                            # offsets from the span box are under 1% of the image
                            # size; which edges are compared depends on rotation.
                            if rotation == 90:
                                same_span = (char["bbox"][0] - curr_box[0]) / img_size[0] < 0.01 and abs(char["bbox"][1] - curr_box[3]) / img_size[1] < 0.01
                            elif rotation == 180:
                                same_span = (char["bbox"][2] - curr_box[0]) / img_size[0] < 0.01 and (char["bbox"][1] - curr_box[1]) / img_size[1] < 0.01
                            elif rotation == 270:
                                same_span = (char["bbox"][0] - curr_box[0]) / img_size[0] < 0.01 and abs(char["bbox"][3] - curr_box[1]) / img_size[1] < 0.01
                            else:
                                same_span = (char["bbox"][0] - curr_box[2]) / img_size[0] < 0.01 and (char["bbox"][1] - curr_box[1]) / img_size[1] < 0.01

                        if curr_span is None:
                            # First character of the line starts a span.
                            curr_span = char["char"]
                            curr_box = char["bbox"]
                        elif same_span:
                            # Extend the span text and grow its box to the union.
                            curr_span += char["char"]
                            curr_box = [min(curr_box[0], char["bbox"][0]), min(curr_box[1], char["bbox"][1]),
                                        max(curr_box[2], char["bbox"][2]), max(curr_box[3], char["bbox"][3])]
                        else:
                            # Gap detected: emit the finished span, start a new one.
                            table_text.append({"text": curr_span, "bbox": curr_box})
                            curr_span = char["char"]
                            curr_box = char["bbox"]
                # Flush the last span of the line.
                if curr_span is not None:
                    table_text.append({"text": curr_span, "bbox": curr_box})
        # Adjust to be relative to input table
        for item in table_text:
            item["bbox"] = [
                item["bbox"][0] - table[0],
                item["bbox"][1] - table[1],
                item["bbox"][2] - table[0],
                item["bbox"][3] - table[1]
            ]
        table_text = sort_text_lines(table_text)
        table_texts.append(table_text)
    return table_texts
|
surya/input/processing.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import math
|
6 |
+
import pypdfium2
|
7 |
+
from PIL import Image, ImageOps, ImageDraw
|
8 |
+
import torch
|
9 |
+
from surya.settings import settings
|
10 |
+
|
11 |
+
|
12 |
+
def convert_if_not_rgb(images: List[Image.Image]) -> List[Image.Image]:
    """Return a new list in which every image has mode ``"RGB"``.

    Images already in RGB are passed through as the same object; others are
    replaced by the result of ``convert("RGB")``.
    """
    return [
        picture if picture.mode == "RGB" else picture.convert("RGB")
        for picture in images
    ]
|
19 |
+
|
20 |
+
|
21 |
+
def get_total_splits(image_size, processor):
    """Return how many vertical chunks an image of ``image_size`` is cut into.

    The second entry of ``image_size`` is taken as the height. Images no
    taller than ``settings.DETECTOR_IMAGE_CHUNK_HEIGHT`` count as one chunk;
    taller ones are divided into chunks of the processor's height, rounded up.
    """
    height = list(image_size)[1]
    chunk_limit = settings.DETECTOR_IMAGE_CHUNK_HEIGHT
    tile_height = processor.size["height"]
    return math.ceil(height / tile_height) if height > chunk_limit else 1
|
29 |
+
|
30 |
+
|
31 |
+
def split_image(img, processor):
    """Cut a tall image into processor-height tiles.

    The input image itself is never returned: the result holds either crops
    of it or a copy.

    Returns:
        ``(tiles, heights)``. ``heights`` gives the number of real (unpadded)
        rows in each tile. Images no taller than
        ``settings.DETECTOR_IMAGE_CHUNK_HEIGHT`` yield a single copy.
    """
    full_height = list(img.size)[1]
    chunk_limit = settings.DETECTOR_IMAGE_CHUNK_HEIGHT
    tile_height = processor.size["height"]

    if not full_height > chunk_limit:
        return [img.copy()], [full_height]

    tiles, real_heights = [], []
    for tile_idx in range(math.ceil(full_height / tile_height)):
        top = tile_idx * tile_height
        bottom = (tile_idx + 1) * tile_height
        if bottom > full_height:
            bottom = full_height

        tile = img.crop((0, top, img.size[0], bottom))
        rows = bottom - top
        if rows < tile_height:
            # Short last tile: pad to full tile height, content anchored top-left.
            tile = ImageOps.pad(tile, (img.size[0], tile_height), color=255, centering=(0, 0))

        tiles.append(tile)
        real_heights.append(rows)
    return tiles, real_heights
|
53 |
+
|
54 |
+
|
55 |
+
def prepare_image_detection(img, processor):
    """Resize ``img`` to the processor's input size and return it as a tensor.

    Note that ``img.thumbnail`` is called on the argument itself, so the
    caller's image object is shrunk in place.
    """
    target = (processor.size["width"], processor.size["height"])

    # This double resize actually necessary for downstream accuracy
    img.thumbnail(target, Image.Resampling.LANCZOS)
    stretched = img.resize(target, Image.Resampling.LANCZOS)  # Stretch smaller dimension to fit new size

    pixels = np.asarray(stretched, dtype=np.uint8)
    processed = processor(pixels)["pixel_values"][0]
    return torch.from_numpy(processed)
|
66 |
+
|
67 |
+
|
68 |
+
def open_pdf(pdf_filepath):
    """Open ``pdf_filepath`` as a ``pypdfium2.PdfDocument``."""
    document = pypdfium2.PdfDocument(pdf_filepath)
    return document
|
70 |
+
|
71 |
+
|
72 |
+
def get_page_images(doc, indices: List, dpi=settings.IMAGE_DPI):
    """Render the pages at ``indices`` of ``doc`` as RGB PIL images.

    The render scale is ``dpi / 72``.
    """
    page_stream = doc.render(
        pypdfium2.PdfBitmap.to_pil,
        page_indices=indices,
        scale=dpi / 72,
    )
    # Render every requested page first, then normalise the colour mode.
    rendered_pages = list(page_stream)
    return [page.convert("RGB") for page in rendered_pages]
|
81 |
+
|
82 |
+
|
83 |
+
def slice_bboxes_from_image(image: Image.Image, bboxes):
    """Crop ``image`` once per box and return the crops in order.

    Each box is read as its first four entries and passed to ``image.crop``.
    A zero-width crop triggers a printed warning but is still returned, so the
    output always has one entry per input box.
    """
    crops = []
    for box in bboxes:
        region = (box[0], box[1], box[2], box[3])
        crop = image.crop(region)
        if crop.size[0] == 0:
            print(f"Warning: found an empty line with bbox {box}")
        crops.append(crop)
    return crops
|
91 |
+
|
92 |
+
|
93 |
+
def slice_polys_from_image(image: Image.Image, polys):
    """Cut one padded rectangle out of ``image`` for each polygon.

    The image is converted to a uint8 array once and every polygon is handed
    to ``slice_and_pad_poly`` with that array.
    """
    pixels = np.array(image, dtype=np.uint8)
    return [slice_and_pad_poly(pixels, polygon) for polygon in polys]
|
99 |
+
|
100 |
+
|
101 |
+
def slice_and_pad_poly(image_array: np.array, coordinates):
    """Crop the bounding rectangle of a polygon and blank out everything outside it.

    The polygon's axis-aligned bounding box is copied out of ``image_array``;
    pixels of that crop lying outside the polygon are overwritten with
    ``settings.RECOGNITION_PAD_VALUE``. The result is returned as a PIL image.
    """
    corners = [(corner[0], corner[1]) for corner in coordinates]
    xs = [x for x, _ in corners]
    ys = [y for _, y in corners]
    left, top, right, bottom = min(xs), min(ys), max(xs), max(ys)

    # Work on a copy so the source array is left untouched.
    region = image_array[top:bottom, left:right].copy()
    local_corners = [(x - left, y - top) for x, y in corners]

    # Rasterise the polygon (in crop coordinates) into a 0/1 mask, then repeat
    # it three times along a new last axis to match the crop's channels.
    inside = np.zeros(region.shape[:2], dtype=np.uint8)
    cv2.fillPoly(inside, [np.int32(local_corners)], 1)
    inside = np.stack([inside] * 3, axis=-1)

    region[inside == 0] = settings.RECOGNITION_PAD_VALUE
    return Image.fromarray(region)
|
surya/languages.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Maps each supported language code to its display name.
# "_math" is a special non-language entry.
CODE_TO_LANGUAGE = {
    "_math": "Math",
    'af': 'Afrikaans',
    'am': 'Amharic',
    'ar': 'Arabic',
    'as': 'Assamese',
    'az': 'Azerbaijani',
    'be': 'Belarusian',
    'bg': 'Bulgarian',
    'bn': 'Bengali',
    'br': 'Breton',
    'bs': 'Bosnian',
    'ca': 'Catalan',
    'cs': 'Czech',
    'cy': 'Welsh',
    'da': 'Danish',
    'de': 'German',
    'el': 'Greek',
    'en': 'English',
    'eo': 'Esperanto',
    'es': 'Spanish',
    'et': 'Estonian',
    'eu': 'Basque',
    'fa': 'Persian',
    'fi': 'Finnish',
    'fr': 'French',
    'fy': 'Western Frisian',
    'ga': 'Irish',
    'gd': 'Scottish Gaelic',
    'gl': 'Galician',
    'gu': 'Gujarati',
    'ha': 'Hausa',
    'he': 'Hebrew',
    'hi': 'Hindi',
    'hr': 'Croatian',
    'hu': 'Hungarian',
    'hy': 'Armenian',
    'id': 'Indonesian',
    'is': 'Icelandic',
    'it': 'Italian',
    'ja': 'Japanese',
    'jv': 'Javanese',
    'ka': 'Georgian',
    'kk': 'Kazakh',
    'km': 'Khmer',
    'kn': 'Kannada',
    'ko': 'Korean',
    'ku': 'Kurdish',
    'ky': 'Kyrgyz',
    'la': 'Latin',
    'lo': 'Lao',
    'lt': 'Lithuanian',
    'lv': 'Latvian',
    'mg': 'Malagasy',
    'mk': 'Macedonian',
    'ml': 'Malayalam',
    'mn': 'Mongolian',
    'mr': 'Marathi',
    'ms': 'Malay',
    'my': 'Burmese',
    'ne': 'Nepali',
    'nl': 'Dutch',
    'no': 'Norwegian',
    'om': 'Oromo',
    'or': 'Oriya',
    'pa': 'Punjabi',
    'pl': 'Polish',
    'ps': 'Pashto',
    'pt': 'Portuguese',
    'ro': 'Romanian',
    'ru': 'Russian',
    'sa': 'Sanskrit',
    'sd': 'Sindhi',
    'si': 'Sinhala',
    'sk': 'Slovak',
    'sl': 'Slovenian',
    'so': 'Somali',
    'sq': 'Albanian',
    'sr': 'Serbian',
    'su': 'Sundanese',
    'sv': 'Swedish',
    'sw': 'Swahili',
    'ta': 'Tamil',
    'te': 'Telugu',
    'th': 'Thai',
    'tl': 'Tagalog',
    'tr': 'Turkish',
    'ug': 'Uyghur',
    'uk': 'Ukrainian',
    'ur': 'Urdu',
    'uz': 'Uzbek',
    'vi': 'Vietnamese',
    'xh': 'Xhosa',
    'yi': 'Yiddish',
    'zh': 'Chinese',
}

# Reverse lookup: display name -> code. Names are unique above; if two codes
# ever shared a name, the later entry would win here.
LANGUAGE_TO_CODE = {v: k for k, v in CODE_TO_LANGUAGE.items()}
|
99 |
+
|
100 |
+
|
101 |
+
def is_arabic(lang_code):
    """Return True when ``lang_code`` is one of ``ar``, ``fa``, ``ps``, ``ug`` or ``ur``."""
    matching_codes = ("ar", "fa", "ps", "ug", "ur")
    return lang_code in matching_codes
|
surya/layout.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
from concurrent.futures import ProcessPoolExecutor
|
3 |
+
from typing import List, Optional
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from surya.detection import batch_detection
|
8 |
+
from surya.postprocessing.heatmap import keep_largest_boxes, get_and_clean_boxes, get_detected_boxes
|
9 |
+
from surya.schema import LayoutResult, LayoutBox, TextDetectionResult
|
10 |
+
from surya.settings import settings
|
11 |
+
|
12 |
+
|
13 |
+
def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
    """Build layout boxes from class heatmaps, snapped to detected text lines.

    Args:
        detection_result: Text detection output for the same image. Its
            ``vertical_lines`` and ``bboxes`` are rescaled **in place** into
            heatmap coordinates and are not scaled back afterwards.
        heatmaps: One 2-D heatmap per layout class; index 0 is the blank class.
        orig_size: Size the returned boxes are rescaled to (same ordering as
            the reversed heatmap shape).
        id2label: Maps class index to label name.
        segment_assignment: Per-pixel class index; pixels not assigned to a
            class are zeroed in that class's map.
        vertical_line_width: Pixels added to the right edge of every vertical
            line before the covered area is zeroed out.

    Returns:
        Layout boxes in ``orig_size`` coordinates. Text lines not claimed by
        any box are added as ``"Text"`` boxes; boxes almost entirely contained
        in another one are dropped, unless they are captions.
    """
    logits = np.stack(heatmaps, axis=0)
    vertical_line_bboxes = detection_result.vertical_lines
    line_bboxes = detection_result.bboxes

    # Scale back to processor size
    for line in vertical_line_bboxes:
        line.rescale_bbox(orig_size, list(reversed(heatmaps[0].shape)))

    for line in line_bboxes:
        line.rescale(orig_size, list(reversed(heatmaps[0].shape)))

    for bbox in vertical_line_bboxes:
        # Give some width to the vertical lines
        vert_bbox = list(bbox.bbox)
        # vert_bbox[2] indexes the column (last) axis below, so it must be
        # clamped to the heatmap width (shape[1]), not its height.
        vert_bbox[2] = min(heatmaps[0].shape[1], vert_bbox[2] + vertical_line_width)

        logits[:, vert_bbox[1]:vert_bbox[3], vert_bbox[0]:vert_bbox[2]] = 0  # zero out where the column lines are

    logits[:, logits[0] >= .5] = 0  # zero out where blanks are

    # Zero out where other segments are
    for i in range(logits.shape[0]):
        logits[i, segment_assignment != i] = 0

    detected_boxes = []
    for heatmap_idx in range(1, len(id2label)):  # Skip the blank class
        heatmap = logits[heatmap_idx]
        if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD:
            continue
        bboxes = get_detected_boxes(heatmap)
        bboxes = [bbox for bbox in bboxes if bbox.area > 25]
        for bb in bboxes:
            bb.fit_to_bounds([0, 0, heatmap.shape[1] - 1, heatmap.shape[0] - 1])

        for bbox in bboxes:
            detected_boxes.append(LayoutBox(polygon=bbox.polygon, label=id2label[heatmap_idx], confidence=1))

    detected_boxes = sorted(detected_boxes, key=lambda x: x.confidence, reverse=True)
    # Expand bbox to cover intersecting lines
    box_lines = defaultdict(list)
    used_lines = set()

    # We try 2 rounds of identifying the correct lines to snap to
    # First round is majority intersection, second lowers the threshold
    for thresh in [.5, .4]:
        for bbox_idx, bbox in enumerate(detected_boxes):
            for line_idx, line_bbox in enumerate(line_bboxes):
                if line_bbox.intersection_pct(bbox) > thresh and line_idx not in used_lines:
                    box_lines[bbox_idx].append(line_bbox.bbox)
                    used_lines.add(line_idx)

    new_boxes = []
    for bbox_idx, bbox in enumerate(detected_boxes):
        if bbox.label == "Picture" and bbox.area < 200:  # Remove very small figures
            continue

        # Decide membership once, *before* reading the defaultdict: indexing
        # box_lines[bbox_idx] would insert the key and make every later
        # "bbox_idx in box_lines" test true.
        has_lines = bbox_idx in box_lines

        # Skip if we didn't find any lines to snap to, except for Pictures and Formulas
        if not has_lines and bbox.label not in ["Picture", "Formula"]:
            continue

        covered_lines = box_lines[bbox_idx] if has_lines else []
        # Snap non-picture layout boxes to correct text boundaries
        if len(covered_lines) > 0 and bbox.label not in ["Picture"]:
            min_x = min([line[0] for line in covered_lines])
            min_y = min([line[1] for line in covered_lines])
            max_x = max([line[2] for line in covered_lines])
            max_y = max([line[3] for line in covered_lines])

            # Tables and formulas can contain text, but text isn't the whole area
            if bbox.label in ["Table", "Formula"]:
                min_x_box = min([b[0] for b in bbox.polygon])
                min_y_box = min([b[1] for b in bbox.polygon])
                max_x_box = max([b[0] for b in bbox.polygon])
                max_y_box = max([b[1] for b in bbox.polygon])

                min_x = min(min_x, min_x_box)
                min_y = min(min_y, min_y_box)
                max_x = max(max_x, max_x_box)
                max_y = max(max_y, max_y_box)

            bbox.polygon = [
                [min_x, min_y],
                [max_x, min_y],
                [max_x, max_y],
                [min_x, max_y]
            ]

        # A picture that overlaps text lines is relabelled as a figure.
        if has_lines and bbox.label in ["Picture"]:
            bbox.label = "Figure"

        new_boxes.append(bbox)

    # Merge tables together (sometimes one column is detected as a separate table)
    mergeable_types = ["Table", "Picture", "Figure"]
    for ftype in mergeable_types:
        to_remove = set()
        for bbox_idx, bbox in enumerate(new_boxes):
            if bbox.label != ftype or bbox_idx in to_remove:
                continue

            for bbox_idx2, bbox2 in enumerate(new_boxes):
                if bbox2.label != ftype or bbox_idx2 in to_remove or bbox_idx == bbox_idx2:
                    continue

                if bbox.intersection_pct(bbox2, x_margin=.25) > .1:
                    bbox.merge(bbox2)
                    to_remove.add(bbox_idx2)

        new_boxes = [bbox for idx, bbox in enumerate(new_boxes) if idx not in to_remove]

    # Ensure we account for all text lines in the layout
    unused_lines = [line for idx, line in enumerate(line_bboxes) if idx not in used_lines]
    for bbox in unused_lines:
        new_boxes.append(LayoutBox(polygon=bbox.polygon, label="Text", confidence=.5))

    for bbox in new_boxes:
        bbox.rescale(list(reversed(heatmaps[0].shape)), orig_size)

    detected_boxes = [bbox for bbox in new_boxes if bbox.area > 16]

    # Remove bboxes contained inside others, unless they're captions
    contained_bbox = set()  # only used for membership tests
    for i, bbox in enumerate(detected_boxes):
        for j, bbox2 in enumerate(detected_boxes):
            if i == j:
                continue

            if bbox2.intersection_pct(bbox) >= .95 and bbox2.label not in ["Caption"]:
                contained_bbox.add(j)

    detected_boxes = [bbox for idx, bbox in enumerate(detected_boxes) if idx not in contained_bbox]

    return detected_boxes
|
147 |
+
|
148 |
+
|
149 |
+
def get_regions(heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment) -> List[LayoutBox]:
    """Extract layout boxes from per-class heatmaps without text-line snapping.

    Class 0 (blank) is skipped. For each remaining class, pixels assigned to a
    different class are zeroed **in the caller's heatmap**, and classes whose
    maximum falls below ``settings.DETECTOR_BLANK_THRESHOLD`` are ignored.
    The collected boxes are filtered through ``keep_largest_boxes``.
    """
    layout_boxes = []
    for class_id in range(1, len(id2label)):  # Skip the blank class
        class_map = heatmaps[class_id]
        assert class_map.shape == segment_assignment.shape
        class_map[segment_assignment != class_id] = 0  # zero out where another segment is

        # Skip processing empty labels
        if np.max(class_map) < settings.DETECTOR_BLANK_THRESHOLD:
            continue

        map_size = list(reversed(class_map.shape))
        found_boxes = get_and_clean_boxes(class_map, map_size, orig_size)
        layout_boxes.extend(
            LayoutBox(polygon=found.polygon, label=id2label[class_id])
            for found in found_boxes
        )

    return keep_largest_boxes(layout_boxes)
|
166 |
+
|
167 |
+
|
168 |
+
def parallel_get_regions(heatmaps: List[np.ndarray], orig_size, id2label, detection_results=None) -> LayoutResult:
    """Post-process one image's layout heatmaps into a ``LayoutResult``.

    Every pixel is assigned to the class with the highest heatmap value. Boxes
    are then derived with text-line snapping when ``detection_results`` is
    given, and directly from the heatmaps otherwise.
    """
    segment_assignment = np.stack(heatmaps, axis=0).argmax(axis=0)

    if detection_results is None:
        layout_boxes = get_regions(heatmaps, orig_size, id2label, segment_assignment)
    else:
        layout_boxes = get_regions_from_detection_result(
            detection_results, heatmaps, orig_size, id2label, segment_assignment
        )

    # Class indices are stored as an 8-bit image.
    segmentation_img = Image.fromarray(segment_assignment.astype(np.uint8))

    return LayoutResult(
        bboxes=layout_boxes,
        segmentation_map=segmentation_img,
        heatmaps=heatmaps,
        image_bbox=[0, 0, orig_size[0], orig_size[1]],
    )
|
187 |
+
|
188 |
+
|
189 |
+
def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None) -> List[LayoutResult]:
    """Run layout detection on ``images`` and post-process every prediction.

    Predictions come from ``batch_detection`` in batches. Each one is handed
    to ``parallel_get_regions`` together with the entry of
    ``detection_results`` at the same running image index (or ``None`` when no
    detection results were supplied). Post-processing uses a process pool or
    runs inline, depending on settings and the number of images.

    Returns:
        One ``LayoutResult`` per image, in input order.
    """
    layout_batches = batch_detection(images, model, processor, batch_size=batch_size)
    id2label = model.config.id2label

    layouts = []
    worker_count = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
    use_pool = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

    def detection_for(index):
        # Text detection result matching the running image index, if any.
        return detection_results[index] if detection_results else None

    image_index = 0
    if not use_pool:
        for batch_preds, batch_sizes in layout_batches:
            for page_pred, page_size in zip(batch_preds, batch_sizes):
                layouts.append(parallel_get_regions(
                    page_pred,
                    page_size,
                    id2label,
                    detection_for(image_index)
                ))
                image_index += 1
        return layouts

    with ProcessPoolExecutor(max_workers=worker_count) as pool:
        for batch_preds, batch_sizes in layout_batches:
            # Submit the whole batch first, then collect in submission order.
            pending = []
            for page_pred, page_size in zip(batch_preds, batch_sizes):
                pending.append(pool.submit(
                    parallel_get_regions,
                    page_pred,
                    page_size,
                    id2label,
                    detection_for(image_index)
                ))
                image_index += 1

            for job in pending:
                layouts.append(job.result())

    return layouts
|
surya/model/detection/config.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
|
3 |
+
|
4 |
+
class EfficientViTConfig(PretrainedConfig):
    r"""
    Configuration for the EfficientViT detection model.

    Every constructor argument listed below is stored unchanged as an attribute
    of the same name; any extra keyword arguments are forwarded to
    `PretrainedConfig.__init__`.

    Args:
        num_classes: Default 2. Presumably the number of output heatmaps/classes — confirm against the model.
        num_channels: Default 3. Presumably the number of input image channels — confirm.
        widths: Default (32, 64, 128, 256, 512).
        head_dim: Default 32.
        num_stages: Default 4.
        depths: Default (1, 1, 1, 6, 6).
        strides: Default (2, 2, 2, 2, 2).
        hidden_sizes: Default (32, 64, 160, 256).
        patch_size: Default (7, 7).
        hidden_dropout_prob: Default 0.0.
        attention_probs_dropout_prob: Default 0.0.
        classifier_dropout_prob: Default 0.0.
        layer_norm_eps: Default 1e-6.
        decoder_layer_hidden_size: Default 128.
        decoder_hidden_size: Default 512.
        semantic_loss_ignore_index: Default 255.
        initializer_range: Default 0.02.
    """

    # Identifier under which this configuration type is registered.
    model_type = "efficientvit"

    def __init__(
        self,
        num_classes=2,
        num_channels=3,
        widths=(32, 64, 128, 256, 512),
        head_dim=32,
        num_stages=4,
        depths=(1, 1, 1, 6, 6),
        strides=(2, 2, 2, 2, 2),
        hidden_sizes=(32, 64, 160, 256),
        patch_size=(7, 7),
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        classifier_dropout_prob=0.0,
        layer_norm_eps=1e-6,
        decoder_layer_hidden_size=128,
        decoder_hidden_size=512,
        semantic_loss_ignore_index=255,
        initializer_range=0.02,
        **kwargs,
    ):
        # Let the base class consume the generic configuration options first.
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.widths = widths
        self.head_dim = head_dim

        self.num_channels = num_channels
        self.num_stages = num_stages
        self.depths = depths
        self.strides = strides
        self.hidden_sizes = hidden_sizes
        self.patch_size = patch_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.classifier_dropout_prob = classifier_dropout_prob
        self.layer_norm_eps = layer_norm_eps
        self.decoder_hidden_size = decoder_hidden_size
        self.decoder_layer_hidden_size = decoder_layer_hidden_size
        self.semantic_loss_ignore_index = semantic_loss_ignore_index

        self.initializer_range = initializer_range
|
surya/model/detection/model.py
ADDED
@@ -0,0 +1,767 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This is an implementation of efficientvit, with some modifications (decode head, etc).
|
3 |
+
|
4 |
+
Original paper at https://arxiv.org/abs/2205.14756
|
5 |
+
|
6 |
+
Code adapted from timm, https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_mit.py
|
7 |
+
Original code (that timm adapted from) at https://github.com/mit-han-lab/efficientvit
|
8 |
+
"""
|
9 |
+
|
10 |
+
from typing import Optional, Union, Tuple
|
11 |
+
from functools import partial
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
from transformers import PreTrainedModel
|
18 |
+
from transformers.modeling_outputs import SemanticSegmenterOutput
|
19 |
+
|
20 |
+
from surya.model.detection.config import EfficientViTConfig
|
21 |
+
from surya.model.detection.processor import SegformerImageProcessor
|
22 |
+
from surya.settings import settings
|
23 |
+
|
24 |
+
|
25 |
+
def load_model(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the detection model from ``checkpoint`` and put it in eval mode.

    Args:
        checkpoint: Identifier or path handed to ``from_pretrained``.
        device: Device the model is moved to.
        dtype: Passed to ``from_pretrained`` as ``torch_dtype``.

    Returns:
        The loaded ``EfficientViTForSemanticSegmentation`` on ``device``,
        in evaluation mode.
    """
    model_config = EfficientViTConfig.from_pretrained(checkpoint)
    detector = EfficientViTForSemanticSegmentation.from_pretrained(
        checkpoint,
        config=model_config,
        torch_dtype=dtype,
        ignore_mismatched_sizes=True,
    )
    # Both ``to`` and ``eval`` return the module, so they can be chained.
    detector = detector.to(device).eval()
    print(f"Loaded detection model {checkpoint} on device {device} with dtype {dtype}")
    return detector
|
32 |
+
|
33 |
+
|
34 |
+
def load_processor(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT):
    """Return the ``SegformerImageProcessor`` loaded from ``checkpoint``."""
    return SegformerImageProcessor.from_pretrained(checkpoint)
|
37 |
+
|
38 |
+
|
39 |
+
def val2list(x: object, repeat_time=1):
    """Coerce ``x`` to a list.

    Args:
        x: Any value. A list or tuple is copied into a new list; anything
            else is treated as a single element.
        repeat_time: How many times a non-sequence ``x`` is repeated.
            Ignored when ``x`` is already a list or tuple.

    Returns:
        A new list: ``list(x)`` for list/tuple input, otherwise
        ``repeat_time`` references to ``x``.
    """
    # The previous annotation ``list or tuple or any`` was a boolean
    # expression that evaluated to plain ``list``; ``object`` states the
    # real contract (anything is accepted).
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x] * repeat_time
|
43 |
+
|
44 |
+
|
45 |
+
def val2tuple(x: object, min_len: int = 1, idx_repeat: int = -1):
    """Coerce ``x`` to a tuple with at least ``min_len`` elements.

    ``x`` is first converted with ``val2list``. If the result is non-empty
    and shorter than ``min_len``, the element at ``idx_repeat`` is duplicated
    (inserted at position ``idx_repeat``) until the length reaches
    ``min_len``. For example, ``val2tuple(3, 2)`` gives ``(3, 3)`` and
    ``val2tuple(("a", "b"), 3)`` gives ``("a", "b", "b")``.

    Args:
        x: A scalar, list or tuple.
        min_len: Minimum length of the returned tuple.
        idx_repeat: Index of the element to duplicate when padding.

    Returns:
        The padded values as a tuple; an empty input yields ``()``.
    """
    # The previous annotation ``list or tuple or any`` evaluated to plain
    # ``list``; ``object`` reflects that any value is accepted.
    items = val2list(x)
    if items:
        # Slice assignment at [i:i] inserts without replacing anything.
        items[idx_repeat:idx_repeat] = [items[idx_repeat] for _ in range(min_len - len(items))]

    return tuple(items)
|
52 |
+
|
53 |
+
|
54 |
+
def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
    """Return the padding that keeps spatial size for an odd kernel at stride 1.

    Args:
        kernel_size: An odd integer, or a (possibly nested) tuple of them.

    Returns:
        ``kernel_size // 2`` for an integer, or a tuple of the per-element
        results for a tuple.

    Raises:
        AssertionError: If an integer kernel size is even.
    """
    # Annotations use typing.Union/Tuple: the former ``int or tuple[int, ...]``
    # collapsed to ``int`` and the ``tuple[...]`` subscript is evaluated at
    # definition time, which fails on Python < 3.9.
    if isinstance(kernel_size, tuple):
        return tuple(get_same_padding(ks) for ks in kernel_size)
    assert kernel_size % 2 > 0, "kernel size should be odd number"
    return kernel_size // 2
|
60 |
+
|
61 |
+
|
62 |
+
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1) -> int:
    """Compute the symmetric padding for a convolution.

    Returns ``((stride - 1) + dilation * (kernel_size - 1)) // 2``.
    """
    dilated_extent = dilation * (kernel_size - 1)
    return (dilated_extent + (stride - 1)) // 2
|
65 |
+
|
66 |
+
class ConvNormAct(nn.Module):
    """A ``Conv2d`` followed by an optional norm layer and an optional activation.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        kernel_size, stride, dilation, groups, bias: Forwarded to ``nn.Conv2d``;
            padding is derived from them via ``get_padding``.
        dropout: Probability used to build ``self.dropout``.
        norm_layer: Callable invoked as ``norm_layer(num_features=out_channels)``;
            a falsy value selects ``nn.Identity``.
        act_layer: Callable invoked as ``act_layer(inplace=True)``;
            ``None`` selects ``nn.Identity``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        bias=False,
        dropout=0.,
        norm_layer=nn.BatchNorm2d,
        act_layer=nn.ReLU,
    ):
        super().__init__()
        # NOTE(review): this dropout module is constructed but ``forward``
        # never applies it - confirm that skipping it is intentional.
        self.dropout = nn.Dropout(dropout, inplace=False)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=get_padding(kernel_size, stride, dilation),
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        if norm_layer:
            self.norm = norm_layer(num_features=out_channels)
        else:
            self.norm = nn.Identity()

        if act_layer is None:
            self.act = nn.Identity()
        else:
            self.act = act_layer(inplace=True)

    def forward(self, x):
        """Apply convolution, then normalization, then activation."""
        return self.act(self.norm(self.conv(x)))
|
101 |
+
|
102 |
+
|
103 |
+
class DSConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    ``use_bias``, ``norm_layer`` and ``act_layer`` may each be a single value
    or a pair; they are expanded with ``val2tuple(..., 2)`` and element 0 is
    used for the depthwise stage, element 1 for the pointwise stage.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)

        # groups == in_channels makes this a depthwise convolution.
        self.depth_conv = ConvNormAct(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=in_channels,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        self.point_conv = ConvNormAct(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )

    def forward(self, x):
        """Run the depthwise stage and then the pointwise stage."""
        return self.point_conv(self.depth_conv(x))
|
142 |
+
|
143 |
+
|
144 |
+
class ConvBlock(nn.Module):
    """Two stacked ``ConvNormAct`` layers sharing the same kernel size.

    The first layer maps ``in_channels`` to ``mid_channels`` with the given
    ``stride``; the second maps ``mid_channels`` to ``out_channels`` with
    stride 1. When ``mid_channels`` is falsy it defaults to
    ``round(in_channels * expand_ratio)``. ``use_bias``, ``norm_layer`` and
    ``act_layer`` are expanded to pairs, one entry per layer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=1,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)
        hidden = mid_channels or round(in_channels * expand_ratio)

        self.conv1 = ConvNormAct(
            in_channels,
            hidden,
            kernel_size=kernel_size,
            stride=stride,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        self.conv2 = ConvNormAct(
            hidden,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )

    def forward(self, x):
        """Apply ``conv1`` and then ``conv2``."""
        return self.conv2(self.conv1(x))
|
186 |
+
|
187 |
+
|
188 |
+
class MBConv(nn.Module):
    """Inverted-bottleneck block: 1x1 expand, depthwise conv, 1x1 project.

    ``mid_channels`` defaults to ``round(in_channels * expand_ratio)`` when
    falsy. ``use_bias``, ``norm_layer`` and ``act_layer`` are expanded to
    triples, one entry per layer in the order expand / depthwise / project.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 3)
        norms = val2tuple(norm_layer, 3)
        acts = val2tuple(act_layer, 3)
        hidden = mid_channels or round(in_channels * expand_ratio)

        # 1x1 expansion to the hidden width.
        self.inverted_conv = ConvNormAct(
            in_channels,
            hidden,
            kernel_size=1,
            stride=1,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        # Depthwise spatial convolution; the block's stride is applied here.
        self.depth_conv = ConvNormAct(
            hidden,
            hidden,
            kernel_size=kernel_size,
            stride=stride,
            groups=hidden,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )
        # 1x1 projection to the output width.
        self.point_conv = ConvNormAct(
            hidden,
            out_channels,
            kernel_size=1,
            bias=biases[2],
            norm_layer=norms[2],
            act_layer=acts[2],
        )

    def forward(self, x):
        """Apply the expand, depthwise and project layers in sequence."""
        return self.point_conv(self.depth_conv(self.inverted_conv(x)))
|
240 |
+
|
241 |
+
|
242 |
+
class FusedMBConv(nn.Module):
    """Fused inverted bottleneck: one spatial conv followed by a 1x1 projection.

    The spatial convolution maps ``in_channels`` to ``mid_channels`` (default
    ``round(in_channels * expand_ratio)``) using ``kernel_size``, ``stride``
    and ``groups``. ``use_bias``, ``norm_layer`` and ``act_layer`` are
    expanded to pairs, one entry per layer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        groups=1,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)
        hidden = mid_channels or round(in_channels * expand_ratio)

        self.spatial_conv = ConvNormAct(
            in_channels,
            hidden,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        self.point_conv = ConvNormAct(
            hidden,
            out_channels,
            kernel_size=1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )

    def forward(self, x):
        """Apply the spatial convolution and then the pointwise projection."""
        return self.point_conv(self.spatial_conv(x))
|
285 |
+
|
286 |
+
|
287 |
+
class LiteMLA(nn.Module):
    """Lightweight multi-scale linear attention.

    A 1x1 convolution produces q/k/v. One extra q/k/v set is produced per
    entry of ``scales`` by a depthwise convolution of that kernel size
    followed by a grouped 1x1 convolution. All sets go through ReLU-kernel
    linear attention and are projected to ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        heads: int or None = None,
        heads_ratio: float = 1.0,
        dim=8,
        use_bias=False,
        norm_layer=(None, nn.BatchNorm2d),
        act_layer=(None, None),
        kernel_func=nn.ReLU,
        scales=(5,),
        eps=1e-5,
    ):
        super().__init__()
        self.eps = eps
        heads = heads or int(in_channels // dim * heads_ratio)
        total_dim = heads * dim
        qkv_channels = 3 * total_dim
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)

        self.dim = dim
        self.qkv = ConvNormAct(
            in_channels,
            qkv_channels,
            kernel_size=1,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        # One aggregation branch per scale: depthwise conv, then grouped 1x1.
        self.aggreg = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    qkv_channels,
                    qkv_channels,
                    scale,
                    padding=get_same_padding(scale),
                    groups=qkv_channels,
                    bias=biases[0],
                ),
                nn.Conv2d(qkv_channels, qkv_channels, 1, groups=3 * heads, bias=biases[0]),
            )
            for scale in scales
        ])
        self.kernel_func = kernel_func(inplace=False)

        # The projection consumes the base branch plus one branch per scale.
        self.proj = ConvNormAct(
            total_dim * (1 + len(scales)),
            out_channels,
            kernel_size=1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )

    def _attn(self, q, k, v):
        """Compute normalized linear attention in float32 and cast back.

        ``v`` is expected to carry an extra trailing column of ones, as
        ``forward`` adds, so the last output channel holds the normalizer.
        """
        out_dtype = v.dtype
        q32, k32, v32 = q.float(), k.float(), v.float()
        context = k32.transpose(-1, -2) @ v32
        attended = q32 @ context
        # Divide the value channels by the accumulated normalizer channel.
        normalized = attended[..., :-1] / (attended[..., -1:] + self.eps)
        return normalized.to(out_dtype)

    def forward(self, x):
        """Apply multi-scale linear attention to a ``B, C, H, W`` tensor."""
        batch, _, height, width = x.shape

        # Base q/k/v plus one aggregated copy per scale, stacked on channels.
        qkv = self.qkv(x)
        branches = [qkv] + [aggregate(qkv) for aggregate in self.aggreg]
        stacked = torch.cat(branches, dim=1)
        stacked = stacked.reshape(batch, -1, 3 * self.dim, height * width).transpose(-1, -2)
        # Each of q, k, v ends up with head_dim as its last dimension.
        q, k, v = stacked.chunk(3, dim=-1)

        # ReLU feature map on q and k; v gets a ones column for the normalizer.
        q = self.kernel_func(q)
        k = self.kernel_func(k)
        v = F.pad(v, (0, 1), mode="constant", value=1.)

        attended = self._attn(q, k, v)

        # Back to B, C, H, W for the final projection.
        attended = attended.transpose(-1, -2).reshape(batch, -1, height, width)
        return self.proj(attended)
|
379 |
+
|
380 |
+
|
381 |
+
class EfficientVitBlock(nn.Module):
    """Attention block: residual ``LiteMLA`` followed by a residual ``MBConv``.

    Both sub-modules keep the channel count at ``in_channels`` and use an
    identity shortcut.
    """

    def __init__(
        self,
        in_channels,
        heads_ratio=1.0,
        head_dim=32,
        expand_ratio=4,
        norm_layer=nn.BatchNorm2d,
        act_layer=nn.Hardswish,
    ):
        super().__init__()
        attention = LiteMLA(
            in_channels=in_channels,
            out_channels=in_channels,
            heads_ratio=heads_ratio,
            dim=head_dim,
            norm_layer=(None, norm_layer),
        )
        self.context_module = ResidualBlock(attention, nn.Identity())

        # Only the final projection of the MBConv is normalized; the first
        # two layers use a bias instead.
        feed_forward = MBConv(
            in_channels=in_channels,
            out_channels=in_channels,
            expand_ratio=expand_ratio,
            use_bias=(True, True, False),
            norm_layer=(None, None, norm_layer),
            act_layer=(act_layer, act_layer, None),
        )
        self.local_module = ResidualBlock(feed_forward, nn.Identity())

    def forward(self, x):
        """Apply the context (attention) module and then the local module."""
        return self.local_module(self.context_module(x))
|
418 |
+
|
419 |
+
|
420 |
+
class ResidualBlock(nn.Module):
    """Wrap ``main`` with an optional pre-norm and an optional shortcut.

    Computes ``main(pre_norm(x))`` and, when ``shortcut`` is given, adds
    ``shortcut(x)`` to it. A missing ``pre_norm`` becomes ``nn.Identity``.
    """

    def __init__(
        self,
        main: Optional[nn.Module],
        shortcut: Optional[nn.Module] = None,
        pre_norm: Optional[nn.Module] = None,
    ):
        super().__init__()
        if pre_norm is None:
            pre_norm = nn.Identity()
        self.pre_norm = pre_norm
        self.main = main
        self.shortcut = shortcut

    def forward(self, x):
        """Return ``main(pre_norm(x))``, plus ``shortcut(x)`` when present."""
        out = self.main(self.pre_norm(x))
        if self.shortcut is None:
            return out
        return out + self.shortcut(x)
|
437 |
+
|
438 |
+
|
439 |
+
def build_local_block(
    in_channels: int,
    out_channels: int,
    stride: int,
    kernel_size: int,
    expand_ratio: float,
    norm_layer: str,
    act_layer: str,
    fewer_norm: bool = False,
    block_type: str = "default",
):
    """Build the convolutional block for the given expand ratio and type.

    Selection:
        * ``expand_ratio == 1``: ``DSConv`` for ``"default"``, else ``ConvBlock``.
        * otherwise: ``MBConv`` for ``"default"``, else ``FusedMBConv``.

    With ``fewer_norm`` set, only the last layer of the block gets
    ``norm_layer`` and the earlier layers use a bias instead. Otherwise every
    layer gets ``norm_layer`` and no bias.

    NOTE(review): ``norm_layer`` and ``act_layer`` are annotated ``str`` but
    are forwarded unchanged as the layer constructors of the chosen block -
    confirm and correct the annotations.

    Raises:
        AssertionError: If ``block_type`` is not "default", "large" or "fused".
    """
    assert block_type in ["default", "large", "fused"]
    is_default = block_type == "default"
    shared = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        stride=stride,
        kernel_size=kernel_size,
    )
    # Settings for the blocks made of two ConvNormAct layers.
    two_layer = dict(
        use_bias=(True, False) if fewer_norm else False,
        norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
        act_layer=(act_layer, None),
    )

    if expand_ratio == 1:
        block_cls = DSConv if is_default else ConvBlock
        return block_cls(**shared, **two_layer)

    if is_default:
        return MBConv(
            **shared,
            expand_ratio=expand_ratio,
            use_bias=(True, True, False) if fewer_norm else False,
            norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer,
            act_layer=(act_layer, act_layer, None),
        )

    return FusedMBConv(**shared, expand_ratio=expand_ratio, **two_layer)
|
496 |
+
|
497 |
+
|
498 |
+
class Stem(nn.Sequential):
    """Input stem: a strided ``ConvNormAct`` followed by ``depth`` residual blocks.

    The first module is registered as ``in_conv`` and uses kernel size
    ``stride + 1``. The residual blocks are registered as ``res0``, ``res1``,
    and so on; each keeps ``out_chs`` channels at stride 1 with an identity
    shortcut.
    """

    def __init__(self, in_chs, out_chs, depth, stride, norm_layer, act_layer, block_type='default'):
        super().__init__()
        self.stride = stride

        stem_conv = ConvNormAct(
            in_chs,
            out_chs,
            kernel_size=stride + 1,
            stride=stride,
            norm_layer=norm_layer,
            act_layer=act_layer,
        )
        self.add_module('in_conv', stem_conv)

        for block_idx in range(depth):
            local_block = build_local_block(
                in_channels=out_chs,
                out_channels=out_chs,
                stride=1,
                kernel_size=3,
                expand_ratio=1,
                norm_layer=norm_layer,
                act_layer=act_layer,
                block_type=block_type,
            )
            self.add_module(f'res{block_idx}', ResidualBlock(local_block, nn.Identity()))
|
526 |
+
|
527 |
+
|
528 |
+
class EfficientVitLargeStage(nn.Module):
    """One backbone stage: a strided entry block followed by ``depth`` blocks.

    The entry block changes ``in_chs`` to ``out_chs`` with the given
    ``stride`` and has no shortcut. The following blocks are
    ``EfficientVitBlock`` instances when ``vit_stage`` is set, otherwise
    residual local convolution blocks with an identity shortcut.
    """

    def __init__(
        self,
        in_chs,
        out_chs,
        depth,
        stride,
        norm_layer,
        act_layer,
        head_dim,
        vit_stage=False,
        fewer_norm=False,
    ):
        super().__init__()
        local_block_type = 'default' if fewer_norm else 'fused'

        # Entry block; no shortcut because shape changes across it.
        entry = build_local_block(
            in_channels=in_chs,
            out_channels=out_chs,
            stride=stride,
            kernel_size=stride + 1,
            expand_ratio=24 if vit_stage else 16,
            norm_layer=norm_layer,
            act_layer=act_layer,
            fewer_norm=vit_stage or fewer_norm,
            block_type=local_block_type,
        )
        layers = [ResidualBlock(entry, None)]

        # Every remaining block keeps the channel count at out_chs.
        for _ in range(depth):
            if vit_stage:
                layers.append(
                    EfficientVitBlock(
                        in_channels=out_chs,
                        head_dim=head_dim,
                        expand_ratio=6,
                        norm_layer=norm_layer,
                        act_layer=act_layer,
                    )
                )
            else:
                local_block = build_local_block(
                    in_channels=out_chs,
                    out_channels=out_chs,
                    stride=1,
                    kernel_size=3,
                    expand_ratio=4,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    fewer_norm=fewer_norm,
                    block_type=local_block_type,
                )
                layers.append(ResidualBlock(local_block, nn.Identity()))

        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through every block of the stage."""
        return self.blocks(x)
|
592 |
+
|
593 |
+
|
594 |
+
class EfficientVitLarge(nn.Module):
    """EfficientViT backbone: a stem followed by a sequence of stages.

    Index 0 of ``config.widths``, ``config.depths`` and ``config.strides``
    configures the stem; the remaining entries configure one stage each.
    Stages with index >= 2 use ``fewer_norm`` and stages with index >= 3 are
    attention (``vit_stage``) stages. ``forward`` returns the output of every
    stage.
    """

    def __init__(
        self,
        config: EfficientViTConfig,
        norm_layer=nn.BatchNorm2d,
        act_layer=nn.Hardswish,
    ):
        super().__init__()
        self.grad_checkpointing = False
        self.num_classes = config.num_classes
        self.norm_eps = config.layer_norm_eps
        # Bind the configured epsilon into every norm layer built below.
        norm_layer = partial(norm_layer, eps=self.norm_eps)

        # input stem
        stem_width = config.widths[0]
        stem_stride = config.strides[0]
        self.stem = Stem(
            config.num_channels,
            stem_width,
            config.depths[0],
            stem_stride,
            norm_layer,
            act_layer,
            block_type='large',
        )

        # stages
        self.feature_info = []
        self.stages = nn.Sequential()
        reduction = stem_stride
        prev_width = stem_width
        stage_specs = zip(config.widths[1:], config.depths[1:], config.strides[1:])
        for stage_idx, (width, depth, stage_stride) in enumerate(stage_specs):
            self.stages.append(EfficientVitLargeStage(
                prev_width,
                width,
                depth=depth,
                stride=stage_stride,
                norm_layer=norm_layer,
                act_layer=act_layer,
                head_dim=config.head_dim,
                vit_stage=stage_idx >= 3,
                fewer_norm=stage_idx >= 2,
            ))
            # Track the cumulative downsampling factor for feature_info.
            reduction *= stage_stride
            prev_width = width
            self.feature_info.append(
                dict(num_chs=prev_width, reduction=reduction, module=f'stages.{stage_idx}')
            )

        self.num_features = prev_width

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Record the gradient-checkpointing flag on the module."""
        self.grad_checkpointing = enable

    def forward(self, x):
        """Return a list with the output of each stage, in order."""
        stage_outputs = []
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
            stage_outputs.append(x)

        return stage_outputs
|
645 |
+
|
646 |
+
|
647 |
+
class EfficientViTPreTrainedModel(PreTrainedModel):
    """
    Base class that ties the EfficientViT models to ``EfficientViTConfig`` and
    supplies the weight-initialization hook.
    """

    config_class = EfficientViTConfig
    base_model_prefix = "efficientvit"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module.

        Linear, Conv2d and Embedding weights are drawn from a normal
        distribution with std ``config.initializer_range``. Biases and the
        embedding padding row are zeroed. LayerNorm is reset to weight 1 and
        bias 0. Other module types are left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
672 |
+
|
673 |
+
|
674 |
+
class DecodeMLP(nn.Module):
    """Linear projection applied to every spatial position of a feature map."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, output_dim)

    def forward(self, hidden_states: torch.Tensor):
        """Project a ``B, C, H, W`` tensor to ``B, H*W, output_dim``."""
        # Collapse the spatial dims and move channels last: B, HW, C.
        tokens = hidden_states.flatten(2).transpose(1, 2)
        return self.proj(tokens)
|
685 |
+
|
686 |
+
|
687 |
+
class DecodeHead(EfficientViTPreTrainedModel):
    """Segmentation head that fuses the per-stage encoder features.

    Each stage output is projected to ``config.decoder_layer_hidden_size``
    channels and resized to the first stage's spatial size. The results are
    concatenated, fused by a 1x1 convolution with batch norm and ReLU, and
    classified by a final 1x1 convolution.
    """

    def __init__(self, config: EfficientViTConfig):
        super().__init__(config)

        # One projection per stage so every feature map ends up with
        # config.decoder_layer_hidden_size channels.
        self.linear_c = nn.ModuleList([
            DecodeMLP(input_dim=stage_width, output_dim=config.decoder_layer_hidden_size)
            for stage_width in config.widths[1:]
        ])

        # 1x1 conv + batch norm + ReLU fuse the concatenated stage features.
        self.linear_fuse = nn.Conv2d(
            in_channels=config.decoder_layer_hidden_size * config.num_stages,
            out_channels=config.decoder_hidden_size,
            kernel_size=1,
            bias=False,
        )
        self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size)
        self.activation = nn.ReLU()

        # NOTE(review): this dropout is constructed but ``forward`` never
        # applies it - confirm that skipping it is intentional.
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)

        self.config = config

    def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
        """Fuse the stage outputs and return the classifier output.

        Args:
            encoder_hidden_states: Sequence of ``B, C, H, W`` feature maps,
                one per stage.

        Returns:
            Tensor of shape ``B, num_labels, H0, W0``, where ``H0, W0`` is
            the spatial size of the first feature map.
        """
        batch_size = encoder_hidden_states[-1].shape[0]
        target_size = encoder_hidden_states[0].size()[2:]

        resized = []
        for feature_map, mlp in zip(encoder_hidden_states, self.linear_c):
            height, width = feature_map.shape[2], feature_map.shape[3]
            # The MLP yields B, HW, C; restore the B, C, H, W layout.
            projected = mlp(feature_map).permute(0, 2, 1)
            projected = projected.reshape(batch_size, -1, height, width)
            # Bring every stage to the resolution of the first feature map.
            resized.append(
                nn.functional.interpolate(
                    projected, size=target_size, mode="bilinear", align_corners=False
                )
            )

        # Concatenate from the last stage to the first before fusing.
        fused = self.linear_fuse(torch.cat(resized[::-1], dim=1))
        fused = self.activation(self.batch_norm(fused))

        return self.classifier(fused)
|
737 |
+
|
738 |
+
|
739 |
+
class EfficientViTForSemanticSegmentation(EfficientViTPreTrainedModel):
    """EfficientViT backbone combined with the segmentation decode head."""

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.vit = EfficientVitLarge(config)
        self.decode_head = DecodeHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor
    ) -> Union[Tuple, SemanticSegmenterOutput]:
        """Run the backbone and decode head on ``B, C, H, W`` pixel values.

        Returns:
            A ``SemanticSegmenterOutput`` with ``loss=None``. Its ``logits``
            field holds the sigmoid of the decode-head output, so values lie
            in 0-1. Its ``hidden_states`` field holds the per-stage backbone
            features.
        """
        stage_features = self.vit(pixel_values)

        # The sigmoid is applied here, so ``logits`` carries 0-1 scores.
        scores = torch.special.expit(self.decode_head(stage_features))

        return SemanticSegmenterOutput(
            loss=None,
            logits=scores,
            hidden_states=stage_features,
        )
|
surya/model/detection/processor.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
from typing import Any, Dict, List, Optional, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
7 |
+
from transformers.image_transforms import to_channel_dimension_format
|
8 |
+
from transformers.image_utils import (
|
9 |
+
IMAGENET_DEFAULT_MEAN,
|
10 |
+
IMAGENET_DEFAULT_STD,
|
11 |
+
ChannelDimension,
|
12 |
+
ImageInput,
|
13 |
+
PILImageResampling,
|
14 |
+
infer_channel_dimension_format,
|
15 |
+
make_list_of_images,
|
16 |
+
)
|
17 |
+
from transformers.utils import TensorType
|
18 |
+
|
19 |
+
|
20 |
+
import PIL.Image
|
21 |
+
import torch
|
22 |
+
|
23 |
+
|
24 |
+
class SegformerImageProcessor(BaseImageProcessor):
|
25 |
+
r"""
|
26 |
+
Constructs a Segformer image processor.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
30 |
+
Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
|
31 |
+
size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
|
32 |
+
size (`Dict[str, int]` *optional*, defaults to `{"height": 512, "width": 512}`):
|
33 |
+
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
|
34 |
+
method.
|
35 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
36 |
+
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
|
37 |
+
`preprocess` method.
|
38 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
39 |
+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
|
40 |
+
parameter in the `preprocess` method.
|
41 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
42 |
+
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
43 |
+
method.
|
44 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
45 |
+
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
46 |
+
method.
|
47 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
|
48 |
+
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
49 |
+
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
50 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
|
51 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
52 |
+
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
53 |
+
do_reduce_labels (`bool`, *optional*, defaults to `False`):
|
54 |
+
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
|
55 |
+
used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
|
56 |
+
background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
|
57 |
+
`preprocess` method.
|
58 |
+
"""
|
59 |
+
|
60 |
+
model_input_names = ["pixel_values"]
|
61 |
+
|
62 |
+
def __init__(
    self,
    do_resize: bool = True,
    size: Dict[str, int] = None,
    resample: PILImageResampling = PILImageResampling.BILINEAR,
    do_rescale: bool = True,
    rescale_factor: Union[int, float] = 1 / 255,
    do_normalize: bool = True,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_reduce_labels: bool = False,
    **kwargs,
) -> None:
    """Store the default preprocessing options on the processor.

    When omitted, `size` falls back to 512x512 and `image_mean` / `image_std`
    fall back to `IMAGENET_DEFAULT_MEAN` / `IMAGENET_DEFAULT_STD`. The
    deprecated `reduce_labels` keyword is still accepted: it triggers a
    `FutureWarning` and overrides `do_reduce_labels`. All remaining keyword
    arguments are passed to the base class constructor.
    """
    # The legacy keyword has to be removed before kwargs are handed to the base class.
    if "reduce_labels" in kwargs:
        warnings.warn(
            "The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use "
            "`do_reduce_labels` instead.",
            FutureWarning,
        )
        do_reduce_labels = kwargs.pop("reduce_labels")

    super().__init__(**kwargs)

    if size is None:
        size = {"height": 512, "width": 512}
    normalized_size = get_size_dict(size)

    self.do_resize = do_resize
    self.size = normalized_size
    self.resample = resample
    self.do_rescale = do_rescale
    self.rescale_factor = rescale_factor
    self.do_normalize = do_normalize
    self.image_mean = IMAGENET_DEFAULT_MEAN if image_mean is None else image_mean
    self.image_std = IMAGENET_DEFAULT_STD if image_std is None else image_std
    self.do_reduce_labels = do_reduce_labels
    # Names of the keyword arguments recognised by the preprocessing entry point.
    self._valid_processor_keys = [
        "images",
        "segmentation_maps",
        "do_resize",
        "size",
        "resample",
        "do_rescale",
        "rescale_factor",
        "do_normalize",
        "image_mean",
        "image_std",
        "do_reduce_labels",
        "return_tensors",
        "data_format",
        "input_data_format",
    ]
|
111 |
+
|
112 |
+
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
    """
    Build the processor from a dict, honouring a `reduce_labels` keyword argument.

    A `reduce_labels` entry in `kwargs` is moved into a copy of `image_processor_dict` (the caller's dict is not
    modified) before delegating to the base class `from_dict`, so that calls such as
    `SegformerImageProcessor.from_pretrained(checkpoint, reduce_labels=True)` take effect.
    """
    settings = image_processor_dict.copy()
    if "reduce_labels" in kwargs:
        settings["reduce_labels"] = kwargs.pop("reduce_labels")
    return super().from_dict(settings, **kwargs)
|
123 |
+
|
124 |
+
def _preprocess(
    self,
    image: ImageInput,
    do_resize: bool,
    do_rescale: bool,
    do_normalize: bool,
    size: Optional[Dict[str, int]] = None,
    resample: PILImageResampling = None,
    rescale_factor: Optional[float] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
    """Apply the optional rescale and normalize steps to one image and return it.

    `do_resize`, `size` and `resample` are part of the signature but are not
    read by this method: no resizing is performed here.
    """
    result = image

    if do_rescale:
        result = self.rescale(image=result, scale=rescale_factor, input_data_format=input_data_format)

    if do_normalize:
        result = self.normalize(image=result, mean=image_mean, std=image_std, input_data_format=input_data_format)

    return result
|
145 |
+
|
146 |
+
def _preprocess_image(
    self,
    image: ImageInput,
    do_resize: bool = None,
    size: Dict[str, int] = None,
    resample: PILImageResampling = None,
    do_rescale: bool = None,
    rescale_factor: float = None,
    do_normalize: bool = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """Preprocess a single image and optionally convert its channel layout.

    The input channel layout is inferred from the image when
    `input_data_format` is not given. The image then goes through
    `_preprocess`, and is converted to `data_format` only if one is requested.
    """
    # NOTE(review): no conversion to numpy happens here; presumably the caller passes arrays — confirm.
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    processed = self._preprocess(
        image=image,
        do_resize=do_resize,
        size=size,
        resample=resample,
        do_rescale=do_rescale,
        rescale_factor=rescale_factor,
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
        input_data_format=input_data_format,
    )

    if data_format is None:
        return processed
    return to_channel_dimension_format(processed, data_format, input_channel_dim=input_data_format)
|
180 |
+
|
181 |
+
def __call__(self, images, segmentation_maps=None, **kwargs):
    """
    Preprocess a batch of images and, optionally, segmentation maps.

    This override exists so that `segmentation_maps` can be given positionally right after `images`; it is
    forwarded to the parent `__call__` as a keyword argument together with any other keyword arguments.
    """
    parent_call = super().__call__
    return parent_call(images, segmentation_maps=segmentation_maps, **kwargs)
|
189 |
+
|
190 |
+
def preprocess(
    self,
    images: ImageInput,
    segmentation_maps: Optional[ImageInput] = None,
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: PILImageResampling = None,
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_reduce_labels: Optional[bool] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    data_format: ChannelDimension = ChannelDimension.FIRST,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    **kwargs,
) -> BatchFeature:
    """
    Preprocess an image or batch of images.

    Every option left as `None` falls back to the value stored on the processor. Each image is then passed
    through `_preprocess_image` and the results are returned under the `pixel_values` key.

    Args:
        images (`ImageInput`):
            Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
            passing in images with pixel values between 0 and 1, set `do_rescale=False`.
        segmentation_maps (`ImageInput`, *optional*):
            Accepted for interface compatibility. This method does not read it, so segmentation maps are not
            preprocessed or returned.
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the image. Forwarded to `_preprocess_image`.
        size (`Dict[str, int]`, *optional*, defaults to `self.size`):
            Size of the image after `resize` is applied. Forwarded to `_preprocess_image`.
        resample (`int`, *optional*, defaults to `self.resample`):
            Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
            has an effect if `do_resize` is set to `True`. Forwarded to `_preprocess_image`.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether to rescale the image values between [0 - 1].
        rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
            Rescale factor to rescale the image by if `do_rescale` is set to `True`.
        do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
            Image mean.
        image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
            Image standard deviation.
        do_reduce_labels (`bool`, *optional*):
            Accepted for interface compatibility. This method does not read it.
        return_tensors (`str` or `TensorType`, *optional*):
            The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
        data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
            The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format for the input image. If unset, the channel dimension format is inferred
            from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

    Returns:
        `BatchFeature`: a batch whose `pixel_values` entry holds the preprocessed images, converted according to
        `return_tensors`.
    """
    # Resolve per-call overrides against the defaults stored on the processor.
    do_resize = self.do_resize if do_resize is None else do_resize
    do_rescale = self.do_rescale if do_rescale is None else do_rescale
    do_normalize = self.do_normalize if do_normalize is None else do_normalize
    resample = self.resample if resample is None else resample
    size = self.size if size is None else size
    rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
    image_mean = self.image_mean if image_mean is None else image_mean
    image_std = self.image_std if image_std is None else image_std

    images = make_list_of_images(images)
    images = [
        self._preprocess_image(
            image=img,
            do_resize=do_resize,
            resample=resample,
            size=size,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            data_format=data_format,
            input_data_format=input_data_format,
        )
        for img in images
    ]

    data = {"pixel_values": images}
    return BatchFeature(data=data, tensor_type=return_tensors)
|
surya/model/ordering/config.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import MBartConfig, DonutSwinConfig
|
2 |
+
|
3 |
+
|
4 |
+
class MBartOrderConfig(MBartConfig):
    """`MBartConfig` subclass that defines no fields or methods of its own."""
|
6 |
+
|
7 |
+
class VariableDonutSwinConfig(DonutSwinConfig):
    """`DonutSwinConfig` subclass that defines no fields or methods of its own."""
|
surya/model/ordering/decoder.py
ADDED
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from typing import Optional, List, Union, Tuple
|
3 |
+
|
4 |
+
from transformers import MBartForCausalLM, MBartConfig
|
5 |
+
from torch import nn
|
6 |
+
from transformers.activations import ACT2FN
|
7 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask
|
8 |
+
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
|
9 |
+
from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel, MBartDecoder, MBartLearnedPositionalEmbedding, MBartDecoderLayer
|
10 |
+
from surya.model.ordering.config import MBartOrderConfig
|
11 |
+
import torch
|
12 |
+
import math
|
13 |
+
|
14 |
+
|
15 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension (adapted from llama).

    Equivalent to `torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)`: the input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When `n_rep == 1` the input tensor itself is returned.
    """
    # Unpacking first also enforces that the input is 4-dimensional, even on the n_rep == 1 path.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast over it, then fold it into the head axis.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
|
26 |
+
|
27 |
+
|
28 |
+
class MBartGQAttention(nn.Module):
    """Multi-head attention in which several query heads share one key/value head.

    Queries are projected to `num_heads` heads while keys and values are
    projected to `num_kv_heads` heads; the key/value heads are repeated
    `num_heads // num_kv_heads` times with `repeat_kv` before the attention
    product. The module serves as self-attention, or as cross-attention when
    `key_value_states` is passed to `forward`.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[MBartConfig] = None,
    ):
        """Create the projections and validate the head configuration.

        Raises:
            ValueError: if a head count is not positive, if `num_heads` is not
                divisible by `num_kv_heads`, or if `embed_dim` is not
                divisible by `num_kv_heads` or by `num_heads`.
        """
        super().__init__()

        # Validate before any division so a bad head count gives a clear message
        # instead of a ZeroDivisionError, and so the checks survive `python -O`.
        if num_heads <= 0 or num_kv_heads <= 0:
            raise ValueError(
                f"num_heads and num_kv_heads must be positive (got `num_heads`: {num_heads}"
                f" and `num_kv_heads`: {num_kv_heads})."
            )
        if num_heads % num_kv_heads != 0:
            raise ValueError(f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
        if embed_dim % num_kv_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_kv_heads ({num_kv_heads})")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        # Number of query heads that share each key/value head.
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        # Keys and values use the smaller number of heads; queries and the output keep the full width.
        self.k_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape (bsz, seq_len, num_heads * head_dim) to (bsz, num_heads, seq_len, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _shape_key_value(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape (bsz, seq_len, num_kv_heads * head_dim) to (bsz, num_kv_heads, seq_len, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Returns a tuple `(attn_output, attn_weights, past_key_value)`.
        `attn_weights` is `None` unless `output_attentions` is set. When the
        module is a decoder, `past_key_value` is the `(key_states, value_states)`
        pair taken before the key/value heads are repeated; otherwise the
        argument is returned unchanged.

        Raises:
            ValueError: if the attention weights, `attention_mask`,
                `layer_head_mask` or the attention output have an unexpected size.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling

        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape_key_value(self.v_proj(key_value_states), -1, bsz)
        else:
            # self_attention: project the current states once...
            key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)
            if past_key_value is not None:
                # ...and prepend the cached keys/values along the sequence axis
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)

        # Expand kv heads, then match query shape
        key_states = repeat_kv(key_states, self.num_kv_groups)
        value_states = repeat_kv(value_states, self.num_kv_groups)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            # The mask is added to the scores and broadcast over the head axis.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
|
196 |
+
|
197 |
+
|
198 |
+
# Lookup table from an attention implementation name to the attention module class.
# "flash_attention_2" has no class registered here (its entry is None).
MBART_ATTENTION_CLASSES = {
    "eager": MBartGQAttention,
    "flash_attention_2": None,
}
|
202 |
+
|
203 |
+
|
204 |
+
class MBartOrderDecoderLayer(MBartDecoderLayer):
    """Decoder layer whose self- and cross-attention come from `MBART_ATTENTION_CLASSES`.

    Only the constructor is overridden; everything else is inherited from
    `MBartDecoderLayer`.
    """

    def __init__(self, config: MBartConfig):
        # nn.Module.__init__ is called directly, so the MBartDecoderLayer constructor does not run
        # and every submodule is created below.
        nn.Module.__init__(self)
        self.embed_dim = config.d_model

        attention_cls = MBART_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            num_kv_heads=config.kv_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        cross_attention_cls = MBART_ATTENTION_CLASSES[config._attn_implementation]
        # Unlike self-attention, the cross-attention module is built without `is_causal=True`.
        self.encoder_attn = cross_attention_cls(
            self.embed_dim,
            config.decoder_attention_heads,
            num_kv_heads=config.kv_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # Feed-forward block and its layer norm.
        ffn_dim = config.decoder_ffn_dim
        self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
235 |
+
|
236 |
+
|
237 |
+
class BboxEmbedding(nn.Module):
    """Embed integer bounding boxes as a sum of learned per-feature embeddings.

    Each box `(x1, y1, x2, y2)` contributes embeddings for its four
    coordinates, its width and height, and its center. On the first decoding
    step (`past_key_values_length == 0`) a learned box-position embedding is
    added to the box range of every sample.
    """

    def __init__(self, config):
        super().__init__()
        # Horizontal features index tables of size max_width, vertical ones max_height.
        n_cols, n_rows, dim = config.max_width, config.max_height, config.d_model
        self.x1_embed = nn.Embedding(n_cols, dim)
        self.y1_embed = nn.Embedding(n_rows, dim)
        self.x2_embed = nn.Embedding(n_cols, dim)
        self.y2_embed = nn.Embedding(n_rows, dim)
        self.w_embed = nn.Embedding(n_cols, dim)
        self.h_embed = nn.Embedding(n_rows, dim)
        self.cx_embed = nn.Embedding(n_cols, dim)
        self.cy_embed = nn.Embedding(n_rows, dim)
        self.box_pos_embed = nn.Embedding(config.max_position_embeddings, dim)

    def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor, past_key_values_length: int):
        """Return the summed embeddings for `boxes`.

        `boxes` is split on its last axis into `(x1, y1, x2, y2)`.
        `input_box_counts[j, 0]` and `input_box_counts[j, 1] - 1` bound the
        slice of sample `j` that receives box-position embeddings.
        """
        left, top, right, bottom = boxes.unbind(dim=-1)
        width = right - left
        height = bottom - top
        # True division yields floats; converting back to long truncates the half-pixel.
        center_x = ((left + right) / 2).long()
        center_y = ((top + bottom) / 2).long()

        # Shape is (batch_size, num_boxes/seq len, d_model)
        corner_sum = self.x1_embed(left) + self.y1_embed(top) + self.x2_embed(right) + self.y2_embed(bottom)
        embedded = (
            corner_sum + self.w_embed(width) + self.h_embed(height) + self.cx_embed(center_x) + self.cy_embed(center_y)
        )

        # Box-position embeddings are only added when nothing is cached yet.
        if past_key_values_length != 0:
            return embedded

        for sample in range(embedded.shape[0]):
            first = input_box_counts[sample, 0]
            last = input_box_counts[sample, 1] - 1  # Skip the sep token
            n_boxes = last - first
            embedded[sample, first:last] = embedded[sample, first:last] + self.box_pos_embed.weight[:n_boxes]

        return embedded
|
273 |
+
|
274 |
+
|
275 |
+
class MBartOrderDecoder(MBartDecoder):
|
276 |
+
def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
    """Build the decoder stack.

    Args:
        config: model configuration; supplies dropout, layer count, model
            width, padding index and maximum position count.
        embed_tokens: optional pre-built input embedding module. When it is
            `None`, a new `BboxEmbedding` is created from `config`.
    """
    # MBartPreTrainedModel.__init__ is called directly, so the MBartDecoder constructor does not run
    # and every submodule is created below.
    MBartPreTrainedModel.__init__(self, config)
    self.dropout = config.dropout
    self.layerdrop = config.decoder_layerdrop
    self.padding_idx = config.pad_token_id
    self.max_target_positions = config.max_position_embeddings
    self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

    # A supplied module is used as-is. The former `self.embed_tokens.weight = embed_tokens.weight`
    # assigned the module's weight to itself and failed for modules without a `weight` attribute.
    self.embed_tokens = BboxEmbedding(config) if embed_tokens is None else embed_tokens

    self.embed_positions = MBartLearnedPositionalEmbedding(
        config.max_position_embeddings,
        config.d_model,
    )
    self.layers = nn.ModuleList([MBartOrderDecoderLayer(config) for _ in range(config.decoder_layers)])
    self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
    self.layernorm_embedding = nn.LayerNorm(config.d_model)
    self.layer_norm = nn.LayerNorm(config.d_model)

    self.gradient_checkpointing = False
    # Initialize weights and apply final processing
    self.post_init()
|
302 |
+
|
303 |
+
def forward(
|
304 |
+
self,
|
305 |
+
input_boxes: torch.LongTensor = None,
|
306 |
+
input_boxes_mask: Optional[torch.Tensor] = None,
|
307 |
+
input_boxes_counts: Optional[torch.Tensor] = None,
|
308 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
309 |
+
encoder_attention_mask: Optional[torch.LongTensor] = None,
|
310 |
+
head_mask: Optional[torch.Tensor] = None,
|
311 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
312 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
313 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
314 |
+
use_cache: Optional[bool] = None,
|
315 |
+
output_attentions: Optional[bool] = None,
|
316 |
+
output_hidden_states: Optional[bool] = None,
|
317 |
+
return_dict: Optional[bool] = None,
|
318 |
+
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
319 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
320 |
+
output_hidden_states = (
|
321 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
322 |
+
)
|
323 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
324 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
325 |
+
|
326 |
+
# retrieve input_ids and inputs_embeds
|
327 |
+
if input_boxes is not None and inputs_embeds is not None:
|
328 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
329 |
+
elif input_boxes is not None:
|
330 |
+
input = input_boxes
|
331 |
+
input_shape = input_boxes.size()[:-1] # Shape (batch_size, num_boxes)
|
332 |
+
elif inputs_embeds is not None:
|
333 |
+
input_shape = inputs_embeds.size()[:-1]
|
334 |
+
input = inputs_embeds[:, :, -1]
|
335 |
+
else:
|
336 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
337 |
+
|
338 |
+
# past_key_values_length
|
339 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
340 |
+
|
341 |
+
if inputs_embeds is None:
|
342 |
+
inputs_embeds = self.embed_tokens(input_boxes, input_boxes_counts, past_key_values_length) * self.embed_scale
|
343 |
+
|
344 |
+
if self._use_flash_attention_2:
|
345 |
+
# 2d mask is passed through the layers
|
346 |
+
attention_mask = input_boxes_mask if (input_boxes_mask is not None and 0 in input_boxes_mask) else None
|
347 |
+
else:
|
348 |
+
# 4d mask is passed through the layers
|
349 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
350 |
+
input_boxes_mask, input_shape, inputs_embeds, past_key_values_length
|
351 |
+
)
|
352 |
+
|
353 |
+
if past_key_values_length == 0:
|
354 |
+
box_ends = input_boxes_counts[:, 1]
|
355 |
+
box_starts = input_boxes_counts[:, 0]
|
356 |
+
input_shape_arranged = torch.arange(input_shape[1], device=attention_mask.device)[None, :]
|
357 |
+
# Enable all boxes to attend to each other (before the sep token)
|
358 |
+
# Ensure that the boxes are not attending to the padding tokens
|
359 |
+
boxes_end_mask = input_shape_arranged < box_ends[:, None]
|
360 |
+
boxes_start_mask = input_shape_arranged >= box_starts[:, None]
|
361 |
+
boxes_mask = boxes_end_mask & boxes_start_mask
|
362 |
+
boxes_mask = boxes_mask.unsqueeze(1).unsqueeze(1) # Enable proper broadcasting
|
363 |
+
attention_mask = attention_mask.masked_fill(boxes_mask, 0)
|
364 |
+
|
365 |
+
# expand encoder attention mask
|
366 |
+
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
367 |
+
if self._use_flash_attention_2:
|
368 |
+
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
|
369 |
+
else:
|
370 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
371 |
+
encoder_attention_mask = _prepare_4d_attention_mask(
|
372 |
+
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
373 |
+
)
|
374 |
+
|
375 |
+
# embed positions
|
376 |
+
positions = self.embed_positions(input, past_key_values_length)
|
377 |
+
|
378 |
+
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
|
379 |
+
hidden_states = self.layernorm_embedding(hidden_states)
|
380 |
+
|
381 |
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
382 |
+
|
383 |
+
if self.gradient_checkpointing and self.training:
|
384 |
+
if use_cache:
|
385 |
+
use_cache = False
|
386 |
+
|
387 |
+
# decoder layers
|
388 |
+
all_hidden_states = () if output_hidden_states else None
|
389 |
+
all_self_attns = () if output_attentions else None
|
390 |
+
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
391 |
+
next_decoder_cache = () if use_cache else None
|
392 |
+
|
393 |
+
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
394 |
+
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
395 |
+
if attn_mask is not None:
|
396 |
+
if attn_mask.size()[0] != len(self.layers):
|
397 |
+
raise ValueError(
|
398 |
+
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
|
399 |
+
f" {attn_mask.size()[0]}."
|
400 |
+
)
|
401 |
+
for idx, decoder_layer in enumerate(self.layers):
|
402 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
403 |
+
if output_hidden_states:
|
404 |
+
all_hidden_states += (hidden_states,)
|
405 |
+
if self.training:
|
406 |
+
dropout_probability = torch.rand([])
|
407 |
+
if dropout_probability < self.layerdrop:
|
408 |
+
continue
|
409 |
+
|
410 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
411 |
+
|
412 |
+
if self.gradient_checkpointing and self.training:
|
413 |
+
layer_outputs = self._gradient_checkpointing_func(
|
414 |
+
decoder_layer.__call__,
|
415 |
+
hidden_states,
|
416 |
+
attention_mask,
|
417 |
+
encoder_hidden_states,
|
418 |
+
encoder_attention_mask,
|
419 |
+
head_mask[idx] if head_mask is not None else None,
|
420 |
+
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
421 |
+
None,
|
422 |
+
output_attentions,
|
423 |
+
use_cache,
|
424 |
+
)
|
425 |
+
else:
|
426 |
+
layer_outputs = decoder_layer(
|
427 |
+
hidden_states,
|
428 |
+
attention_mask=attention_mask,
|
429 |
+
encoder_hidden_states=encoder_hidden_states,
|
430 |
+
encoder_attention_mask=encoder_attention_mask,
|
431 |
+
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
432 |
+
cross_attn_layer_head_mask=(
|
433 |
+
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
434 |
+
),
|
435 |
+
past_key_value=past_key_value,
|
436 |
+
output_attentions=output_attentions,
|
437 |
+
use_cache=use_cache,
|
438 |
+
)
|
439 |
+
hidden_states = layer_outputs[0]
|
440 |
+
|
441 |
+
if use_cache:
|
442 |
+
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
|
443 |
+
|
444 |
+
if output_attentions:
|
445 |
+
all_self_attns += (layer_outputs[1],)
|
446 |
+
|
447 |
+
if encoder_hidden_states is not None:
|
448 |
+
all_cross_attentions += (layer_outputs[2],)
|
449 |
+
|
450 |
+
hidden_states = self.layer_norm(hidden_states)
|
451 |
+
|
452 |
+
# add hidden states from the last decoder layer
|
453 |
+
if output_hidden_states:
|
454 |
+
all_hidden_states += (hidden_states,)
|
455 |
+
|
456 |
+
next_cache = next_decoder_cache if use_cache else None
|
457 |
+
if not return_dict:
|
458 |
+
return tuple(
|
459 |
+
v
|
460 |
+
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
|
461 |
+
if v is not None
|
462 |
+
)
|
463 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
464 |
+
last_hidden_state=hidden_states,
|
465 |
+
past_key_values=next_cache,
|
466 |
+
hidden_states=all_hidden_states,
|
467 |
+
attentions=all_self_attns,
|
468 |
+
cross_attentions=all_cross_attentions,
|
469 |
+
)
|
470 |
+
|
471 |
+
|
472 |
+
class MBartOrderDecoderWrapper(MBartPreTrainedModel):
    """
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.

    It owns a single `MBartOrderDecoder` under the attribute name `decoder` and forwards
    every call to it unchanged.
    """

    def __init__(self, config):
        super().__init__(config)
        # The attribute name `decoder` becomes part of the parameter path of the wrapped module
        self.decoder = MBartOrderDecoder(config)

    def forward(self, *args, **kwargs):
        # Pure delegation: arguments and return value are passed through untouched
        return self.decoder(*args, **kwargs)
|
484 |
+
|
485 |
+
|
486 |
+
class MBartOrder(MBartForCausalLM):
    """
    Causal-LM head on top of `MBartOrderDecoder`.

    The decoder output is projected to `config.vocab_size` logits by a bias-free linear
    layer. No loss is computed: `labels` is accepted for interface compatibility and the
    returned loss is always None.
    """

    config_class = MBartOrderConfig
    _tied_weights_keys = []

    def __init__(self, config, **kwargs):
        # Work on a private copy so the caller's config is not flipped into decoder mode
        decoder_only_config = copy.deepcopy(config)
        decoder_only_config.is_decoder = True
        decoder_only_config.is_encoder_decoder = False
        MBartPreTrainedModel.__init__(self, decoder_only_config)

        self.model = MBartOrderDecoderWrapper(decoder_only_config)
        self.lm_head = nn.Linear(
            decoder_only_config.hidden_size,
            decoder_only_config.vocab_size,
            bias=False,
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_boxes: torch.LongTensor = None,
        input_boxes_mask: Optional[torch.Tensor] = None,
        input_boxes_counts: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        """
        Decode the box sequence and project the hidden states to logits.

        Returns a `CausalLMOutputWithCrossAttentions` (loss is None) when `return_dict`
        resolves to true, otherwise the tuple `(logits, *remaining decoder outputs)`.
        """
        # Unset flags fall back to the model config
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        decoder_call_args = dict(
            input_boxes=input_boxes,
            input_boxes_mask=input_boxes_mask,
            input_boxes_counts=input_boxes_counts,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        outputs = self.model.decoder(**decoder_call_args)

        logits = self.lm_head(outputs[0])

        if return_dict:
            return CausalLMOutputWithCrossAttentions(
                loss=None,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                cross_attentions=outputs.cross_attentions,
            )

        # Tuple form: there is never a loss to prepend, so logits lead the tuple
        return (logits,) + outputs[1:]
|
surya/model/ordering/encoder.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
import torch
|
3 |
+
from typing import Optional, Tuple, Union
|
4 |
+
import collections
|
5 |
+
import math
|
6 |
+
|
7 |
+
from transformers import DonutSwinPreTrainedModel
|
8 |
+
from transformers.models.donut.modeling_donut_swin import DonutSwinPatchEmbeddings, DonutSwinEmbeddings, DonutSwinModel, \
|
9 |
+
DonutSwinEncoder
|
10 |
+
|
11 |
+
from surya.model.ordering.config import VariableDonutSwinConfig
|
12 |
+
|
13 |
+
class VariableDonutSwinEmbeddings(DonutSwinEmbeddings):
    """
    Patch embeddings with optional absolute and optional 2D (row + column) position
    embeddings, plus an optional mask token.
    """

    def __init__(self, config, use_mask_token=False, **kwargs):
        super().__init__(config, use_mask_token)

        embed_dim = config.embed_dim

        self.patch_embeddings = DonutSwinPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if use_mask_token else None

        # One learned vector per patch (+1 spare slot), only when absolute embeddings are enabled
        self.position_embeddings = (
            nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
            if config.use_absolute_embeddings
            else None
        )

        # Separate learned vectors per grid row and per grid column (+1 spare slot each)
        use_2d = config.use_2d_embeddings
        self.row_embeddings = (
            nn.Parameter(torch.zeros(1, self.patch_grid[0] + 1, embed_dim)) if use_2d else None
        )
        self.column_embeddings = (
            nn.Parameter(torch.zeros(1, self.patch_grid[1] + 1, embed_dim)) if use_2d else None
        )

        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None, **kwargs
    ) -> Tuple[torch.Tensor]:
        """
        Embed `pixel_values` into a patch sequence and add the enabled position embeddings.

        Returns the embedded sequence together with the grid dimensions reported by the
        patch-embedding module.
        """
        patches, output_dimensions = self.patch_embeddings(pixel_values)
        # Layernorm across the last dimension (each patch is a single row)
        patches = self.norm(patches)
        batch_size, seq_len, _ = patches.size()

        if bool_masked_pos is not None:
            # replace the masked visual tokens by mask_tokens
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            keep_or_mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            patches = patches * (1.0 - keep_or_mask) + mask_tokens * keep_or_mask

        if self.position_embeddings is not None:
            patches = patches + self.position_embeddings[:, :seq_len, :]

        has_2d_embeddings = not (self.row_embeddings is None or self.column_embeddings is None)
        if has_2d_embeddings:
            grid_rows = output_dimensions[0]
            grid_cols = output_dimensions[1]
            # Row index stays constant along a row: 0, 0, 0, ..., 1, 1, 1, ...
            per_patch_rows = self.row_embeddings[:, :grid_rows, :].repeat_interleave(grid_cols, dim=1)
            # Column index cycles within each row: 0, 1, 2, 3, 0, 1, 2, 3, ...
            per_patch_cols = self.column_embeddings[:, :grid_cols, :].repeat(1, grid_rows, 1)

            patches = patches + per_patch_rows + per_patch_cols

        patches = self.dropout(patches)

        return patches, output_dimensions
|
67 |
+
|
68 |
+
|
69 |
+
class VariableDonutSwinModel(DonutSwinModel):
    """
    DonutSwin model that uses `VariableDonutSwinEmbeddings` for its embedding layer.

    The encoder is rebuilt from the new embeddings' patch grid, and the pooler is an
    adaptive average pool that is only created when `add_pooling_layer` is true.
    """

    config_class = VariableDonutSwinConfig
    def __init__(self, config, add_pooling_layer=True, use_mask_token=False, **kwargs):
        # NOTE(review): the parent constructor is called with `config` only, so
        # `add_pooling_layer` / `use_mask_token` take effect solely through the
        # attributes assigned below — presumably replacing whatever the parent built.
        super().__init__(config)
        self.config = config
        self.num_layers = len(config.depths)
        # Channel width doubles at each stage after the first
        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))

        self.embeddings = VariableDonutSwinEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)

        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()
|
surya/model/ordering/encoderdecoder.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union, Tuple, List
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from transformers import VisionEncoderDecoderModel
|
5 |
+
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
|
6 |
+
|
7 |
+
|
8 |
+
class OrderVisionEncoderDecoderModel(VisionEncoderDecoderModel):
    """
    Vision encoder / box decoder pair for reading-order prediction.

    The encoder turns `pixel_values` into hidden states; the decoder consumes those
    together with the bounding-box inputs (`decoder_input_boxes*`).
    """

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_boxes: torch.LongTensor = None,
        # Shape (batch_size, num_boxes, 4); coordinates and the sep/pad ids are whatever the
        # box preprocessing emits (see OrderImageProcessor / load_processor in this package)
        decoder_input_boxes_mask: torch.LongTensor = None,  # Shape (batch_size, num_boxes), 0 if padding, 1 otherwise
        decoder_input_boxes_counts: torch.LongTensor = None,  # Shape (batch_size, 2): [first box index, end index] per image
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[List[List[int]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """
        Encode the image (unless `encoder_outputs` is supplied) and decode the boxes.

        Extra keyword arguments are routed by prefix: names starting with `decoder_` go to
        the decoder with the prefix stripped, everything else goes to the encoder.

        Returns:
            `Seq2SeqLMOutput` when `return_dict` resolves to true, otherwise the decoder
            output tuple followed by the encoder output tuple.

        Raises:
            ValueError: if neither `encoder_outputs` nor `pixel_values` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        decoder_prefix = "decoder_"
        kwargs_encoder = {
            argument: value for argument, value in kwargs.items() if not argument.startswith(decoder_prefix)
        }
        kwargs_decoder = {
            argument[len(decoder_prefix):]: value
            for argument, value in kwargs.items()
            if argument.startswith(decoder_prefix)
        }

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # optionally project encoder_hidden_states
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # The image encoder output carries no padding, so no cross-attention mask is built
        encoder_attention_mask = None

        # Decode
        decoder_outputs = self.decoder(
            input_boxes=decoder_input_boxes,
            input_boxes_mask=decoder_input_boxes_mask,
            input_boxes_counts=decoder_input_boxes_counts,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            labels=labels,
            **kwargs_decoder,
        )

        if not return_dict:
            # A caller-supplied tuple was wrapped in BaseModelOutput above (and callers may
            # pass an output object directly); such objects cannot be concatenated to a
            # tuple with `+`, so flatten them first.
            if hasattr(encoder_outputs, "to_tuple"):
                encoder_outputs = encoder_outputs.to_tuple()
            return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=decoder_outputs.loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
|
surya/model/ordering/model.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import DetrConfig, BeitConfig, DetrImageProcessor, VisionEncoderDecoderConfig, AutoModelForCausalLM, \
|
2 |
+
AutoModel
|
3 |
+
from surya.model.ordering.config import MBartOrderConfig, VariableDonutSwinConfig
|
4 |
+
from surya.model.ordering.decoder import MBartOrder
|
5 |
+
from surya.model.ordering.encoder import VariableDonutSwinModel
|
6 |
+
from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
|
7 |
+
from surya.model.ordering.processor import OrderImageProcessor
|
8 |
+
from surya.settings import settings
|
9 |
+
|
10 |
+
|
11 |
+
def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """
    Load the reading-order encoder/decoder model from `checkpoint`.

    The checkpoint's generic encoder/decoder configs are rebuilt as the project's own
    config classes, the custom model classes are registered with the transformers auto
    classes, and the loaded model is moved to `device` and switched to eval mode.
    """
    config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)

    # Re-create both sub-configs as the custom config types, carrying over every attribute
    config.decoder = MBartOrderConfig(**vars(config.decoder))
    config.encoder = VariableDonutSwinConfig(**vars(config.encoder))

    # Get transformers to load custom model
    registrations = (
        (AutoModel, MBartOrderConfig, MBartOrder),
        (AutoModelForCausalLM, MBartOrderConfig, MBartOrder),
        (AutoModel, VariableDonutSwinConfig, VariableDonutSwinModel),
    )
    for auto_class, config_class, model_class in registrations:
        auto_class.register(config_class, model_class)

    model = OrderVisionEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)
    # Sanity-check that the registrations above were actually honoured
    assert isinstance(model.decoder, MBartOrder)
    assert isinstance(model.encoder, VariableDonutSwinModel)

    model = model.to(device).eval()
    print(f"Loaded reading order model {checkpoint} on device {device} with dtype {dtype}")
    return model
|
surya/model/ordering/processor.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
from typing import Dict, Union, Optional, List, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import TensorType
|
6 |
+
from transformers import DonutImageProcessor, DonutProcessor
|
7 |
+
from transformers.image_processing_utils import BatchFeature
|
8 |
+
from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, \
|
9 |
+
valid_images, to_numpy_array
|
10 |
+
import numpy as np
|
11 |
+
from PIL import Image
|
12 |
+
import PIL
|
13 |
+
from surya.settings import settings
|
14 |
+
|
15 |
+
|
16 |
+
def load_processor(checkpoint=settings.ORDER_MODEL_CHECKPOINT):
    """
    Load the `OrderImageProcessor` for `checkpoint` and attach the box/token settings.

    The separator and padding ids are placed directly after the combined range of
    `max_tokens` and `box_size` values.
    """
    box_size = 1024
    max_tokens = 256
    special_id_base = max_tokens + box_size

    processor = OrderImageProcessor.from_pretrained(checkpoint)
    processor.size = settings.ORDER_IMAGE_SIZE
    processor.token_sep_id = special_id_base + 1
    processor.token_pad_id = special_id_base + 2
    # One slot fewer than the configured maximum (the processor appends a sep entry per image)
    processor.max_boxes = settings.ORDER_MAX_BOXES - 1
    processor.box_size = dict(height=box_size, width=box_size)
    return processor
|
26 |
+
|
27 |
+
|
28 |
+
class OrderImageProcessor(DonutImageProcessor):
    """
    Image and bounding-box preprocessor for the reading-order model.

    Besides the inherited image-processor settings, instances are expected to carry
    `token_sep_id`, `token_pad_id`, `max_boxes` and `box_size`; these are not set in
    `__init__` and must be assigned by the caller before `preprocess` is used.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.patch_size = kwargs.get("patch_size", (4, 4))

    def process_inner(self, images: List[np.ndarray]):
        """Convert HWC images to float32 CHW arrays, then rescale and normalize them."""
        images = [img.transpose(2, 0, 1) for img in images]  # convert to CHW format

        assert images[0].shape[0] == 3  # RGB input images; channel dim is first after the transpose

        # Convert to float32 for rescale/normalize
        images = [img.astype(np.float32) for img in images]

        # Rescale and normalize
        images = [
            self.rescale(img, scale=self.rescale_factor, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]
        images = [
            self.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]

        return images

    def process_boxes(self, boxes):
        """
        Append a separator entry to each image's boxes and left-pad to a common length.

        Returns:
            padded_boxes: per image, pad entries followed by the boxes and the sep entry.
            box_masks: per image, 0 for pad positions and 1 for real positions.
            box_counts: per image, `[pad_len, max_len]` — the index of the first real
                entry and the padded sequence length.
        """
        padded_boxes = []
        for b in boxes:
            # Copy so the caller's lists are not extended with the sep entry
            padded_b = deepcopy(b)
            padded_b.append([self.token_sep_id] * 4)  # Sep token to indicate start of label predictions
            padded_boxes.append(padded_b)

        # default=0 keeps an empty batch from raising inside max()
        max_boxes = max((len(b) for b in padded_boxes), default=0)

        box_masks = []
        box_counts = []
        for i, real_boxes in enumerate(padded_boxes):
            box_len = len(real_boxes)
            pad_len = max_boxes - box_len
            # Left pad for generation; build distinct rows so pad entries never alias each other
            pad_rows = [[self.token_pad_id] * 4 for _ in range(pad_len)]
            padded_boxes[i] = pad_rows + real_boxes
            box_masks.append([0] * pad_len + [1] * box_len)
            box_counts.append([pad_len, max_boxes])

        return padded_boxes, box_masks, box_counts

    def resize_img_and_boxes(self, img, boxes):
        """
        Resize a PIL image to `self.size` and rescale its boxes into the `box_size` grid.

        Neither `img` nor `boxes` is modified; a uint8 array and a new list of boxes are
        returned.
        """
        orig_dim = img.size
        new_size = (self.size["width"], self.size["height"])
        # Image.thumbnail resizes in place, so work on a copy to leave the caller's image intact
        img = img.copy()
        img.thumbnail(new_size, Image.Resampling.LANCZOS)  # Shrink largest dimension to fit new size
        img = img.resize(new_size, Image.Resampling.LANCZOS)  # Stretch smaller dimension to fit new size

        img = np.asarray(img, dtype=np.uint8)

        width, height = orig_dim
        box_width, box_height = self.box_size["width"], self.box_size["height"]
        scaled_boxes = []
        for box in boxes:
            # Rescale from original pixel coordinates into the box grid; the top-left corner
            # is clamped from below and the bottom-right corner from above.
            scaled = list(box)
            scaled[0] = max(box[0] / width * box_width, 0)
            scaled[1] = max(box[1] / height * box_height, 0)
            scaled[2] = min(box[2] / width * box_width, box_width)
            scaled[3] = min(box[3] / height * box_height, box_height)
            scaled_boxes.append(scaled)

        return img, scaled_boxes

    def preprocess(
        self,
        images: ImageInput,
        boxes: List[List[int]],
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_thumbnail: bool = None,
        do_align_long_axis: bool = None,
        do_pad: bool = None,
        random_padding: bool = False,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Prepare images and their boxes for the model.

        Only `images`, `boxes` and `return_tensors` are used; the remaining parameters are
        accepted for signature compatibility and ignored — resizing, rescaling and
        normalization always run with the processor's own attributes.

        Returns:
            A `BatchFeature` with `pixel_values`, `input_boxes`, `input_boxes_mask` and
            `input_boxes_counts`.

        Raises:
            ValueError: if the images are of an unsupported type, or an image has more
                than `self.max_boxes` boxes.
        """
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        new_images = []
        new_boxes = []
        for img, box in zip(images, boxes):
            if len(box) > self.max_boxes:
                raise ValueError(f"Too many boxes, max is {self.max_boxes}")
            img, box = self.resize_img_and_boxes(img, box)
            new_images.append(img)
            new_boxes.append(box)

        images = new_images
        boxes = new_boxes

        # Convert to numpy for later processing steps
        images = [np.array(image) for image in images]

        images = self.process_inner(images)
        boxes, box_mask, box_counts = self.process_boxes(boxes)
        data = {
            "pixel_values": images,
            "input_boxes": boxes,
            "input_boxes_mask": box_mask,
            "input_boxes_counts": box_counts,
        }
        return BatchFeature(data=data, tensor_type=return_tensors)
|
surya/model/recognition/config.py
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from transformers import PretrainedConfig
|
5 |
+
from transformers.utils import ModelOutput
|
6 |
+
|
7 |
+
|
8 |
+
class SuryaOCRConfig(PretrainedConfig):
    """
    Composite config holding an `encoder` and a `decoder` sub-config.

    Both keyword arguments are required. The start, pad and eos token ids are copied from
    the decoder config, which may be either a plain dict or a config object.
    """

    model_type = "vision-encoder-decoder"
    is_composition = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.encoder = kwargs.pop("encoder")
        decoder_config = kwargs.pop("decoder")
        self.decoder = decoder_config
        self.is_encoder_decoder = True

        # Pick the lookup style once: item access for dicts, attribute access otherwise
        if isinstance(decoder_config, dict):
            def read_field(name):
                return decoder_config[name]
        else:
            def read_field(name):
                return getattr(decoder_config, name)

        self.decoder_start_token_id = read_field("bos_token_id")
        self.pad_token_id = read_field("pad_token_id")
        self.eos_token_id = read_field("eos_token_id")
|
30 |
+
|
31 |
+
|
32 |
+
class DonutSwinConfig(PretrainedConfig):
    """
    Configuration for the Donut Swin image encoder.

    Stores the constructor arguments as attributes and derives `num_layers` (number of
    stages) and `hidden_size` (channel width after the last stage) from them.
    """

    model_type = "donut-swin"

    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        image_size=(256, 896),
        patch_size=4,
        num_channels=3,
        embed_dim=128,
        depths=[2, 2, 14, 2],
        num_heads=[4, 8, 16, 32],
        num_kv_heads=[1, 2, 4, 8],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        use_absolute_embeddings=True,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        encoder_length=256,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        # The list defaults in the signature are created once and shared by every call;
        # store copies so mutating one config can never leak into another instance (or
        # into the defaults themselves).
        self.depths = list(depths)
        self.num_layers = len(depths)
        self.num_heads = list(num_heads)
        self.num_kv_heads = list(num_kv_heads)
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.use_absolute_embeddings = use_absolute_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
        self.encoder_length = encoder_length
|
86 |
+
|
87 |
+
|
88 |
+
class SuryaOCRDecoderConfig(PretrainedConfig):
    """Hyper-parameters for the Surya OCR text decoder.

    All constructor arguments are stored as attributes of the same name, with
    these exceptions and derived values:

    * ``lru_width`` falls back to ``hidden_size`` when ``None``.
    * ``num_key_value_heads`` falls back to ``num_attention_heads`` when ``None``.
    * ``block_types`` is stored as a list.
    * ``head_dim`` is ``hidden_size // num_attention_heads``.
    * ``final_w_init_variance_scale`` is ``2.0 / num_hidden_layers``.
    * ``pad_token_id`` / ``bos_token_id`` / ``eos_token_id`` are forwarded to
      ``PretrainedConfig.__init__`` together with any extra ``kwargs``.

    ``cross_attn_layers``, ``self_attn_layers`` and ``global_attn_layers`` are
    collections of layer indices that the decoder layers test membership
    against.

    Raises:
        ValueError: if ``num_key_value_heads`` exceeds ``num_attention_heads``.
    """

    model_type = "surya_ocr"

    def __init__(
        self,
        num_hidden_layers=10,
        vocab_size=65792,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        num_attention_heads=16,
        lru_width=None,
        attention_window_size=16,
        conv1d_width=4,
        logits_soft_cap=30.0,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=1,
        hidden_activation="gelu_pytorch_tanh",
        rope_theta=10000.0,
        block_types=("attention",),
        cross_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9),
        self_attn_layers=(0, 1, 3, 5, 7, 9),
        global_attn_layers=(0, 1, 3, 5, 7, 9),
        attention_dropout=0.0,
        num_key_value_heads=2,
        attention_bias=False,
        w_init_variance_scale=0.01,
        init_std=0.02,
        tie_word_embeddings=False,
        aux_heads=0,  # How many n-token-ahead heads to add
        encoder_hidden_size=1024,
        causal=False,
        **kwargs,
    ):
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.lru_width = lru_width if lru_width is not None else hidden_size
        self.attention_window_size = attention_window_size
        self.conv1d_width = conv1d_width
        self.logits_soft_cap = logits_soft_cap
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.block_types = list(block_types)
        self.hidden_activation = hidden_activation
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
        if self.num_key_value_heads > self.num_attention_heads:
            # Equality is accepted by the check above, so say so in the message.
            raise ValueError(
                "The number of `num_key_value_heads` must be smaller than or equal to `num_attention_heads`"
            )
        self.cross_attn_layers = cross_attn_layers
        self.self_attn_layers = self_attn_layers
        self.global_attn_layers = global_attn_layers
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias
        self.w_init_variance_scale = w_init_variance_scale
        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
        self.init_std = init_std
        # NOTE(review): `tie_word_embeddings` is not forwarded to the base
        # initializer below; depending on the installed transformers version
        # the base class may overwrite this attribute with its own default --
        # confirm before relying on it.
        self.tie_word_embeddings = tie_word_embeddings
        self.aux_heads = aux_heads
        self.encoder_hidden_size = encoder_hidden_size
        self.causal = causal

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        """Return ``block_types`` cycled out to ``num_hidden_layers`` entries."""
        pattern_length = len(self.block_types)
        # Ceiling division gives the repetitions needed to cover every layer.
        # Never go below the historical 100 repetitions so results for all
        # previously supported sizes are unchanged; an empty pattern yields
        # an empty result.
        needed = -(-self.num_hidden_layers // pattern_length) if pattern_length else 0
        return (self.block_types * max(100, needed))[: self.num_hidden_layers]
|
165 |
+
|
166 |
+
|
167 |
+
class SuryaOCRTextEncoderConfig(PretrainedConfig):
    """Hyper-parameters for the Surya OCR text encoder.

    All constructor arguments are stored as attributes of the same name
    (including ``iteration_count`` and ``query_token_count``), with these
    exceptions and derived values:

    * ``lru_width`` falls back to ``hidden_size`` when ``None``.
    * ``num_key_value_heads`` falls back to ``num_attention_heads`` when ``None``.
    * ``block_types`` is stored as a list.
    * ``head_dim`` is ``hidden_size // num_attention_heads``.
    * ``final_w_init_variance_scale`` is ``2.0 / num_hidden_layers``.
    * ``pad_token_id`` / ``bos_token_id`` / ``eos_token_id`` are forwarded to
      ``PretrainedConfig.__init__`` together with any extra ``kwargs``.

    Raises:
        ValueError: if ``num_key_value_heads`` exceeds ``num_attention_heads``.
    """

    model_type = "surya_ocr"

    def __init__(
        self,
        num_hidden_layers=10,
        vocab_size=65792,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        num_attention_heads=16,
        lru_width=None,
        attention_window_size=16,
        conv1d_width=4,
        logits_soft_cap=30.0,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=1,
        hidden_activation="gelu_pytorch_tanh",
        rope_theta=10000.0,
        block_types=("attention",),
        cross_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9),
        self_attn_layers=(0, 1, 3, 5, 7, 9),
        global_attn_layers=(0, 1, 3, 5, 7, 9),
        attention_dropout=0.0,
        num_key_value_heads=2,
        attention_bias=False,
        w_init_variance_scale=0.01,
        init_std=0.02,
        tie_word_embeddings=False,
        aux_heads=0,  # How many n-token-ahead heads to add
        encoder_hidden_size=1024,
        iteration_count=1,
        causal=False,
        query_token_count=128,
        **kwargs,
    ):
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.lru_width = lru_width if lru_width is not None else hidden_size
        self.attention_window_size = attention_window_size
        self.conv1d_width = conv1d_width
        self.logits_soft_cap = logits_soft_cap
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.block_types = list(block_types)
        self.hidden_activation = hidden_activation
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
        if self.num_key_value_heads > self.num_attention_heads:
            # Equality is accepted by the check above, so say so in the message.
            raise ValueError(
                "The number of `num_key_value_heads` must be smaller than or equal to `num_attention_heads`"
            )
        self.cross_attn_layers = cross_attn_layers
        self.self_attn_layers = self_attn_layers
        self.global_attn_layers = global_attn_layers
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias
        self.w_init_variance_scale = w_init_variance_scale
        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
        self.init_std = init_std
        # NOTE(review): `tie_word_embeddings` is not forwarded to the base
        # initializer below; depending on the installed transformers version
        # the base class may overwrite this attribute with its own default --
        # confirm before relying on it.
        self.tie_word_embeddings = tie_word_embeddings
        self.aux_heads = aux_heads
        self.encoder_hidden_size = encoder_hidden_size
        self.iteration_count = iteration_count
        self.causal = causal
        self.query_token_count = query_token_count

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        """Return ``block_types`` cycled out to ``num_hidden_layers`` entries."""
        pattern_length = len(self.block_types)
        # Ceiling division gives the repetitions needed to cover every layer.
        # Never go below the historical 100 repetitions so results for all
        # previously supported sizes are unchanged; an empty pattern yields
        # an empty result.
        needed = -(-self.num_hidden_layers // pattern_length) if pattern_length else 0
        return (self.block_types * max(100, needed))[: self.num_hidden_layers]
|
248 |
+
|
249 |
+
# Vocabulary layout: the total size is the sum of the three parts below.
TOTAL_TOKENS = 65536
TOKEN_OFFSET = 3  # Pad, eos, bos
SPECIAL_TOKENS = 253
TOTAL_VOCAB_SIZE = TOTAL_TOKENS + TOKEN_OFFSET + SPECIAL_TOKENS

# Language code -> integer id. Ids are consecutive and follow the order of
# the codes listed here (ten per row), with "_math" taking the final id.
LANGUAGE_MAP = {
    code: index
    for index, code in enumerate(
        (
            "af", "am", "ar", "as", "az", "be", "bg", "bn", "br", "bs",
            "ca", "cs", "cy", "da", "de", "el", "en", "eo", "es", "et",
            "eu", "fa", "fi", "fr", "fy", "ga", "gd", "gl", "gu", "ha",
            "he", "hi", "hr", "hu", "hy", "id", "is", "it", "ja", "jv",
            "ka", "kk", "km", "kn", "ko", "ku", "ky", "la", "lo", "lt",
            "lv", "mg", "mk", "ml", "mn", "mr", "ms", "my", "ne", "nl",
            "no", "om", "or", "pa", "pl", "ps", "pt", "ro", "ru", "sa",
            "sd", "si", "sk", "sl", "so", "sq", "sr", "su", "sv", "sw",
            "ta", "te", "th", "tl", "tr", "ug", "uk", "ur", "uz", "vi",
            "xh", "yi", "zh", "_math",
        )
    )
}
|
surya/model/recognition/decoder.py
ADDED
@@ -0,0 +1,695 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Dict, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.utils.checkpoint
|
6 |
+
from torch import nn
|
7 |
+
from transformers.utils import ModelOutput
|
8 |
+
|
9 |
+
from surya.model.recognition.config import SuryaOCRDecoderConfig, SuryaOCRTextEncoderConfig
|
10 |
+
from transformers import PreTrainedModel
|
11 |
+
from transformers.activations import ACT2FN
|
12 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
13 |
+
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput
|
14 |
+
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
15 |
+
|
16 |
+
from surya.settings import settings
|
17 |
+
|
18 |
+
# NOTE(review): this constant is not referenced in the part of the file
# visible here; presumably an upper bound used when clamping a square-root
# gradient elsewhere -- confirm it is still used before keeping it.
_MAX_SQRT_GRADIENT = 1000.0
|
19 |
+
|
20 |
+
|
21 |
+
@dataclass
class OCRModelOutput(ModelOutput):
    """Output container for the OCR decoder.

    Only ``logits`` is required; the other two fields default to ``None``.
    ``Optional[...]`` is used instead of ``X | None`` because these
    annotations are evaluated when the class is created, and the ``|`` form
    raises ``TypeError`` on interpreters older than Python 3.10. It also
    matches the ``Optional`` style used by the rest of this module.
    """

    # Required logits tensor.
    logits: torch.Tensor
    # Logits of the auxiliary heads, when present (presumably the
    # "n-token-ahead" heads configured via ``aux_heads`` -- confirm).
    aux_logits: Optional[torch.Tensor] = None
    # Hidden states, when the producer chooses to return them.
    hidden_states: Optional[torch.Tensor] = None
|
26 |
+
|
27 |
+
|
28 |
+
class SuryaOCRDecoderRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    The learned ``weight`` starts at zero and is applied as ``1 + weight``,
    so a freshly constructed module performs plain RMS normalisation. The
    computation runs in float32 and the result is cast back to the dtype of
    the input.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        # eps keeps rsqrt finite for all-zero rows.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normalized = self._norm(x.float())
        # Llama does x.to(float16) * w whilst SuryaOCRDecoder is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        scale = 1.0 + self.weight.float()
        scaled = normalized * scale
        return scaled.type_as(x)

    def extra_repr(self):
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.eps}"
|
46 |
+
|
47 |
+
|
48 |
+
# Register the custom RMSNorm in the ALL_LAYERNORM_LAYERS list imported from
# transformers.pytorch_utils (presumably so transformers utilities treat it
# like the built-in normalisation layers -- confirm against the library).
ALL_LAYERNORM_LAYERS.append(SuryaOCRDecoderRMSNorm)
|
49 |
+
|
50 |
+
|
51 |
+
class SuryaOCRDecoderRotaryEmbedding(nn.Module):
    """Computes the cosine/sine tables used for rotary position embeddings.

    Args:
        dim: size of the rotary dimension; ``dim // 2`` inverse frequencies
            are generated (for even ``dim``).
        base: base of the geometric frequency progression.
        device: accepted for signature compatibility; not used.
    """

    def __init__(self, dim, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        # Non-persistent: the buffer follows the module across devices but is
        # not written to the state dict.
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    # Adapted from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaOCRDecoder
    def forward(self, x, position_ids, seq_len=None):
        """Return ``(cos, sin)`` for the given positions.

        Args:
            x: tensor whose device and dtype determine those of the outputs;
                its values are not read.
            position_ids: indexed as ``[batch, seq_len]``.
            seq_len: unused.

        Returns:
            Two tensors of shape ``[batch, seq_len, 2 * len(inv_freq)]`` in
            the dtype of ``x``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Tensor.to returns a new tensor rather than moving in place, so the
        # result has to be captured for the device move to take effect.
        inv_freq = self.inv_freq.to(x.device)
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        # Duplicate the angles so the table lines up with rotate_half's
        # first-half / second-half split of the head dimension.
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
72 |
+
|
73 |
+
|
74 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
75 |
+
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For an input whose last dimension is ``[a, b]`` (split at the midpoint,
    rounded down) the result is ``[-b, a]``.
    """
    midpoint = x.shape[-1] // 2
    front = x[..., :midpoint]
    back = x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
|
80 |
+
|
81 |
+
|
82 |
+
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
83 |
+
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate query and key tensors by the supplied rotary tables.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into `cos` and `sin` so they broadcast against `q`
            and `k`. With tables shaped [batch_size, seq_len, head_dim], use 1
            when q/k are [batch_size, heads, seq_len, head_dim] and 2 when
            they are [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors, in that order.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def rotate(states):
        # Standard rotary formula: x * cos + rotate_half(x) * sin.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return rotate(q), rotate(k)
|
106 |
+
|
107 |
+
|
108 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
109 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: an
    input of shape (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With
    ``n_rep == 1`` the input tensor itself is returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold
    # it back into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
119 |
+
|
120 |
+
|
121 |
+
class SuryaOCRDecoderSdpaCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper
    Modified for GQA

    Queries come from the decoder hidden states; keys and values are
    projected from the encoder hidden states. The projected keys/values can
    be cached on the module (``key_states`` / ``value_states``) so they are
    computed only once per encoder output.
    """

    def __init__(self, config: SuryaOCRDecoderConfig):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
        # Constructed for parity with the self-attention block; forward() in
        # this class never applies it.
        self.rotary_emb = SuryaOCRDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> torch.Tensor:
        """Attend from decoder states to encoder states.

        Args:
            hidden_states: decoder activations, unpacked as (batch, q_len, hidden).
            encoder_hidden_states: encoder activations, unpacked as
                (batch, v_len, encoder_hidden).
            attention_mask: accepted but not used.
            encoder_attention_mask: accepted but not used.
            use_cache: when true and no cached keys exist yet, the freshly
                projected keys/values are stored for later calls.

        Returns:
            Tensor of shape (batch, q_len, hidden_size).
        """
        # Encoder attention mask currently ignored

        bsz, q_len, _ = hidden_states.size()
        _, v_len, _ = encoder_hidden_states.size()

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)

        # `key_states` only exists after _setup_cache() has run; treat a
        # missing attribute like an empty cache instead of raising
        # AttributeError when forward() is called without cache setup.
        cached_keys = getattr(self, "key_states", None)
        if cached_keys is None:
            key_states = self.k_proj(encoder_hidden_states)
            value_states = self.v_proj(encoder_hidden_states)
            key_states = key_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            if use_cache:
                self._update_cache(key_states, value_states)
        else:
            # Reuse the projections cached by an earlier call.
            key_states = cached_keys
            value_states = self.value_states

        # Expand the grouped KV heads so they line up with the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states.contiguous(),
            key_states.contiguous(),
            value_states.contiguous(),
            attn_mask=None,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _setup_cache(self, batch_size, device, dtype=None):
        """Clear any cached encoder projections (arguments are unused)."""
        # Setup initial caches
        self.value_states = None
        self.key_states = None

    @torch.no_grad()
    def _update_cache(self, key_states, value_states, **cache_kwargs):
        """Store the projected encoder keys/values for reuse by later calls."""
        self.value_states = value_states
        self.key_states = key_states
|
198 |
+
|
199 |
+
|
200 |
+
class SuryaOCRDecoderSdpaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Self-attention with grouped key/value heads and rotary position
    embeddings. Keys and values can be cached on the module, either in a
    pre-allocated buffer (``settings.RECOGNITION_STATIC_CACHE``) or by
    concatenation along the sequence axis.
    """

    def __init__(self, config: SuryaOCRDecoderConfig):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
        self.rotary_emb = SuryaOCRDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: bool = False,
        window_attn: bool = False,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: activations, unpacked as (batch, q_len, hidden).
            position_ids: positions handed to the rotary embedding.
            attention_mask: additive mask indexed as (batch, head, seq_len,
                kv_len); trimmed to the current key length before use.
            cache_position: cache slots written by this call; its last entry
                is taken as the current decoding position.
            use_cache: update and read the KV cache (only when the cache
                attributes exist, i.e. after ``_setup_cache``).
            window_attn: forwarded to the cache update as a keyword; the
                cache updaters in this class do not read it.

        Returns:
            Tensor of shape (batch, q_len, hidden_size).
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Final is bsz, num_attention_heads, seq_len, head_dim
        query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if use_cache and hasattr(self, "key_states"):
            cache_kwargs = {"cache_position": cache_position, "window_attn": window_attn}
            key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs)

        # Expand the grouped KV heads so they line up with the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            # Mask is batch, head, seq_len, kv_len
            causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]
            current_cache_position = cache_position[-1].item() if cache_position is not None else None
            # Compare against None explicitly: position 0 is a valid cache
            # position and must not be treated as "no position".
            if current_cache_position is not None and settings.RECOGNITION_STATIC_CACHE:
                # Mask out future cache positions
                position_mask = torch.ones_like(causal_mask, dtype=torch.bool, device=causal_mask.device)
                position_mask[:, :, :, :current_cache_position + 1] = False
                causal_mask = torch.where(position_mask, torch.finfo(causal_mask.dtype).min, causal_mask)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states.contiguous(),
            key_states.contiguous(),
            value_states.contiguous(),
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _setup_cache(self, batch_size, device, dtype=None):
        """Reset the KV cache; pre-allocate it when the static cache is enabled.

        The dtype falls back to ``config.torch_dtype`` and then to float32.
        """
        if dtype is None and self.config.torch_dtype is not None:
            dtype = self.config.torch_dtype
        dtype = dtype if dtype is not None else torch.float32

        # Setup initial caches
        self.value_states = None
        self.key_states = None

        if settings.RECOGNITION_STATIC_CACHE:
            cache_shape = (batch_size, self.num_key_value_heads, settings.RECOGNITION_MAX_TOKENS, self.head_dim)
            self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device)
            self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device)

    def _update_static_cache(self, key_states, value_states, **cache_kwargs):
        """Write the new keys/values into the pre-allocated buffers.

        The slots are selected by ``cache_kwargs["cache_position"]`` along
        the sequence axis. Returns the full buffers.
        """
        cache_position = cache_kwargs.get("cache_position")
        k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)

        k_out[:, :, cache_position] = key_states.to(k_out.dtype)
        v_out[:, :, cache_position] = value_states.to(v_out.dtype)

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

    def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs):
        """Append the new keys/values to the cache along the sequence axis.

        Returns the concatenated keys and values (or the inputs unchanged
        when the cache was empty).
        """
        k_out = key_states
        if self.key_states is not None:
            k_out = torch.cat([self.key_states, key_states], dim=2)

        v_out = value_states
        if self.value_states is not None:
            v_out = torch.cat([self.value_states, value_states], dim=2)

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

    @torch.no_grad()
    def _update_cache(self, key_states, value_states, **cache_kwargs):
        """Dispatch to the static or dynamic cache update based on settings."""
        if settings.RECOGNITION_STATIC_CACHE:
            return self._update_static_cache(key_states, value_states, **cache_kwargs)

        return self._update_dynamic_cache(key_states, value_states, **cache_kwargs)
|
319 |
+
|
320 |
+
|
321 |
+
class SuryaOCRDecoderMlp(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))``.

    All three projections are bias-free. When the config carries no
    ``hidden_activation``, it is set on the config to ``"gelu_pytorch_tanh"``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        intermediate = config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = intermediate

        self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
        self.up_proj = nn.Linear(hidden, intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)

        # The fallback is written back onto the (shared) config object.
        if config.hidden_activation is None:
            config.hidden_activation = "gelu_pytorch_tanh"
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
|
337 |
+
|
338 |
+
|
339 |
+
class SuryaOCRDecoderLayer(nn.Module):
    """One decoder layer: optional cross-attention, optional self-attention, MLP.

    Whether the layer owns a cross-attention and/or self-attention block is
    decided by membership of ``layer_idx`` in ``config.cross_attn_layers``
    and ``config.self_attn_layers``. ``window_attn`` is true for layers not
    listed in ``config.global_attn_layers``.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.cross_pre_norm = SuryaOCRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.temporal_pre_norm = SuryaOCRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.temporal_block = None
        if layer_idx in config.self_attn_layers:
            self.temporal_block = SuryaOCRDecoderSdpaAttention(config)

        self.cross_attn_block = None
        if layer_idx in config.cross_attn_layers:
            self.cross_attn_block = SuryaOCRDecoderSdpaCrossAttention(config)

        self.window_attn = layer_idx not in config.global_attn_layers
        self.channel_pre_norm = SuryaOCRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp_block = SuryaOCRDecoderMlp(config)

    def forward(
        self,
        activations: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        encoder_attention_mask: torch.Tensor = None,
        cache_position: torch.Tensor = None,
        use_cache: bool = None,
    ) -> torch.Tensor:
        """Apply the layer to ``activations`` and return the new hidden states.

        Args:
            activations: input hidden states of the layer.
            position_ids: positions forwarded to the self-attention block.
            attention_mask: mask forwarded to both attention blocks.
            encoder_hidden_states: encoder output for cross-attention.
            encoder_attention_mask: forwarded to cross-attention.
            cache_position: cache slots forwarded to self-attention.
            use_cache: forwarded to both attention blocks.
        """
        raw_activations = activations

        if self.cross_attn_block is not None:
            # Do cross-attention on encoder outputs
            cross_attn_inputs = self.cross_pre_norm(activations)
            cross_attn_path = self.cross_attn_block(
                cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache
            )
            cross_attn_output = cross_attn_path + raw_activations
        else:
            cross_attn_output = raw_activations

        if self.temporal_block is not None:
            inputs_normalized = self.temporal_pre_norm(cross_attn_output)  # RMSNorm introduces slight differences
            hidden_states = self.temporal_block(
                inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache, window_attn=self.window_attn
            )

            # NOTE(review): the skip connection adds the layer input
            # (raw_activations), not cross_attn_output, so the cross-attention
            # result reaches later computation only through the self-attention
            # input. Presumably existing checkpoints depend on this wiring --
            # do not change it without confirming.
            residual = hidden_states + raw_activations
        else:
            residual = cross_attn_output

        hidden_states = self.channel_pre_norm(residual)
        hidden_states = self.mlp_block(hidden_states)

        hidden_states = hidden_states + residual
        return hidden_states
|
395 |
+
|
396 |
+
|
397 |
+
class SuryaOCRDecoderPreTrainedModel(PreTrainedModel):
    """
    Shared base class for the Surya OCR decoder models.

    Holds the `PreTrainedModel` class-level flags, the weight initialisation scheme and the
    per-layer cache setup used by the concrete models that subclass it.
    """

    config_class = SuryaOCRDecoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SuryaOCRDecoderLayer"]
    _skip_keys_device_placement = ["cache"]
    _supports_flash_attn_2 = False
    _supports_sdpa = False  # we can't compare with eager for now
    _supports_cache_class = True
    _supports_quantized_cache = True

    def _init_weights(self, module):
        """Initialise `module` in place with a normal distribution of std `config.init_std`."""
        std = self.config.init_std
        if isinstance(module, SuryaOCRDecoderSdpaAttention):
            # Same order as before (q, k, v, o) so the random stream is consumed identically.
            for projection in (module.q_proj, module.k_proj, module.v_proj, module.o_proj):
                torch.nn.init.normal_(projection.weight, mean=0.0, std=std)
        elif isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if getattr(module, "bias", None) is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding row at exactly zero
                module.weight.data[module.padding_idx].zero_()

    def _setup_cache(self, config, batch, device, dtype):
        """
        Ask every attention block of every layer to allocate its cache.

        Args:
            config: Unused here; kept so existing callers keep working.
            batch: Batch size forwarded to each block's `_setup_cache`.
            device: Device forwarded to each block's `_setup_cache`.
            dtype: Dtype forwarded to each block's `_setup_cache`.
        """
        # Works both on the bare model (has `.layers`) and on wrappers that hold it as `.model`.
        layers = getattr(self, "model", self).layers
        for layer in layers:
            # Explicit `is not None` matches the checks in the layer's forward pass and does not
            # depend on how truthiness happens to be defined for the sub-module.
            if layer.temporal_block is not None:
                layer.temporal_block._setup_cache(batch, device, dtype)
            if layer.cross_attn_block is not None:
                layer.cross_attn_block._setup_cache(batch, device, dtype)

    def reset_cache(self, batch, device, dtype):
        """No-op."""
        pass

    def _tie_weights(self):
        """No-op: no weights are tied here."""
        pass

    def tie_weights(self):
        """No-op: no weights are tied here."""
        pass
|
440 |
+
|
441 |
+
|
442 |
+
class SuryaOCRDecoderModel(SuryaOCRDecoderPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuryaOCRDecoderLayer`]

    Args:
        config: SuryaOCRDecoderConfig
    """

    def __init__(self, config: SuryaOCRDecoderConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # When False, `_update_causal_mask` returns None and no causal mask is built
        self.causal = config.causal

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [SuryaOCRDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.final_norm = SuryaOCRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # sqrt(hidden_size) as a non-persistent buffer; it is not read anywhere else in this class
        self.register_buffer(
            "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32), persistent=False
        )
        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings
    def get_input_embeddings(self):
        return self.embed_tokens

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        prefill: bool = False
    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
        """
        Embed `input_ids`, run them through every decoder layer and apply the final norm.

        Args:
            input_ids: Token ids to embed.
            position_ids: Defaults to `cache_position.unsqueeze(0)` when None.
            attention_mask: Optional padding mask folded into the causal mask.
            encoder_hidden_states: Forwarded to every layer.
            encoder_attention_mask: Forwarded to every layer.
            cache_position: Defaults to `arange(sequence_length)` when None.
            use_cache: Defaults to `config.use_cache`; forced off while training with
                gradient checkpointing.
            output_hidden_states: When truthy, collect the input of every layer plus the final
                normalised output.
            return_dict: Defaults to `config.use_return_dict`.
            prefill: When True together with `use_cache`, the per-layer caches are set up first.

        Returns:
            A `BaseModelOutputWithNoAttention`, or a tuple of its non-None fields when
            `return_dict` is falsy.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        if use_cache and prefill:
            self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype)

        if cache_position is None:
            cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

        all_hidden_states = () if output_hidden_states else None
        for i, residual_block in enumerate(self.layers):
            if output_hidden_states:
                # Record the input of each layer
                all_hidden_states += (hidden_states,)
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    residual_block.__call__, hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache
                )
            else:
                hidden_states = residual_block(hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache)

        hidden_states = self.final_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )

    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
    # Ignore copy
    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        """
        Build the additive attention mask passed to every layer.

        Entries are 0 where attention is allowed and `torch.finfo(dtype).min` where it is
        blocked. Returns None when `self.causal` is False.
        """
        if not self.causal:
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # The key axis is at least settings.RECOGNITION_MAX_TOKENS wide, independent of the input length
        target_length = max(settings.RECOGNITION_MAX_TOKENS, sequence_length)

        # Start fully blocked, shape (sequence_length, target_length)
        diagonal = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        causal_mask = diagonal
        if sequence_length != 1:
            # Select the upper triangular part of the matrix, but unmask current token (the diagonal)
            # triu will be the min_dtype, everything else is 0 (attended to)
            causal_mask = torch.triu(diagonal, diagonal=1)

        # Keep the blocking value only at key positions after each query's cache position;
        # multiplying by the boolean comparison zeroes (unblocks) every other entry.
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        # Add batch and head axes: (batch, 1, sequence_length, target_length)
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            if attention_mask.dim() == 2:
                # Mask positions in the causal mask that are masked in the attention mask
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

        if attention_mask is not None and attention_mask.device.type == "cuda":
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
|
572 |
+
|
573 |
+
|
574 |
+
class SuryaOCRDecoder(SuryaOCRDecoderPreTrainedModel):
    """
    Decoder stack with a linear output head.

    The head emits `vocab_size * (aux_heads + 1)` values per position; `forward` splits them
    into the primary logits and the auxiliary ones.
    """

    _tied_weights_keys = None

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.model = SuryaOCRDecoderModel(config)
        self.vocab_size = config.vocab_size
        # One primary head plus however many auxiliary heads the config requests
        extra_heads = 0 if config.aux_heads is None else config.aux_heads
        head_count = extra_heads + 1
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size * head_count, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped model."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped model."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the output head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the output head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder model."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped decoder model."""
        return self.model

    # Ignore copy
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        prefill: bool = False,
        **kwargs
    ) -> Union[Tuple, OCRModelOutput]:
        """
        Run the decoder and project its last hidden state to logits.

        Extra keyword arguments are accepted and ignored. Hidden states are always requested
        from the wrapped model and returned alongside the logits.
        """
        decoder_out = self.model(
            input_ids=input_ids,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_hidden_states=True,
            return_dict=True,
            prefill=prefill,
        )

        # Project, then cut the last axis into vocab-sized chunks: first chunk is the primary head
        stacked_logits = self.lm_head(decoder_out[0])
        per_head_logits = torch.split(stacked_logits, self.vocab_size, dim=-1)
        if len(per_head_logits) > 1:
            auxiliary = per_head_logits[1:]
        else:
            auxiliary = None

        return OCRModelOutput(
            logits=per_head_logits[0],
            aux_logits=auxiliary,
            hidden_states=decoder_out.hidden_states,
        )
|
641 |
+
|
642 |
+
@dataclass
class TextEncoderOutput(CausalLMOutput):
    """Output container returned by `SuryaOCRTextEncoder.forward`."""

    # Filled by SuryaOCRTextEncoder.forward with the wrapped model's `last_hidden_state`
    hidden_states: torch.FloatTensor = None
|
645 |
+
|
646 |
+
|
647 |
+
class SuryaOCRTextEncoder(SuryaOCRDecoderPreTrainedModel):
    """
    Wraps a `SuryaOCRDecoderModel` and exposes only its final hidden state.

    No output head is attached: `forward` returns the normalised last hidden state wrapped
    in a `TextEncoderOutput`.
    """

    _tied_weights_keys = None
    config_class = SuryaOCRTextEncoderConfig

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.model = SuryaOCRDecoderModel(config)
        self.vocab_size = config.vocab_size

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped model."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped model."""
        self.model.embed_tokens = value

    def set_decoder(self, decoder):
        """Replace the wrapped model."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped model."""
        return self.model

    # Ignore copy
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutput]:
        """
        Encode `input_ids` with the wrapped model.

        Extra keyword arguments are accepted and ignored.
        """
        encoded = self.model(
            input_ids=input_ids,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_hidden_states=True,
            return_dict=True,
        )

        final_state = encoded.last_hidden_state
        return TextEncoderOutput(hidden_states=final_state)
|
surya/model/recognition/encoder.py
ADDED
@@ -0,0 +1,852 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
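"""Swin-style image encoder (``DonutSwin*`` classes).

The definitions that follow are marked "Copied from transformers.models.swin.modeling_swin"
(with Swin -> DonutSwin renames), i.e. they derive from the Hugging Face Swin Transformer code.

NOTE(review): this header originally described EfficientViT (by MIT Song Han's Lab), which does
not match the classes defined below. The original references are kept here for provenance until
the attribution is confirmed:

Paper: `Efficientvit: Enhanced linear attention for high-resolution low-computation visual recognition`
    - https://arxiv.org/abs/2205.14756

Code adapted from timm, https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_mit.py
Original code (that timm adapted from) at https://github.com/mit-han-lab/efficientvit
"""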
|
9 |
+
|
10 |
+
import collections.abc
|
11 |
+
import math
|
12 |
+
from dataclasses import dataclass
|
13 |
+
from typing import Optional, Tuple, Union
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.utils.checkpoint
|
17 |
+
from torch import nn
|
18 |
+
|
19 |
+
from transformers.activations import ACT2FN
|
20 |
+
from transformers.modeling_utils import PreTrainedModel
|
21 |
+
from transformers.pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
|
22 |
+
from transformers.utils import ModelOutput
|
23 |
+
from surya.model.recognition.config import DonutSwinConfig
|
24 |
+
|
25 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 49, 1024]
|
26 |
+
|
27 |
+
|
28 |
+
@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
class DonutSwinEncoderOutput(ModelOutput):
    """Output container for the DonutSwin encoder; every field defaults to None."""

    # Final hidden state of the encoder
    last_hidden_state: torch.FloatTensor = None
    # Optional tuple of per-stage hidden states
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional tuple of attention tensors
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional tuple of hidden states in reshaped form
    # NOTE(review): presumably reshaped to a spatial layout as in the upstream Swin code -- confirm
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
36 |
+
|
37 |
+
|
38 |
+
@dataclass
class DonutSwinModelOutput(ModelOutput):
    """Minimal model output carrying only the final hidden state."""

    # Final hidden state; defaults to None
    last_hidden_state: torch.FloatTensor = None
|
41 |
+
|
42 |
+
|
43 |
+
# Copied from transformers.models.swin.modeling_swin.window_partition
|
44 |
+
def window_partition(input_feature, window_size):
    """
    Cut a `(batch, height, width, channels)` feature map into square windows.

    Returns a tensor of shape `(num_windows * batch, window_size, window_size, channels)`.
    Windows are ordered row-major over the window grid. `height` and `width` must be
    divisible by `window_size`.
    """
    batch, rows, cols, channels = input_feature.shape
    # Split each spatial axis into (grid index, offset inside the window)
    tiled = input_feature.view(
        batch, rows // window_size, window_size, cols // window_size, window_size, channels
    )
    # Bring the two grid axes together, then fold them (and the batch) into one leading axis
    grouped = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grouped.view(-1, window_size, window_size, channels)
|
54 |
+
|
55 |
+
|
56 |
+
# Copied from transformers.models.swin.modeling_swin.window_reverse
|
57 |
+
def window_reverse(windows, window_size, height, width):
    """
    Reassemble windows into feature maps of shape `(batch, height, width, channels)`.

    `windows` has shape `(num_windows * batch, window_size, window_size, channels)` with the
    windows ordered row-major over the window grid.
    """
    channels = windows.shape[-1]
    grid_rows = height // window_size
    grid_cols = width // window_size
    # Recover the (grid row, grid col) axes from the flat leading window axis
    grid = windows.view(-1, grid_rows, grid_cols, window_size, window_size, channels)
    # Interleave grid and in-window axes so each spatial axis is contiguous again
    stitched = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return stitched.view(-1, height, width, channels)
|
65 |
+
|
66 |
+
|
67 |
+
# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin
|
68 |
+
class DonutSwinEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.
    """

    def __init__(self, config, use_mask_token=False):
        super().__init__()

        self.patch_embeddings = DonutSwinPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None

        if config.use_absolute_embeddings:
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
        else:
            self.position_embeddings = None

        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Args:
            embeddings: Embeddings the position encodings will be added to.
            height: Input image height in pixels.
            width: Input image width in pixels.

        Returns:
            The stored position embeddings when no resizing is needed, otherwise a bicubically
            resized copy with the first (class) position kept in front.
        """
        # NOTE(review): the result keeps a leading class-position slot; confirm its sequence
        # length matches the patch embeddings it is added to in `forward`.
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        # Read the patch size from the patch-embedding module: this class never stores `config`,
        # so `self.config.patch_size` raised AttributeError. The module keeps the size as a
        # (height, width) pair, which also covers non-square patches.
        patch_height, patch_width = self.patch_embeddings.patch_size[0], self.patch_embeddings.patch_size[1]
        h0 = height // patch_height
        w0 = width // patch_width
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        side = math.sqrt(num_positions)
        patch_pos_embed = patch_pos_embed.reshape(1, int(side), int(side), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / side, w0 / side),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor],
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Tuple[torch.Tensor]:
        """
        Embed `pixel_values` into a normalised patch sequence.

        Args:
            pixel_values: Image batch of shape `(batch, channels, height, width)`.
            bool_masked_pos: Optional per-patch mask; masked patches are replaced by the mask
                token (requires the module to have been built with `use_mask_token=True`).
            interpolate_pos_encoding: Resize the absolute position embeddings to the input size.

        Returns:
            `(embeddings, output_dimensions)`, both as produced by the patch-embedding module,
            with normalisation, optional masking / position embeddings and dropout applied to
            the embeddings.
        """
        _, _, height, width = pixel_values.shape
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            if interpolate_pos_encoding:
                embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
            else:
                embeddings = embeddings + self.position_embeddings[:, :seq_len]

        embeddings = self.dropout(embeddings)

        return embeddings, output_dimensions
|
147 |
+
|
148 |
+
|
149 |
+
# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin
|
150 |
+
class DonutSwinPatchEmbeddings(nn.Module):
    """
    Project an image batch `(batch_size, num_channels, height, width)` into a sequence of patch
    embeddings `(batch_size, seq_length, hidden_size)` with a strided convolution, ready to be
    consumed by a Transformer.
    """

    def __init__(self, config):
        super().__init__()

        def as_pair(value):
            # Scalars become (value, value); iterables are kept as given
            return value if isinstance(value, collections.abc.Iterable) else (value, value)

        image_size = as_pair(config.image_size)
        patch_size = as_pair(config.patch_size)
        grid_rows = image_size[0] // patch_size[0]
        grid_cols = image_size[1] // patch_size[1]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = grid_cols * grid_rows
        self.grid_size = (grid_rows, grid_cols)

        # Non-overlapping patches: kernel and stride are both the patch size
        self.projection = nn.Conv2d(
            config.num_channels, config.embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def maybe_pad(self, pixel_values, height, width):
        """Zero-pad on the right / bottom so both spatial sizes are multiples of the patch size."""
        width_remainder = width % self.patch_size[1]
        if width_remainder != 0:
            right = self.patch_size[1] - width_remainder
            pixel_values = nn.functional.pad(pixel_values, (0, right))
        height_remainder = height % self.patch_size[0]
        if height_remainder != 0:
            bottom = self.patch_size[0] - height_remainder
            pixel_values = nn.functional.pad(pixel_values, (0, 0, 0, bottom))
        return pixel_values

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
        """Return `(embeddings, (grid_height, grid_width))` for `pixel_values`."""
        _, _, in_height, in_width = pixel_values.shape
        # pad the input to be divisible by self.patch_size, if needed
        padded = self.maybe_pad(pixel_values, in_height, in_width)
        projected = self.projection(padded)
        output_dimensions = (projected.shape[2], projected.shape[3])
        # (batch, hidden, h, w) -> (batch, h * w, hidden)
        embeddings = projected.flatten(2).transpose(1, 2)

        return embeddings, output_dimensions
|
191 |
+
|
192 |
+
|
193 |
+
# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
|
194 |
+
class DonutSwinPatchMerging(nn.Module):
    """
    Patch Merging Layer: concatenates each 2x2 neighbourhood of tokens along the channel axis,
    normalises it and projects `4 * dim` channels down to `2 * dim`.

    Args:
        input_resolution (`Tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        merged_channels = 4 * dim
        self.reduction = nn.Linear(merged_channels, 2 * dim, bias=False)
        self.norm = norm_layer(merged_channels)

    def maybe_pad(self, input_feature, height, width):
        """Pad one row / column at the end when `height` / `width` is odd."""
        odd_rows = height % 2
        odd_cols = width % 2
        if odd_rows == 1 or odd_cols == 1:
            # Pad spec runs from the last axis backwards: channels, width, height
            input_feature = nn.functional.pad(input_feature, (0, 0, 0, odd_cols, 0, odd_rows))

        return input_feature

    def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
        """Merge `(batch, height * width, channels)` tokens into a 2x-downsampled sequence."""
        height, width = input_dimensions
        batch_size, _, num_channels = input_feature.shape

        grid = input_feature.view(batch_size, height, width, num_channels)
        # pad input to be divisible by width and height, if needed
        grid = self.maybe_pad(grid, height, width)
        # Four strided views, each [batch_size, height/2, width/2, num_channels], in the order
        # (row 0, col 0), (row 1, col 0), (row 0, col 1), (row 1, col 1)
        quadrants = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        # batch_size height/2*width/2 4*C
        merged = torch.cat(quadrants, -1).view(batch_size, -1, 4 * num_channels)

        return self.reduction(self.norm(merged))
|
246 |
+
|
247 |
+
|
248 |
+
# Copied from transformers.models.beit.modeling_beit.drop_path
|
249 |
+
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Stochastic depth: zero whole samples with probability `drop_prob` and rescale the survivors
    by `1 / (1 - drop_prob)`. Intended for the main path of residual blocks.

    The input is returned unchanged (same object) when `drop_prob` is 0 or `training` is False.

    Naming follows Ross Wightman's note: this matches his "DropConnect" implementation for
    EfficientNet-style networks, but "Drop Connect" is a different technique in a separate paper
    (see https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956), hence
    "drop path" instead of "survival rate" / DropConnect.
    """
    if not training or drop_prob == 0.0:
        return input
    keep_prob = 1 - drop_prob
    # One random draw per sample, broadcast over all remaining axes (works for any rank)
    per_sample_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    keep_mask = keep_prob + torch.rand(per_sample_shape, dtype=input.dtype, device=input.device)
    keep_mask.floor_()  # binarize: 1 keeps the sample, 0 drops it
    return input.div(keep_prob) * keep_mask
|
267 |
+
|
268 |
+
|
269 |
+
# Copied from transformers.models.swin.modeling_swin.SwinDropPath
|
270 |
+
class DonutSwinDropPath(nn.Module):
    """Module wrapper around `drop_path` (per-sample stochastic depth for residual main paths)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `drop_path` with this module's probability and current training flag."""
        probability, is_training = self.drop_prob, self.training
        return drop_path(hidden_states, probability, is_training)

    def extra_repr(self) -> str:
        """Show the drop probability in the module's repr."""
        return "p={}".format(self.drop_prob)
|
282 |
+
|
283 |
+
|
284 |
+
# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DonutSwin
|
285 |
+
class DonutSwinSelfAttention(nn.Module):
|
286 |
+
def __init__(self, config, dim, num_heads, num_kv_heads, window_size):
|
287 |
+
super().__init__()
|
288 |
+
if dim % num_heads != 0:
|
289 |
+
raise ValueError(
|
290 |
+
f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
|
291 |
+
)
|
292 |
+
|
293 |
+
self.num_attention_heads = num_heads
|
294 |
+
self.num_kv_heads = num_kv_heads
|
295 |
+
self.kv_repeats = self.num_attention_heads // self.num_kv_heads
|
296 |
+
self.attention_head_size = int(dim / num_heads)
|
297 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
298 |
+
self.kv_head_size = self.num_kv_heads * self.attention_head_size
|
299 |
+
self.window_size = (
|
300 |
+
window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
|
301 |
+
)
|
302 |
+
|
303 |
+
self.relative_position_bias_table = nn.Parameter(
|
304 |
+
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
|
305 |
+
)
|
306 |
+
|
307 |
+
# get pair-wise relative position index for each token inside the window
|
308 |
+
coords_h = torch.arange(self.window_size[0])
|
309 |
+
coords_w = torch.arange(self.window_size[1])
|
310 |
+
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
|
311 |
+
coords_flatten = torch.flatten(coords, 1)
|
312 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
313 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
314 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1
|
315 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
316 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
317 |
+
relative_position_index = relative_coords.sum(-1)
|
318 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
319 |
+
|
320 |
+
self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
|
321 |
+
self.key = nn.Linear(self.all_head_size, self.kv_head_size, bias=config.qkv_bias)
|
322 |
+
self.value = nn.Linear(self.all_head_size, self.kv_head_size, bias=config.qkv_bias)
|
323 |
+
|
324 |
+
self.dropout_p = config.attention_probs_dropout_prob
|
325 |
+
|
326 |
+
def transpose_for_scores(self, x):
    """Split the last axis of ``x`` into per-head slices and move heads to axis 1.

    The trailing dimension is viewed as
    ``(self.num_attention_heads, self.attention_head_size)`` and axes 1 and 2
    are then swapped, so a 3-D input ``(a, b, heads * head_size)`` comes back
    as ``(a, heads, b, head_size)``.
    """
    leading_dims = tuple(x.shape[:-1])
    per_head = x.view(*leading_dims, self.num_attention_heads, self.attention_head_size)
    return per_head.permute(0, 2, 1, 3)
|
330 |
+
|
331 |
+
def transpose_kv_for_scores(self, x, repeats):
    """Reshape a key/value projection into heads and tile it to match the query heads.

    The last axis of ``x`` is viewed as
    ``(self.num_kv_heads, self.attention_head_size)``. The head axis is then
    tiled ``repeats`` times with ``Tensor.repeat`` (the whole group of heads is
    repeated back-to-back, i.e. ``h0, h1, h0, h1, ...``), and finally the head
    axis is moved to position 1. The result is returned contiguous.
    """
    kv_shape = tuple(x.shape[:-1]) + (self.num_kv_heads, self.attention_head_size)
    grouped = x.view(kv_shape)
    # Tile along the head axis so the head count lines up with the query heads.
    tiled = grouped.repeat(1, 1, repeats, 1)
    return tiled.permute(0, 2, 1, 3).contiguous()
|
336 |
+
|
337 |
+
def forward(
|
338 |
+
self,
|
339 |
+
hidden_states: torch.Tensor,
|
340 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
341 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
342 |
+
output_attentions: Optional[bool] = False,
|
343 |
+
) -> Tuple[torch.Tensor]:
|
344 |
+
batch_size, dim, num_channels = hidden_states.shape
|
345 |
+
mixed_query_layer = self.query(hidden_states)
|
346 |
+
|
347 |
+
# Final is (batch_size, num_attention_heads, seq_len, attention_head_size)
|
348 |
+
key_layer = self.transpose_kv_for_scores(self.key(hidden_states), self.kv_repeats)
|
349 |
+
value_layer = self.transpose_kv_for_scores(self.value(hidden_states), self.kv_repeats)
|
350 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
351 |
+
|
352 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
|
353 |
+
relative_position_bias = relative_position_bias.view(
|
354 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
|
355 |
+
)
|
356 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
|
357 |
+
if attention_mask is None:
|
358 |
+
attention_mask = relative_position_bias
|
359 |
+
else:
|
360 |
+
mask_shape = attention_mask.shape[0]
|
361 |
+
repeat_count = (batch_size // mask_shape)
|
362 |
+
attention_mask = attention_mask.repeat(repeat_count, 1, 1).unsqueeze(1)
|
363 |
+
attention_mask = attention_mask + relative_position_bias
|
364 |
+
|
365 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
366 |
+
query_layer.contiguous(),
|
367 |
+
key_layer.contiguous(),
|
368 |
+
value_layer.contiguous(),
|
369 |
+
attn_mask=attention_mask,
|
370 |
+
dropout_p=self.dropout_p if self.training else 0.0,
|
371 |
+
scale=self.attention_head_size**-0.5,
|
372 |
+
)
|
373 |
+
|
374 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
375 |
+
attn_output = attn_output.view(batch_size, dim, num_channels)
|
376 |
+
|
377 |
+
outputs = (attn_output,)
|
378 |
+
return outputs
|
379 |
+
|
380 |
+
# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
|
381 |
+
class DonutSwinSelfOutput(nn.Module):
    """Output projection applied to the attention result: linear layer, then dropout."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project ``hidden_states`` and apply dropout.

        ``input_tensor`` is part of the signature but is not read here; no
        residual connection is added inside this module.
        """
        return self.dropout(self.dense(hidden_states))
|
392 |
+
|
393 |
+
|
394 |
+
# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin
|
395 |
+
class DonutSwinAttention(nn.Module):
    """Couples the windowed self-attention module with its output projection."""

    def __init__(self, config, dim, num_heads, num_kv_heads, window_size):
        super().__init__()
        self.self = DonutSwinSelfAttention(config, dim, num_heads, num_kv_heads, window_size)
        self.output = DonutSwinSelfOutput(config, dim)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the projection layers.

        NOTE(review): the same index (computed from the query head count) is
        applied to the key/value projections, whose output width is based on
        ``num_kv_heads`` -- confirm this is valid when the two counts differ.
        """
        if len(heads) == 0:
            return

        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Shrink the three input projections, then the matching input side of
        # the output projection.
        for proj_name in ("query", "key", "value"):
            setattr(attn, proj_name, prune_linear_layer(getattr(attn, proj_name), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the bookkeeping in step with the reduced head count.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run self-attention, project its first output, and forward any extra outputs."""
        attended, *extras = self.self(hidden_states, attention_mask, head_mask, output_attentions)
        projected = self.output(attended, hidden_states)
        # Anything the self-attention returned beyond its first element is passed through.
        return (projected, *extras)
|
431 |
+
|
432 |
+
|
433 |
+
# Copied from transformers.models.swin.modeling_swin.SwinIntermediate
|
434 |
+
class DonutSwinIntermediate(nn.Module):
    """Expanding half of the block MLP: linear layer to ``mlp_ratio * dim``, then activation."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        activation = config.hidden_act
        # A string is resolved through the ACT2FN registry; anything else is
        # stored and called as-is.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the linear expansion followed by the activation function."""
        return self.intermediate_act_fn(self.dense(hidden_states))
|
447 |
+
|
448 |
+
|
449 |
+
# Copied from transformers.models.swin.modeling_swin.SwinOutput
|
450 |
+
class DonutSwinOutput(nn.Module):
    """Contracting half of the block MLP: linear layer back to ``dim``, then dropout."""

    def __init__(self, config, dim):
        super().__init__()
        expanded_width = int(config.mlp_ratio * dim)
        self.dense = nn.Linear(expanded_width, dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project back to ``dim`` and apply dropout (no residual is added here)."""
        return self.dropout(self.dense(hidden_states))
|
460 |
+
|
461 |
+
|
462 |
+
# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin
|
463 |
+
class DonutSwinLayer(nn.Module):
    """A single Swin block: (shifted-)window attention and an MLP, each with a residual add.

    The token sequence is reshaped to a 2-D grid, padded to a multiple of the
    window size, optionally rolled by ``shift_size``, cut into windows,
    attended window-by-window, and stitched back together before the MLP.
    """

    def __init__(self, config, dim, input_resolution, num_heads, num_kv_heads, shift_size=0):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.shift_size = shift_size
        self.window_size = config.window_size
        self.input_resolution = input_resolution
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = DonutSwinAttention(config, dim, num_heads, num_kv_heads, window_size=self.window_size)
        self.drop_path = DonutSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = DonutSwinIntermediate(config, dim)
        self.output = DonutSwinOutput(config, dim)

    def set_shift_and_window_size(self, input_resolution):
        """Disable shifting and shrink the window when the input is not larger than the window.

        This overwrites ``self.shift_size`` and ``self.window_size`` on the
        module, so the change persists for subsequent calls.
        """
        if min(input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = int(0)
            self.window_size = (
                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
            )

    def get_attn_mask(self, height, width, dtype, device):
        """Build the additive mask for shifted-window attention.

        Returns ``None`` when no shift is applied. Otherwise the (padded) grid
        is labelled with one id per region created by the cyclic shift; token
        pairs within a window that carry different ids receive ``-100.0`` and
        pairs with the same id receive ``0.0``.
        """
        if self.shift_size <= 0:
            return None

        # calculate attention mask for SW-MSA
        img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
        # The same three bands are used along both the height and width axes.
        bands = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        region_id = 0
        for height_slice in bands:
            for width_slice in bands:
                img_mask[:, height_slice, width_slice, :] = region_id
                region_id += 1

        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def maybe_pad(self, hidden_states, height, width):
        """Zero-pad the grid on the right/bottom up to a multiple of the window size.

        Returns the padded tensor together with the pad tuple that was passed
        to ``nn.functional.pad`` (entries 3 and 5 hold the right and bottom
        padding).
        """
        pad_right = (self.window_size - width % self.window_size) % self.window_size
        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
        hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the block to a flattened grid of tokens.

        Args:
            hidden_states: tensor of shape ``(batch, height * width, channels)``.
            input_dimensions: ``(height, width)`` of the token grid.
            head_mask: forwarded unchanged to the attention module.
            output_attentions: when true, attention weights are appended to the
                result *if the attention module returned any*.
            always_partition: when true, the window/shift sizes are left
                untouched even for inputs not larger than the window.

        Returns:
            ``(layer_output,)``, or ``(layer_output, attentions)`` when
            ``output_attentions`` is set and the attention module supplied
            attention weights.
        """
        if not always_partition:
            self.set_shift_and_window_size(input_dimensions)

        height, width = input_dimensions
        batch_size, _, channels = hidden_states.size()
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)
        hidden_states = hidden_states.view(batch_size, height, width, channels)

        # pad hidden_states to multiples of window size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
        _, height_pad, width_pad, _ = hidden_states.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_hidden_states = hidden_states

        # partition windows
        hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
        hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
        attn_mask = self.get_attn_mask(
            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
        )

        attention_outputs = self.attention(
            hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
        )
        attention_output = attention_outputs[0]

        attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
        shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)

        # reverse cyclic shift
        if self.shift_size > 0:
            attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            attention_windows = shifted_windows

        # Crop away whatever maybe_pad added on the right/bottom.
        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_windows = attention_windows[:, :height, :width, :].contiguous()

        attention_windows = attention_windows.view(batch_size, height * width, channels)

        hidden_states = shortcut + self.drop_path(attention_windows)

        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)
        layer_output = hidden_states + self.output(layer_output)

        layer_outputs = (layer_output,)
        if output_attentions:
            # The attention module may return only its output (no attention
            # weights); slicing instead of indexing avoids an IndexError then.
            layer_outputs += tuple(attention_outputs[1:2])
        return layer_outputs
|
586 |
+
|
587 |
+
|
588 |
+
# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin
|
589 |
+
class DonutSwinStage(nn.Module):
    """A stack of ``depth`` Swin blocks, optionally followed by a downsampling layer.

    Blocks at even positions use no shift; blocks at odd positions use a shift
    of half the configured window size.
    """

    def __init__(self, config, dim, input_resolution, depth, num_heads, num_kv_heads, drop_path, downsample):
        super().__init__()
        self.config = config
        self.dim = dim

        # ``drop_path`` is accepted but not read here; each block is built
        # from ``config`` alone.
        stage_blocks = []
        for block_idx in range(depth):
            shift = config.window_size // 2 if block_idx % 2 else 0
            stage_blocks.append(
                DonutSwinLayer(
                    config=config,
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    num_kv_heads=num_kv_heads,
                    shift_size=shift,
                )
            )
        self.blocks = nn.ModuleList(stage_blocks)

        # patch merging layer
        self.downsample = (
            None if downsample is None else downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
        )

        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run every block in order, then downsample if a downsampling layer exists.

        Returns:
            ``(hidden_states, hidden_states_before_downsampling, output_dimensions)``
            where ``output_dimensions`` is
            ``(height, width, height_after, width_after)``. When
            ``output_attentions`` is set, any extra outputs of the last block
            are appended.
        """
        height, width = input_dimensions

        for idx, block in enumerate(self.blocks):
            block_head_mask = None if head_mask is None else head_mask[idx]
            layer_outputs = block(
                hidden_states, input_dimensions, block_head_mask, output_attentions, always_partition
            )
            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is None:
            output_dimensions = (height, width, height, width)
        else:
            # Downsampling halves each side, rounding up.
            output_dimensions = (height, width, (height + 1) // 2, (width + 1) // 2)
            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)

        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
        if output_attentions:
            stage_outputs += layer_outputs[1:]
        return stage_outputs
|
647 |
+
|
648 |
+
|
649 |
+
# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin
|
650 |
+
class DonutSwinEncoder(nn.Module):
    """Sequence of Swin stages; every stage except the last ends with patch merging.

    Stage ``i`` works at width ``embed_dim * 2**i`` on a grid that is
    ``2**i`` times smaller (integer division) than ``grid_size`` on each side.
    """

    def __init__(self, config, grid_size):
        super().__init__()
        depths = config.depths
        self.num_layers = len(depths)
        self.config = config

        # One linearly spaced drop-path value per block, sliced per stage below.
        dpr = [rate.item() for rate in torch.linspace(0, config.drop_path_rate, sum(depths))]

        stages = []
        first_block = 0
        for i_layer, depth in enumerate(depths):
            scale = 2**i_layer
            is_last_stage = i_layer >= self.num_layers - 1
            stages.append(
                DonutSwinStage(
                    config=config,
                    dim=int(config.embed_dim * scale),
                    input_resolution=(grid_size[0] // scale, grid_size[1] // scale),
                    depth=depth,
                    num_heads=config.num_heads[i_layer],
                    num_kv_heads=config.num_kv_heads[i_layer],
                    drop_path=dpr[first_block : first_block + depth],
                    downsample=None if is_last_stage else DonutSwinPatchMerging,
                )
            )
            first_block += depth
        self.layers = nn.ModuleList(stages)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        always_partition: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, DonutSwinEncoderOutput]:
        """Run all stages, optionally collecting per-stage hidden states and attentions.

        When ``output_hidden_states`` is set, the input and each stage's output
        are recorded both as sequences and reshaped to ``(batch, channels,
        height, width)``; with ``output_hidden_states_before_downsampling`` the
        pre-downsampling tensor of each stage is recorded instead.

        Returns:
            A ``DonutSwinEncoderOutput``, or (when ``return_dict`` is false) a
            tuple of the non-``None`` values among the last hidden state, the
            collected hidden states and the collected attentions.
        """

        def _to_feature_map(states, grid):
            # rearrange b (h w) c -> b c h w
            batch, _, width = states.shape
            return states.view(batch, *grid, width).permute(0, 3, 1, 2)

        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (_to_feature_map(hidden_states, input_dimensions),)

        for i, layer_module in enumerate(self.layers):
            layer_head_mask = None if head_mask is None else head_mask[i]

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    input_dimensions,
                    layer_head_mask,
                    output_attentions,
                    always_partition,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
                )

            hidden_states, hidden_states_before_downsampling, output_dimensions = layer_outputs[:3]

            # The next stage sees the (possibly downsampled) grid.
            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

            if output_hidden_states:
                if output_hidden_states_before_downsampling:
                    # here we use the original (not downsampled) height and width
                    recorded = hidden_states_before_downsampling
                    grid = (output_dimensions[0], output_dimensions[1])
                else:
                    recorded = hidden_states
                    grid = input_dimensions
                all_hidden_states += (recorded,)
                all_reshaped_hidden_states += (_to_feature_map(recorded, grid),)

            if output_attentions:
                all_self_attentions += layer_outputs[3:]

        if not return_dict:
            candidates = [hidden_states, all_hidden_states, all_self_attentions]
            return tuple(value for value in candidates if value is not None)

        return DonutSwinEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )
|
750 |
+
|
751 |
+
|
752 |
+
# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin
|
753 |
+
class DonutSwinPreTrainedModel(PreTrainedModel):
    """
    Base class that wires the DonutSwin configuration into the ``PreTrainedModel`` machinery and supplies the weight
    initialisation scheme.
    """

    config_class = DonutSwinConfig
    base_model_prefix = "swin"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DonutSwinStage"]

    def _init_weights(self, module):
        """Initialise one sub-module in place.

        Linear and Conv2d weights are drawn from a normal distribution with
        standard deviation ``config.initializer_range`` and their biases are
        zeroed; LayerNorm is reset to weight 1 and bias 0. Other module types
        are left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return

        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
|
776 |
+
|
777 |
+
|
778 |
+
class DonutSwinModel(DonutSwinPreTrainedModel):
    """DonutSwin image encoder: patch embeddings, Swin stages, plus a learned position embedding
    added to the final sequence."""

    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
        super().__init__(config)
        self.config = config
        self.num_layers = len(config.depths)
        # The channel width doubles at every stage after the first.
        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))

        # ``add_pooling_layer`` is accepted but no pooling layer is built here.
        self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)

        self.position_embeddings = nn.Parameter(torch.zeros(1, config.encoder_length, config.hidden_size))

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the patch embedding module."""
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel

        NOTE(review): this reads ``self.encoder.layer`` while the encoder defined in this file stores its stages
        under ``layers`` -- confirm whether this path is reachable as written.
        """
        for layer_idx in heads_to_prune:
            self.encoder.layer[layer_idx].attention.prune_heads(heads_to_prune[layer_idx])

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, DonutSwinModelOutput]:
        r"""
        Encode an image batch and add the learned position embeddings to the encoder output.

        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Unset `output_attentions`, `output_hidden_states` and `return_dict` fall back to the values stored on the
        config. The result is always a `DonutSwinModelOutput` carrying only `last_hidden_state`.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, len(self.config.depths))

        embedding_output, input_dimensions = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # Only as many position embeddings as there are output tokens are used;
        # the addition is done in place on the encoder output.
        sequence_length = last_hidden_state.size(1)
        last_hidden_state += self.position_embeddings[:, :sequence_length, :]

        return DonutSwinModelOutput(last_hidden_state=last_hidden_state)
|
surya/model/recognition/encoderdecoder.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig
|
5 |
+
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
|
6 |
+
from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import shift_tokens_right
|
7 |
+
from surya.model.recognition.encoder import DonutSwinModel
|
8 |
+
from surya.model.recognition.decoder import SuryaOCRDecoder, SuryaOCRTextEncoder
|
9 |
+
|
10 |
+
|
11 |
+
class OCREncoderDecoderModel(PreTrainedModel):
    """Vision encoder / text decoder pair for OCR, with an additional text encoder sub-model.

    The image encoder output is fed to the decoder as ``encoder_hidden_states``.
    If the encoder and decoder hidden sizes differ and the decoder config does
    not declare a ``cross_attention_hidden_size``, a linear projection
    (``enc_to_dec_proj``) maps encoder states to the decoder width.
    """

    config_class = VisionEncoderDecoderConfig
    base_model_prefix = "vision_encoder_decoder"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _supports_param_buffer_assignment = False

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
        text_encoder: Optional[PreTrainedModel] = None,
    ):
        """Build the composite model.

        Args:
            config: composite config exposing ``encoder``, ``decoder`` and
                ``text_encoder`` sub-configs. Required despite the ``Optional``
                annotation.
            encoder / decoder / text_encoder: pre-built sub-models; any that is
                ``None`` is constructed from the matching sub-config.

        Raises:
            ValueError: if ``config`` is ``None``.
        """
        if config is None:
            # Nothing below can work without a config; fail with a clear message
            # instead of an AttributeError on ``None``.
            raise ValueError("A config is required to build OCREncoderDecoderModel")

        # initialize with config
        # make sure input & output embeddings is not tied
        config.tie_word_embeddings = False
        config.decoder.tie_word_embeddings = False
        super().__init__(config)

        if encoder is None:
            encoder = DonutSwinModel(config.encoder)

        if decoder is None:
            decoder = SuryaOCRDecoder(config.decoder, attn_implementation=config._attn_implementation)

        if text_encoder is None:
            text_encoder = SuryaOCRTextEncoder(config.text_encoder, attn_implementation=config._attn_implementation)

        self.encoder = encoder
        self.decoder = decoder
        self.text_encoder = text_encoder

        # make sure that the individual model's config refers to the shared config
        # so that the updates to the config will be synced
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder
        self.text_encoder.config = self.config.text_encoder

        # forward() projects encoder states under exactly this condition, so the
        # projection layer has to exist whenever the condition holds.
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            self.enc_to_dec_proj = torch.nn.Linear(
                self.encoder.config.hidden_size, self.decoder.config.hidden_size
            )

    def get_encoder(self):
        """Return the image encoder sub-model."""
        return self.encoder

    def get_decoder(self):
        """Return the text decoder sub-model."""
        return self.decoder

    def get_output_embeddings(self):
        """Delegate to the decoder's output embeddings."""
        return self.decoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Delegate to the decoder's output-embedding setter."""
        return self.decoder.set_output_embeddings(new_embeddings)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_cache_position: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Encode the image (unless ``encoder_outputs`` is given) and run the decoder.

        Extra keyword arguments prefixed with ``decoder_`` are passed to the
        decoder with the prefix stripped; all others go to the encoder.

        Raises:
            ValueError: if neither ``encoder_outputs`` nor ``pixel_values`` is provided.
        """
        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # optionally project encoder_hidden_states
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # No mask is applied over the encoder sequence.
        encoder_attention_mask = None

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            cache_position=decoder_cache_position,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            **kwargs_decoder,
        )

        return Seq2SeqLMOutput(
            logits=decoder_outputs.logits,
            decoder_hidden_states=decoder_outputs.hidden_states,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Build decoder inputs by shifting ``labels`` one position to the right."""
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        """Assemble the keyword arguments for one generation step from the decoder's own preparation."""
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
        input_dict = {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }
        return input_dict

    def resize_token_embeddings(self, *args, **kwargs):
        """Always raises: embeddings must be resized on the wrapped decoder instead."""
        raise NotImplementedError(
            "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the"
            " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
        )

    def _reorder_cache(self, past_key_values, beam_idx):
        # apply decoder cache reordering here
        return self.decoder._reorder_cache(past_key_values, beam_idx)
|
surya/model/recognition/model.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated")
|
6 |
+
|
7 |
+
import logging
|
8 |
+
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
|
9 |
+
|
10 |
+
from typing import List, Optional, Tuple
|
11 |
+
from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel
|
12 |
+
from surya.model.recognition.config import DonutSwinConfig, SuryaOCRConfig, SuryaOCRDecoderConfig, SuryaOCRTextEncoderConfig
|
13 |
+
from surya.model.recognition.encoder import DonutSwinModel
|
14 |
+
from surya.model.recognition.decoder import SuryaOCRDecoder, SuryaOCRTextEncoder
|
15 |
+
from surya.settings import settings
|
16 |
+
|
17 |
+
# Import-time backend configuration: when efficient attention is switched off
# in settings, warn on stdout and reconfigure PyTorch's scaled-dot-product
# attention backends. The memory-efficient kernel is disabled, while the flash
# and math kernels are (re-)enabled.
if not settings.ENABLE_EFFICIENT_ATTENTION:
    print("Efficient attention is disabled. This will use significantly more VRAM.")
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_math_sdp(True)
|
22 |
+
|
23 |
+
|
24 |
+
def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the recognition model from ``checkpoint`` and put it in eval mode.

    Args:
        checkpoint: Identifier passed to ``from_pretrained`` for both the
            config and the weights.
        device: Device the loaded model is moved to.
        dtype: Passed to ``from_pretrained`` as ``torch_dtype``.

    Returns:
        The loaded ``OCREncoderDecoderModel`` on ``device``, in eval mode.

    Raises:
        AssertionError: if a sub-module is not of the expected class.
    """
    config = SuryaOCRConfig.from_pretrained(checkpoint)

    # Each sub-config is unpacked with ** below and replaced by its typed
    # config object before the model is built.
    config.decoder = SuryaOCRDecoderConfig(**config.decoder)
    config.encoder = DonutSwinConfig(**config.encoder)
    config.text_encoder = SuryaOCRTextEncoderConfig(**config.text_encoder)

    model = OCREncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)

    # Sanity-check that the expected component classes were instantiated.
    for attr_name, expected_cls in (
        ("decoder", SuryaOCRDecoder),
        ("encoder", DonutSwinModel),
        ("text_encoder", SuryaOCRTextEncoder),
    ):
        assert isinstance(getattr(model, attr_name), expected_cls)

    model = model.to(device)
    model = model.eval()

    print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}")
    return model
|
surya/model/recognition/processor.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Union, Optional, List, Iterable
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
from torch import TensorType
|
5 |
+
from transformers import DonutImageProcessor, DonutProcessor
|
6 |
+
from transformers.image_processing_utils import BatchFeature
|
7 |
+
from transformers.image_transforms import pad, normalize
|
8 |
+
from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, get_image_size
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image
|
11 |
+
import PIL
|
12 |
+
from surya.model.recognition.tokenizer import Byt5LangTokenizer
|
13 |
+
from surya.settings import settings
|
14 |
+
|
15 |
+
|
16 |
+
def load_processor():
    """Build a ``SuryaProcessor`` configured for inference from project settings.

    The image processor is put in non-training mode and given the recognition
    image size; the tokenizer's maximum length is set to the recognition token
    limit.
    """
    surya_processor = SuryaProcessor()

    image_proc = surya_processor.image_processor
    image_proc.train = False
    image_proc.max_size = settings.RECOGNITION_IMAGE_SIZE

    surya_processor.tokenizer.model_max_length = settings.RECOGNITION_MAX_TOKENS
    return surya_processor
|
22 |
+
|
23 |
+
|
24 |
+
class SuryaImageProcessor(DonutImageProcessor):
    """Prepare RGB images for the recognition encoder.

    The pipeline in ``process_inner`` rotates images whose orientation differs
    from the target size, resizes them to ``max_size`` with OpenCV, converts
    them to float32, pads, rescales and normalizes them. Arrays leave the
    resize step in channel-first layout.
    """

    def __init__(self, *args, max_size=None, train=False, **kwargs):
        super().__init__(*args, **kwargs)

        self.patch_size = kwargs.get("patch_size", (4, 4))
        # Target size; read as a dict with "height" and "width" keys by the
        # resize / pad / align helpers below.
        self.max_size = max_size
        self.train = train

    @classmethod
    def numpy_resize(cls, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4):
        """Resize a channel-last image to ``size`` and return it channel-first.

        Args:
            image: Array indexed as (height, width, channels).
            size: Dict with ``"width"`` and ``"height"`` keys.
            interpolation: Flag forwarded to ``cv2.resize``.
        """
        max_width, max_height = size["width"], size["height"]

        # cv2.resize takes the target as (width, height).
        resized_image = cv2.resize(image, (max_width, max_height), interpolation=interpolation)
        # Move the channel axis to the front.
        resized_image = resized_image.transpose(2, 0, 1)

        return resized_image

    def process_inner(self, images: List[np.ndarray]):
        """Run the full array pipeline on a list of channel-last RGB images.

        Returns:
            A list with one processed channel-first array per input image.

        Raises:
            AssertionError: if the first image does not have 3 channels last,
                or if an image is taller than wide after alignment.
        """
        assert images[0].shape[2] == 3  # RGB input images, channel dim last

        # Rotate any image whose orientation (portrait vs. landscape) differs from the target size
        images = [SuryaImageProcessor.align_long_axis(image, size=self.max_size, input_data_format=ChannelDimension.LAST) for image in images]

        # Verify that the image is wider than it is tall
        for img in images:
            assert img.shape[1] >= img.shape[0]

        # This also applies the right channel dim format, to channel x height x width
        # NOTE(review): self.resample is presumably a PIL resampling enum inherited from
        # DonutImageProcessor, while cv2.resize expects a cv2 INTER_* flag — confirm the
        # integer value selects the intended filter.
        images = [SuryaImageProcessor.numpy_resize(img, self.max_size, self.resample) for img in images]
        assert images[0].shape[0] == 3  # RGB input images, channel dim first

        # Convert to float32 for rescale/normalize
        images = [img.astype(np.float32) for img in images]

        # Pads with the configured pad value
        # Pad to max size to improve performance
        # (the resize above already targets max_size, so the padding deltas are zero here)
        max_size = self.max_size
        images = [
            SuryaImageProcessor.pad_image(
                image=image,
                size=max_size,
                input_data_format=ChannelDimension.FIRST,
                pad_value=settings.RECOGNITION_PAD_VALUE
            )
            for image in images
        ]

        # Rescale and normalize
        images = [img * self.rescale_factor for img in images]
        images = [
            SuryaImageProcessor.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]

        return images

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        do_thumbnail: Optional[bool] = None,
        do_align_long_axis: Optional[bool] = None,
        do_pad: Optional[bool] = None,
        random_padding: bool = False,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Convert one or more images into a ``BatchFeature`` of pixel values.

        Only ``images`` and ``return_tensors`` are used; the remaining
        parameters are accepted but not read by this implementation.

        Returns:
            ``BatchFeature`` with a single ``"pixel_values"`` entry.
        """
        images = make_list_of_images(images)

        # Convert to numpy for later processing steps
        images = [np.array(img) for img in images]
        images = self.process_inner(images)

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @classmethod
    def pad_image(
        cls,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        pad_value: float = 0.0,
    ) -> np.ndarray:
        """Pad ``image`` to ``size``, splitting the padding between both sides.

        When a delta is odd, the extra pixel goes to the bottom / right.

        Raises:
            AssertionError: if the image is larger than ``size`` in either dimension.
        """
        output_height, output_width = size["height"], size["width"]
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)

        delta_width = output_width - input_width
        delta_height = output_height - input_height

        assert delta_width >= 0 and delta_height >= 0

        pad_top = delta_height // 2
        pad_left = delta_width // 2

        pad_bottom = delta_height - pad_top
        pad_right = delta_width - pad_left

        padding = ((pad_top, pad_bottom), (pad_left, pad_right))
        return pad(image, padding, data_format=data_format, input_data_format=input_data_format, constant_values=pad_value)

    @classmethod
    def align_long_axis(
        cls,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Rotate ``image`` when its orientation differs from the target ``size``.

        The first two axes of ``image`` are read as (height, width). A portrait
        image with a landscape target (or the reverse) is rotated with
        ``np.rot90(image, 3)``; otherwise the image is returned unchanged.
        ``data_format`` and ``input_data_format`` are not used.
        """
        input_height, input_width = image.shape[:2]
        output_height, output_width = size["height"], size["width"]

        if (output_width < output_height and input_width > input_height) or (
            output_width > output_height and input_width < input_height
        ):
            image = np.rot90(image, 3)

        return image

    @classmethod
    def normalize(
        cls,
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """Forward to the module-level ``normalize`` helper with the given mean and std."""
        return normalize(
            image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
        )
|
166 |
+
|
167 |
+
|
168 |
+
class SuryaProcessor(DonutProcessor):
    """Processor pairing a ``SuryaImageProcessor`` with a ``Byt5LangTokenizer``."""

    def __init__(self, image_processor=None, tokenizer=None, train=False, **kwargs):
        """Create the processor, building default components when none are given.

        Args:
            image_processor: Image processor to use. When ``None``, one is loaded
                from ``settings.RECOGNITION_MODEL_CHECKPOINT``.
            tokenizer: Tokenizer to use. When ``None``, a ``Byt5LangTokenizer``
                is created.
            train: Accepted but not used here.
            **kwargs: Accepted but not used here.
        """
        # Only fall back to the defaults when the caller supplied nothing;
        # previously the arguments were overwritten unconditionally.
        if image_processor is None:
            image_processor = SuryaImageProcessor.from_pretrained(settings.RECOGNITION_MODEL_CHECKPOINT)
        if tokenizer is None:
            tokenizer = Byt5LangTokenizer()

        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor
        self._in_target_context_manager = False

    def __call__(self, *args, **kwargs):
        """Process images, text, or both.

        ``images``, ``text`` and ``langs`` are taken from keyword arguments; a
        first positional argument, if present, overrides ``images``. Remaining
        positional arguments go to the image processor, and remaining keyword
        arguments go to both the image processor and the tokenizer.

        Returns:
            The image processor output when only images are given, the
            tokenizer output when only text is given, or the image output with
            ``"labels"`` and ``"langs"`` added when both are given.

        Raises:
            ValueError: if neither images nor text is provided.
        """
        images = kwargs.pop("images", None)
        text = kwargs.pop("text", None)
        langs = kwargs.pop("langs", None)

        if args:
            images = args[0]
            args = args[1:]

        if images is None and text is None:
            raise ValueError("You need to specify either an `images` or `text` input to process.")

        inputs = None
        if images is not None:
            inputs = self.image_processor(images, *args, **kwargs)

        encodings = None
        if text is not None:
            encodings = self.tokenizer(text, langs, **kwargs)

        if text is None:
            return inputs
        if images is None:
            return encodings

        inputs["labels"] = encodings["input_ids"]
        inputs["langs"] = encodings["langs"]
        return inputs
|
surya/model/recognition/tokenizer.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from itertools import chain
|
2 |
+
import random
|
3 |
+
from typing import List, Optional, Tuple, Union
|
4 |
+
from tokenizers import AddedToken
|
5 |
+
from transformers import ByT5Tokenizer
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from surya.model.recognition.config import LANGUAGE_MAP, TOTAL_TOKENS, TOKEN_OFFSET
|
9 |
+
|
10 |
+
|
11 |
+
def text_to_utf16_numbers(text):
    """Return the UTF-16 code units of ``text`` as a list of ints.

    Characters outside the Basic Multilingual Plane yield two values (a
    surrogate pair).
    """
    # Little-endian encoding: each code unit is a low byte followed by a high byte.
    raw = text.encode('utf-16le')
    return [raw[pos] + (raw[pos + 1] << 8) for pos in range(0, len(raw), 2)]
|
23 |
+
|
24 |
+
|
25 |
+
def utf16_numbers_to_text(numbers):
    """Decode a sequence of UTF-16 code units back into a string.

    Each number contributes its low byte followed by its next-higher byte; any
    bits above the low 16 are dropped. Byte sequences that do not decode as
    little-endian UTF-16 are ignored.
    """
    raw = bytearray(
        byte
        for unit in numbers
        for byte in (unit & 0xFF, (unit >> 8) & 0xFF)  # lower byte, then upper byte
    )
    return raw.decode('utf-16le', errors="ignore")
|
34 |
+
|
35 |
+
|
36 |
+
def _tokenize(text: str, langs: List[str] | None, eos_token_id: int = 1, add_eos: bool = False, add_bos: bool = True):
    """Turn ``text`` and optional language names into token ids.

    The result is laid out as ``[eos_token_id] + language tokens + text tokens``
    when ``add_bos`` is true, and without the leading id otherwise. ``add_eos``
    is accepted but not used.

    Returns:
        A ``(tokens, lang_list)`` tuple, where ``lang_list`` holds just the
        language token ids (empty when ``langs`` is empty or ``None``).

    Raises:
        KeyError: if a language is not present in ``LANGUAGE_MAP``.
    """
    # Account for special pad, etc, tokens by shifting every code unit.
    text_tokens = [unit + TOKEN_OFFSET for unit in text_to_utf16_numbers(text)]

    # Language ids are placed after the text-token range.
    lang_list = [LANGUAGE_MAP[lang] + TOKEN_OFFSET + TOTAL_TOKENS for lang in langs] if langs else []

    tokens = lang_list + text_tokens
    if add_bos:
        tokens = [eos_token_id] + tokens

    return tokens, lang_list
|
52 |
+
|
53 |
+
|
54 |
+
class Byt5LangTokenizer(ByT5Tokenizer):
    """Tokenizer mapping text to offset UTF-16 code units plus language tokens."""

    def __init__(self,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        model_max_length=None,
        **kwargs,
    ):
        """Store the special tokens and ids, then initialise the base tokenizer.

        The base ``__init__`` is called without arguments; ``**kwargs`` are
        accepted but not forwarded.
        """
        self.pad_token = pad_token
        self.eos_token = eos_token
        self.unk_token = unk_token
        # The EOS token doubles as the BOS token.
        self.bos_token = eos_token
        self.offset = TOKEN_OFFSET

        self.pad_id = 0
        self.eos_id = 1
        self.unk_id = 2

        self.model_max_length = model_max_length
        # Ids at or above this value are not text tokens (see decode).
        self.special_token_start = TOKEN_OFFSET + TOTAL_TOKENS

        super().__init__()

    def __call__(self, texts: List[str] | str, langs: List[List[str]] | List[str] | None = None, pad_token_id: int = 0, **kwargs):
        """Tokenize one string or a list of strings.

        Args:
            texts: A single string or a list of strings.
            langs: Languages per text. A flat list of strings is treated as the
                languages of a single text. ``None`` means no languages.
            pad_token_id: Accepted but not used.
            **kwargs: Accepted but not used.

        Returns:
            Dict with ``"input_ids"`` and ``"langs"``. Both are flat lists when
            ``texts`` was a string, and lists of lists otherwise.

        Raises:
            AssertionError: if the number of language entries differs from the
                number of texts.
        """
        tokenized = []
        all_langs = []

        is_list = True
        # Convert to list of lists format
        if isinstance(texts, str):
            texts = [texts]
            is_list = False

        if langs is None:
            langs = [None] * len(texts)

        # Guard against an empty batch before peeking at the first element.
        if langs and isinstance(langs[0], str):
            langs = [langs]

        assert len(langs) == len(texts)

        for text, lang in zip(texts, langs):
            tokens, lang_list = _tokenize(text, lang)
            tokenized.append(tokens)
            all_langs.append(lang_list)

        # Convert back to flat format
        if not is_list:
            tokenized = tokenized[0]
            all_langs = all_langs[0]

        return {"input_ids": tokenized, "langs": all_langs}

    def decode(
        self,
        token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool = None,
        **kwargs,
    ) -> str:
        """Decode token ids back to text, dropping every non-text token.

        Only ids in ``[TOKEN_OFFSET, special_token_start)`` are kept.
        ``skip_special_tokens`` and ``clean_up_tokenization_spaces`` are
        accepted but not used.
        """
        if isinstance(token_ids, (np.ndarray, torch.Tensor)):
            token_ids = token_ids.tolist()

        # A bare int (allowed by the signature, and what tolist() yields for
        # a 0-d array/tensor) is treated as a one-element sequence.
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        code_units = [t - TOKEN_OFFSET for t in token_ids if TOKEN_OFFSET <= t < self.special_token_start]
        return utf16_numbers_to_text(code_units)
|
surya/model/table_rec/config.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
from surya.settings import settings
|
3 |
+
|
4 |
+
BOX_DIM = 1024
|
5 |
+
SPECIAL_TOKENS = 7
|
6 |
+
MAX_ROWS = 384
|
7 |
+
|
8 |
+
|
9 |
+
class SuryaTableRecConfig(PretrainedConfig):
    """Composite config holding the encoder, decoder and text-encoder configs."""

    model_type = "vision-encoder-decoder"
    is_composition = True

    def __init__(self, **kwargs):
        """Build the composite config.

        ``kwargs`` must contain ``"encoder"``, ``"decoder"`` and
        ``"text_encoder"``. The decoder config may be a dict or an object; its
        BOS, pad and EOS token ids are copied onto this config.

        Raises:
            KeyError: if one of the three sub-configs is missing.
        """
        super().__init__(**kwargs)

        self.encoder = kwargs.pop("encoder")
        decoder_config = kwargs.pop("decoder")
        self.decoder = decoder_config
        self.text_encoder = kwargs.pop("text_encoder")
        self.is_encoder_decoder = True

        # The decoder config may still be a plain dict at this point.
        if isinstance(decoder_config, dict):
            bos_id = decoder_config["bos_token_id"]
            pad_id = decoder_config["pad_token_id"]
            eos_id = decoder_config["eos_token_id"]
        else:
            bos_id = decoder_config.bos_token_id
            pad_id = decoder_config.pad_token_id
            eos_id = decoder_config.eos_token_id

        # Generation starts from the decoder's BOS token.
        self.decoder_start_token_id = bos_id
        self.pad_token_id = pad_id
        self.eos_token_id = eos_id
|
33 |
+
|
34 |
+
|
35 |
+
class DonutSwinTableRecConfig(PretrainedConfig):
    """Configuration for the Donut-Swin image encoder used in table recognition."""

    model_type = "donut-swin"

    # Lets the generic names resolve to this config's own attribute names.
    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        image_size=(settings.TABLE_REC_IMAGE_SIZE["width"], settings.TABLE_REC_IMAGE_SIZE["height"]),
        patch_size=4,
        num_channels=3,
        embed_dim=128,
        depths=[2, 2, 14, 2],
        num_heads=[4, 8, 16, 32],
        num_kv_heads=[4, 8, 16, 32],
        window_size=8,
        mlp_ratio=4.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        use_absolute_embeddings=True,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        encoder_length=1024,
        **kwargs,
    ):
        """Store the encoder hyper-parameters.

        Each argument is stored under the same name. ``num_layers`` is derived
        as ``len(depths)`` and ``hidden_size`` as
        ``embed_dim * 2 ** (len(depths) - 1)``. Extra ``kwargs`` go to
        ``PretrainedConfig``.
        """
        super().__init__(**kwargs)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        # Copy the per-stage sequences: the defaults are module-level lists, so
        # storing them directly would share one list between all instances.
        self.depths = list(depths)
        self.num_layers = len(depths)
        self.num_heads = list(num_heads)
        self.num_kv_heads = list(num_kv_heads)
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.use_absolute_embeddings = use_absolute_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
        self.encoder_length = encoder_length
|
89 |
+
|
90 |
+
|
91 |
+
class SuryaTableRecDecoderConfig(PretrainedConfig):
    """Hyper-parameters for the table-recognition decoder.

    Constructor arguments are stored on the instance under the same name, with
    a few derived values (``lru_width``, ``head_dim``, ``num_key_value_heads``,
    ``final_w_init_variance_scale``). The special token ids and any extra
    ``kwargs`` are forwarded to ``PretrainedConfig``.
    """

    model_type = "surya_tablerec"

    def __init__(
        self,
        num_hidden_layers=3,
        vocab_size=settings.TABLE_REC_MAX_ROWS + SPECIAL_TOKENS,
        hidden_size=512,
        intermediate_size=4 * 512,
        encoder_hidden_size=1024,
        num_attention_heads=8,
        lru_width=None,
        attention_window_size=16,
        conv1d_width=4,
        logits_soft_cap=30.0,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        hidden_activation="gelu_pytorch_tanh",
        rope_theta=10000.0,
        block_types=("attention",),
        cross_attn_layers=(0, 1, 2, 3),
        encoder_cross_attn_layers=(0, 1, 2, 3),
        self_attn_layers=(0, 1, 2, 3),
        global_attn_layers=(0, 1, 2, 3),
        attention_dropout=0.0,
        num_key_value_heads=4,
        attention_bias=False,
        w_init_variance_scale=0.01,
        init_std=0.02,
        tie_word_embeddings=False,
        aux_heads=0,  # How many n-token-ahead heads to add
        causal=True,
        max_classes=2 + SPECIAL_TOKENS,
        max_width=1024 + SPECIAL_TOKENS,
        max_height=1024 + SPECIAL_TOKENS,
        out_box_size=1024,
        **kwargs,
    ):
        """Store the decoder settings.

        Raises:
            ValueError: if ``num_key_value_heads`` exceeds ``num_attention_heads``.
        """
        # Core sizes.
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        # lru_width falls back to the hidden size when not given.
        self.lru_width = hidden_size if lru_width is None else lru_width

        # Block-level settings.
        self.attention_window_size = attention_window_size
        self.conv1d_width = conv1d_width
        self.logits_soft_cap = logits_soft_cap
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.block_types = list(block_types)
        self.hidden_activation = hidden_activation

        # Attention head geometry.
        self.head_dim = hidden_size // num_attention_heads
        if num_key_value_heads is None:
            # Default to one key/value head per attention head.
            self.num_key_value_heads = num_attention_heads
        else:
            self.num_key_value_heads = num_key_value_heads
        if self.num_key_value_heads > num_attention_heads:
            raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`")

        # Which layers use which attention variant.
        self.cross_attn_layers = cross_attn_layers
        self.self_attn_layers = self_attn_layers
        self.global_attn_layers = global_attn_layers
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias

        # Initialisation.
        self.w_init_variance_scale = w_init_variance_scale
        self.final_w_init_variance_scale = 2.0 / num_hidden_layers
        self.init_std = init_std
        self.tie_word_embeddings = tie_word_embeddings

        # Remaining decoder-specific settings.
        self.aux_heads = aux_heads
        self.encoder_hidden_size = encoder_hidden_size
        self.causal = causal
        self.encoder_cross_attn_layers = encoder_cross_attn_layers
        self.max_classes = max_classes
        self.max_width = max_width
        self.max_height = max_height
        self.out_box_size = out_box_size

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        """Block type per layer: ``block_types`` repeated and cut to ``num_hidden_layers``."""
        repeated = self.block_types * 100
        return repeated[: self.num_hidden_layers]
|
178 |
+
|
179 |
+
|
180 |
+
class SuryaTableRecTextEncoderConfig(PretrainedConfig):
    """Hyper-parameters for the table-recognition text encoder.

    Constructor arguments are stored on the instance under the same name, with
    a few derived values (``lru_width``, ``head_dim``, ``num_key_value_heads``,
    ``final_w_init_variance_scale``). The special token ids and any extra
    ``kwargs`` are forwarded to ``PretrainedConfig``.
    """

    model_type = "surya_tablerec"

    def __init__(
        self,
        num_hidden_layers=4,
        vocab_size=settings.TABLE_REC_MAX_ROWS + SPECIAL_TOKENS,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        encoder_hidden_size=1024,
        num_attention_heads=16,
        lru_width=None,
        attention_window_size=16,
        conv1d_width=4,
        logits_soft_cap=30.0,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        hidden_activation="gelu_pytorch_tanh",
        rope_theta=10000.0,
        block_types=("attention",),
        cross_attn_layers=(0, 1, 2, 3, 4, 5),
        self_attn_layers=(0, 1, 2, 3, 4, 5),
        global_attn_layers=(0, 1, 2, 3, 4, 5),
        attention_dropout=0.0,
        num_key_value_heads=16,
        attention_bias=False,
        w_init_variance_scale=0.01,
        init_std=0.02,
        tie_word_embeddings=False,
        causal=False,
        max_width=BOX_DIM + SPECIAL_TOKENS,
        max_height=BOX_DIM + SPECIAL_TOKENS,
        max_position_embeddings=1024,
        **kwargs,
    ):
        """Store the text-encoder settings.

        Raises:
            ValueError: if ``num_key_value_heads`` exceeds ``num_attention_heads``.
        """
        # Core sizes.
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        # lru_width falls back to the hidden size when not given.
        self.lru_width = hidden_size if lru_width is None else lru_width

        # Block-level settings.
        self.attention_window_size = attention_window_size
        self.conv1d_width = conv1d_width
        self.logits_soft_cap = logits_soft_cap
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.block_types = list(block_types)
        self.hidden_activation = hidden_activation

        # Attention head geometry.
        self.head_dim = hidden_size // num_attention_heads
        if num_key_value_heads is None:
            # Default to one key/value head per attention head.
            self.num_key_value_heads = num_attention_heads
        else:
            self.num_key_value_heads = num_key_value_heads
        if self.num_key_value_heads > num_attention_heads:
            raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`")

        # Which layers use which attention variant.
        self.cross_attn_layers = cross_attn_layers
        self.self_attn_layers = self_attn_layers
        self.global_attn_layers = global_attn_layers
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias

        # Initialisation.
        self.w_init_variance_scale = w_init_variance_scale
        self.final_w_init_variance_scale = 2.0 / num_hidden_layers
        self.init_std = init_std
        self.tie_word_embeddings = tie_word_embeddings

        # Remaining encoder-specific settings.
        self.encoder_hidden_size = encoder_hidden_size
        self.causal = causal
        self.max_width = max_width
        self.max_height = max_height
        self.max_position_embeddings = max_position_embeddings

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        """Block type per layer: ``block_types`` repeated and cut to ``num_hidden_layers``."""
        repeated = self.block_types * 100
        return repeated[: self.num_hidden_layers]
|
surya/model/table_rec/decoder.py
ADDED
@@ -0,0 +1,795 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Dict, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.utils.checkpoint
|
6 |
+
from torch import nn
|
7 |
+
from transformers.utils import ModelOutput
|
8 |
+
|
9 |
+
from surya.model.table_rec.config import SuryaTableRecDecoderConfig, SuryaTableRecTextEncoderConfig
|
10 |
+
from transformers import PreTrainedModel
|
11 |
+
from transformers.activations import ACT2FN
|
12 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
13 |
+
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput
|
14 |
+
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
15 |
+
|
16 |
+
from surya.settings import settings
|
17 |
+
|
18 |
+
_MAX_SQRT_GRADIENT = 1000.0
|
19 |
+
|
20 |
+
@dataclass
class TableRecModelOutput(ModelOutput):
    """Output container for the table-recognition decoder.

    Only ``bbox_logits`` is required; the other fields default to ``None``.
    """

    # presumably logits over bounding-box coordinate tokens — TODO confirm against the decoder head
    bbox_logits: torch.Tensor
    # presumably logits over box classes; optional
    class_logits: torch.Tensor | None = None
    # optional hidden states; exact layer/shape not visible here
    hidden_states: torch.Tensor | None = None
|
25 |
+
|
26 |
+
|
27 |
+
class SuryaTableRecDecoderRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    The learned ``weight`` starts at zero and is applied as ``1 + weight``, so
    a freshly constructed module only normalises.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension (``eps`` added for stability)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalise in float32, scale by ``1 + weight``, and cast back to ``x``'s dtype."""
        normed = self._norm(x.float())
        # Llama does x.to(float16) * w whilst SuryaTableRecDecoder is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        scale = 1.0 + self.weight.float()
        scaled = normed * scale
        return scaled.type_as(x)

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.eps}"
|
45 |
+
|
46 |
+
|
47 |
+
# Add the custom RMSNorm class to transformers' ALL_LAYERNORM_LAYERS list.
# NOTE(review): presumably so that utilities which consult this list treat it
# like a regular LayerNorm -- confirm against the transformers version in use.
ALL_LAYERNORM_LAYERS.append(SuryaTableRecDecoderRMSNorm)
|
48 |
+
|
49 |
+
|
50 |
+
class SuryaTableRecDecoderRotaryEmbedding(nn.Module):
    """Computes the cos/sin tables used by rotary position embeddings.

    Args:
        dim: Size of the rotary dimension; ``dim // 2`` (rounded up) inverse
            frequencies are generated from ``torch.arange(0, dim, 2)``.
        base: Base of the geometric frequency progression.
        device: Accepted for interface compatibility; not used here.
    """

    def __init__(self, dim, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        # Non-persistent: recomputed on construction, never stored in checkpoints.
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaTableRecDecoder
    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        """Return ``(cos, sin)`` for the given positions.

        Args:
            x: Tensor used only for its device and dtype
                (expected layout: [bs, num_attention_heads, seq_len, head_size]).
            position_ids: Tensor of shape (batch, seq_len) holding positions.
            seq_len: Unused; kept for interface compatibility.

        Returns:
            Two tensors of shape (batch, seq_len, 2 * len(inv_freq)) cast to
            ``x.dtype``.
        """
        # ``Tensor.to`` returns a new tensor instead of moving in place, so the
        # result has to be captured for the device transfer to take effect.
        inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)
        inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].to(device=x.device, dtype=torch.float32)

        # (batch, n_freq, 1) @ (batch, 1, seq) -> (batch, n_freq, seq) -> (batch, seq, n_freq)
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
71 |
+
|
72 |
+
|
73 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
74 |
+
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is split at ``size // 2``; the trailing part is negated
    and moved in front of the leading part.
    """
    split_at = x.shape[-1] // 2
    leading, trailing = x[..., :split_at], x[..., split_at:]
    return torch.cat((-trailing, leading), dim=-1)
|
79 |
+
|
80 |
+
|
81 |
+
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
82 |
+
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate the query and key tensors with a rotary position embedding.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they
            broadcast against `q` and `k`. With `cos`/`sin` shaped
            [batch_size, seq_len, head_dim], use `unsqueeze_dim=1` when `q` and
            `k` are [batch_size, heads, seq_len, head_dim], and
            `unsqueeze_dim=2` when they are [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Standard RoPE combination: x * cos + rotate_half(x) * sin
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
|
105 |
+
|
106 |
+
|
107 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
108 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep): the input goes from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
    The input tensor itself is returned when `n_rep` is 1.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
    expanded = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
118 |
+
|
119 |
+
|
120 |
+
class SuryaTableRecDecoderSdpaCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper
    Modified for GQA

    Queries are projected from the decoder hidden states; keys and values are
    projected from the encoder hidden states.  Projected keys/values are only
    computed while no cached copy is stored on the module: once they have been
    cached (``use_cache=True``) every later call reuses them until
    ``_setup_cache`` clears the cache again.
    """

    def __init__(self, config: SuryaTableRecDecoderConfig):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
        # Not used by this module's forward pass; kept so the module layout is unchanged.
        self.rotary_emb = SuryaTableRecDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

        # Start with an empty cache. Without this, ``forward`` raised
        # AttributeError on ``self.key_states`` whenever it ran before
        # ``_setup_cache`` had been called.
        self.key_states = None
        self.value_states = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> torch.Tensor:
        """Attend from decoder states to encoder states.

        Args:
            hidden_states: Decoder input of shape (batch, q_len, hidden_size).
            encoder_hidden_states: Encoder output of shape
                (batch, v_len, encoder_hidden_size).
            attention_mask: Accepted for interface compatibility; not used.
            encoder_attention_mask: Currently ignored.
            use_cache: When true, freshly projected keys/values are stored on
                the module for reuse by later calls.

        Returns:
            Tensor of shape (batch, q_len, hidden_size) produced by ``o_proj``.
        """
        # Encoder attention mask currently ignored

        bsz, q_len, _ = hidden_states.size()
        _, v_len, _ = encoder_hidden_states.size()

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)

        if self.key_states is None:
            # No cached projection yet: project the encoder states now.
            key_states = self.k_proj(encoder_hidden_states)
            value_states = self.v_proj(encoder_hidden_states)
            key_states = key_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            if use_cache:
                self._update_cache(key_states, value_states)
        else:
            key_states = self.key_states
            value_states = self.value_states

        # Expand grouped KV heads so they line up with the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states.contiguous(),
            key_states.contiguous(),
            value_states.contiguous(),
            attn_mask=None,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _setup_cache(self, batch_size, device, dtype=None):
        """Clear the cached key/value projections (arguments are unused)."""
        # Setup initial caches
        self.value_states = None
        self.key_states = None

    @torch.no_grad()
    def _update_cache(self, key_states, value_states, **cache_kwargs):
        """Store the projected key/value states for reuse by later calls."""
        self.value_states = value_states
        self.key_states = key_states
|
197 |
+
|
198 |
+
|
199 |
+
class SuryaTableRecDecoderSdpaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Self-attention with rotary position embeddings and grouped key/value
    heads.  Key/value states can be cached on the module, either in a
    pre-allocated buffer (``settings.RECOGNITION_STATIC_CACHE``) or by
    concatenating along the sequence dimension.
    """

    def __init__(self, config: SuryaTableRecDecoderConfig):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads that share one key/value head.
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
        self.rotary_emb = SuryaTableRecDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: bool = False,
        window_attn: bool = False,
    ) -> torch.Tensor:
        """Run rotary self-attention over ``hidden_states``.

        Args:
            hidden_states: Input of shape (batch, q_len, hidden_size).
            position_ids: Positions handed to the rotary embedding.
            attention_mask: Optional mask passed to scaled-dot-product
                attention after being cut to the key length on its last dim.
            cache_position: Cache slots for the current tokens; forwarded to
                the cache update, and its last element drives the static-cache
                masking below.
            use_cache: When true and the cache attributes exist on the module,
                keys/values are routed through ``_update_cache``.
            window_attn: Forwarded to the cache update via ``cache_kwargs``;
                not otherwise read in this method.

        Returns:
            Tensor of shape (batch, q_len, hidden_size) produced by ``o_proj``.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Final is bsz, num_attention_heads, seq_len, head_dim
        query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # value_states is only passed so the rotary module can read device/dtype.
        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # The cache attributes are created by _setup_cache; without them caching is skipped.
        if use_cache and hasattr(self, "key_states"):
            cache_kwargs = {"cache_position": cache_position, "window_attn": window_attn}
            key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs)

        # Expand grouped KV heads so they line up with the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            # Mask is batch, head, seq_len, kv_len
            causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]
            current_cache_position = cache_position[-1].item() if cache_position is not None else None
            # NOTE(review): a cache position of 0 is falsy, so this branch is also
            # skipped for position 0 -- confirm that is intended.
            if current_cache_position and settings.RECOGNITION_STATIC_CACHE:
                # Mask out future cache positions
                position_mask = torch.ones_like(causal_mask, dtype=torch.bool, device=causal_mask.device)
                position_mask[:, :, :, :current_cache_position + 1] = False
                causal_mask = torch.where(position_mask, torch.finfo(causal_mask.dtype).min, causal_mask)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states.contiguous(),
            key_states.contiguous(),
            value_states.contiguous(),
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _setup_cache(self, batch_size, device, dtype=None):
        """Reset the KV cache; pre-allocate zero buffers in static-cache mode.

        The dtype falls back to ``config.torch_dtype`` and then to float32.
        """
        if dtype is None and self.config.torch_dtype is not None:
            dtype = self.config.torch_dtype
        dtype = dtype if dtype is not None else torch.float32

        # Setup initial caches
        self.value_states = None
        self.key_states = None

        # NOTE(review): this table-rec decoder reads the RECOGNITION_* settings for
        # its cache size -- confirm these are the intended settings.
        if settings.RECOGNITION_STATIC_CACHE:
            cache_shape = (batch_size, self.num_key_value_heads, settings.RECOGNITION_MAX_TOKENS, self.head_dim)
            self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device)
            self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device)

    def _update_static_cache(self, key_states, value_states, **cache_kwargs):
        """Write the new keys/values into the buffers at ``cache_position``.

        Returns the full cache buffers, which are also stored back on the module.
        """
        cache_position = cache_kwargs.get("cache_position")
        k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)

        # Index along the sequence dimension (dim 2) of the cache buffers.
        k_out[:, :, cache_position] = key_states.to(k_out.dtype)
        v_out[:, :, cache_position] = value_states.to(v_out.dtype)

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

    def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs):
        """Append the new keys/values to the cache along the sequence dimension.

        When nothing is cached yet the new states become the cache.  Returns
        the concatenated keys and values.
        """
        k_out = key_states
        if self.key_states is not None:
            k_out = torch.cat([self.key_states, key_states], dim=2)

        v_out = value_states
        if self.value_states is not None:
            v_out = torch.cat([self.value_states, value_states], dim=2)

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

    @torch.no_grad()
    def _update_cache(self, key_states, value_states, **cache_kwargs):
        """Dispatch to the static or dynamic cache update based on settings."""
        if settings.RECOGNITION_STATIC_CACHE:
            return self._update_static_cache(key_states, value_states, **cache_kwargs)

        return self._update_dynamic_cache(key_states, value_states, **cache_kwargs)
|
318 |
+
|
319 |
+
|
320 |
+
class SuryaTableRecDecoderMlp(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))``.

    All three projections are bias-free.  When ``config.hidden_activation`` is
    ``None`` it is set on the config to ``"gelu_pytorch_tanh"``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        narrow, wide = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(narrow, wide, bias=False)
        self.up_proj = nn.Linear(narrow, wide, bias=False)
        self.down_proj = nn.Linear(wide, narrow, bias=False)

        # The default is written back onto the config object itself.
        if config.hidden_activation is None:
            config.hidden_activation = "gelu_pytorch_tanh"
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        """Apply the gated projection to ``x`` and project back down."""
        gated = self.act_fn(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gated * lifted)
|
336 |
+
|
337 |
+
|
338 |
+
class SuryaTableRecDecoderLayer(nn.Module):
    """One decoder layer: optional cross-attention, optional self-attention, MLP.

    Whether the layer owns a cross-attention block and/or a self-attention
    ("temporal") block depends on whether ``layer_idx`` appears in
    ``config.cross_attn_layers`` / ``config.self_attn_layers``.
    """

    def __init__(self, config, layer_idx):
        # The original called super().__init__() twice; once is sufficient.
        super().__init__()
        self.cross_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.temporal_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.temporal_block = None
        if layer_idx in config.self_attn_layers:
            self.temporal_block = SuryaTableRecDecoderSdpaAttention(config)

        self.cross_attn_block = None
        if layer_idx in config.cross_attn_layers:
            self.cross_attn_block = SuryaTableRecDecoderSdpaCrossAttention(config)

        # Forwarded to the self-attention block as its ``window_attn`` flag.
        self.window_attn = layer_idx not in config.global_attn_layers
        self.channel_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp_block = SuryaTableRecDecoderMlp(config)

    def forward(
        self,
        activations: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
    ) -> torch.Tensor:
        """Run the layer on ``activations`` and return the new hidden states.

        Args:
            activations: Input hidden states of the layer.
            position_ids: Positions forwarded to the self-attention block.
            attention_mask: Mask forwarded to both attention blocks.
            encoder_hidden_states: Encoder output for cross-attention.
            encoder_attention_mask: Forwarded to the cross-attention block.
            cache_position: Cache slots forwarded to the self-attention block.
            use_cache: Forwarded to both attention blocks.

        Returns:
            Tensor with the same shape as ``activations``.
        """
        raw_activations = activations

        if self.cross_attn_block is not None:
            # Do cross-attention on encoder outputs
            cross_attn_inputs = self.cross_pre_norm(activations)
            cross_attn_path = self.cross_attn_block(
                cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache
            )
            cross_attn_output = cross_attn_path + raw_activations
        else:
            cross_attn_output = raw_activations

        if self.temporal_block is not None:
            inputs_normalized = self.temporal_pre_norm(cross_attn_output)  # RMSNorm introduces slight differences
            hidden_states = self.temporal_block(
                inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache, window_attn=self.window_attn
            )

            # NOTE(review): the skip connection adds the layer input, not
            # cross_attn_output, so the cross-attention result only reaches later
            # stages through the self-attention block. Kept exactly as-is because
            # changing the arithmetic would alter the outputs of existing weights.
            residual = hidden_states + raw_activations
        else:
            residual = cross_attn_output

        hidden_states = self.channel_pre_norm(residual)
        hidden_states = self.mlp_block(hidden_states)

        hidden_states = hidden_states + residual
        return hidden_states
|
394 |
+
|
395 |
+
|
396 |
+
class SuryaTableRecDecoderPreTrainedModel(PreTrainedModel):
    """Shared base for the table-recognition decoder models.

    Provides weight initialisation, per-layer KV-cache setup and no-op
    weight-tying hooks.
    """

    config_class = SuryaTableRecDecoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SuryaTableRecDecoderLayer"]
    _skip_keys_device_placement = ["cache"]
    _supports_flash_attn_2 = False
    _supports_sdpa = False  # we can't compare with eager for now
    _supports_cache_class = True
    _supports_quantized_cache = True

    def _init_weights(self, module):
        """Initialise ``module`` with normal(0, ``config.init_std``) weights.

        Self-attention modules get their four projection weights initialised;
        linear layers additionally get zeroed biases; embeddings get a zeroed
        padding row when ``padding_idx`` is set.
        """

        def _fill_normal(weight):
            torch.nn.init.normal_(weight, mean=0.0, std=self.config.init_std)

        if isinstance(module, SuryaTableRecDecoderSdpaAttention):
            for projection in (module.q_proj, module.k_proj, module.v_proj, module.o_proj):
                _fill_normal(projection.weight)
            return

        if isinstance(module, nn.Linear):
            _fill_normal(module.weight)
            if getattr(module, "bias", None) is not None:
                torch.nn.init.zeros_(module.bias)
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _setup_cache(self, config, batch, device, dtype):
        """Reset the caches of every attention block in every layer.

        Layers are taken from ``self.model`` when that attribute exists,
        otherwise from ``self``.  ``config`` is not used.
        """
        decoder = getattr(self, "model", self)
        for layer in decoder.layers:
            for block in (layer.temporal_block, layer.cross_attn_block):
                if block:
                    block._setup_cache(batch, device, dtype)

    def reset_cache(self, batch, device, dtype):
        """No-op."""
        pass

    def _tie_weights(self):
        """No-op: this model does not tie any weights."""
        pass

    def tie_weights(self):
        """No-op: this model does not tie any weights."""
        pass
|
439 |
+
|
440 |
+
|
441 |
+
class LabelEmbedding(nn.Module):
    """Embeds ``(cx, cy, w, h, class)`` labels as a sum of learned lookups.

    Corner coordinates are derived from the centre and size, every index is
    clamped into its table's valid range, and the nine embeddings (four
    corners, width, height, centre x/y, class) are summed.
    """

    def __init__(self, config):
        super().__init__()
        self.vocab_size = config.vocab_size
        width, height, hidden = config.max_width, config.max_height, config.hidden_size
        self.x1_embed = nn.Embedding(width, hidden)
        self.y1_embed = nn.Embedding(height, hidden)
        self.x2_embed = nn.Embedding(width, hidden)
        self.y2_embed = nn.Embedding(height, hidden)
        self.w_embed = nn.Embedding(width, hidden)
        self.h_embed = nn.Embedding(height, hidden)
        self.cx_embed = nn.Embedding(width, hidden)
        self.cy_embed = nn.Embedding(height, hidden)
        self.class_embed = nn.Embedding(config.max_classes, hidden)
        self.max_width = width
        self.max_height = height
        self.max_classes = config.max_classes

    def forward(self, labels: torch.LongTensor, input_box_counts: torch.LongTensor):
        """Return the summed embedding for each label.

        Args:
            labels: Tensor whose last dimension holds ``cx, cy, w, h, class``.
            input_box_counts: Unused; kept for interface parity with
                ``BboxEmbedding``.

        Returns:
            Tensor of shape ``labels.shape[:-1] + (hidden_size,)``.
        """
        cx, cy, w, h, class_ = labels.to(torch.long).unbind(dim=-1)
        last_x = self.max_width - 1
        last_y = self.max_height - 1
        half_w = w // 2
        half_h = h // 2

        # Shape is (batch_size, num_boxes/seq len, d_model)
        lookups = (
            (self.x1_embed, (cx - half_w).long().clamp(0, last_x)),
            (self.y1_embed, (cy - half_h).long().clamp(0, last_y)),
            (self.x2_embed, (cx + half_w).long().clamp(0, last_x)),
            (self.y2_embed, (cy + half_h).long().clamp(0, last_y)),
            (self.w_embed, w.clamp(0, last_x).long()),
            (self.h_embed, h.clamp(0, last_y).long()),
            (self.cx_embed, cx.clamp(0, last_x).long()),
            (self.cy_embed, cy.clamp(0, last_y).long()),
            (self.class_embed, class_.clamp(0, self.max_classes - 1).long()),
        )

        # Accumulate left to right so the summation order stays fixed.
        (first_table, first_index), *remaining = lookups
        embedded = first_table(first_index)
        for table, index in remaining:
            embedded = embedded + table(index)

        return embedded
|
482 |
+
|
483 |
+
|
484 |
+
class BboxEmbedding(nn.Module):
    """Embeds ``(x1, y1, x2, y2)`` boxes as a sum of learned lookups.

    Width, height and centre are derived from the clamped corners; the eight
    embeddings are summed.  With ``embed_positions`` enabled, a learned
    per-box positional embedding is added to a span of each sequence.
    """

    def __init__(self, config, embed_positions=False):
        super().__init__()
        width, height, hidden = config.max_width, config.max_height, config.hidden_size
        self.x1_embed = nn.Embedding(width, hidden)
        self.y1_embed = nn.Embedding(height, hidden)
        self.x2_embed = nn.Embedding(width, hidden)
        self.y2_embed = nn.Embedding(height, hidden)
        self.w_embed = nn.Embedding(width, hidden)
        self.h_embed = nn.Embedding(height, hidden)
        self.cx_embed = nn.Embedding(width, hidden)
        self.cy_embed = nn.Embedding(height, hidden)
        self.box_pos_embed = nn.Embedding(config.max_position_embeddings, hidden)
        self.max_width = width
        self.max_height = height
        self.embed_positions = embed_positions

    def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor):
        """Return the summed embedding for each box.

        Args:
            boxes: Tensor whose last dimension holds ``x1, y1, x2, y2``.
            input_box_counts: Indexed as ``[j, 0]`` (span start) and ``[j, 1]``
                (span end, including a separator token) for each batch row;
                only read when ``embed_positions`` is true.

        Returns:
            Tensor of shape ``boxes.shape[:-1] + (hidden_size,)``.
        """
        last_x = self.max_width - 1
        last_y = self.max_height - 1

        left, top, right, bottom = boxes.unbind(dim=-1)
        left = left.clamp(0, last_x).long()
        top = top.clamp(0, last_y).long()
        right = right.clamp(0, last_x).long()
        bottom = bottom.clamp(0, last_y).long()

        # Shape is (batch_size, num_boxes/seq len, d_model)
        width = (right - left).clamp(0, last_x).long()
        height = (bottom - top).clamp(0, last_y).long()
        # Centre is computed with true division and then truncated to long.
        centre_x = ((left + right) / 2).long().clamp(0, last_x).long()
        centre_y = ((top + bottom) / 2).long().clamp(0, last_y).long()

        corner_sum = self.x1_embed(left) + self.y1_embed(top) + self.x2_embed(right) + self.y2_embed(bottom)
        embedded = (
            corner_sum
            + self.w_embed(width)
            + self.h_embed(height)
            + self.cx_embed(centre_x)
            + self.cy_embed(centre_y)
        )

        if not self.embed_positions:
            return embedded

        # Add in positional embeddings for the boxes and labels
        for row in range(embedded.shape[0]):
            span_start = input_box_counts[row, 0]
            span_end = input_box_counts[row, 1] - 1  # Skip the sep token
            span_len = span_end - span_start
            embedded[row, span_start:span_end] = (
                embedded[row, span_start:span_end] + self.box_pos_embed.weight[:span_len]
            )

        return embedded
|
533 |
+
|
534 |
+
|
535 |
+
class SuryaTableRecDecoderModel(SuryaTableRecDecoderPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuryaTableRecDecoderLayer`]

    Inputs are embedded with `LabelEmbedding` when `embed_labels` is true and with
    `BboxEmbedding` otherwise.

    Args:
        config: SuryaTableRecDecoderConfig
        embed_labels: Select `LabelEmbedding` instead of `BboxEmbedding` for the inputs.
        embed_positions: Forwarded to `BboxEmbedding`; ignored when `embed_labels` is true.
    """

    def __init__(self, config: SuryaTableRecDecoderConfig, embed_labels=False, embed_positions=True):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # When false, _update_causal_mask returns None and no mask is applied.
        self.causal = config.causal

        if embed_labels:
            self.embed_tokens = LabelEmbedding(config)
        else:
            self.embed_tokens = BboxEmbedding(config, embed_positions=embed_positions)

        self.layers = nn.ModuleList(
            [SuryaTableRecDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.final_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # NOTE(review): this buffer is registered but not read anywhere in this class.
        self.register_buffer(
            "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32), persistent=False
        )
        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings
    def get_input_embeddings(self):
        """Return the input embedding module."""
        return self.embed_tokens

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings
    def set_input_embeddings(self, value):
        """Replace the input embedding module."""
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        input_boxes_counts: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        prefill: bool = False
    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
        """Embed the inputs and run them through every decoder layer.

        Args:
            input_ids: Box or label tensor handed to the embedding module.
            input_boxes_counts: Second argument of the embedding module.
            position_ids: Defaults to ``cache_position.unsqueeze(0)``.
            attention_mask: Optional mask merged into the causal mask.
            encoder_hidden_states: Encoder output forwarded to every layer.
            encoder_attention_mask: Forwarded to every layer.
            cache_position: Defaults to ``arange(sequence_length)``.
            use_cache: Defaults to ``config.use_cache``; forced to false when
                gradient checkpointing is active during training.
            output_hidden_states: Collect the input of every layer plus the
                final normalised output.
            return_dict: Defaults to ``config.use_return_dict``.
            prefill: When true together with ``use_cache``, the per-layer
                caches are reset before the layers run.

        Returns:
            ``BaseModelOutputWithNoAttention`` or, when ``return_dict`` is
            false, a tuple of the non-``None`` values among the last hidden
            state and the collected hidden states.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        inputs_embeds = self.embed_tokens(input_ids, input_boxes_counts)
        hidden_states = inputs_embeds

        if use_cache and prefill:
            self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype)

        if cache_position is None:
            cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

        all_hidden_states = () if output_hidden_states else None
        for i, residual_block in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    residual_block.__call__, hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache
                )
            else:
                hidden_states = residual_block(hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache)

        hidden_states = self.final_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )

    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
    # Ignore copy
    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        """Build the additive causal mask for the current step.

        Returns ``None`` when the model is not causal.  Otherwise returns a
        tensor of shape (batch, 1, sequence_length, target_length) in the dtype
        of ``input_tensor``, holding 0 where attention is allowed and the
        dtype's minimum value where it is blocked.  ``target_length`` is the
        larger of ``settings.TABLE_REC_MAX_BOXES`` and the sequence length.
        """
        if not self.causal:
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        target_length = max(settings.TABLE_REC_MAX_BOXES, sequence_length)

        # Start fully masked; the steps below open up the allowed positions.
        diagonal = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        causal_mask = diagonal
        if sequence_length != 1:
            # Select the upper triangular part of the matrix, but unmask current token (the diagonal)
            # triu will be the min_dtype, everything else is 0 (attended to)
            causal_mask = torch.triu(diagonal, diagonal=1)

        # Zero out (unmask) every key position at or before each query's cache position.
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            if attention_mask.dim() == 2:
                # Mask positions in the causal mask that are masked in the attention mask
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

        if attention_mask is not None and attention_mask.device.type == "cuda":
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
|
670 |
+
|
671 |
+
|
672 |
+
class SuryaTableRecDecoder(SuryaTableRecDecoderPreTrainedModel):
    """Decoder that turns backbone hidden states into box and class logits.

    Wraps a ``SuryaTableRecDecoderModel`` (built with ``embed_labels=True`` and
    ``embed_positions=False``) and projects its output through two bias-free
    linear heads, ``bbox_head`` and ``class_head``.
    """

    _tied_weights_keys = None

    def __init__(self, config, **kwargs):
        """Build the backbone and both prediction heads from ``config``.

        ``**kwargs`` is accepted so callers may pass extra keywords; nothing in
        this constructor reads it.
        """
        super().__init__(config)
        self.model = SuryaTableRecDecoderModel(config, embed_labels=True, embed_positions=False)
        self.vocab_size = config.vocab_size

        # Four coordinates per box, each scored over ``max_width`` bins
        # (see the ``view`` in ``forward``).
        self.bbox_head = nn.Linear(config.hidden_size, config.max_width * 4, bias=False)
        self.class_head = nn.Linear(config.hidden_size, config.max_classes, bias=False)
        self.max_width = config.max_width

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with ``value``."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the output embedding head, or ``None`` when there is none.

        This class never creates an ``lm_head`` (its heads are ``bbox_head``
        and ``class_head``), so a plain ``self.lm_head`` lookup raised
        ``AttributeError``. An ``lm_head`` only exists if a caller installed
        one through ``set_output_embeddings``.
        """
        return getattr(self, "lm_head", None)

    def set_output_embeddings(self, new_embeddings):
        """Store ``new_embeddings`` as ``lm_head``."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Swap the wrapped backbone for ``decoder``."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped backbone."""
        return self.model

    # Ignore copy
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        prefill: bool = False,
        **kwargs
    ) -> Union[Tuple, TableRecModelOutput]:
        """Run the backbone and apply both heads.

        All named arguments are forwarded unchanged to the backbone, which is
        always called with ``output_hidden_states=True`` and
        ``return_dict=True``. Extra ``**kwargs`` are accepted but not forwarded.

        Returns:
            A ``TableRecModelOutput`` carrying ``bbox_logits`` reshaped to
            ``(batch, seq_len, 4, max_width)``, ``class_logits`` as produced by
            ``class_head``, and the hidden states the heads were applied to.
        """
        outputs = self.model(
            input_ids=input_ids,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_hidden_states=True,
            return_dict=True,
            prefill=prefill,
        )

        # First entry of the backbone output; presumably the final-layer
        # hidden states -- confirm against SuryaTableRecDecoderModel.
        hidden_states = outputs[0]
        bbox_logits = self.bbox_head(hidden_states)
        class_logits = self.class_head(hidden_states)
        bsz, seq_len = class_logits.shape[:2]
        # Split the flat ``max_width * 4`` projection into one row per coordinate.
        bbox_logits = bbox_logits.view(bsz, seq_len, 4, self.max_width)

        return TableRecModelOutput(
            bbox_logits=bbox_logits,
            class_logits=class_logits,
            hidden_states=hidden_states,
        )
|
740 |
+
@dataclass
class TextEncoderOutput(CausalLMOutput):
    """Output container returned by ``SuryaTableRecTextEncoder.forward``."""

    # Filled with the backbone's ``last_hidden_state`` by the text encoder.
    hidden_states: torch.FloatTensor = None
|
743 |
+
|
744 |
+
|
745 |
+
class SuryaTableRecTextEncoder(SuryaTableRecDecoderPreTrainedModel):
    """Encoder over input boxes, built on ``SuryaTableRecDecoderModel``.

    The backbone is created with ``embed_labels=False`` and
    ``embed_positions=True``; ``forward`` returns only its last hidden state.
    """

    _tied_weights_keys = None
    config_class = SuryaTableRecTextEncoderConfig

    def __init__(self, config, **kwargs):
        """Create the backbone from ``config``; ``**kwargs`` is not read here."""
        super().__init__(config)
        self.model = SuryaTableRecDecoderModel(config, embed_labels=False, embed_positions=True)
        self.vocab_size = config.vocab_size

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with ``value``."""
        self.model.embed_tokens = value

    def set_decoder(self, decoder):
        """Swap the wrapped backbone for ``decoder``."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped backbone."""
        return self.model

    # Ignore copy
    def forward(
        self,
        input_boxes: Optional[torch.LongTensor] = None,
        input_boxes_counts: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutput]:
        """Encode ``input_boxes`` and return the backbone's last hidden state.

        ``input_boxes`` is handed to the backbone as ``input_ids``; the other
        named arguments are forwarded under their own names. Extra ``**kwargs``
        are accepted but not forwarded.
        """
        backbone_kwargs = dict(
            input_ids=input_boxes,
            input_boxes_counts=input_boxes_counts,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_hidden_states=True,
            return_dict=True,
        )
        encoded = self.model(**backbone_kwargs)

        return TextEncoderOutput(hidden_states=encoded.last_hidden_state)
|
surya/model/table_rec/encoderdecoder.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Optional, Union, Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn import CrossEntropyLoss
|
8 |
+
from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig
|
9 |
+
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
|
10 |
+
from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import shift_tokens_right
|
11 |
+
from surya.model.table_rec.decoder import SuryaTableRecTextEncoder, SuryaTableRecDecoder
|
12 |
+
from surya.model.recognition.encoder import DonutSwinModel
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from transformers.utils import ModelOutput
|
15 |
+
|
16 |
+
|
17 |
+
@dataclass
class TableRecOutput(ModelOutput):
    """Result container built by ``TableRecEncoderDecoderModel.forward``."""

    # Copied from the decoder output's ``row_logits`` attribute.
    row_logits: torch.FloatTensor = None
    # Copied from the decoder output's ``col_logits`` attribute.
    col_logits: torch.FloatTensor = None
    # Copied from the decoder output's ``hidden_states`` attribute.
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
22 |
+
|
23 |
+
|
24 |
+
class TableRecEncoderDecoderModel(PreTrainedModel):
    """Bundle of image encoder, text (box) encoder and decoder under one config.

    Sub-models that are not supplied are built from the matching sub-config
    (``config.encoder``, ``config.text_encoder``, ``config.decoder``).
    """

    config_class = VisionEncoderDecoderConfig
    base_model_prefix = "vision_encoder_decoder"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _supports_param_buffer_assignment = False

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        text_encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
    ):
        """Store (or build) the three sub-models and share the config with them.

        NOTE(review): ``config`` defaults to ``None`` but is dereferenced on
        the first line, so in practice it is required.
        """
        # Input and output embeddings must never be tied for this model.
        config.tie_word_embeddings = False
        config.decoder.tie_word_embeddings = False
        super().__init__(config)

        attn_impl_source = config

        if encoder is None:
            encoder = DonutSwinModel(config.encoder)
        if text_encoder is None:
            text_encoder = SuryaTableRecTextEncoder(
                config.text_encoder, attn_implementation=attn_impl_source._attn_implementation
            )
        if decoder is None:
            decoder = SuryaTableRecDecoder(
                config.decoder, attn_implementation=attn_impl_source._attn_implementation
            )

        self.encoder = encoder
        self.decoder = decoder
        self.text_encoder = text_encoder

        # Point every sub-model at the shared sub-config object so that later
        # edits to ``self.config`` are visible to the sub-model as well.
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder
        self.text_encoder.config = self.config.text_encoder

    def get_encoder(self):
        """Return the image encoder."""
        return self.encoder

    def get_decoder(self):
        """Return the decoder."""
        return self.decoder

    def get_output_embeddings(self):
        """Delegate to the decoder's ``get_output_embeddings``."""
        return self.decoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Delegate to the decoder's ``set_output_embeddings``."""
        return self.decoder.set_output_embeddings(new_embeddings)

    def forward(
        self,
        decoder_input_ids: torch.LongTensor = None,
        decoder_cache_position: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], TableRecOutput]:
        """Run only the decoder on precomputed ``encoder_outputs``.

        Extra keyword arguments whose name starts with ``decoder_`` are passed
        to the decoder with that prefix removed; all other extras are ignored.
        ``return_dict`` is accepted but not consulted.

        NOTE(review): the decoder is called with ``input_labels=`` and the
        result is read through ``row_logits`` / ``col_logits``, whereas
        ``SuryaTableRecDecoder.forward`` names its input ``input_ids`` and
        builds its output from ``bbox_logits`` / ``class_logits``. Confirm this
        path still matches the decoder before relying on it.
        """
        prefix = "decoder_"
        kwargs_decoder = {}
        for argument, value in kwargs.items():
            if argument.startswith(prefix):
                kwargs_decoder[argument[len(prefix):]] = value

        # Decode
        decoder_outputs = self.decoder(
            input_labels=decoder_input_ids,
            input_boxes_counts=None,
            cache_position=decoder_cache_position,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs,
            encoder_attention_mask=None,
            use_cache=use_cache,
            **kwargs_decoder,
        )

        return TableRecOutput(
            row_logits=decoder_outputs.row_logits,
            col_logits=decoder_outputs.col_logits,
            decoder_hidden_states=decoder_outputs.hidden_states,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Shift ``labels`` right using the config's pad and start token ids."""
        pad_id = self.config.pad_token_id
        start_id = self.config.decoder_start_token_id
        return shift_tokens_right(labels, pad_id, start_id)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        """Assemble the keyword dict for one generation step.

        The decoder's own ``prepare_inputs_for_generation`` supplies the
        decoder input ids, cache and (when present) decoder attention mask.
        """
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)

        decoder_attention_mask = None
        if "attention_mask" in decoder_inputs:
            decoder_attention_mask = decoder_inputs["attention_mask"]

        return {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }

    def resize_token_embeddings(self, *args, **kwargs):
        """Always raise: resizing must be done on the wrapped decoder."""
        raise NotImplementedError(
            "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the"
            " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
        )

    def _reorder_cache(self, past_key_values, beam_idx):
        """Delegate cache reordering to the decoder."""
        return self.decoder._reorder_cache(past_key_values, beam_idx)
|
surya/model/table_rec/model.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from surya.model.recognition.encoder import DonutSwinModel
|
2 |
+
from surya.model.table_rec.config import SuryaTableRecConfig, SuryaTableRecDecoderConfig, DonutSwinTableRecConfig, \
|
3 |
+
SuryaTableRecTextEncoderConfig
|
4 |
+
from surya.model.table_rec.decoder import SuryaTableRecDecoder, SuryaTableRecTextEncoder
|
5 |
+
from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
|
6 |
+
from surya.settings import settings
|
7 |
+
|
8 |
+
|
9 |
+
def load_model(checkpoint=settings.TABLE_REC_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the table-recognition encoder/decoder model, ready for inference.

    Args:
        checkpoint: Identifier passed to ``from_pretrained`` for both the
            config and the weights.
        device: Device the model is moved to after loading.
        dtype: Passed as ``torch_dtype`` when loading the weights.

    Returns:
        A ``TableRecEncoderDecoderModel`` on ``device`` in eval mode.
    """
    config = SuryaTableRecConfig.from_pretrained(checkpoint)

    # Each sub-config is ``**``-unpacked, so it is loaded as a mapping; replace
    # it with the typed config object the sub-models expect.
    config.decoder = SuryaTableRecDecoderConfig(**config.decoder)
    config.encoder = DonutSwinTableRecConfig(**config.encoder)
    config.text_encoder = SuryaTableRecTextEncoderConfig(**config.text_encoder)

    model = TableRecEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)

    # Sanity checks that the expected sub-model classes were instantiated.
    assert isinstance(model.decoder, SuryaTableRecDecoder)
    assert isinstance(model.encoder, DonutSwinModel)
    assert isinstance(model.text_encoder, SuryaTableRecTextEncoder)

    model = model.to(device)
    model = model.eval()

    # The message used to say "recognition model", which is a different model;
    # this function loads the table-recognition checkpoint.
    print(f"Loaded table recognition model {checkpoint} on device {device} with dtype {dtype}")
    return model
|
surya/model/table_rec/processor.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Dict, Union, Optional, List, Iterable
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
from torch import TensorType
|
7 |
+
from transformers import DonutImageProcessor, DonutProcessor
|
8 |
+
from transformers.image_processing_utils import BatchFeature
|
9 |
+
from transformers.image_transforms import pad, normalize
|
10 |
+
from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, get_image_size
|
11 |
+
import numpy as np
|
12 |
+
from PIL import Image
|
13 |
+
import PIL
|
14 |
+
from surya.model.recognition.tokenizer import Byt5LangTokenizer
|
15 |
+
from surya.settings import settings
|
16 |
+
from surya.model.table_rec.config import BOX_DIM, SPECIAL_TOKENS
|
17 |
+
|
18 |
+
|
19 |
+
def load_processor():
    """Create a ``SuryaProcessor`` configured for table-recognition inference.

    Sets the image processor to non-training mode with the table-recognition
    image size, then attaches the special token ids, the box coordinate range
    and the special-token offset that ``SuryaProcessor`` reads at call time.
    """
    processor = SuryaProcessor()

    image_processor = processor.image_processor
    image_processor.train = False
    image_processor.max_size = settings.TABLE_REC_IMAGE_SIZE

    special_token_ids = {
        "token_pad_id": 0,
        "token_eos_id": 1,
        "token_bos_id": 2,
        "token_row_id": 3,
        "token_unused_id": 4,
    }
    for attribute, token_id in special_token_ids.items():
        setattr(processor, attribute, token_id)

    processor.box_size = (BOX_DIM, BOX_DIM)
    processor.special_token_count = SPECIAL_TOKENS
    return processor
|
32 |
+
|
33 |
+
|
34 |
+
class SuryaImageProcessor(DonutImageProcessor):
    """Image processor that resizes, pads, rescales and normalizes with numpy."""

    def __init__(self, *args, max_size=None, train=False, **kwargs):
        """Initialise the base processor and record ``max_size`` / ``train``.

        ``patch_size`` is read from ``kwargs`` and defaults to ``(4, 4)``.
        """
        super().__init__(*args, **kwargs)

        self.patch_size = kwargs.get("patch_size", (4, 4))
        self.max_size = max_size
        self.train = train

    @classmethod
    def numpy_resize(cls, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4):
        """Resize ``image`` to exactly ``size`` and move channels to the front.

        ``size`` is a mapping with ``"width"`` and ``"height"`` keys. The
        result is transposed from (H, W, C) to (C, H, W).
        """
        target = (size["width"], size["height"])
        return cv2.resize(image, target, interpolation=interpolation).transpose(2, 0, 1)

    def process_inner(self, images: List[np.ndarray]):
        """Resize, pad, rescale and normalize a list of channel-last RGB arrays.

        Returns a list of float32 channel-first arrays. Only the first image
        is checked for the expected channel layout.
        """
        assert images[0].shape[2] == 3  # RGB input images, channel dim last

        # Resizing also converts the layout to channel x height x width.
        # NOTE(review): ``self.resample`` comes from the base image processor
        # and is presumably a PIL resampling code, yet it is used here as the
        # OpenCV interpolation flag -- confirm that is intended.
        resized = [SuryaImageProcessor.numpy_resize(img, self.max_size, self.resample) for img in images]
        assert resized[0].shape[0] == 3  # RGB input images, channel dim first

        # Convert to float32 for rescale/normalize
        as_float = [img.astype(np.float32) for img in resized]

        # Pad every image out to ``max_size`` using the recognition pad value.
        target_size = self.max_size
        padded = []
        for img in as_float:
            padded.append(
                SuryaImageProcessor.pad_image(
                    image=img,
                    size=target_size,
                    input_data_format=ChannelDimension.FIRST,
                    pad_value=settings.RECOGNITION_PAD_VALUE,
                )
            )

        # Rescale and normalize
        rescaled = [img * self.rescale_factor for img in padded]
        return [
            SuryaImageProcessor.normalize(
                img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST
            )
            for img in rescaled
        ]

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_thumbnail: bool = None,
        do_align_long_axis: bool = None,
        do_pad: bool = None,
        random_padding: bool = False,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Convert ``images`` to numpy, run ``process_inner`` and wrap the result.

        Only ``images`` and ``return_tensors`` are used; the remaining
        parameters are accepted but ignored by this override.

        Returns:
            A ``BatchFeature`` with a single ``"pixel_values"`` entry.
        """
        image_list = make_list_of_images(images)

        # Convert to numpy for later processing steps
        arrays = [np.array(img) for img in image_list]
        pixel_values = self.process_inner(arrays)

        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)

    @classmethod
    def pad_image(
        cls,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        pad_value: float = 0.0,
    ) -> np.ndarray:
        """Pad ``image`` with ``pad_value`` to ``size``, centring the content.

        When the padding is odd, the extra pixel goes to the bottom / right.
        Asserts that the image is not larger than ``size`` in either dimension.
        """
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)

        delta_height = size["height"] - input_height
        delta_width = size["width"] - input_width
        assert delta_width >= 0 and delta_height >= 0

        pad_top, pad_left = delta_height // 2, delta_width // 2
        padding = (
            (pad_top, delta_height - pad_top),
            (pad_left, delta_width - pad_left),
        )
        return pad(image, padding, data_format=data_format, input_data_format=input_data_format, constant_values=pad_value)

    @classmethod
    def align_long_axis(
        cls,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Rotate ``image`` with ``np.rot90(image, 3)`` if its orientation differs from ``size``.

        The image is rotated only when one of it and the target is strictly
        portrait and the other strictly landscape. ``data_format`` and
        ``input_data_format`` are not used; height and width are read from the
        first two axes of ``image``.
        """
        input_height, input_width = image.shape[:2]
        output_height, output_width = size["height"], size["width"]

        target_is_portrait = output_width < output_height
        target_is_landscape = output_width > output_height
        input_is_landscape = input_width > input_height
        input_is_portrait = input_width < input_height

        if (target_is_portrait and input_is_landscape) or (target_is_landscape and input_is_portrait):
            image = np.rot90(image, 3)

        return image

    @classmethod
    def normalize(
        cls,
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """Forward all arguments to the module-level ``normalize`` helper."""
        return normalize(
            image,
            mean=mean,
            std=std,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
|
169 |
+
|
170 |
+
|
171 |
+
class SuryaProcessor(DonutProcessor):
    """Processor that prepares images and their boxes for table recognition.

    ``resize_boxes`` and ``__call__`` read ``box_size``,
    ``special_token_count``, ``token_row_id`` and ``token_unused_id``. They are
    not assigned in ``__init__``; ``load_processor`` sets them after
    construction.
    """

    def __init__(self, image_processor=None, tokenizer=None, train=False, **kwargs):
        """Create the processor with its own image processor and tokenizer.

        The ``image_processor`` and ``tokenizer`` arguments are accepted for
        signature compatibility but are always replaced: the image processor is
        loaded from ``settings.RECOGNITION_MODEL_CHECKPOINT`` and the tokenizer
        is a fresh ``Byt5LangTokenizer``. ``train`` is not used.

        Keyword Args:
            max_input_boxes: Maximum number of box slots per image (default 256).
            extra_input_boxes: Number of trailing "unused" box slots (default 32).
        """
        # ``from_pretrained`` either returns an instance or raises, so the
        # former ``is None`` check after it could never fire and was removed.
        image_processor = SuryaImageProcessor.from_pretrained(settings.RECOGNITION_MODEL_CHECKPOINT)
        tokenizer = Byt5LangTokenizer()

        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor
        self._in_target_context_manager = False
        self.max_input_boxes = kwargs.get("max_input_boxes", 256)
        self.extra_input_boxes = kwargs.get("extra_input_boxes", 32)

    def resize_boxes(self, img, boxes):
        """Return ``boxes`` rescaled from image pixels to the ``box_size`` grid.

        Each box is read as ``[x1, y1, x2, y2, ...]``. The first four entries
        are scaled by ``box_size / img.size``; then ``x1`` and ``y1`` are
        clamped to be at least 0 and ``x2`` / ``y2`` to be at most the grid
        width / height. Any further entries are kept as they are.

        The input is left untouched: a new list of new box lists is returned,
        so calling the processor twice on the same boxes does not rescale them
        twice.
        """
        width, height = img.size
        box_width, box_height = self.box_size

        resized = []
        for box in boxes:
            scaled = list(box)  # copy so the caller's box is not modified
            scaled[0] = box[0] / width * box_width
            scaled[1] = box[1] / height * box_height
            scaled[2] = box[2] / width * box_width
            scaled[3] = box[3] / height * box_height

            if scaled[0] < 0:
                scaled[0] = 0
            if scaled[1] < 0:
                scaled[1] = 0
            if scaled[2] > box_width:
                scaled[2] = box_width
            if scaled[3] > box_height:
                scaled[3] = box_height

            resized.append(scaled)

        return resized

    def __call__(self, *args, **kwargs):
        """Build model inputs from ``images`` and per-image ``boxes``.

        Keyword Args:
            images: Sequence of images (each must expose ``.size``).
            boxes: One list of boxes per image.

        A first positional argument, if given, replaces ``images``. Remaining
        positional and keyword arguments go to the image processor.

        Returns:
            The image processor's output with three added entries:
            ``input_boxes`` (long tensor of box tokens),
            ``input_boxes_mask`` (long tensor, all ones) and
            ``input_boxes_counts`` (per image ``[0, used_slot_count]``).
        """
        images = kwargs.pop("images", [])
        boxes = kwargs.pop("boxes", [])
        # NOTE(review): this compares the keyword ``images`` with ``boxes``
        # before a positional ``images`` is picked up below -- confirm the
        # intended calling convention.
        assert len(images) == len(boxes)

        if len(args) > 0:
            images = args[0]
            args = args[1:]

        max_len = self.max_input_boxes + self.extra_input_boxes
        new_boxes = []
        box_masks = []
        box_ends = []
        for i, image_boxes in enumerate(boxes):
            # Thin out over-long box lists by keeping every k-th box. This
            # binds a local name only, so the caller's list is not replaced.
            if len(image_boxes) > self.max_input_boxes:
                downsample_ratio = math.ceil(len(image_boxes) / self.max_input_boxes)
                image_boxes = image_boxes[::downsample_ratio]

            nb = self.resize_boxes(images[i], image_boxes)
            nb = [[b + self.special_token_count for b in box] for box in nb]  # shift up
            nb = nb[:self.max_input_boxes - 1]  # leave one slot for the row token

            nb.insert(0, [self.token_row_id] * 4)  # Insert special token for max rows/cols
            nb.extend([self.token_unused_id] * 4 for _ in range(self.extra_input_boxes))

            pad_length = max_len - len(nb)
            # Padding slots are marked 1 as well, exactly like used slots.
            box_mask = [1] * len(nb) + [1] * (pad_length)
            box_ends.append(len(nb))
            nb = nb + [[self.token_unused_id] * 4] * pad_length

            new_boxes.append(nb)
            box_masks.append(box_mask)

        box_ends = torch.tensor(box_ends, dtype=torch.long)
        box_starts = torch.tensor([0] * len(boxes), dtype=torch.long)
        box_ranges = torch.stack([box_starts, box_ends], dim=1)

        inputs = self.image_processor(images, *args, **kwargs)
        inputs["input_boxes"] = torch.tensor(new_boxes, dtype=torch.long)
        inputs["input_boxes_mask"] = torch.tensor(box_masks, dtype=torch.long)
        inputs["input_boxes_counts"] = box_ranges
        return inputs
|
surya/ocr.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
from typing import List
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
from surya.detection import batch_text_detection
|
6 |
+
from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image, convert_if_not_rgb
|
7 |
+
from surya.postprocessing.text import sort_text_lines
|
8 |
+
from surya.recognition import batch_recognition
|
9 |
+
from surya.schema import TextLine, OCRResult
|
10 |
+
|
11 |
+
|
12 |
+
def run_recognition(images: List[Image.Image], langs: List[List[str] | None], rec_model, rec_processor, bboxes: List[List[List[int]]] = None, polygons: List[List[List[List[int]]]] = None, batch_size=None) -> List[OCRResult]:
    """Recognize text inside caller-supplied regions of each image.

    Regions come from ``polygons`` when given, otherwise from ``bboxes``; at
    least one of the two must be provided. Polygons are four corner points
    ``[[x1, y1], [x2, y2], [x3, y3], [x4, y4]]``; bboxes are
    ``[x1, y1, x2, y2]`` and are expanded to their four corners in the result.

    Returns:
        One ``OCRResult`` per image, whose text lines follow the input region
        order.
    """
    assert bboxes is not None or polygons is not None
    assert len(images) == len(langs), "You need to pass in one list of languages for each image"

    images = convert_if_not_rgb(images)
    use_polygons = polygons is not None

    # Flatten all region crops into one list so recognition runs in one batch,
    # remembering how many crops belong to each image.
    crops_per_image = []
    flat_crops = []
    flat_langs = []
    for idx, (image, lang) in enumerate(zip(images, langs)):
        if use_polygons:
            crops = slice_polys_from_image(image, polygons[idx])
        else:
            crops = slice_bboxes_from_image(image, bboxes[idx])
        crops_per_image.append(len(crops))
        flat_crops.extend(crops)
        flat_langs.extend([deepcopy(lang)] * len(crops))

    rec_predictions, _ = batch_recognition(flat_crops, flat_langs, rec_model, rec_processor, batch_size=batch_size)

    # Split the flat predictions back into per-image groups.
    results = []
    cursor = 0
    for idx, (image, lang) in enumerate(zip(images, langs)):
        stop = cursor + crops_per_image[idx]
        image_lines = rec_predictions[cursor:stop]
        cursor = stop

        text_lines = []
        for i, line_text in enumerate(image_lines):
            if use_polygons:
                poly = polygons[idx][i]
            else:
                x1, y1, x2, y2 = (bboxes[idx][i][k] for k in range(4))
                poly = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]

            text_lines.append(TextLine(text=line_text, polygon=poly))

        results.append(
            OCRResult(
                text_lines=text_lines,
                languages=lang,
                image_bbox=[0, 0, image.size[0], image.size[1]],
            )
        )

    return results
|
61 |
+
|
62 |
+
|
63 |
+
def run_ocr(images: List[Image.Image], langs: List[List[str] | None], det_model, det_processor, rec_model, rec_processor, batch_size=None, highres_images: List[Image.Image] | None = None) -> List[OCRResult]:
    """Detect text lines in each image, then recognize the text in them.

    Detection always runs on ``images``. When ``highres_images`` is supplied,
    the detected polygons are scaled to the matching high-resolution image and
    the crops for recognition are taken from it instead.

    Returns:
        One ``OCRResult`` per image, with lines ordered by ``sort_text_lines``
        and ``image_bbox`` taken from the detection result.
    """
    images = convert_if_not_rgb(images)
    if highres_images is not None:
        highres_images = convert_if_not_rgb(highres_images)
    else:
        highres_images = [None] * len(images)
    det_predictions = batch_text_detection(images, det_model, det_processor)

    # Flatten every detected line crop into one list for batched recognition.
    flat_crops = []
    crops_per_image = []
    flat_langs = []
    for det_pred, image, highres_image, lang in zip(det_predictions, images, highres_images, langs):
        polygons = [detected.polygon for detected in det_pred.bboxes]
        if highres_image:
            # Map polygon corners from detection-image to high-res coordinates.
            width_scaler = highres_image.size[0] / image.size[0]
            height_scaler = highres_image.size[1] / image.size[1]
            scaled_polygons = []
            for polygon in polygons:
                scaled_polygons.append(
                    [[int(p[0] * width_scaler), int(p[1] * height_scaler)] for p in polygon]
                )
            crops = slice_polys_from_image(highres_image, scaled_polygons)
        else:
            crops = slice_polys_from_image(image, polygons)
        crops_per_image.append(len(crops))
        flat_langs.extend([lang] * len(crops))
        flat_crops.extend(crops)

    rec_predictions, confidence_scores = batch_recognition(flat_crops, flat_langs, rec_model, rec_processor, batch_size=batch_size)

    # Hand predictions and confidences back to the image they came from.
    results = []
    cursor = 0
    for idx, (image, det_pred, lang) in enumerate(zip(images, det_predictions, langs)):
        stop = cursor + crops_per_image[idx]
        image_lines = rec_predictions[cursor:stop]
        line_confidences = confidence_scores[cursor:stop]
        cursor = stop

        assert len(image_lines) == len(det_pred.bboxes)

        lines = [
            TextLine(
                text=text_line,
                polygon=bbox.polygon,
                bbox=bbox.bbox,
                confidence=confidence,
            )
            for text_line, confidence, bbox in zip(image_lines, line_confidences, det_pred.bboxes)
        ]
        lines = sort_text_lines(lines)

        results.append(
            OCRResult(
                text_lines=lines,
                languages=lang,
                image_bbox=det_pred.image_bbox,
            )
        )

    return results
|
surya/ordering.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
from typing import List
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
from surya.input.processing import convert_if_not_rgb
|
7 |
+
from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
|
8 |
+
from surya.schema import OrderBox, OrderResult
|
9 |
+
from surya.settings import settings
|
10 |
+
from tqdm import tqdm
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
|
14 |
+
def get_batch_size():
    """Return the batch size for reading-order prediction.

    ``settings.ORDER_BATCH_SIZE`` wins when it is set. Otherwise the default
    is 32 on ``"cuda"`` and 8 on every other device (including ``"mps"``).
    """
    configured = settings.ORDER_BATCH_SIZE
    if configured is not None:
        return configured

    if settings.TORCH_DEVICE_MODEL == "cuda":
        return 32
    return 8
|
23 |
+
|
24 |
+
|
25 |
+
def rank_elements(arr):
    """Return the zero-based rank of every element of ``arr``.

    ``result[i]`` is the position ``arr[i]`` would take if ``arr`` were sorted
    in ascending order. Equal values are ranked in their original order,
    because the sort is stable.
    """
    # Original indices, listed in ascending order of their values.
    sorted_indices = [index for index, _ in sorted(enumerate(arr), key=lambda pair: pair[1])]

    ranks = [0] * len(arr)
    for position, index in enumerate(sorted_indices):
        ranks[index] = position

    return ranks
|
33 |
+
|
34 |
+
|
35 |
+
def batch_ordering(images: List, bboxes: List[List[List[float]]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[OrderResult]:
    """Predict a reading order for the boxes of each image.

    Args:
        images: PIL images, one per page (asserted below).
        bboxes: For every image, the list of boxes to order; must have the
            same length as ``images``.
        model: Encoder-decoder model called once per decoding step.
        processor: Callable producing ``pixel_values``, ``input_boxes``,
            ``input_boxes_mask`` and ``input_boxes_counts`` for a batch; its
            ``box_size["height"]`` is used to offset predicted tokens.
        batch_size: Images per forward pass; falls back to
            ``get_batch_size()`` when ``None``.

    Returns:
        One ``OrderResult`` per image, holding an ``OrderBox`` (bbox plus
        predicted ``position``) for every input box, in input order.

    Raises:
        AssertionError: on non-PIL images, mismatched input lengths, or when
            the number of predictions for an image differs from its number
            of boxes.
    """
    assert all([isinstance(image, Image.Image) for image in images])
    assert len(images) == len(bboxes)
    if batch_size is None:
        batch_size = get_batch_size()


    output_order = []
    for i in tqdm(range(0, len(images), batch_size), desc="Finding reading order"):
        # The processor receives a deep copy of the boxes; the caller's
        # `bboxes` are read again untouched when building the results.
        batch_bboxes = deepcopy(bboxes[i:i+batch_size])
        batch_images = images[i:i+batch_size]
        batch_images = [image.convert("RGB") for image in batch_images]  # also copies the images

        orig_sizes = [image.size for image in batch_images]
        model_inputs = processor(images=batch_images, boxes=batch_bboxes)

        batch_pixel_values = model_inputs["pixel_values"]
        batch_bboxes = model_inputs["input_boxes"]
        batch_bbox_mask = model_inputs["input_boxes_mask"]
        batch_bbox_counts = model_inputs["input_boxes_counts"]

        # Move everything onto the model's device with the dtypes used below.
        batch_bboxes = torch.from_numpy(np.array(batch_bboxes, dtype=np.int32)).to(model.device)
        batch_bbox_mask = torch.from_numpy(np.array(batch_bbox_mask, dtype=np.int32)).to(model.device)
        batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
        batch_bbox_counts = torch.tensor(np.array(batch_bbox_counts), dtype=torch.long).to(model.device)

        token_count = 0
        past_key_values = None
        encoder_outputs = None
        # batch_predictions[j] collects, step by step, the label picked for image j.
        batch_predictions = [[] for _ in range(len(batch_images))]
        done = torch.zeros(len(batch_images), dtype=torch.bool, device=model.device)

        with torch.inference_mode():
            # One label per image is decoded per iteration, capped at
            # settings.ORDER_MAX_BOXES iterations.
            while token_count < settings.ORDER_MAX_BOXES:
                return_dict = model(
                    pixel_values=batch_pixel_values,
                    decoder_input_boxes=batch_bboxes,
                    decoder_input_boxes_mask=batch_bbox_mask,
                    decoder_input_boxes_counts=batch_bbox_counts,
                    encoder_outputs=encoder_outputs,
                    past_key_values=past_key_values,
                )
                logits = return_dict["logits"].detach()

                last_tokens = []
                last_token_mask = []
                min_val = torch.finfo(model.dtype).min
                for j in range(logits.shape[0]):
                    label_count = batch_bbox_counts[j, 1] - batch_bbox_counts[j, 0] - 1  # Subtract 1 for the sep token
                    new_logits = logits[j, -1]
                    new_logits[batch_predictions[j]] = min_val  # Mask out already predicted tokens, we can only predict each token once
                    new_logits[label_count:] = min_val  # Mask out all logit positions above the number of bboxes
                    pred = int(torch.argmax(new_logits, dim=-1).item())

                    # Add one to avoid colliding with the 1000 height/width token for bboxes
                    last_tokens.append([[pred + processor.box_size["height"] + 1] * 4])
                    if len(batch_predictions[j]) == label_count - 1:  # Minus one since we're appending the final label
                        # Final label for this image: record it and mark the row finished.
                        last_token_mask.append([0])
                        batch_predictions[j].append(pred)
                        done[j] = True
                    elif len(batch_predictions[j]) < label_count - 1:
                        last_token_mask.append([1])
                        batch_predictions[j].append(pred)  # Get rank prediction for given position
                    else:
                        # Row already has all its labels; its new token is masked out.
                        last_token_mask.append([0])

                if done.all():
                    break

                # Reuse the cached decoder state and encoder output on the next step,
                # feeding only the tokens that were just predicted.
                past_key_values = return_dict["past_key_values"]
                encoder_outputs = (return_dict["encoder_last_hidden_state"],)

                batch_bboxes = torch.tensor(last_tokens, dtype=torch.long).to(model.device)
                token_bbox_mask = torch.tensor(last_token_mask, dtype=torch.long).to(model.device)
                batch_bbox_mask = torch.cat([batch_bbox_mask, token_bbox_mask], dim=1)
                token_count += 1

        for j, row_pred in enumerate(batch_predictions):
            row_bboxes = bboxes[i+j]
            assert len(row_pred) == len(row_bboxes), f"Mismatch between logits and bboxes.  Logits: {len(row_pred)}, Bboxes: {len(row_bboxes)}"

            orig_size = orig_sizes[j]
            ranks = [0] * len(row_bboxes)

            # row_pred[k] is the label chosen at decode step k, so the box with
            # that index gets position k.
            for box_idx in range(len(row_bboxes)):
                ranks[row_pred[box_idx]] = box_idx

            order_boxes = []
            for row_bbox, rank in zip(row_bboxes, ranks):
                order_box = OrderBox(
                    bbox=row_bbox,
                    position=rank,
                )
                order_boxes.append(order_box)

            result = OrderResult(
                bboxes=order_boxes,
                image_bbox=[0, 0, orig_size[0], orig_size[1]],
            )
            output_order.append(result)
    return output_order
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
|
surya/postprocessing/affinity.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from PIL import Image, ImageDraw
|
7 |
+
|
8 |
+
from surya.postprocessing.util import get_line_angle, rescale_bbox
|
9 |
+
from surya.schema import ColumnLine
|
10 |
+
|
11 |
+
|
12 |
+
def get_detected_lines_sobel(image, vertical=True):
    """Return an 8-bit Sobel edge map emphasising line-like structures.

    Args:
        image: Image array accepted by ``cv2.Sobel``.
        vertical: When true, differentiate along x (responds to vertical
            edges); otherwise differentiate along y (horizontal edges).

    Returns:
        A ``uint8`` array of the same shape as ``image``: the absolute Sobel
        response scaled to 0-255, then eroded once and dilated three times
        with a 20x1 kernel. An image without any edge response yields an
        all-zero map.
    """
    # Sobel with a 3x3 kernel; the derivative axis selects the edge orientation.
    dx, dy = (1, 0) if vertical else (0, 1)
    sobel = cv2.Sobel(image, cv2.CV_32F, dx, dy, ksize=3)

    # Absolute Sobel (to capture both rising and falling edges)
    abs_sobel = np.absolute(sobel)

    # Convert to 8-bit image. A flat image has a zero peak; dividing by it
    # would produce NaNs (and an undefined uint8 cast), so emit zeros instead.
    peak = np.max(abs_sobel)
    if peak > 0:
        scaled_sobel = np.uint8(255 * abs_sobel / peak)
    else:
        scaled_sobel = np.zeros(abs_sobel.shape, dtype=np.uint8)

    # Erode then dilate with a tall, 1-pixel-wide kernel.
    kernel = np.ones((20, 1), np.uint8)
    eroded = cv2.erode(scaled_sobel, kernel, iterations=1)
    scaled_sobel = cv2.dilate(eroded, kernel, iterations=3)

    return scaled_sobel
|
35 |
+
|
36 |
+
|
37 |
+
def get_detected_lines(image, slope_tol_deg=2, vertical=False, horizontal=False) -> List[ColumnLine]:
    """Detect straight line segments in ``image`` with Canny + probabilistic Hough.

    Args:
        image: Array that is multiplied by 255 before edge detection.
        slope_tol_deg: Angular tolerance (degrees) used to classify a segment
            as vertical (around +/-90) or horizontal (around 0).
        vertical: Pre-filter with the vertical Sobel map and keep only
            vertical segments.
        horizontal: Pre-filter with the horizontal Sobel map and keep only
            horizontal segments. Mutually exclusive with ``vertical``.

    Returns:
        ``ColumnLine`` objects whose bbox is normalised to
        ``[min_x, min_y, max_x, max_y]``.
    """
    assert not (vertical and horizontal)

    working = image.astype(np.float32) * 255  # Convert to 0-255 range
    if vertical or horizontal:
        working = get_detected_lines_sobel(working, vertical)
    working = working.astype(np.uint8)

    edges = cv2.Canny(working, 150, 200, apertureSize=3)

    # Vertical detection tolerates much larger gaps between collinear pieces.
    max_gap, min_length = (100, 10) if vertical else (10, 4)
    segments = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=150, minLineLength=min_length, maxLineGap=max_gap)

    found = []
    for segment in ([] if segments is None else segments):
        x1, y1, x2, y2 = segment[0]

        is_vertical = False
        is_horizontal = False
        if x2 == x1:
            # Perfectly vertical: no angle to compute.
            is_vertical = True
        else:
            angle = get_line_angle(x1, y1, x2, y2)
            near_up = 90 - slope_tol_deg < angle < 90 + slope_tol_deg
            near_down = -90 - slope_tol_deg < angle < -90 + slope_tol_deg
            if near_up or near_down:
                is_vertical = True
            elif -slope_tol_deg < angle < slope_tol_deg:
                is_horizontal = True

        # Order the endpoints so the bbox is (left, top, right, bottom).
        left, right = (x2, x1) if x2 < x1 else (x1, x2)
        top, bottom = (y2, y1) if y2 < y1 else (y1, y2)
        found.append(ColumnLine(bbox=[left, top, right, bottom], vertical=is_vertical, horizontal=is_horizontal))

    if vertical:
        found = [entry for entry in found if entry.vertical]

    if horizontal:
        found = [entry for entry in found if entry.horizontal]

    return found
|
87 |
+
|
88 |
+
|
89 |
+
def draw_lines_on_image(line_info: List[ColumnLine], img):
    """Draw the vertical lines from ``line_info`` onto ``img`` in red.

    Coordinates are snapped down to a multiple of 20 (200 for lines flagged
    horizontal) before drawing. Only lines flagged vertical are drawn.
    ``img`` is modified in place and also returned.
    """
    canvas = ImageDraw.Draw(img)

    for entry in line_info:
        step = 200 if entry.horizontal else 20
        x1, y1, x2, y2 = [coord // step * step for coord in entry.bbox]
        if entry.vertical:
            canvas.line((x1, y1, x2, y2), fill="red", width=3)

    return img
|
101 |
+
|
102 |
+
|
103 |
+
def get_vertical_lines(image, processor_size, image_size, divisor=20, x_tolerance=40, y_tolerance=20) -> List[ColumnLine]:
    """Detect vertical lines and consolidate overlapping segments.

    Args:
        image: Input passed to ``get_detected_lines(..., vertical=True)``.
        processor_size: Passed to ``ColumnLine.rescale_bbox`` together with
            ``image_size``.
        image_size: See ``processor_size``.
        divisor: Passed to ``ColumnLine.round_bbox`` for every line.
        x_tolerance: Two lines closer than this in x (and overlapping in y)
            are treated as redundant in the second pass.
        y_tolerance: Amount by which a line is extended up and down when
            testing for y overlap in the first pass.

    Returns:
        The remaining ``ColumnLine`` objects, sorted by their left x
        coordinate; the first one has its top forced to 0.
    """
    vertical_lines = get_detected_lines(image, vertical=True)
    for line in vertical_lines:
        line.rescale_bbox(processor_size, image_size)
    vertical_lines = sorted(vertical_lines, key=lambda x: x.bbox[0])
    for line in vertical_lines:
        line.round_bbox(divisor)

    # Merge adjacent line segments together
    # Pass 1: only lines with exactly the same x (after rounding) are merged.
    to_remove = []
    for i, line in enumerate(vertical_lines):
        for j, line2 in enumerate(vertical_lines):
            if j <= i:
                continue
            if line.bbox[0] != line2.bbox[0]:
                continue

            # Grow the first line by y_tolerance on both ends before testing overlap.
            expanded_line1 = [line.bbox[0], line.bbox[1] - y_tolerance, line.bbox[2],
                              line.bbox[3] + y_tolerance]

            line1_points = set(range(int(expanded_line1[1]), int(expanded_line1[3])))
            line2_points = set(range(int(line2.bbox[1]), int(line2.bbox[3])))
            intersect_y = len(line1_points.intersection(line2_points)) > 0

            if intersect_y:
                # The later line absorbs the earlier one's extent; the earlier one is dropped.
                vertical_lines[j].bbox[1] = min(line.bbox[1], line2.bbox[1])
                vertical_lines[j].bbox[3] = max(line.bbox[3], line2.bbox[3])
                to_remove.append(i)

    vertical_lines = [line for i, line in enumerate(vertical_lines) if i not in to_remove]

    # Remove redundant segments
    # Pass 2: lines that are merely close in x and overlap in y (no tolerance here).
    to_remove = []
    for i, line in enumerate(vertical_lines):
        if i in to_remove:
            continue
        for j, line2 in enumerate(vertical_lines):
            if j <= i or j in to_remove:
                continue
            close_in_x = abs(line.bbox[0] - line2.bbox[0]) < x_tolerance
            line1_points = set(range(int(line.bbox[1]), int(line.bbox[3])))
            line2_points = set(range(int(line2.bbox[1]), int(line2.bbox[3])))

            intersect_y = len(line1_points.intersection(line2_points)) > 0

            if close_in_x and intersect_y:
                # Keep the longer line and extend it
                if len(line2_points) > len(line1_points):
                    vertical_lines[j].bbox[1] = min(line.bbox[1], line2.bbox[1])
                    vertical_lines[j].bbox[3] = max(line.bbox[3], line2.bbox[3])
                    to_remove.append(i)
                else:
                    vertical_lines[i].bbox[1] = min(line.bbox[1], line2.bbox[1])
                    vertical_lines[i].bbox[3] = max(line.bbox[3], line2.bbox[3])
                    to_remove.append(j)

    vertical_lines = [line for i, line in enumerate(vertical_lines) if i not in to_remove]

    if len(vertical_lines) > 0:
        # Always start with top left of page
        vertical_lines[0].bbox[1] = 0

    return vertical_lines
|
surya/postprocessing/fonts.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional
|
2 |
+
import os
|
3 |
+
import requests
|
4 |
+
|
5 |
+
from surya.settings import settings
|
6 |
+
|
7 |
+
|
8 |
+
def get_font_path(langs: Optional[List[str]] = None) -> str:
    """Return the local path of the render font, downloading it if missing.

    Args:
        langs: Optional language codes. A language-specific font from
            ``settings.RECOGNITION_RENDER_FONTS`` is used only when exactly
            one language is given and it has an entry; otherwise the
            ``"all"`` font is used.

    Returns:
        Path to the font file on disk.

    Raises:
        requests.HTTPError: if the download responds with an error status.
        OSError: if the font file cannot be written.
    """
    font_path = settings.RECOGNITION_RENDER_FONTS["all"]
    if langs is not None and len(langs) == 1:
        for k in settings.RECOGNITION_RENDER_FONTS:
            if k in langs:
                font_path = settings.RECOGNITION_RENDER_FONTS[k]
                break

    if not os.path.exists(font_path):
        font_dir = os.path.dirname(font_path)
        if font_dir:  # a bare filename has no directory to create
            os.makedirs(font_dir, exist_ok=True)

        font_dl_path = f"{settings.RECOGNITION_FONT_DL_BASE}/{os.path.basename(font_path)}"
        # Download into a temporary sibling file and move it into place only
        # once complete, so a failed or interrupted download never leaves an
        # empty/partial file at font_path that later calls would accept.
        tmp_path = f"{font_path}.{os.getpid()}.part"
        try:
            with requests.get(font_dl_path, stream=True) as r:
                # Check the status before creating any file on disk.
                r.raise_for_status()
                with open(tmp_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(tmp_path, font_path)
        finally:
            # After a successful os.replace the temp file is gone; on any
            # failure this removes the leftover partial download.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    return font_path
|
surya/postprocessing/heatmap.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
import math
|
6 |
+
from PIL import ImageDraw, ImageFont
|
7 |
+
|
8 |
+
from surya.postprocessing.fonts import get_font_path
|
9 |
+
from surya.postprocessing.util import rescale_bbox
|
10 |
+
from surya.schema import PolygonBox
|
11 |
+
from surya.settings import settings
|
12 |
+
from surya.postprocessing.text import get_text_size
|
13 |
+
|
14 |
+
|
15 |
+
def keep_largest_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]:
    """Drop boxes that are mostly covered by a strictly larger box.

    A box is removed when some other box (different polygon and different
    bbox) overlaps it by more than 90% according to ``intersection_pct`` and
    has a larger bbox area. Order of the surviving boxes is preserved.
    """
    def _area(rect):
        return (rect[2] - rect[0]) * (rect[3] - rect[1])

    def _is_dominated(candidate):
        rect = candidate.bbox
        rect_area = _area(rect)
        for other in boxes:
            if other.polygon == candidate.polygon:
                continue

            other_rect = other.bbox
            if rect == other_rect:
                continue

            # find overlap percentage
            overlap = candidate.intersection_pct(other)
            if overlap > .9 and rect_area < _area(other_rect):
                return True
        return False

    return [candidate for candidate in boxes if not _is_dominated(candidate)]
|
37 |
+
|
38 |
+
|
39 |
+
def clean_contained_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]:
    """Drop every box whose bbox lies entirely inside another box's bbox.

    Boxes with an identical polygon or an identical bbox are never used to
    eliminate each other. Order of the surviving boxes is preserved.
    """
    def _is_swallowed(candidate):
        rect = candidate.bbox
        for other in boxes:
            if other.polygon == candidate.polygon:
                continue

            outer = other.bbox
            if rect == outer:
                continue

            fits_inside = (
                rect[0] >= outer[0]
                and rect[1] >= outer[1]
                and rect[2] <= outer[2]
                and rect[3] <= outer[3]
            )
            if fits_inside:
                return True
        return False

    return [candidate for candidate in boxes if not _is_swallowed(candidate)]
|
57 |
+
|
58 |
+
|
59 |
+
def get_dynamic_thresholds(linemap, text_threshold, low_text, typical_top10_avg=0.7):
    """Scale the detection thresholds by how bright the heatmap is.

    The mean of the brightest 10% of ``linemap`` values is compared with
    ``typical_top10_avg``; the square root of that ratio (clipped to [0, 1])
    scales both thresholds, which are then clipped to fixed ranges.

    Returns:
        ``(text_threshold, low_text)`` with ``text_threshold`` in
        [0.15, 0.8] and ``low_text`` in [0.1, 0.6].
    """
    values = linemap.ravel()
    # Index of the 90th percentile: everything from here up is the top 10%.
    cutoff = int(len(values) * 0.9)
    brightest = np.partition(values, cutoff)[cutoff:]

    ratio = np.mean(brightest) / typical_top10_avg
    scale = np.clip(ratio, 0, 1) ** (1 / 2)

    scaled_low = np.clip(low_text * scale, 0.1, 0.6)
    scaled_text = np.clip(text_threshold * scale, 0.15, 0.8)

    return scaled_text, scaled_low
|
70 |
+
|
71 |
+
|
72 |
+
def detect_boxes(linemap, text_threshold, low_text):
    """Extract rotated text boxes and confidences from a 2-D heatmap.

    Args:
        linemap: 2-D array of per-pixel scores (unpacked as height, width).
        text_threshold: A connected component is kept only if its peak score
            reaches this value (after dynamic adjustment).
        low_text: Pixels above this value (after dynamic adjustment) form the
            connected components.

    Returns:
        ``(det, confidences)``: a list of 4x2 corner arrays and a parallel
        list of confidences, each component's peak score divided by the
        highest peak among the kept components.
    """
    # From CRAFT - https://github.com/clovaai/CRAFT-pytorch
    # Modified to return boxes and for speed, accuracy
    img_h, img_w = linemap.shape

    text_threshold, low_text = get_dynamic_thresholds(linemap, text_threshold, low_text)

    # Binary mask of candidate text pixels, split into 4-connected components.
    text_score_comb = (linemap > low_text).astype(np.uint8)
    label_count, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb, connectivity=4)

    det = []
    confidences = []
    max_confidence = 0

    # Label 0 is skipped; iteration starts at 1.
    for k in range(1, label_count):
        # size filtering
        size = stats[k, cv2.CC_STAT_AREA]
        if size < 10:
            continue

        # make segmentation map
        x, y, w, h = stats[k, [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT]]

        # Dilation amount grows with the square root of the component's shorter side.
        try:
            niter = int(np.sqrt(min(w, h)))
        except ValueError:
            niter = 0

        # Work on a crop around the component, padded so the dilation has room.
        buffer = 1
        sx, sy = max(0, x - niter - buffer), max(0, y - niter - buffer)
        ex, ey = min(img_w, x + w + niter + buffer), min(img_h, y + h + niter + buffer)

        mask = (labels[sy:ey, sx:ex] == k)
        selected_linemap = linemap[sy:ey, sx:ex][mask]
        line_max = np.max(selected_linemap)

        # thresholding
        if line_max < text_threshold:
            continue

        segmap = mask.astype(np.uint8)

        ksize = buffer + niter
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(ksize, ksize))
        selected_segmap = cv2.dilate(segmap, kernel)

        # make box
        # Crop-local pixel indices are shifted back into full-image coordinates.
        indices = np.nonzero(selected_segmap)
        x_inds = indices[1] + sx
        y_inds = indices[0] + sy
        np_contours = np.column_stack((x_inds, y_inds))
        rectangle = cv2.minAreaRect(np_contours)
        box = cv2.boxPoints(rectangle)

        # align diamond-shape
        # Near-square rectangles are replaced by the axis-aligned bounding box.
        w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
        box_ratio = max(w, h) / (min(w, h) + 1e-5)
        if abs(1 - box_ratio) <= 0.1:
            l, r = min(np_contours[:, 0]), max(np_contours[:, 0])
            t, b = min(np_contours[:, 1]), max(np_contours[:, 1])
            box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)

        # make clock-wise order
        # Rotate the corner list so the corner with the smallest x+y comes first.
        startidx = box.sum(axis=1).argmin()
        box = np.roll(box, 4-startidx, 0)
        box = np.array(box)

        confidence = line_max
        max_confidence = max(max_confidence, line_max)

        confidences.append(confidence)
        det.append(box)

    # Normalise so the strongest kept component has confidence 1.
    if max_confidence > 0:
        confidences = [c / max_confidence for c in confidences]
    return det, confidences
|
148 |
+
|
149 |
+
|
150 |
+
def get_detected_boxes(textmap, text_threshold=None, low_text=None) -> List[PolygonBox]:
    """Run box detection on ``textmap`` and wrap the results as ``PolygonBox``.

    Missing thresholds default to ``settings.DETECTOR_TEXT_THRESHOLD`` and
    ``settings.DETECTOR_BLANK_THRESHOLD``. The input array is copied and
    converted to float32 before detection.
    """
    if text_threshold is None:
        text_threshold = settings.DETECTOR_TEXT_THRESHOLD
    if low_text is None:
        low_text = settings.DETECTOR_BLANK_THRESHOLD

    heatmap = textmap.copy().astype(np.float32)
    polygons, scores = detect_boxes(heatmap, text_threshold, low_text)

    # From point form to box form
    return [
        PolygonBox(polygon=polygon, confidence=score)
        for polygon, score in zip(polygons, scores)
    ]
|
163 |
+
|
164 |
+
|
165 |
+
def get_and_clean_boxes(textmap, processor_size, image_size, text_threshold=None, low_text=None) -> List[PolygonBox]:
    """Detect boxes, map them to image coordinates and drop contained ones.

    Every detected box is rescaled from ``processor_size`` to ``image_size``
    and fitted to the image bounds before ``clean_contained_boxes`` runs.
    """
    detected = get_detected_boxes(textmap, text_threshold, low_text)
    for box in detected:
        box.rescale(processor_size, image_size)
        box.fit_to_bounds([0, 0, image_size[0], image_size[1]])

    return clean_contained_boxes(detected)
|
173 |
+
|
174 |
+
|
175 |
+
|
176 |
+
def draw_bboxes_on_image(bboxes, image, labels=None, label_font_size=10, color: str | list='red'):
    """Draw axis-aligned boxes on ``image`` via ``draw_polys_on_image``.

    Each ``[x1, y1, x2, y2]`` box is expanded into its four corners in
    clockwise order starting at the top-left.
    """
    polys = [
        [
            [box[0], box[1]],
            [box[2], box[1]],
            [box[2], box[3]],
            [box[0], box[3]],
        ]
        for box in bboxes
    ]

    return draw_polys_on_image(polys, image, labels, label_font_size=label_font_size, color=color)
|
189 |
+
|
190 |
+
|
191 |
+
def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offset=1, label_font_size=10, color: str | list='red'):
    """Outline polygons on ``image`` and optionally label them.

    Args:
        corners: Sequence of polygons, each a sequence of ``(x, y)`` points.
        image: PIL image; drawn on in place.
        labels: Optional labels, indexed in parallel with ``corners``.
        box_padding: Padding of the white background behind a label.
        label_offset: Offset of the label from the polygon's top-left extent.
        label_font_size: Font size used for labels.
        color: A single colour, or a list with one colour per polygon.

    Returns:
        The same ``image`` object.
    """
    draw = ImageDraw.Draw(image)

    # The font is only needed for labels. Resolving it lazily avoids a font
    # lookup (which may trigger a download) when only outlines are drawn.
    label_font = None
    if labels is not None:
        label_font = ImageFont.truetype(get_font_path(), label_font_size)

    for i, corner_points in enumerate(corners):
        poly = [(int(p[0]), int(p[1])) for p in corner_points]
        poly_color = color[i] if isinstance(color, list) else color
        draw.polygon(poly, outline=poly_color, width=1)

        if labels is None:
            continue

        label = labels[i]
        # Anchor the label at the top-left corner of the polygon's extent.
        text_position = (
            min(p[0] for p in poly) + label_offset,
            min(p[1] for p in poly) + label_offset,
        )
        text_size = get_text_size(label, label_font)
        box_position = (
            text_position[0] - box_padding + label_offset,
            text_position[1] - box_padding + label_offset,
            text_position[0] + text_size[0] + box_padding + label_offset,
            text_position[1] + text_size[1] + box_padding + label_offset,
        )
        # White backdrop keeps the label readable over page content.
        draw.rectangle(box_position, fill="white")
        draw.text(
            text_position,
            label,
            fill=poly_color,
            font=label_font,
        )

    return image
|
223 |
+
|
224 |
+
|
surya/postprocessing/math/latex.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from ftfy import fix_text
|
3 |
+
|
4 |
+
|
5 |
+
def contains_math(text):
    """Return True when ``text`` starts or ends with a ``$`` fence."""
    first_char, last_char = text[:1], text[-1:]
    return first_char == "$" or last_char == "$"
|
7 |
+
|
8 |
+
|
9 |
+
def fix_math(text):
    """Clean up a LaTeX string so it can be rendered.

    Runs ``fix_text`` first, then strips labels/references, replaces
    constructs KaTeX cannot render, and finally repairs the ``$`` fences.
    """
    cleaned = fix_text(text)
    for step in (remove_labels, replace_katex_invalid, fix_fences):
        cleaned = step(cleaned)
    return cleaned
|
18 |
+
|
19 |
+
|
20 |
+
def remove_labels(text):
    """Strip ``\\label{...}``, ``\\ref{...}`` and ``\\pageref{...}`` from ``text``.

    The three commands are removed one after another, in that order.
    """
    patterns = (
        r'\\label\{[^}]*\}',
        r'\\ref\{[^}]*\}',
        r'\\pageref\{[^}]*\}',
    )
    for pattern in patterns:
        text = re.sub(pattern, '', text)
    return text
|
30 |
+
|
31 |
+
|
32 |
+
def replace_katex_invalid(string):
    """Rewrite LaTeX constructs that KaTeX cannot render.

    Removes ``\\tag{...}``, unwraps ``\\big``/``\\bigg``/``\\Big``/``\\Bigg``
    and ``\\mbox`` groups (including a leading ``\\quad`` before ``\\mbox``),
    then removes single ``$`` signs inside ``$$...$$`` blocks.
    """
    # Applied in order; each entry is (pattern, replacement).
    substitutions = (
        (r'\\tag\{.*?\}', ''),
        (r'\\(?:Bigg?|bigg?)\{(.*?)\}', r'\1'),
        (r'\\quad\\mbox\{(.*?)\}', r'\1'),
        (r'\\mbox\{(.*?)\}', r'\1'),
    )
    for pattern, replacement in substitutions:
        string = re.sub(pattern, replacement, string)
    return remove_inner_dollars(string)
|
40 |
+
|
41 |
+
|
42 |
+
def remove_inner_dollars(text):
    """Remove stray ``$`` characters inside ``$$...$$`` blocks.

    The surrounding ``$$`` fences are kept; text outside display blocks is
    left untouched.
    """
    display_block = r'\$\$(.*?)\$\$'
    return re.sub(
        display_block,
        lambda found: '$$' + found.group(1).replace('$', '') + '$$',
        text,
        flags=re.DOTALL,
    )
|
50 |
+
|
51 |
+
|
52 |
+
def extract_latex_with_positions(text):
    """Find ``$$...$$`` and ``$...$`` blocks in ``text``.

    Returns:
        A list of ``(block, start, end)`` tuples in order of appearance.
    """
    fenced = r'(\$\$.*?\$\$|\$.*?\$)'
    return [
        (found.group(), found.start(), found.end())
        for found in re.finditer(fenced, text, re.DOTALL)
    ]
|
58 |
+
|
59 |
+
|
60 |
+
def slice_latex(text):
    """Split ``text`` into alternating plain-text and LaTeX chunks.

    Returns:
        A list of ``{"text": ..., "type": "text" | "latex"}`` dicts covering
        the input in order; empty plain-text gaps are omitted.
    """
    chunks = []
    cursor = 0
    for block, start, end in extract_latex_with_positions(text):
        # Plain text between the previous block and this one, if any.
        leading = text[cursor:start]
        if leading:
            chunks.append({"text": leading, "type": "text"})
        chunks.append({"text": block, "type": "latex"})
        cursor = end

    # Whatever follows the last LaTeX block.
    trailing = text[cursor:]
    if trailing:
        chunks.append({"text": trailing, "type": "text"})

    return chunks
|
78 |
+
|
79 |
+
|
80 |
+
def is_latex(text):
    """Return True when ``text`` contains anything that looks like LaTeX.

    Matches ``\\begin{...}``/``\\end{...}``, ``$...$`` or ``$$...$$`` spans,
    and any backslash followed by a character.
    """
    markers = '|'.join((
        r'\\(?:begin|end)\{[a-zA-Z]*\}',
        r'\$.*?\$',
        r'\$\$.*?\$\$',
        r'\\[a-zA-Z]+',
        r'\\[^a-zA-Z]',
    ))
    return re.search(markers, text, re.DOTALL) is not None
|
94 |
+
|
95 |
+
|
96 |
+
def fix_fences(text):
    """Make the ``$`` fences at both ends of ``text`` a matching ``$$`` pair.

    Text with no fence at either end, or with balanced fences, is returned
    unchanged. A missing or single-``$`` fence on one side is completed to
    ``$$`` to match the other side.
    """
    opens_display = text.startswith("$$")
    closes_display = text.endswith("$$")

    # "$$..." without a closing "$$": top up the closing fence.
    if opens_display and not closes_display:
        return text + ("$" if text[-1] == "$" else "$$")

    # "...$$" without an opening "$$": top up the opening fence.
    if closes_display and not opens_display:
        return ("$" if text[0] == "$" else "$$") + text

    opens_inline = text.startswith("$")
    closes_inline = text.endswith("$")

    # A single "$" on one side only: promote to a full "$$" pair.
    if opens_inline and not closes_inline:
        return "$" + text + "$$"
    if closes_inline and not opens_inline:
        return "$$" + text + "$"

    return text
|
116 |
+
|
117 |
+
|
118 |
+
def strip_fences(text):
    """Remove every leading and trailing ``$`` from ``text``."""
    return text.strip("$")
|
124 |
+
|
125 |
+
|
surya/postprocessing/math/render.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from playwright.sync_api import sync_playwright
|
2 |
+
from PIL import Image
|
3 |
+
import io
|
4 |
+
|
5 |
+
|
6 |
+
def latex_to_pil(latex_code, target_width, target_height, fontsize=18):
    """Render LaTeX to a PIL image using KaTeX in a headless Chromium page.

    Args:
        latex_code: Text containing ``$...$`` / ``$$...$$`` math to render.
        target_width: Viewport width in pixels; also the width limit used
            when picking the font size.
        target_height: Viewport height in pixels; also the height limit.
        fontsize: Starting font size in px. It is increased one step at a
            time while the rendered content stays below both limits, up to
            a loop bound of 30.

    Returns:
        The page screenshot as a fully loaded PIL image.
    """
    # NOTE(review): the stylesheet is loaded from katex 0.13.0 while the
    # script is loaded from katex 0.12.0 - confirm this mix is intended.
    html_template = """
<!DOCTYPE html>
<html>
<head>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.0/dist/katex.min.css">
<script src="https://cdn.jsdelivr.net/npm/katex@0.12.0/dist/katex.min.js"></script>
<style>
body {
margin: 0;
padding: 0;
display: flex;
}
#content {
font-size: {fontsize}px;
}
</style>
</head>
<body>
<div id="content">{content}</div>
<script>
function renderMath() {
let content = document.getElementById('content');
let html = content.innerHTML;

// Replace display equations
html = html.replace(/\\$\\$(.*?)\\$\\$/gs, (match, equation) => {
let span = document.createElement('span');
katex.render(equation, span, { displayMode: true, throwOnError: false, errorColor: '#000' });
return span.outerHTML;
});

// Replace inline equations
html = html.replace(/\\$(.*?)\\$/g, (match, equation) => {
if(match.startsWith('\\\\$')) return match; // Ignore escaped dollars
let span = document.createElement('span');
katex.render(equation, span, { displayMode: false, throwOnError: false, errorColor: '#000' });
return span.outerHTML;
});

content.innerHTML = html;
}

renderMath();
</script>
</body>
</html>
"""

    # Newlines become a literal backslash-n and double quotes get a backslash
    # prefix before the text is substituted into the template.
    formatted_latex = latex_code.replace('\n', '\\n').replace('"', '\\"')
    with sync_playwright() as p:
        browser = p.chromium.launch()
        page = browser.new_page()
        page.set_viewport_size({'width': target_width, 'height': target_height})

        # Grow the font until the content reaches either limit, then step
        # back one size. If the limits are never reached, the loop ends with
        # fontsize one past the bound (31).
        while fontsize <= 30:
            # The placeholders are filled with str.replace, not str.format,
            # so the braces in the CSS/JS need no escaping.
            html_content = html_template.replace("{content}", formatted_latex).replace("{fontsize}", str(fontsize))
            page.set_content(html_content)

            dimensions = page.evaluate("""() => {
const render = document.getElementById('content');
return {
width: render.offsetWidth,
height: render.offsetHeight
};
}""")

            if dimensions['width'] >= target_width or dimensions['height'] >= target_height:
                fontsize -= 1
                break
            else:
                fontsize += 1

        # Final render at the chosen size; this is what gets captured.
        html_content = html_template.replace("{content}", formatted_latex).replace("{fontsize}", str(fontsize))
        page.set_content(html_content)

        screenshot_bytes = page.screenshot()
        browser.close()

    image_stream = io.BytesIO(screenshot_bytes)
    pil_image = Image.open(image_stream)
    # Force the pixel data to be read now rather than lazily from the stream.
    pil_image.load()
    return pil_image
|
surya/postprocessing/text.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List, Tuple
|
3 |
+
|
4 |
+
import requests
|
5 |
+
from PIL import Image, ImageDraw, ImageFont
|
6 |
+
|
7 |
+
from surya.postprocessing.fonts import get_font_path
|
8 |
+
from surya.schema import TextLine
|
9 |
+
from surya.settings import settings
|
10 |
+
from surya.postprocessing.math.latex import is_latex
|
11 |
+
|
12 |
+
|
13 |
+
def sort_text_lines(lines: "List[TextLine] | List[dict]", tolerance=1.25):
    """Sort text lines into an approximate reading order.

    Lines are bucketed by their top y coordinate, rounded to the nearest
    multiple of ``tolerance``. Buckets are ordered top to bottom and the
    lines inside each bucket left to right. Not 100% accurate; this should
    only be used as a starting point for more advanced sorting.

    Args:
        lines: ``TextLine`` objects (anything exposing ``.bbox``) or dicts
            with a ``"bbox"`` key; bbox is ``[x1, y1, x2, y2]``.
        tolerance: Vertical bucket size.

    Returns:
        A new list with the same line objects in sorted order.
    """
    def _bbox(line):
        return line["bbox"] if isinstance(line, dict) else line.bbox

    vertical_groups = {}
    for line in lines:
        # Divide first, then round. The previous one-line conditional bound
        # `/ tolerance` to the dict branch only, so TextLine inputs were
        # bucketed differently from dict inputs.
        group_key = round(_bbox(line)[1] / tolerance) * tolerance
        vertical_groups.setdefault(group_key, []).append(line)

    # Sort each group horizontally and flatten the groups into a single list
    sorted_lines = []
    for group_key in sorted(vertical_groups):
        sorted_lines.extend(sorted(vertical_groups[group_key], key=lambda x: _bbox(x)[0]))

    return sorted_lines
|
30 |
+
|
31 |
+
|
32 |
+
def truncate_repetitions(text: str, min_len=15):
    """Strip a repeating tail from ``text``.

    Adapted from nougat, with some cleanup. Looks for the largest length
    ``n`` (``min_len <= n < len(text) // 2``) such that the last ``n``
    characters are immediately preceded by an identical ``n`` characters.
    If one is found, every trailing copy of that ``n``-character chunk is
    removed and the remaining prefix is returned.

    Args:
        text: String to inspect.
        min_len: Smallest repetition length considered. Texts shorter than
            ``2 * min_len`` are returned unchanged.

    Returns:
        ``text`` unchanged when no repeating tail is found, otherwise the
        prefix left after removing all trailing copies of the repeated chunk.
    """
    total = len(text)
    if total < 2 * min_len:
        return text

    # Keep the longest candidate length whose two trailing windows match.
    longest = None
    for span in range(min_len, int(total / 2)):
        if text[total - 2 * span:total - span] == text[total - span:]:
            longest = span

    if longest is None:
        return text

    repeated_tail = text[-longest:]

    # Peel copies of the repeated chunk off the end until none remain.
    remaining = text
    while remaining.endswith(repeated_tail):
        remaining = remaining[:-longest]

    return remaining
|
61 |
+
|
62 |
+
|
63 |
+
def get_text_size(text, font):
    """Measure ``text`` as rendered with ``font``.

    Uses a throwaway zero-sized palette image purely to obtain a drawing
    context, then asks it for the text bounding box anchored at (0, 0).

    Returns:
        A ``(width, height)`` tuple taken from the right and bottom edges
        of that bounding box.
    """
    scratch = Image.new(mode="P", size=(0, 0))
    measurer = ImageDraw.Draw(scratch)
    _left, _top, right, bottom = measurer.textbbox((0, 0), text=text, font=font)
    return right, bottom
|
68 |
+
|
69 |
+
|
70 |
+
def render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size):
    """Draw ``text`` in black inside a bounding box, shrinking it to fit.

    Starting at ``box_font_size``, the font size is reduced one point at a
    time until the measured text fits within ``bbox_width`` x
    ``bbox_height`` or the size reaches 6, whichever comes first.

    Args:
        draw: Drawing context whose ``text`` method is used for output.
        text: String to render.
        s_bbox: Bounding box; only its first two entries (left, top) are read.
        bbox_width: Available width for the text.
        bbox_height: Available height for the text.
        font_path: Path of the TrueType font to load.
        box_font_size: Initial font size to try.
    """
    font = ImageFont.truetype(font_path, box_font_size)
    text_width, text_height = get_text_size(text, font)
    while (text_width > bbox_width or text_height > bbox_height) and box_font_size > 6:
        box_font_size -= 1
        font = ImageFont.truetype(font_path, box_font_size)
        text_width, text_height = get_text_size(text, font)

    # The measurements above already reflect the final font, so there is no
    # need to measure again. Text is left-aligned on the bbox and centered
    # vertically only.
    x = s_bbox[0]
    y = s_bbox[1] + (bbox_height - text_height) / 2

    draw.text((x, y), text, fill="black", font=font)
|
84 |
+
|
85 |
+
|
86 |
+
def render_math(image, draw, text, s_bbox, bbox_width, bbox_height, font_path):
    """Render ``text`` as LaTeX into ``image``, falling back to plain text.

    The LaTeX is rasterised with ``latex_to_pil``, shrunk in place to fit
    the bbox, and pasted at the bbox's top-left corner. If anything in that
    path raises (including the import), the failure is printed and the raw
    text is drawn with ``render_text`` instead.

    Args:
        image: Target image that receives the pasted rendering.
        draw: Drawing context used by the plain-text fallback.
        text: LaTeX source to render.
        s_bbox: Bounding box; its first two entries give the paste position.
        bbox_width: Width available for the rendering.
        bbox_height: Height available for the rendering.
        font_path: Font used by the plain-text fallback.
    """
    try:
        # Imported inside the try so an import failure also triggers the
        # plain-text fallback below.
        from surya.postprocessing.math.render import latex_to_pil

        # 20% of the box height, clamped to the range [10, 24].
        math_font_size = int(.2 * bbox_height)
        math_font_size = max(10, min(math_font_size, 24))

        rendered = latex_to_pil(text, bbox_width, bbox_height, fontsize=math_font_size)
        rendered.thumbnail((bbox_width, bbox_height))
        image.paste(rendered, (s_bbox[0], s_bbox[1]))
    except Exception as e:
        print(f"Failed to render math: {e}")
        # 75% of the box height, clamped to the range [10, 24].
        fallback_font_size = max(10, min(int(.75 * bbox_height), 24))
        render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, fallback_font_size)
|
97 |
+
|
98 |
+
|
99 |
+
def draw_text_on_image(bboxes, texts, image_size: Tuple[int, int], langs: List[str], font_path=None, max_font_size=60, res_upscale=2, has_math=False):
    """Render recognised text onto a blank white image at its bbox positions.

    Args:
        bboxes: Bounding boxes, one per text; each is indexed as
            (left, top, right, bottom).
        texts: Strings to draw, paired with ``bboxes`` by position.
        image_size: (width, height) of the source image.
        langs: Languages used to pick a font when ``font_path`` is None.
        font_path: Font file to use; resolved via ``get_font_path(langs)``
            when omitted.
        max_font_size: Upper bound on the starting font size for plain text.
        res_upscale: Factor applied to the canvas size and to every bbox
            coordinate.
        has_math: When True, texts that ``is_latex`` accepts are rendered
            as math instead of plain text.

    Returns:
        A new RGB image of size ``image_size * res_upscale`` with the text
        drawn on it.
    """
    if font_path is None:
        font_path = get_font_path(langs)

    # Cast to int like the bbox coordinates below: a fractional res_upscale
    # would otherwise yield a float size, which Image.new rejects.
    new_image_size = (int(image_size[0] * res_upscale), int(image_size[1] * res_upscale))
    image = Image.new('RGB', new_image_size, color='white')
    draw = ImageDraw.Draw(image)

    for bbox, text in zip(bboxes, texts):
        s_bbox = [int(coord * res_upscale) for coord in bbox]
        bbox_width = s_bbox[2] - s_bbox[0]
        bbox_height = s_bbox[3] - s_bbox[1]

        if has_math and is_latex(text):
            render_math(image, draw, text, s_bbox, bbox_width, bbox_height, font_path)
        else:
            # Start at 75% of the box height, clamped to [6, max_font_size];
            # render_text shrinks the text further if it does not fit.
            box_font_size = max(6, min(int(.75 * bbox_height), max_font_size))
            render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size)

    return image
|