juliozhao commited on
Commit
778c8b4
·
verified ·
1 Parent(s): dfdfaf2

Upload 13 files

Browse files
README.md CHANGED
@@ -1,14 +1,14 @@
1
  ---
2
- title: DocLayout YOLO
3
- emoji: 📚
4
- colorFrom: green
5
- colorTo: gray
6
  sdk: gradio
7
  sdk_version: 5.1.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
- short_description: Online demo for DocLayout-YOLO
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: DocLayout YOLO Demo
3
+ emoji: 🐢
4
+ colorFrom: yellow
5
+ colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.1.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ short_description: Demo for DocLayout-YOLO
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["GRADIO_TEMP_DIR"] = "./tmp"
3
+
4
+ import sys
5
+ import torch
6
+ import torchvision
7
+ import gradio as gr
8
+ import numpy as np
9
+ from PIL import Image
10
+ from huggingface_hub import snapshot_download
11
+ from visualization import visualize_bbox
12
+
13
+ # == download weights ==
14
+ model_dir = snapshot_download('juliozhao/DocLayout-YOLO-DocStructBench', local_dir='./models/DocLayout-YOLO-DocStructBench')
15
+ # == select device ==
16
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
17
+
18
+ id_to_names = {
19
+ 0: 'title',
20
+ 1: 'plain text',
21
+ 2: 'abandon',
22
+ 3: 'figure',
23
+ 4: 'figure_caption',
24
+ 5: 'table',
25
+ 6: 'table_caption',
26
+ 7: 'table_footnote',
27
+ 8: 'isolate_formula',
28
+ 9: 'formula_caption'
29
+ }
30
+
31
+ def recognize_image(input_img, conf_threshold, iou_threshold):
32
+ det_res = model.predict(
33
+ input_img,
34
+ imgsz=1024,
35
+ conf=conf_threshold,
36
+ device=device,
37
+ )[0]
38
+ boxes = det_res.__dict__['boxes'].xyxy
39
+ classes = det_res.__dict__['boxes'].cls
40
+ scores = det_res.__dict__['boxes'].conf
41
+
42
+ indices = torchvision.ops.nms(boxes=torch.Tensor(boxes), scores=torch.Tensor(scores),iou_threshold=iou_threshold)
43
+ boxes, scores, classes = boxes[indices], scores[indices], classes[indices]
44
+ if len(boxes.shape) == 1:
45
+ boxes = np.expand_dims(boxes, 0)
46
+ scores = np.expand_dims(scores, 0)
47
+ classes = np.expand_dims(classes, 0)
48
+
49
+ vis_result = visualize_bbox(input_img, boxes, classes, scores, id_to_names)
50
+ return vis_result
51
+
52
+ def gradio_reset():
53
+ return gr.update(value=None), gr.update(value=None)
54
+
55
+
56
+ if __name__ == "__main__":
57
+ root_path = os.path.abspath(os.getcwd())
58
+ # == load model ==
59
+ from doclayout_yolo import YOLOv10
60
+ print(f"Using device: {device}")
61
+ model = YOLOv10(os.path.join(os.path.dirname(__file__), "models", "DocLayout-YOLO-DocStructBench", "doclayout_yolo_docstructbench_imgsz1024.pt")) # load an official model
62
+
63
+ with open("header.html", "r") as file:
64
+ header = file.read()
65
+ with gr.Blocks() as demo:
66
+ gr.HTML(header)
67
+
68
+ with gr.Row():
69
+ with gr.Column():
70
+
71
+ input_img = gr.Image(label=" ", interactive=True)
72
+ with gr.Row():
73
+ clear = gr.Button(value="Clear")
74
+ predict = gr.Button(value="Detect", interactive=True, variant="primary")
75
+
76
+ with gr.Row():
77
+ conf_threshold = gr.Slider(
78
+ label="Confidence Threshold",
79
+ minimum=0.0,
80
+ maximum=1.0,
81
+ step=0.05,
82
+ value=0.25,
83
+ )
84
+
85
+ with gr.Row():
86
+ iou_threshold = gr.Slider(
87
+ label="NMS IOU Threshold",
88
+ minimum=0.0,
89
+ maximum=1.0,
90
+ step=0.05,
91
+ value=0.45,
92
+ )
93
+
94
+ with gr.Accordion("Examples:"):
95
+ example_root = os.path.join(os.path.dirname(__file__), "assets", "example")
96
+ gr.Examples(
97
+ examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
98
+ _.endswith("jpg")],
99
+ inputs=[input_img],
100
+ )
101
+ with gr.Column():
102
+ gr.Button(value="Predict Result:", interactive=False)
103
+ output_img = gr.Image(label=" ", interactive=False)
104
+
105
+ clear.click(gradio_reset, inputs=None, outputs=[input_img, output_img])
106
+ predict.click(recognize_image, inputs=[input_img,conf_threshold,iou_threshold], outputs=[output_img])
107
+
108
+ demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
assets/.DS_Store ADDED
Binary file (6.15 kB). View file
 
