onuralpszr committed
Commit ab8156f
1 parent: 4fb6d61

feat: ✨ YOLO-World-Seg image processing added


Signed-off-by: Onuralp SEZER <thunderbirdtr@gmail.com>

Files changed (3)
  1. README.md +3 -3
  2. app.py +184 -10
  3. requirements.txt +12 -12
README.md CHANGED
@@ -1,13 +1,13 @@
 ---
 title: YOLO World Seg
-emoji:
+emoji: 🎨
 colorFrom: purple
-colorTo: indigo
+colorTo: red
 sdk: gradio
 sdk_version: 4.19.1
 app_file: app.py
 pinned: false
-license: mit
+license: gpl-3.0
 ---
 - openai/clip-vit-base-patch32
 - wondervictor/YOLO-World
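The block between the `---` markers is the Hugging Face Spaces front matter: display title, emoji, theme colors, SDK pin, entry file, and license (now GPL-3.0). As a minimal sketch of how such front matter can be read programmatically, assuming PyYAML is installed (it is not in this repo's requirements, and the `load_front_matter` helper name is hypothetical):

```python
# Minimal sketch: parse Spaces-style YAML front matter from a README.
# Assumes PyYAML ("pip install pyyaml"); helper name is hypothetical.
import yaml

def load_front_matter(readme_text: str) -> dict:
    # The front matter sits between the first two "---" delimiter lines.
    _, block, _ = readme_text.split("---", 2)
    return yaml.safe_load(block)

config = load_front_matter(open("README.md").read())
print(config["sdk"], config["sdk_version"])  # e.g. "gradio 4.19.1"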
app.py CHANGED
@@ -1,16 +1,190 @@
-# import os
-# os.system("mim install 'mmengine>=0.7.0'")
-# os.system("mim install mmcv")
-# os.system("mim install 'mmdet>=3.0.0'")
-# os.system("pip install -e .")


-# from yolo_world import version
+import os
+os.system("mim install 'mmengine>=0.7.0'")
+os.system("mim install mmcv")
+os.system("mim install 'mmdet>=3.0.0'")
+os.system("pip install -e .")
+
+
+import numpy as np
+import torch
+from mmengine.config import Config
+from mmengine.dataset import Compose
+from mmengine.runner import Runner
+from mmengine.runner.amp import autocast
+from mmyolo.registry import RUNNERS
+from torchvision.ops import nms
+import supervision as sv
+import PIL.Image
+import cv2

 import gradio as gr

-def greet(name):
-    return "text"

-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-iface.launch()
+TITLE = """
+# YOLO-World-Seg
+
+This is a demo of zero-shot object detection and instance segmentation using
+[YOLO-World](https://github.com/AILab-CVC/YOLO-World).
+
+Powered by [Supervision](https://github.com/roboflow/supervision).
+"""
+
+EXAMPLES = [
+    ["https://media.roboflow.com/efficient-sam/corgi.jpg", "dog", 0.5, 0.5],
+    ["https://media.roboflow.com/efficient-sam/horses.jpg", "horse", 0.5, 0.5],
+    ["https://media.roboflow.com/efficient-sam/bears.jpg", "bear", 0.5, 0.5],
+]
+
+box_annotator = sv.BoxAnnotator()
+label_annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)
+mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
+
+
+def load_runner():
+    cfg = Config.fromfile(
+        "./configs/segmentation/yolo_world_seg_l_dual_vlpan_2e-4_80e_8gpus_seghead_finetune_lvis.py"
+    )
+    cfg.work_dir = "."
+    cfg.load_from = "yolo_world_seg_l_dual_vlpan_2e-4_80e_8gpus_seghead_finetune_lvis-5a642d30.pth"
+    runner = Runner.from_cfg(cfg)
+    runner.call_hook("before_run")
+    runner.load_or_resume()
+    pipeline = cfg.test_dataloader.dataset.pipeline
+    runner.pipeline = Compose(pipeline)
+    runner.model.eval()
+    return runner
+
+
+def run_image(
+    input_image,
+    class_names="person,car,bus,truck",
+    score_thr=0.05,
+    iou_thr=0.5,
+    nms_thr=0.5,
+    max_num_boxes=100,
+):
+    runner = load_runner()
+    # Gradio delivers the image as an RGB numpy array; write it to disk
+    # so the mmengine test pipeline can load it by path.
+    PIL.Image.fromarray(input_image).save("input.jpeg")
+
+    class_names = [class_name.strip() for class_name in class_names.split(",")]
+    # YOLO-World expects one prompt list per class, plus a trailing
+    # empty prompt for the background class.
+    texts = [[name] for name in class_names] + [[" "]]
+    data_info = runner.pipeline(dict(img_id=0, img_path="input.jpeg", texts=texts))
+
+    data_batch = dict(
+        inputs=data_info["inputs"].unsqueeze(0),
+        data_samples=[data_info["data_samples"]],
+    )
+
+    with autocast(enabled=False), torch.no_grad():
+        output = runner.model.test_step(data_batch)[0]
+        runner.model.class_names = texts
+        pred_instances = output.pred_instances
+
+    # Filter predictions: NMS first, then the score threshold, then cap
+    # the number of kept instances.
+    keep_idxs = nms(pred_instances.bboxes, pred_instances.scores, iou_threshold=iou_thr)
+    pred_instances = pred_instances[keep_idxs]
+    pred_instances = pred_instances[pred_instances.scores.float() > score_thr]
+
+    if len(pred_instances.scores) > max_num_boxes:
+        indices = pred_instances.scores.float().topk(max_num_boxes)[1]
+        pred_instances = pred_instances[indices]
+    output.pred_instances = pred_instances
+
+    result = pred_instances.cpu().numpy()
+    detections = sv.Detections(
+        xyxy=result["bboxes"],
+        class_id=result["labels"],
+        confidence=result["scores"],
+        mask=result["masks"],
+    )
+    detections = detections.with_nms(threshold=nms_thr)
+
+    labels = [
+        f"{texts[class_id][0]} {confidence:.3f}"
+        for class_id, confidence
+        in zip(detections.class_id, detections.confidence)
+    ]
+
+    svimage = box_annotator.annotate(input_image, detections)
+    svimage = label_annotator.annotate(svimage, detections, labels)
+    svimage = mask_annotator.annotate(svimage, detections)
+    return svimage
+
+
+confidence_threshold_component = gr.Slider(
+    minimum=0,
+    maximum=1.0,
+    value=0.3,
+    step=0.01,
+    label="Confidence Threshold",
+    info=(
+        "The confidence threshold for the YOLO-World model. Lower the threshold to "
+        "reduce false negatives, enhancing the model's sensitivity to detect "
+        "sought-after objects. Conversely, increase the threshold to minimize false "
+        "positives, preventing the model from identifying objects it shouldn't."
+    ))
+
+iou_threshold_component = gr.Slider(
+    minimum=0,
+    maximum=1.0,
+    value=0.5,
+    step=0.01,
+    label="IoU Threshold",
+    info=(
+        "The Intersection over Union (IoU) threshold for non-maximum suppression. "
+        "Decrease the value to lessen the occurrence of overlapping bounding boxes, "
+        "making the detection process stricter. On the other hand, increase the value "
+        "to allow more overlapping bounding boxes, accommodating a broader range of "
+        "detections."
+    ))
+
+with gr.Blocks() as demo:
+    gr.Markdown(TITLE)
+    with gr.Accordion("Configuration", open=False):
+        confidence_threshold_component.render()
+        iou_threshold_component.render()
+    with gr.Tab(label="Image"):
+        with gr.Row():
+            input_image_component = gr.Image(
+                type="numpy",
+                label="Input Image"
+            )
+            output_image_component = gr.Image(
+                type="numpy",
+                label="Output Image"
+            )
+        with gr.Row():
+            image_categories_text_component = gr.Textbox(
+                label="Categories",
+                placeholder="comma separated list of categories",
+                scale=7
+            )
+            image_submit_button_component = gr.Button(
+                value="Submit",
+                scale=1,
+                variant="primary"
+            )
+        gr.Examples(
+            fn=run_image,
+            examples=EXAMPLES,
+            inputs=[
+                input_image_component,
+                image_categories_text_component,
+                confidence_threshold_component,
+                iou_threshold_component,
+            ],
+            outputs=output_image_component
+        )
+
+    image_submit_button_component.click(
+        fn=run_image,
+        inputs=[
+            input_image_component,
+            image_categories_text_component,
+            confidence_threshold_component,
+            iou_threshold_component,
+        ],
+        outputs=output_image_component
+    )
+
+demo.launch(debug=False, show_error=True)
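For orientation, here is a minimal, self-contained sketch of the post-processing path `run_image` follows: raw boxes and scores go through torchvision NMS and a score threshold, then the survivors are wrapped in a `supervision` `Detections` object for annotation. The arrays below are dummy data standing in for model output, not values from YOLO-World:

```python
# Minimal sketch of the run_image post-processing with dummy arrays
# in place of model output (values are illustrative only).
import numpy as np
import torch
import supervision as sv
from torchvision.ops import nms

boxes = torch.tensor([[10., 10., 100., 100.], [12., 12., 98., 99.]])
scores = torch.tensor([0.9, 0.4])

# Step 1: torchvision NMS keeps the highest-scoring box among overlaps.
keep = nms(boxes, scores, iou_threshold=0.5)
boxes, scores = boxes[keep], scores[keep]

# Step 2: drop low-confidence predictions.
confident = scores > 0.05
boxes, scores = boxes[confident], scores[confident]

# Step 3: wrap the survivors for supervision's annotators.
detections = sv.Detections(
    xyxy=boxes.numpy(),
    confidence=scores.numpy(),
    class_id=np.zeros(len(boxes), dtype=int),
)
print(len(detections))  # the overlapping box is suppressed -> 1 detection
```

The two sliders map directly onto these steps: lowering the confidence threshold trades more recall for more false positives, while raising the NMS IoU threshold lets more overlapping boxes survive suppression.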
requirements.txt CHANGED
@@ -1,14 +1,14 @@
-openmim
+openmim
 gradio
-transformers
-# numpy
-# opencv-python
-# supervision
-# wheel
+transformers
+numpy
+opencv-python
+supervision
+wheel

-# --extra-index-url https://download.pytorch.org/whl/cu121
-# torch==2.1.0+cu121
-# torchdata==0.7.0
-# torchsummary==1.5.1
-# torchtext==0.16.0
-# torchvision==0.16.0+cu121
+--extra-index-url https://download.pytorch.org/whl/cu121
+torch==2.1.0+cu121
+torchdata==0.7.0
+torchsummary==1.5.1
+torchtext==0.16.0
+torchvision==0.16.0+cu121
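The requirements now uncomment the CUDA 12.1 wheel pins, pulled from the extra index URL, so `torch==2.1.0+cu121` and the matching `torchvision==0.16.0+cu121` must resolve together. As a minimal post-install sanity check:

```python
# Minimal sanity check that the pinned CUDA 12.1 wheels resolved correctly.
import torch
import torchvision

print(torch.__version__)          # expected: 2.1.0+cu121
print(torchvision.__version__)    # expected: 0.16.0+cu121
print(torch.cuda.is_available())  # True only on a CUDA-capable host
```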