|
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image, ImageDraw, ImageFont
from surya.model.ordering.model import load_model
from surya.model.ordering.processor import load_processor
from surya.ordering import batch_ordering
from ultralytics import RTDETR
|
|
|
# Download only the RT-DETR checkpoint instead of a snapshot of the whole
# repository. Despite its name, ``model_dir`` holds the path to the weights
# *file*, not a directory.
model_dir = hf_hub_download(
    repo_id="omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS",
    filename="rtdetr_1024_crops.pt",
)

# Layout detector used by detect_layout().
model = RTDETR(model_dir)

# Surya reading-order model and its processor, used by get_orders().
order_model = load_model()
processor = load_processor()
|
|
|
def detect_layout(img, conf_threshold, iou_threshold, *, imgsz=1024, max_det=34):
    """Detect document-layout regions in ``img`` with the module-level RT-DETR model.

    Args:
        img: Image source forwarded unchanged to ``model.predict`` (the Gradio
            UI in this file supplies a PIL image).
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold forwarded to ``model.predict``.
        imgsz: Inference image size passed to ``model.predict``.
        max_det: Maximum number of detections returned for the image.

    Returns:
        ``(bboxes, classes)`` where ``bboxes`` is a list of ``[x1, y1, x2, y2]``
        boxes taken from ``results.boxes.xyxy`` and ``classes`` is the parallel
        list of class-name strings.
    """
    # Class-id -> class-name table for the fine-tuned checkpoint.
    mapping = {
        0: 'CheckBox',
        1: 'List',
        2: 'P',
        3: 'abandon',
        4: 'figure',
        5: 'gridless_table',
        6: 'handwritten_signature',
        7: 'qr_code',
        8: 'table',
        9: 'title',
    }

    results = model.predict(
        source=img,
        conf=conf_threshold,
        iou=iou_threshold,
        show_labels=True,
        show_conf=True,
        imgsz=imgsz,
        agnostic_nms=True,  # suppress overlaps regardless of class
        max_det=max_det,
        nms=True,
    )[0]  # a single image in -> take the single result

    bboxes = results.boxes.xyxy.cpu().tolist()
    # Class ids may come back as floats (e.g. 2.0); cast explicitly before lookup.
    classes = [mapping[int(cls_id)] for cls_id in results.boxes.cls.cpu().tolist()]
    return bboxes, classes
|
|
|
def get_orders(image_path, boxes):
    """Predict the reading order of ``boxes`` on a single page with Surya.

    Args:
        image_path: Page image handed to ``batch_ordering`` (in this app it is
            actually a PIL image coming from the Gradio input).
        boxes: List of ``[x1, y1, x2, y2]`` boxes for that page.

    Returns:
        A list holding the ``position`` of each box in Surya's ordering result
        (presumably parallel to ``boxes`` -- TODO confirm for the installed
        surya version). Returns ``[]`` when ``boxes`` is empty.
    """
    if not boxes:
        # Nothing to order, so skip the model call entirely.
        # NOTE(review): it is unverified whether batch_ordering accepts an empty box list.
        return []
    order_predictions = batch_ordering([image_path], [boxes], order_model, processor)
    # One image in -> one result out; collect its per-box positions.
    return [bbox.position for bbox in order_predictions[0].bboxes]
|
|
|
def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
    """Draw labelled layout boxes on an RGB copy of a page image.

    Args:
        image_path: Either an already-loaded ``PIL.Image.Image`` (what the
            Gradio ``type="pil"`` input provides) or anything
            ``PIL.Image.open`` accepts, such as a filesystem path.
        bboxes: Sequence of ``(x1, y1, x2, y2)`` boxes in pixel coordinates.
        classes: Class name for each box; indexed in parallel with ``bboxes``.
        reading_order: Reading-order position for each box; indexed in
            parallel with ``bboxes``.

    Returns:
        A new RGB image where each box is outlined in its class colour and
        labelled ``"<class>-<order>"``. The input image is not modified.
    """
    class_colors = {
        'CheckBox': 'orange',
        'List': 'blue',
        'P': 'green',
        'abandon': 'purple',
        'figure': 'cyan',
        'gridless_table': 'yellow',
        'handwritten_signature': 'magenta',
        'qr_code': 'red',
        'table': 'brown',
        'title': 'pink',
    }

    if isinstance(image_path, Image.Image):
        # Image.open() cannot take an Image object; convert() also returns a
        # copy, so the caller's image is left untouched.
        image = image_path.convert("RGB")
    else:
        # The context manager closes the file once the pixels are loaded.
        with Image.open(image_path) as src:
            image = src.convert("RGB")

    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype("arial.ttf", 20)
        title_font = ImageFont.truetype("arial.ttf", 30)
    except OSError:
        # arial.ttf is often missing (e.g. on Linux hosts); fall back to
        # Pillow's bundled font. NOTE: load_default(size=...) needs Pillow >= 10.1.
        font = ImageFont.load_default(size=30)
        title_font = font

    for i, (x1, y1, x2, y2) in enumerate(bboxes):
        class_name = classes[i]
        order = reading_order[i]

        # Unknown classes still get drawn rather than raising KeyError.
        color = class_colors.get(class_name, 'black')

        # Titles are emphasised with a thicker outline and a larger label.
        is_title = class_name == 'title'
        box_thickness = 4 if is_title else 2
        label_font = title_font if is_title else font

        draw.rectangle([x1, y1, x2, y2], outline=color, width=box_thickness)

        label = f"{class_name}-{order}"
        # Only the height is needed; it does not depend on the anchor point.
        _, top, _, bottom = draw.textbbox((0, 0), label, font=label_font)
        label_height = bottom - top

        # Put the label just above the box, clamped so that boxes touching the
        # top edge still get a visible label.
        draw.text((x1, max(0, y1 - label_height)), label, fill="black", font=label_font)

    return image
