File size: 6,454 Bytes
fa50974 679d3c5 190bae6 fa50974 190bae6 74c842e 679d3c5 fa50974 679d3c5 fa50974 679d3c5 a231a61 fa50974 27a5f83 fa50974 679d3c5 fa50974 679d3c5 fa50974 679d3c5 fa50974 faf74d9 fa50974 679d3c5 fa50974 679d3c5 45440e8 63b70c6 679d3c5 2132edd 679d3c5 3ec7d0a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
from ultralytics import RTDETR
import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image
from PIL import Image, ImageDraw, ImageFont
from surya.ordering import batch_ordering
from surya.model.ordering.processor import load_processor
from surya.model.ordering.model import load_model
# Layout detector: snapshot_download returns the local directory of the Hub
# repo snapshot; the RT-DETR checkpoint file inside it is loaded with ultralytics.
_layout_repo_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS")
model_dir = _layout_repo_dir + "/rtdetr_1024_crops.pt"
model = RTDETR(model_dir)

# Surya reading-order model and its processor; both are used by get_orders().
order_model = load_model()
processor = load_processor()
# Detector class id -> layout label. Hoisted so it is not rebuilt on every call.
_LAYOUT_CLASS_NAMES = {
    0: 'CheckBox',
    1: 'List',
    2: 'P',
    3: 'abandon',
    4: 'figure',
    5: 'gridless_table',
    6: 'handwritten_signature',
    7: 'qr_code',
    8: 'table',
    9: 'title',
}


def detect_layout(img, conf_threshold, iou_threshold, imgsz=1024, max_det=34):
    """Detect layout regions in a page image with the module-level RT-DETR model.

    Args:
        img: Image source forwarded to ``model.predict`` (the Gradio app passes
            a PIL image).
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold forwarded to the predictor.
        imgsz: Inference image size passed to ``model.predict``.
        max_det: Maximum number of detections to return.

    Returns:
        ``(bboxes, classes)``: ``bboxes`` is a list of ``[x1, y1, x2, y2]``
        boxes (from ``boxes.xyxy``) and ``classes`` the matching list of label
        strings from ``_LAYOUT_CLASS_NAMES``.

    Raises:
        KeyError: if the model emits a class id that has no label.
    """
    results = model.predict(
        source=img,
        conf=conf_threshold,
        iou=iou_threshold,
        show_labels=True,
        show_conf=True,
        imgsz=imgsz,
        agnostic_nms=True,
        max_det=max_det,
        nms=True,
    )[0]
    bboxes = results.boxes.xyxy.cpu().tolist()
    # .tolist() yields floats (e.g. 2.0); cast explicitly before the id lookup.
    class_ids = results.boxes.cls.cpu().tolist()
    classes = [_LAYOUT_CLASS_NAMES[int(class_id)] for class_id in class_ids]
    return bboxes, classes
def get_orders(image_path, boxes):
    """Predict a reading-order position for each box with the Surya ordering model.

    Args:
        image_path: The page image handed to ``batch_ordering`` (the Gradio app
            passes a PIL image here, despite the parameter name).
        boxes: List of ``[x1, y1, x2, y2]`` boxes on that page.

    Returns:
        A list of ``position`` values, one per box reported by the model. The
        caller pairs them index-by-index with ``boxes``; this presumes the model
        preserves input order (TODO confirm). An empty ``boxes`` returns ``[]``
        without invoking the model.
    """
    if not boxes:
        # Nothing to order: skip running the model on an empty batch.
        return []
    order_predictions = batch_ordering([image_path], [boxes], order_model, processor)
    return [ordered_box.position for ordered_box in order_predictions[0].bboxes]
def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
    """Draw class-coloured, order-labelled boxes on a copy of the page image.

    Args:
        image_path: Either a path/file object accepted by ``Image.open`` or an
            already-loaded ``PIL.Image.Image`` (the Gradio app passes the latter).
        bboxes: List of ``[x1, y1, x2, y2]`` boxes.
        classes: Class label for each box, aligned with ``bboxes``.
        reading_order: Reading-order value for each box, aligned with ``bboxes``.

    Returns:
        A new PIL image. Each box is outlined in its class colour (thicker for
        ``'title'``) and labelled ``"<class>-<order>"`` just above it. The input
        image is not modified.
    """
    # Outline colour per class name; unknown classes fall back to black.
    class_colors = {
        'CheckBox': 'orange',
        'List': 'blue',
        'P': 'green',
        'abandon': 'purple',
        'figure': 'cyan',
        'gridless_table': 'yellow',
        'handwritten_signature': 'magenta',
        'qr_code': 'red',
        'table': 'brown',
        'title': 'pink',
    }

    if isinstance(image_path, Image.Image):
        # Image.open() cannot read an already-decoded image; draw on a copy so
        # the caller's image stays untouched.
        image = image_path.copy()
    else:
        # Copy out of the opened file so the file handle is closed promptly.
        with Image.open(image_path) as source:
            image = source.copy()

    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype("arial.ttf", 20)
        title_font = ImageFont.truetype("arial.ttf", 30)  # larger font for titles
    except OSError:
        # arial.ttf may not be installed; use Pillow's bundled font at the same
        # sizes so titles stay larger than other labels (size= needs Pillow >= 10.1).
        font = ImageFont.load_default(size=20)
        title_font = ImageFont.load_default(size=30)

    for (x1, y1, x2, y2), class_name, order in zip(bboxes, classes, reading_order):
        color = class_colors.get(class_name, 'black')
        is_title = class_name == 'title'
        box_thickness = 4 if is_title else 2
        label_font = title_font if is_title else font

        draw.rectangle([x1, y1, x2, y2], outline=color, width=box_thickness)

        label = f"{class_name}-{order}"
        # textbbox only shifts with its anchor, so measuring at the origin gives
        # the same height as measuring at the box corner.
        _, top, _, bottom = draw.textbbox((0, 0), label, font=label_font)
        label_height = bottom - top
        # Place the label above the box, clamped so it stays visible for boxes
        # touching the top edge of the page.
        draw.text((x1, max(0, y1 - label_height)), label, fill="black", font=label_font)

    return image
