omarelsayeed
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -3,14 +3,67 @@ import gradio as gr
|
|
3 |
from huggingface_hub import snapshot_download
|
4 |
from PIL import Image
|
5 |
from PIL import Image, ImageDraw, ImageFont
|
6 |
-
|
7 |
-
|
8 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
model_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS") + "/rtdetr_1024_crops.pt"
|
11 |
model = RTDETR(model_dir)
|
12 |
-
|
13 |
-
processor = load_processor()
|
14 |
|
15 |
def detect_layout(img, conf_threshold, iou_threshold):
|
16 |
"""Predicts objects in an image using a YOLO11 model with adjustable confidence and IOU thresholds."""
|
@@ -40,9 +93,6 @@ def detect_layout(img, conf_threshold, iou_threshold):
|
|
40 |
classes = [mapping[i] for i in classes]
|
41 |
return bboxes , classes
|
42 |
|
43 |
-
def get_orders(image_path , boxes):
|
44 |
-
order_predictions = batch_ordering([image_path], [boxes], order_model, processor)
|
45 |
-
return [i.position for i in order_predictions[0].bboxes]
|
46 |
|
47 |
def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
|
48 |
# Define a color map for each class name
|
@@ -149,11 +199,12 @@ def remove_overlapping_and_inside_boxes(boxes, classes):
|
|
149 |
|
150 |
return boxes, classes
|
151 |
def full_predictions(IMAGE_PATH, conf_threshold, iou_threshold):
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
|
|
157 |
|
158 |
iface = gr.Interface(
|
159 |
fn=full_predictions,
|
|
|
3 |
from huggingface_hub import snapshot_download
|
4 |
from PIL import Image
|
5 |
from PIL import Image, ImageDraw, ImageFont
|
6 |
+
|
7 |
+
|
8 |
+
from collections import defaultdict
|
9 |
+
from typing import List, Dict
|
10 |
+
import torch
|
11 |
+
from transformers import LayoutLMv3ForTokenClassification
|
12 |
+
|
13 |
+
# Load the LayoutLMv3 model
|
14 |
+
layout_model = LayoutLMv3ForTokenClassification.from_pretrained("omarelsayeed/LayoutReader80Small")
|
15 |
+
|
16 |
+
MAX_LEN = 100
|
17 |
+
CLS_TOKEN_ID = 0
|
18 |
+
UNK_TOKEN_ID = 3
|
19 |
+
EOS_TOKEN_ID = 2
|
20 |
+
|
21 |
+
|
22 |
+
def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
|
23 |
+
bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
|
24 |
+
input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
|
25 |
+
attention_mask = [1] + [1] * len(boxes) + [1]
|
26 |
+
return {
|
27 |
+
"bbox": torch.tensor([bbox]),
|
28 |
+
"attention_mask": torch.tensor([attention_mask]),
|
29 |
+
"input_ids": torch.tensor([input_ids]),
|
30 |
+
}
|
31 |
+
|
32 |
+
def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
|
33 |
+
"""
|
34 |
+
Parse logits to determine the reading order.
|
35 |
+
"""
|
36 |
+
logits = logits[1: length + 1, :length]
|
37 |
+
orders = logits.argsort(descending=False).tolist()
|
38 |
+
ret = [o.pop() for o in orders]
|
39 |
+
while True:
|
40 |
+
order_to_idxes = defaultdict(list)
|
41 |
+
for idx, order in enumerate(ret):
|
42 |
+
order_to_idxes[order].append(idx)
|
43 |
+
# Filter indices with length > 1
|
44 |
+
order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
|
45 |
+
if not order_to_idxes:
|
46 |
+
break
|
47 |
+
# Resolve conflicts
|
48 |
+
for order, idxes in order_to_idxes.items():
|
49 |
+
idxes_to_logit = {idx: logits[idx, order] for idx in idxes}
|
50 |
+
idxes_to_logit = sorted(idxes_to_logit.items(), key=lambda x: x[1], reverse=True)
|
51 |
+
for idx, _ in idxes_to_logit[1:]:
|
52 |
+
ret[idx] = orders[idx].pop()
|
53 |
+
|
54 |
+
return ret
|
55 |
+
|
56 |
+
def get_orders(image_path, boxes):
|
57 |
+
inputs = boxes2inputs(boxes)
|
58 |
+
inputs = {k: v.to(layout_model.device) for k, v in inputs.items()} # Move inputs to model device
|
59 |
+
logits = layout_model(**inputs).logits.cpu().squeeze(0) # Perform inference and get logits
|
60 |
+
orders = parse_logits(logits, len(boxes))
|
61 |
+
return orders
|
62 |
+
|
63 |
|
64 |
model_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS") + "/rtdetr_1024_crops.pt"
|
65 |
model = RTDETR(model_dir)
|
66 |
+
|
|
|
67 |
|
68 |
def detect_layout(img, conf_threshold, iou_threshold):
|
69 |
"""Predicts objects in an image using a YOLO11 model with adjustable confidence and IOU thresholds."""
|
|
|
93 |
classes = [mapping[i] for i in classes]
|
94 |
return bboxes , classes
|
95 |
|
|
|
|
|
|
|
96 |
|
97 |
def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
|
98 |
# Define a color map for each class name
|
|
|
199 |
|
200 |
return boxes, classes
|
201 |
def full_predictions(IMAGE_PATH, conf_threshold, iou_threshold):
|
202 |
+
bboxes, classes = detect_layout(IMAGE_PATH, conf_threshold, iou_threshold)
|
203 |
+
bboxes, classes = remove_overlapping_and_inside_boxes(bboxes, classes)
|
204 |
+
orders = get_orders(IMAGE_PATH, bboxes)
|
205 |
+
final_image = draw_bboxes_on_image(IMAGE_PATH, bboxes, classes, orders)
|
206 |
+
return final_image
|
207 |
+
|
208 |
|
209 |
iface = gr.Interface(
|
210 |
fn=full_predictions,
|