omarelsayeed
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -1,13 +1,18 @@
|
|
1 |
-
from ultralytics import
|
2 |
import gradio as gr
|
3 |
from huggingface_hub import snapshot_download
|
4 |
from PIL import Image
|
|
|
|
|
|
|
|
|
5 |
|
6 |
model_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS") + "/rtdetr_1024_crops.pt"
|
7 |
model = RTDETR(model_dir)
|
|
|
|
|
8 |
|
9 |
-
|
10 |
-
def predict_image(img, conf_threshold, iou_threshold):
|
11 |
"""Predicts objects in an image using a YOLO11 model with adjustable confidence and IOU thresholds."""
|
12 |
results = model.predict(
|
13 |
source=img,
|
@@ -16,17 +21,143 @@ def predict_image(img, conf_threshold, iou_threshold):
|
|
16 |
show_labels=True,
|
17 |
show_conf=True,
|
18 |
imgsz=1024,
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
im = Image.fromarray(im_array[..., ::-1])
|
24 |
|
25 |
-
|
|
|
|
|
|
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
iface = gr.Interface(
|
29 |
-
fn=
|
30 |
inputs=[
|
31 |
gr.Image(type="pil", label="Upload Image"),
|
32 |
gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),
|
|
|
1 |
+
from ultralytics import RTDETR
|
2 |
import gradio as gr
|
3 |
from huggingface_hub import snapshot_download
|
4 |
from PIL import Image
|
5 |
+
from PIL import Image, ImageDraw, ImageFont
|
6 |
+
from surya.ordering import batch_ordering
|
7 |
+
from surya.model.ordering.processor import load_processor
|
8 |
+
from surya.model.ordering.model import load_model
|
9 |
|
10 |
# Fetch the model repository snapshot from the Hugging Face Hub and point at
# the RT-DETR checkpoint file inside it.
model_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS") + "/rtdetr_1024_crops.pt"
model = RTDETR(model_dir)

# Model and processor loaded from surya's ordering modules; both are handed to
# batch_ordering() when the reading order of detected boxes is computed.
order_model = load_model()
processor = load_processor()
|
14 |
|
15 |
+
def detect_layout(img, conf_threshold, iou_threshold):
|
|
|
16 |
"""Predicts objects in an image using a YOLO11 model with adjustable confidence and IOU thresholds."""
|
17 |
results = model.predict(
|
18 |
source=img,
|
|
|
21 |
show_labels=True,
|
22 |
show_conf=True,
|
23 |
imgsz=1024,
|
24 |
+
agnostic_nms= True,
|
25 |
+
max_det=34,
|
26 |
+
nms=True
|
27 |
+
)[0]
|
28 |
+
bboxes = results.boxes.xyxy.cpu().tolist()
|
29 |
+
classes = results.boxes.cls.cpu().tolist()
|
30 |
+
mapping = {0: 'CheckBox',
|
31 |
+
1: 'List',
|
32 |
+
2: 'P',
|
33 |
+
3: 'abandon',
|
34 |
+
4: 'figure',
|
35 |
+
5: 'gridless_table',
|
36 |
+
6: 'handwritten_signature',
|
37 |
+
7: 'qr_code',
|
38 |
+
8: 'table',
|
39 |
+
9: 'title'}
|
40 |
+
classes = [mapping[i] for i in classes]
|
41 |
+
return bboxes , classes
|
42 |
+
|
43 |
+
def get_orders(image_path, boxes):
    """Return the predicted reading-order position of each box.

    Args:
        image_path: Path (or file object) of the page image, or an
            already-loaded ``PIL.Image.Image``.
        boxes: Bounding boxes for that image, passed straight through to
            ``batch_ordering`` (presumably ``[x1, y1, x2, y2]`` lists;
            confirm against the caller).

    Returns:
        A list holding the ``position`` attribute of every entry in the
        first prediction's ``bboxes``; an empty list when ``boxes`` is empty.
    """
    # Nothing to order: skip loading the image and running the model.
    if not boxes:
        return []

    # The Gradio input is declared with type="pil", so the value may already
    # be an image object rather than something Image.open() can read.
    if isinstance(image_path, Image.Image):
        image = image_path
    else:
        image = Image.open(image_path)

    # Use the ``boxes`` parameter; the previous body referenced an undefined
    # name ``bboxes`` and raised NameError.
    order_predictions = batch_ordering([image], [boxes], order_model, processor)
    return [entry.position for entry in order_predictions[0].bboxes]
|
47 |
+
|
48 |
+
def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
    """Draw class-coloured, labelled bounding boxes on an image.

    Args:
        image_path: Path (or file object) of the image, or an already-loaded
            ``PIL.Image.Image``. A passed-in image is copied before drawing
            so the caller's object is not modified.
        bboxes: Sequence of ``(x1, y1, x2, y2)`` boxes.
        classes: Class name for each box; every name must be a key of the
            colour map below.
        reading_order: Reading-order value for each box, shown in the label.

    Returns:
        The PIL image with one rectangle and one ``"<class>-<order>"`` label
        per box.

    Raises:
        KeyError: If a class name has no entry in the colour map.
        IndexError: If ``classes`` or ``reading_order`` is shorter than
            ``bboxes``.
    """
    # Colour used for the outline of each class.
    class_colors = {
        'CheckBox': 'orange',
        'List': 'blue',
        'P': 'green',
        'abandon': 'purple',
        'figure': 'cyan',
        'gridless_table': 'yellow',
        'handwritten_signature': 'magenta',
        'qr_code': 'red',
        'table': 'brown',
        'title': 'pink',
    }

    # Accept an in-memory image as well as a path: the Gradio input is
    # declared with type="pil".
    if isinstance(image_path, Image.Image):
        image = image_path.copy()
    else:
        image = Image.open(image_path)

    draw = ImageDraw.Draw(image)

    # Prefer Arial; fall back to Pillow's built-in font when it is missing.
    try:
        font = ImageFont.truetype("arial.ttf", 20)
        title_font = ImageFont.truetype("arial.ttf", 30)  # Larger font for titles
    except OSError:
        font = ImageFont.load_default(size=30)
        title_font = font  # Same font for titles when Arial is unavailable

    for i, (x1, y1, x2, y2) in enumerate(bboxes):
        class_name = classes[i]
        order = reading_order[i]
        color = class_colors[class_name]

        # Titles get a thicker outline and the larger font.
        is_title = class_name == 'title'
        box_thickness = 4 if is_title else 2
        label_font = title_font if is_title else font

        draw.rectangle([x1, y1, x2, y2], outline=color, width=box_thickness)

        label = f"{class_name}-{order}"

        # Measure the rendered label so it can sit directly above the box.
        _, text_top, _, text_bottom = draw.textbbox((x1, y1 - 20), label, font=label_font)
        label_height = text_bottom - text_top

        # Clamp to the top edge so labels of boxes near y=0 stay visible.
        draw.text((x1, max(y1 - label_height, 0)), label, fill="black", font=label_font)

    return image
|
110 |
+
from PIL import Image, ImageDraw
|
111 |
+
|
112 |
+
def is_inside(box1, box2):
    """Return True when ``box1`` lies entirely within ``box2``.

    Both boxes are indexable as ``[x1, y1, x2, y2]``. Shared edges count as
    inside, so two identical boxes are each inside the other.
    """
    return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
|
115 |
+
|
116 |
+
def is_overlap(box1, box2):
    """Return True when the two boxes share a region of non-zero area.

    Each box unpacks as ``(x1, y1, x2, y2)``. Boxes that merely touch along
    an edge or at a corner are not considered overlapping.
    """
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2

    # No overlap if one box is entirely to the left, right, above, or below
    # the other box.
    return not (x2 <= x3 or x4 <= x1 or y2 <= y3 or y4 <= y1)
|
123 |
+
|
124 |
+
def remove_overlapping_and_inside_boxes(boxes, classes):
    """Drop boxes that are nested in, or overlap with, a larger box.

    Every unordered pair of boxes is compared once:

    * if one box lies inside the other, the inner box is marked for removal;
    * otherwise, if the two overlap, the one with the smaller area is marked.

    On a tie (identical boxes, or overlapping boxes of equal area) the box
    that appears earlier in ``boxes`` is kept. Previously each pair was
    visited in both orders, which marked *both* boxes on a tie.

    A box that has been marked still takes part in later comparisons, so a
    box can be removed because of a neighbour that is itself removed.

    Args:
        boxes: List of ``[x1, y1, x2, y2]`` boxes. Modified in place.
        classes: List of class labels parallel to ``boxes``. Modified in
            place.

    Returns:
        The same ``(boxes, classes)`` list objects with the marked entries
        deleted.
    """
    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    to_remove = set()

    for i, box_i in enumerate(boxes):
        for j in range(i + 1, len(boxes)):
            box_j = boxes[j]
            if is_inside(box_j, box_i):
                # Later box is nested in (or identical to) the earlier one.
                to_remove.add(j)
            elif is_inside(box_i, box_j):
                to_remove.add(i)
            elif is_overlap(box_i, box_j):
                # Partial overlap: drop the smaller; the earlier box wins ties.
                if _area(box_i) < _area(box_j):
                    to_remove.add(i)
                else:
                    to_remove.add(j)

    # Delete from the highest index down so earlier indices stay valid.
    for idx in sorted(to_remove, reverse=True):
        del boxes[idx]
        del classes[idx]

    return boxes, classes
|
152 |
+
def full_predictions(IMAGE_PATH, conf_threshold=0.3, iou_threshold=0):
    """Run detection, de-duplication, ordering and drawing for one image.

    Args:
        IMAGE_PATH: The image input, forwarded unchanged to
            ``detect_layout``, ``get_orders`` and ``draw_bboxes_on_image``.
        conf_threshold: Confidence threshold forwarded to ``detect_layout``.
            Defaults to the previously hard-coded ``0.3``.
        iou_threshold: IOU threshold forwarded to ``detect_layout``.
            Defaults to the previously hard-coded ``0``.

    Returns:
        The image returned by ``draw_bboxes_on_image``.
    """
    # The thresholds are optional parameters so that the Gradio interface,
    # which declares slider inputs next to the image, can call this function
    # without a TypeError; a one-argument call behaves as before.
    bboxes, classes = detect_layout(IMAGE_PATH, conf_threshold, iou_threshold)
    bboxes, classes = remove_overlapping_and_inside_boxes(bboxes, classes)
    orders = get_orders(IMAGE_PATH, bboxes)
    final_image = draw_bboxes_on_image(IMAGE_PATH, bboxes, classes, orders)
    return final_image
|
158 |
|
159 |
iface = gr.Interface(
|
160 |
+
fn=full_predictions,
|
161 |
inputs=[
|
162 |
gr.Image(type="pil", label="Upload Image"),
|
163 |
gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),
|