omarelsayeed commited on
Commit
fa50974
·
verified ·
1 Parent(s): 74c842e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -9
app.py CHANGED
@@ -1,13 +1,18 @@
1
- from ultralytics import ASSETS, YOLO, RTDETR
2
  import gradio as gr
3
  from huggingface_hub import snapshot_download
4
  from PIL import Image
 
 
 
 
5
 
6
  model_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS") + "/rtdetr_1024_crops.pt"
7
  model = RTDETR(model_dir)
 
 
8
 
9
-
10
- def predict_image(img, conf_threshold, iou_threshold):
11
  """Predicts objects in an image using a YOLO11 model with adjustable confidence and IOU thresholds."""
12
  results = model.predict(
13
  source=img,
@@ -16,17 +21,143 @@ def predict_image(img, conf_threshold, iou_threshold):
16
  show_labels=True,
17
  show_conf=True,
18
  imgsz=1024,
19
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- for r in results:
22
- im_array = r.plot()
23
- im = Image.fromarray(im_array[..., ::-1])
24
 
25
- return im
 
 
 
26
 
 
 
 
 
 
 
 
27
 
28
  iface = gr.Interface(
29
- fn=predict_image,
30
  inputs=[
31
  gr.Image(type="pil", label="Upload Image"),
32
  gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),
 
1
+ from ultralytics import RTDETR
2
  import gradio as gr
3
  from huggingface_hub import snapshot_download
4
  from PIL import Image
5
+ from PIL import Image, ImageDraw, ImageFont
6
+ from surya.ordering import batch_ordering
7
+ from surya.model.ordering.processor import load_processor
8
+ from surya.model.ordering.model import load_model
9
 
10
  model_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS") + "/rtdetr_1024_crops.pt"
11
  model = RTDETR(model_dir)
12
+ order_model = load_model()
13
+ processor = load_processor()
14
 
15
+ def detect_layout(img, conf_threshold, iou_threshold):
 
16
  """Predicts objects in an image using a YOLO11 model with adjustable confidence and IOU thresholds."""
17
  results = model.predict(
18
  source=img,
 
21
  show_labels=True,
22
  show_conf=True,
23
  imgsz=1024,
24
+ agnostic_nms= True,
25
+ max_det=34,
26
+ nms=True
27
+ )[0]
28
+ bboxes = results.boxes.xyxy.cpu().tolist()
29
+ classes = results.boxes.cls.cpu().tolist()
30
+ mapping = {0: 'CheckBox',
31
+ 1: 'List',
32
+ 2: 'P',
33
+ 3: 'abandon',
34
+ 4: 'figure',
35
+ 5: 'gridless_table',
36
+ 6: 'handwritten_signature',
37
+ 7: 'qr_code',
38
+ 8: 'table',
39
+ 9: 'title'}
40
+ classes = [mapping[i] for i in classes]
41
+ return bboxes , classes
42
+
43
+ def get_orders(image_path , boxes):
44
+ image = Image.open(image_path)
45
+ order_predictions = batch_ordering([image], [bboxes], order_model, processor)
46
+ return [i.position for i in order_predictions[0].bboxes]
47
+
48
+ def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
49
+ # Define a color map for each class name
50
+ class_colors = {
51
+ 'CheckBox': 'orange',
52
+ 'List': 'blue',
53
+ 'P': 'green',
54
+ 'abandon': 'purple',
55
+ 'figure': 'cyan',
56
+ 'gridless_table': 'yellow',
57
+ 'handwritten_signature': 'magenta',
58
+ 'qr_code': 'red',
59
+ 'table': 'brown',
60
+ 'title': 'pink'
61
+ }
62
+
63
+ # Open the image using PIL
64
+ image = Image.open(image_path)
65
+
66
+ # Prepare to draw on the image
67
+ draw = ImageDraw.Draw(image)
68
+
69
+ # Try loading a default font, if it fails, use a basic font
70
+ try:
71
+ font = ImageFont.truetype("arial.ttf", 20)
72
+ title_font = ImageFont.truetype("arial.ttf", 30) # Larger font for titles
73
+ except IOError:
74
+ font = ImageFont.load_default(size = 30)
75
+ title_font = font # Use the same font for title if custom font fails
76
+
77
+ # Loop through the bounding boxes and corresponding labels
78
+ for i in range(len(bboxes)):
79
+ x1, y1, x2, y2 = bboxes[i]
80
+ class_name = classes[i]
81
+ order = reading_order[i]
82
+
83
+ # Get the color for the class
84
+ color = class_colors[class_name]
85
+
86
+ # If it's a title, make the bounding box thicker and text larger
87
+ if class_name == 'title':
88
+ box_thickness = 4 # Thicker box for title
89
+ label_font = title_font # Larger font for title
90
+ else:
91
+ box_thickness = 2 # Default box thickness
92
+ label_font = font # Default font for other classes
93
+
94
+ # Draw the rectangle with the class color and box thickness
95
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=box_thickness)
96
+
97
+ # Label the box with the class and order
98
+ label = f"{class_name}-{order}"
99
+
100
+ # Calculate text size using textbbox() to get the bounding box of the text
101
+ bbox = draw.textbbox((x1, y1 - 20), label, font=label_font)
102
+ label_width = bbox[2] - bbox[0]
103
+ label_height = bbox[3] - bbox[1]
104
+
105
+ # Draw the text above the box
106
+ draw.text((x1, y1 - label_height), label, fill="black", font=label_font)
107
+
108
+ # Return the modified image as a PIL image object
109
+ return image
110
+ from PIL import Image, ImageDraw
111
+
112
+ def is_inside(box1, box2):
113
+ # Check if box1 is inside box2
114
+ return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
115
+
116
+ def is_overlap(box1, box2):
117
+ # Check if box1 overlaps with box2
118
+ x1, y1, x2, y2 = box1
119
+ x3, y3, x4, y4 = box2
120
+
121
+ # No overlap if one box is to the left, right, above, or below the other box
122
+ return not (x2 <= x3 or x4 <= x1 or y2 <= y3 or y4 <= y1)
123
+
124
+ def remove_overlapping_and_inside_boxes(boxes, classes):
125
+ to_remove = []
126
+
127
+ for i, box1 in enumerate(boxes):
128
+ for j, box2 in enumerate(boxes):
129
+ if i != j:
130
+ if is_inside(box1, box2):
131
+ # Mark the smaller (inside) box for removal
132
+ to_remove.append(i)
133
+ elif is_inside(box2, box1):
134
+ # Mark the smaller (inside) box for removal
135
+ to_remove.append(j)
136
+ elif is_overlap(box1, box2):
137
+ # If the boxes overlap, mark the smaller one for removal
138
+ if (box2[2] - box2[0]) * (box2[3] - box2[1]) < (box1[2] - box1[0]) * (box1[3] - box1[1]):
139
+ to_remove.append(j)
140
+ else:
141
+ to_remove.append(i)
142
 
143
+ # Remove duplicates and sort by the index to keep original boxes
144
+ to_remove = sorted(set(to_remove), reverse=True)
 
145
 
146
+ # Remove the boxes and their corresponding classes from the list
147
+ for idx in to_remove:
148
+ del boxes[idx]
149
+ del classes[idx]
150
 
151
+ return boxes, classes
152
+ def full_predictions(IMAGE_PATH)
153
+ bboxes , classes = detect_layout(IMAGE_PATH , 0.3, 0)
154
+ bboxes , classes = remove_overlapping_and_inside_boxes(bboxes,classes)
155
+ orders = get_orders(IMAGE_PATH , bboxes)
156
+ final_image = draw_bboxes_on_image(IMAGE_PATH , bboxes , classes , orders)
157
+ return final_image
158
 
159
  iface = gr.Interface(
160
+ fn=full_predictions,
161
  inputs=[
162
  gr.Image(type="pil", label="Upload Image"),
163
  gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),