assets/example/academic.jpg ADDED
assets/example/exam_paper.jpg ADDED
assets/example/financial.jpg ADDED
assets/example/fuzzy_scan.jpg ADDED
assets/example/poster.jpg ADDED
assets/example/ppt.jpg ADDED
assets/example/textbook.jpg ADDED
header.html ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <html><head>
2
+ <!-- <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.3/css/bulma.min.css"> -->
3
+ <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
4
+ <style>
5
+ .link-block {
6
+ border: 1px solid transparent;
7
+ border-radius: 24px;
8
+ background-color: rgba(54, 54, 54, 1);
9
+ cursor: pointer !important;
10
+ }
11
+ .link-block:hover {
12
+ background-color: rgba(54, 54, 54, 0.75) !important;
13
+ cursor: pointer !important;
14
+ }
15
+ .external-link {
16
+ display: inline-flex;
17
+ align-items: center;
18
+ height: 36px;
19
+ line-height: 36px;
20
+ padding: 0 16px;
21
+ cursor: pointer !important;
22
+ }
23
+ .external-link,
24
+ .external-link:hover {
25
+ cursor: pointer !important;
26
+ }
27
+ a {
28
+ text-decoration: none;
29
+ }
30
+ </style></head>
31
+
32
+ <body>
33
+ <div style="
34
+ display: flex;
35
+ flex-direction: column;
36
+ justify-content: center;
37
+ align-items: center;
38
+ text-align: center;
39
+ background: linear-gradient(45deg, #007bff 0%, #0056b3 100%);
40
+ padding: 24px;
41
+ gap: 24px;
42
+ border-radius: 8px;
43
+ ">
44
+ <div style="
45
+ display: flex;
46
+ flex-direction: column;
47
+ align-items: center;
48
+ gap: 16px;
49
+ ">
50
+ <div style="display: flex; flex-direction: column; gap: 8px">
51
+ <h1 style="
52
+ font-size: 48px;
53
+ color: #fafafa;
54
+ margin: 0;
55
+ font-family: 'Trebuchet MS', 'Lucida Sans Unicode',
56
+ 'Lucida Grande', 'Lucida Sans', Arial, sans-serif;
57
+ ">
58
+ DocLayout-YOLO
59
+ </h1>
60
+ </div>
61
+ </div>
62
+
63
+ <p style="
64
+ margin: 0;
65
+ line-height: 1.6rem;
66
+ font-size: 16px;
67
+ color: #fafafa;
68
+ opacity: 0.8;
69
+ ">
70
+ An efficient and robust Model for Real-World Document Layout Analysis.<br>
71
+ </p>
72
+ <style>
73
+ .link-block {
74
+ display: inline-block;
75
+ }
76
+ .link-block + .link-block {
77
+ margin-left: 20px;
78
+ }
79
+ </style>
80
+
81
+ <div class="column has-text-centered">
82
+ <div class="publication-links">
83
+ <!-- Code Link. -->
84
+ <span class="link-block">
85
+ <a href="https://github.com/opendatalab/DocLayout-YOLO" class="external-link button is-normal is-rounded is-dark" style="text-decoration: none; cursor: pointer">
86
+ <span class="icon" style="margin-right: 4px">
87
+ <i class="fab fa-github" style="color: white; margin-right: 4px"></i>
88
+ </span>
89
+ <span style="color: white">Code</span>
90
+ </a>
91
+ </span>
92
+
93
+ <!-- Paper Link. -->
94
+ <span class="link-block">
95
+ <a href="https://arxiv.org/abs/2410.12628" class="external-link button is-normal is-rounded is-dark" style="text-decoration: none; cursor: pointer">
96
+ <span class="icon" style="margin-right: 8px">
97
+ <i class="fas fa-globe" style="color: white"></i>
98
+ </span>
99
+ <span style="color: white">Paper</span>
100
+ </a>
101
+ </span>
102
+ </div>
103
+ </div>
104
+
105
+ <!-- New Demo Links -->
106
+ </div>
107
+
108
+
109
+ </body></html>
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ doclayout-yolo==0.0.2
2
+ gradio==5.1.0
3
+ gradio-client==1.4.0
4
+ huggingface_hub
visualization.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ from PIL import Image
4
+
5
+ def colormap(N=256, normalized=False):
6
+ """
7
+ Generate the color map.
8
+
9
+ Args:
10
+ N (int): Number of labels (default is 256).
11
+ normalized (bool): If True, return colors normalized to [0, 1]. Otherwise, return [0, 255].
12
+
13
+ Returns:
14
+ np.ndarray: Color map array of shape (N, 3).
15
+ """
16
+ def bitget(byteval, idx):
17
+ """
18
+ Get the bit value at the specified index.
19
+
20
+ Args:
21
+ byteval (int): The byte value.
22
+ idx (int): The index of the bit.
23
+
24
+ Returns:
25
+ int: The bit value (0 or 1).
26
+ """
27
+ return ((byteval & (1 << idx)) != 0)
28
+
29
+ cmap = np.zeros((N, 3), dtype=np.uint8)
30
+ for i in range(N):
31
+ r = g = b = 0
32
+ c = i
33
+ for j in range(8):
34
+ r = r | (bitget(c, 0) << (7 - j))
35
+ g = g | (bitget(c, 1) << (7 - j))
36
+ b = b | (bitget(c, 2) << (7 - j))
37
+ c = c >> 3
38
+ cmap[i] = np.array([r, g, b])
39
+
40
+ if normalized:
41
+ cmap = cmap.astype(np.float32) / 255.0
42
+
43
+ return cmap
44
+
45
+ def visualize_bbox(image_path, bboxes, classes, scores, id_to_names, alpha=0.3):
46
+ """
47
+ Visualize layout detection results on an image.
48
+
49
+ Args:
50
+ image_path (str): Path to the input image.
51
+ bboxes (list): List of bounding boxes, each represented as [x_min, y_min, x_max, y_max].
52
+ classes (list): List of class IDs corresponding to the bounding boxes.
53
+ id_to_names (dict): Dictionary mapping class IDs to class names.
54
+ alpha (float): Transparency factor for the filled color (default is 0.3).
55
+
56
+ Returns:
57
+ np.ndarray: Image with visualized layout detection results.
58
+ """
59
+ # Check if image_path is a PIL.Image.Image object
60
+ if isinstance(image_path, Image.Image) or isinstance(image_path, np.ndarray):
61
+ image = np.array(image_path)
62
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # Convert RGB to BGR for OpenCV
63
+ else:
64
+ image = cv2.imread(image_path)
65
+
66
+ overlay = image.copy()
67
+
68
+ cmap = colormap(N=len(id_to_names), normalized=False)
69
+
70
+ # Iterate over each bounding box
71
+ for i, bbox in enumerate(bboxes):
72
+ x_min, y_min, x_max, y_max = map(int, bbox)
73
+ class_id = int(classes[i])
74
+ class_name = id_to_names[class_id]
75
+
76
+ text = class_name + f":{scores[i]:.3f}"
77
+
78
+ color = tuple(int(c) for c in cmap[class_id])
79
+ cv2.rectangle(overlay, (x_min, y_min), (x_max, y_max), color, -1)
80
+ cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 2)
81
+
82
+ # Add the class name with a background rectangle
83
+ (text_width, text_height), baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.9, 2)
84
+ cv2.rectangle(image, (x_min, y_min - text_height - baseline), (x_min + text_width, y_min), color, -1)
85
+ cv2.putText(image, text, (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 255, 255), 2)
86
+
87
+ # Blend the overlay with the original image
88
+ cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
89
+
90
+ return image