from PIL import Image, ImageDraw
def is_inside(box1, box2):
    """Return whether ``box1`` lies entirely within ``box2`` (shared edges allowed).

    Boxes are indexed as ``(x1, y1, x2, y2)``.
    """
    # The top-left corner of box1 must not precede box2's, and its
    # bottom-right corner must not extend past box2's; checked in order with
    # short-circuiting.
    return all(box1[k] >= box2[k] for k in (0, 1)) and all(box1[k] <= box2[k] for k in (2, 3))
def is_overlap(box1, box2):
    """Return whether two ``(x1, y1, x2, y2)`` boxes share a positive-area region.

    Boxes that only touch along an edge are not considered overlapping.
    """
    a_left, a_top, a_right, a_bottom = box1
    b_left, b_top, b_right, b_bottom = box2
    # The boxes are disjoint when one ends at or before the other begins on
    # either axis.
    separated = (
        a_right <= b_left
        or b_right <= a_left
        or a_bottom <= b_top
        or b_bottom <= a_top
    )
    return not separated
def remove_overlapping_and_inside_boxes(boxes, classes):
    """Drop boxes that are contained in, or overlap with, a larger box.

    Each unordered pair of boxes is compared exactly once:

    * if one box lies inside the other, the inner box is dropped; for two
      identical boxes the later one is dropped and the earlier kept;
    * otherwise, if the boxes overlap, the box with the smaller area is
      dropped; on an area tie the later one is dropped.

    Comparing ordered pairs (as before) made ties flag each other, so both
    boxes were lost. Decisions are still made against the full input, so a
    box can be dropped because of a box that is itself dropped.

    Args:
        boxes: List of ``[x1, y1, x2, y2]`` boxes.
        classes: Label for each box, aligned with ``boxes``.

    Returns:
        ``(kept_boxes, kept_classes)`` as new lists in their original order.
        The input lists are not modified.
    """
    def _area(box):
        # Width * height of an (x1, y1, x2, y2) box.
        return (box[2] - box[0]) * (box[3] - box[1])

    to_remove = set()
    count = len(boxes)
    for i in range(count):
        for j in range(i + 1, count):
            box_i, box_j = boxes[i], boxes[j]
            if is_inside(box_j, box_i):
                # Also covers identical boxes: keep the earlier one.
                to_remove.add(j)
            elif is_inside(box_i, box_j):
                to_remove.add(i)
            elif is_overlap(box_i, box_j):
                to_remove.add(i if _area(box_i) < _area(box_j) else j)

    kept = [k for k in range(count) if k not in to_remove]
    return [boxes[k] for k in kept], [classes[k] for k in kept]
def full_predictions(IMAGE_PATH, conf_threshold, iou_threshold):
    """Run the whole pipeline on one page and return the annotated image.

    The stages are: detect layout regions, prune nested/overlapping boxes,
    predict reading order, then draw the labelled boxes.

    Args:
        IMAGE_PATH: The page image (the Gradio app passes a PIL image).
        conf_threshold: Detector confidence threshold.
        iou_threshold: Detector IoU threshold.
    """
    detected_boxes, detected_labels = detect_layout(IMAGE_PATH, conf_threshold, iou_threshold)
    kept_boxes, kept_labels = remove_overlapping_and_inside_boxes(detected_boxes, detected_labels)
    reading_order = get_orders(IMAGE_PATH, kept_boxes)
    return draw_bboxes_on_image(IMAGE_PATH, kept_boxes, kept_labels, reading_order)
# Gradio UI: a page image plus the detector's confidence and IoU thresholds in,
# the annotated page out. The example file names are presumably expected to sit
# next to this script.
iface = gr.Interface(
    fn=full_predictions,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU threshold"),
    ],
    outputs=gr.Image(type="pil", label="Result"),
    title="Ultralytics Gradio",
    # The previous text claimed a YOLO11n model; the app actually uses an
    # RT-DETR layout detector plus the Surya reading-order model.
    description=(
        "Upload a document image for layout analysis. Regions are detected with "
        "an Ultralytics RT-DETR model and numbered in reading order by the Surya "
        "ordering model."
    ),
    examples=[
        ["kashida.png", 0.2, 0.45],
        ["image.jpg", 0.2, 0.45],
        ["Screenshot 2024-11-06 130230.png", 0.25, 0.45],
    ],
    theme=gr.themes.Default(),
)

if __name__ == "__main__":
    iface.launch()