omarelsayeed commited on
Commit
1ebac6f
·
verified ·
1 Parent(s): 86dc437

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -94
app.py CHANGED
@@ -1,7 +1,6 @@
1
  from ultralytics import RTDETR
2
  import gradio as gr
3
  from huggingface_hub import snapshot_download
4
- from PIL import Image
5
  from PIL import Image, ImageDraw, ImageFont
6
  import numpy as np
7
  import random
@@ -9,99 +8,36 @@ from collections import defaultdict
9
  from typing import List, Dict
10
  import torch
11
  from transformers import LayoutLMv3ForTokenClassification
12
-
13
- # Load the LayoutLMv3 model
14
- layout_model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
15
-
16
- MAX_LEN = 100
17
- CLS_TOKEN_ID = 0
18
- UNK_TOKEN_ID = 3
19
- EOS_TOKEN_ID = 2
20
-
21
-
22
- def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
23
- bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
24
- input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
25
- attention_mask = [1] + [1] * len(boxes) + [1]
26
- return {
27
- "bbox": torch.tensor([bbox]),
28
- "attention_mask": torch.tensor([attention_mask]),
29
- "input_ids": torch.tensor([input_ids]),
30
- }
31
-
32
- def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
33
- """
34
- Parse logits to determine the reading order.
35
- """
36
- logits = logits[1: length + 1, :length]
37
- orders = logits.argsort(descending=False).tolist()
38
- ret = [o.pop() for o in orders]
39
- while True:
40
- order_to_idxes = defaultdict(list)
41
- for idx, order in enumerate(ret):
42
- order_to_idxes[order].append(idx)
43
- # Filter indices with length > 1
44
- order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
45
- if not order_to_idxes:
46
- break
47
- # Resolve conflicts
48
- for order, idxes in order_to_idxes.items():
49
- idxes_to_logit = {idx: logits[idx, order] for idx in idxes}
50
- idxes_to_logit = sorted(idxes_to_logit.items(), key=lambda x: x[1], reverse=True)
51
- for idx, _ in idxes_to_logit[1:]:
52
- ret[idx] = orders[idx].pop()
53
-
54
- return ret
55
- def get_orders(_,bounding_boxes):
56
- """
57
- Detects reading order for Arabic text layout, given bounding boxes in xyxy format.
58
 
59
- Args:
60
- - bounding_boxes: List of tuples (x1, y1, x2, y2), where
61
- (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner of the bounding box.
62
-
63
- Returns:
64
- - A list of indices representing the reading order.
65
- """
66
- # Convert to numpy array for easier processing
67
- bounding_boxes = [tuple(b) for b in bounding_boxes]
68
- boxes = np.array(bounding_boxes)
69
-
70
- # Extract positions: (x1, y1) as the top-left, (x2, y2) as the bottom-right
71
- # Sort by vertical position first (y1), and then horizontal position (x1), with right-to-left sorting
72
- sorted_indices = np.lexsort((boxes[:, 0], boxes[:, 1])) # Sort by y1, then by x1 (right-to-left)
73
-
74
- # Sort within rows by checking overlap tolerance for y coordinates
75
- rows = []
76
- tolerance = 10 # Tolerance for grouping elements into rows
77
- for idx in sorted_indices:
78
- placed = False
79
- for row in rows:
80
- # Check if the box belongs to an existing row (y1 overlap within tolerance)
81
- if abs(row[-1][1] - boxes[idx][1]) < tolerance:
82
- row.append(boxes[idx])
83
- placed = True
84
- break
85
- if not placed:
86
- rows.append([boxes[idx]])
87
-
88
- # Within each row, sort by x1 (right-to-left)
89
- reading_order = []
90
- for row in rows:
91
- row.sort(key=lambda b: -b[0]) # Sort by x1 descending (right-to-left)
92
- reading_order.extend(row)
93
-
94
- # Return the indices of the bounding boxes in the correct reading order
95
- return [bounding_boxes.index(tuple(box)) for box in reading_order]
96
-
97
-
98
- # def get_orders(image_path, boxes):
99
- # b = scale_and_normalize_boxes(boxes)
100
- # inputs = boxes2inputs(b)
101
- # inputs = {k: v.to(layout_model.device) for k, v in inputs.items()} # Move inputs to model device
102
- # logits = layout_model(**inputs).logits.cpu().squeeze(0) # Perform inference and get logits
103
- # orders = parse_logits(logits, len(b))
104
- # return orders
105
 
106
 
107
  model_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS") + "/rtdetr_1024_crops.pt"
@@ -203,7 +139,7 @@ def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
203
 
204
 
205
 
206
- def scale_and_normalize_boxes(bboxes, old_width = 1024, old_height= 1024, new_width=640, new_height=640, normalize_width=1000, normalize_height=1000):
207
  """
208
  Scales and normalizes bounding boxes from original dimensions to new dimensions.
209
 
 
1
  from ultralytics import RTDETR
2
  import gradio as gr
3
  from huggingface_hub import snapshot_download
 
4
  from PIL import Image, ImageDraw, ImageFont
5
  import numpy as np
6
  import random
 
8
  from typing import List, Dict
9
  import torch
10
  from transformers import LayoutLMv3ForTokenClassification
11
+ from transformers import AutoProcessor
12
+ from transformers import AutoModelForTokenClassification
13
+
14
+ reading_order_model = AutoModelForTokenClassification.from_pretrained("omarelsayeed/yea_yea").to("cuda")
15
+ processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base",
16
+ apply_ocr=False)
17
+
18
+ def predict_reading_order(boxes,image_path):
19
+ words = ["<unk>"]*len(boxes)
20
+ print(boxes)
21
+ encoding = processor(image_path , text = words
22
+ , boxes=boxes
23
+ ,return_tensors="pt" ,
24
+ return_offsets_mapping=True)
25
+ offset_mapping = encoding.pop('offset_mapping')
26
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27
+ for k,v in encoding.items():
28
+ encoding[k] = v.to(device)
29
+ outputs = model(**encoding)
30
+ predictions = outputs.logits.argmax(-1).squeeze().tolist()
31
+ token_boxes = encoding.bbox.squeeze().tolist()
32
+ is_subword = np.array(offset_mapping.squeeze().tolist())[:,0] != 0
33
+ # true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
34
+ predictions = predictions[1:-1]
35
+ return predictions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ def get_orders(image_path, boxes):
38
+ b = scale_and_normalize_boxes(boxes)
39
+ orders = predict_reading_order(b, image_path)
40
+ return orders
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  model_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS") + "/rtdetr_1024_crops.pt"
 
139
 
140
 
141
 
142
+ def scale_and_normalize_boxes(bboxes, old_width = 1024, old_height= 1024, new_width=595.303955, new_height=841.889771, normalize_width=1000, normalize_height=1000):
143
  """
144
  Scales and normalizes bounding boxes from original dimensions to new dimensions.
145