|
from PIL import Image, ImageDraw |
|
|
|
def is_inside(box1, box2):
    """Return True when ``box1`` lies entirely within ``box2``.

    Boxes are indexed as ``(x1, y1, x2, y2)``. Shared edges count as inside,
    so a box is considered inside an identical box.
    """
    # The top-left corner of box1 must not be left of or above box2's corner...
    top_left_ok = all(box1[k] >= box2[k] for k in (0, 1))
    # ...and its bottom-right corner must not extend past box2's corner.
    return top_left_ok and all(box1[k] <= box2[k] for k in (2, 3))
|
|
|
def is_overlap(box1, box2):
    """Return True when two ``(x1, y1, x2, y2)`` boxes share positive area.

    Boxes that only touch along an edge or at a corner do not overlap.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    # The boxes are disjoint if either axis has a gap (or zero-width contact).
    separated_x = ax2 <= bx1 or bx2 <= ax1
    separated_y = ay2 <= by1 or by2 <= ay1
    return not (separated_x or separated_y)
|
|
|
def remove_overlapping_and_inside_boxes(boxes, classes):
    """Drop boxes that are nested in, or overlap with a larger, other box.

    Every unordered pair of boxes is examined once:

    * a box lying inside another box is dropped;
    * of two boxes that merely overlap, the one with the smaller area is dropped;
    * on a tie (identical boxes, or overlapping boxes of equal area) only the
      later box is dropped, so one representative survives instead of both.

    A box that ends up dropped can still cause other boxes to be dropped,
    because all pairs are evaluated before anything is removed.

    Args:
        boxes: List of ``[x1, y1, x2, y2]`` boxes.
        classes: List of labels parallel to ``boxes``.

    Returns:
        ``(boxes, classes)`` -- the same list objects, filtered in place.
    """
    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    to_remove = set()
    for i, box1 in enumerate(boxes):
        # Visit each unordered pair once (j > i) so a tie is decided exactly once.
        for j in range(i + 1, len(boxes)):
            box2 = boxes[j]
            if is_inside(box1, box2):
                # Identical boxes are inside each other: keep the earlier one.
                to_remove.add(j if is_inside(box2, box1) else i)
            elif is_inside(box2, box1):
                to_remove.add(j)
            elif is_overlap(box1, box2):
                # Drop the strictly smaller box; with equal areas drop the later one.
                to_remove.add(i if _area(box1) < _area(box2) else j)

    # Delete from the highest index down so the remaining indices stay valid.
    for idx in sorted(to_remove, reverse=True):
        del boxes[idx]
        del classes[idx]

    return boxes, classes
|
def full_predictions(IMAGE_PATH, conf_threshold, iou_threshold):
    """Run the whole pipeline for one page: detect, de-duplicate, order, draw.

    Args:
        IMAGE_PATH: Page image from the Gradio input (a PIL image, since the
            input component uses ``type="pil"``), or a path.
        conf_threshold: Detection confidence threshold.
        iou_threshold: Detection IoU threshold.

    Returns:
        The annotated image produced by ``draw_bboxes_on_image``.

    Raises:
        gr.Error: If no image was supplied (Gradio passes ``None`` when the
            image input is left empty).
    """
    if IMAGE_PATH is None:
        # Fail early with a readable message instead of feeding None to the models.
        raise gr.Error("Please upload an image first.")

    bboxes, classes = detect_layout(IMAGE_PATH, conf_threshold, iou_threshold)
    bboxes, classes = remove_overlapping_and_inside_boxes(bboxes, classes)
    orders = get_orders(IMAGE_PATH, bboxes)
    return draw_bboxes_on_image(IMAGE_PATH, bboxes, classes, orders)
|
|
|
# Gradio UI: one page image plus the two detection thresholds in, annotated page out.
iface = gr.Interface(
    fn=full_predictions,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU threshold"),
    ],
    outputs=gr.Image(type="pil", label="Result"),
    title="Ultralytics Gradio",
    # The previous text claimed a YOLO11n model; this app uses RT-DETR + Surya.
    description=(
        "Upload a document page image. An RT-DETR model fine-tuned for Arabic "
        "document layout analysis detects the layout regions, and Surya predicts "
        "their reading order."
    ),
    # Example files are expected to sit next to this script.
    examples=[
        ["kashida.png", 0.2, 0.45],
        ["image.jpg", 0.2, 0.45],
        ["Screenshot 2024-11-06 130230.png", 0.25, 0.45],
    ],
    theme=gr.themes.Default(),
)

if __name__ == "__main__":
    iface.launch()