omarelsayeed commited on
Commit
6e73f0b
·
verified ·
1 Parent(s): e059e1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -13
app.py CHANGED
@@ -3,14 +3,67 @@ import gradio as gr
3
  from huggingface_hub import snapshot_download
4
  from PIL import Image
5
  from PIL import Image, ImageDraw, ImageFont
6
- from surya.ordering import batch_ordering
7
- from surya.model.ordering.processor import load_processor
8
- from surya.model.ordering.model import load_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  model_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS") + "/rtdetr_1024_crops.pt"
11
  model = RTDETR(model_dir)
12
- order_model = load_model()
13
- processor = load_processor()
14
 
15
  def detect_layout(img, conf_threshold, iou_threshold):
16
  """Predicts objects in an image using a YOLO11 model with adjustable confidence and IOU thresholds."""
@@ -40,9 +93,6 @@ def detect_layout(img, conf_threshold, iou_threshold):
40
  classes = [mapping[i] for i in classes]
41
  return bboxes , classes
42
 
43
- def get_orders(image_path , boxes):
44
- order_predictions = batch_ordering([image_path], [boxes], order_model, processor)
45
- return [i.position for i in order_predictions[0].bboxes]
46
 
47
  def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
48
  # Define a color map for each class name
@@ -149,11 +199,12 @@ def remove_overlapping_and_inside_boxes(boxes, classes):
149
 
150
  return boxes, classes
151
  def full_predictions(IMAGE_PATH, conf_threshold, iou_threshold):
152
- bboxes , classes = detect_layout(IMAGE_PATH ,conf_threshold, iou_threshold)
153
- bboxes , classes = remove_overlapping_and_inside_boxes(bboxes,classes)
154
- orders = get_orders(IMAGE_PATH , bboxes)
155
- final_image = draw_bboxes_on_image(IMAGE_PATH , bboxes , classes , orders)
156
- return final_image
 
157
 
158
  iface = gr.Interface(
159
  fn=full_predictions,
 
3
  from huggingface_hub import snapshot_download
4
  from PIL import Image
5
  from PIL import Image, ImageDraw, ImageFont
6
+
7
+
8
+ from collections import defaultdict
9
+ from typing import List, Dict
10
+ import torch
11
+ from transformers import LayoutLMv3ForTokenClassification
12
+
13
+ # Load the LayoutLMv3 model
14
+ layout_model = LayoutLMv3ForTokenClassification.from_pretrained("omarelsayeed/LayoutReader80Small")
15
+
16
+ MAX_LEN = 100
17
+ CLS_TOKEN_ID = 0
18
+ UNK_TOKEN_ID = 3
19
+ EOS_TOKEN_ID = 2
20
+
21
+
22
+ def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
23
+ bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
24
+ input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
25
+ attention_mask = [1] + [1] * len(boxes) + [1]
26
+ return {
27
+ "bbox": torch.tensor([bbox]),
28
+ "attention_mask": torch.tensor([attention_mask]),
29
+ "input_ids": torch.tensor([input_ids]),
30
+ }
31
+
32
+ def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
33
+ """
34
+ Parse logits to determine the reading order.
35
+ """
36
+ logits = logits[1: length + 1, :length]
37
+ orders = logits.argsort(descending=False).tolist()
38
+ ret = [o.pop() for o in orders]
39
+ while True:
40
+ order_to_idxes = defaultdict(list)
41
+ for idx, order in enumerate(ret):
42
+ order_to_idxes[order].append(idx)
43
+ # Filter indices with length > 1
44
+ order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
45
+ if not order_to_idxes:
46
+ break
47
+ # Resolve conflicts
48
+ for order, idxes in order_to_idxes.items():
49
+ idxes_to_logit = {idx: logits[idx, order] for idx in idxes}
50
+ idxes_to_logit = sorted(idxes_to_logit.items(), key=lambda x: x[1], reverse=True)
51
+ for idx, _ in idxes_to_logit[1:]:
52
+ ret[idx] = orders[idx].pop()
53
+
54
+ return ret
55
+
56
+ def get_orders(image_path, boxes):
57
+ inputs = boxes2inputs(boxes)
58
+ inputs = {k: v.to(layout_model.device) for k, v in inputs.items()} # Move inputs to model device
59
+ logits = layout_model(**inputs).logits.cpu().squeeze(0) # Perform inference and get logits
60
+ orders = parse_logits(logits, len(boxes))
61
+ return orders
62
+
63
 
64
  model_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS") + "/rtdetr_1024_crops.pt"
65
  model = RTDETR(model_dir)
66
+
 
67
 
68
  def detect_layout(img, conf_threshold, iou_threshold):
69
  """Predicts objects in an image using a YOLO11 model with adjustable confidence and IOU thresholds."""
 
93
  classes = [mapping[i] for i in classes]
94
  return bboxes , classes
95
 
 
 
 
96
 
97
  def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
98
  # Define a color map for each class name
 
199
 
200
  return boxes, classes
201
  def full_predictions(IMAGE_PATH, conf_threshold, iou_threshold):
202
+ bboxes, classes = detect_layout(IMAGE_PATH, conf_threshold, iou_threshold)
203
+ bboxes, classes = remove_overlapping_and_inside_boxes(bboxes, classes)
204
+ orders = get_orders(IMAGE_PATH, bboxes)
205
+ final_image = draw_bboxes_on_image(IMAGE_PATH, bboxes, classes, orders)
206
+ return final_image
207
+
208
 
209
  iface = gr.Interface(
210
  fn=full_predictions,