Commit: Add app

This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
- .gitignore +169 -0
- README.md +1 -1
- app.py +108 -0
- assets/sample.png +0 -0
- exps/default/__init__.py +3 -0
- exps/default/yolov3.py +33 -0
- exps/default/yolox_l.py +15 -0
- exps/default/yolox_m.py +15 -0
- exps/default/yolox_nano.py +48 -0
- exps/default/yolox_s.py +15 -0
- exps/default/yolox_tiny.py +20 -0
- exps/default/yolox_x.py +15 -0
- exps/openlenda_nano.py +53 -0
- exps/openlenda_s.py +21 -0
- exps/openlenda_tiny.py +25 -0
- exps/openlenda_x.py +20 -0
- models/.gitkeep +0 -0
- predictor.py +87 -0
- requirements.txt +7 -0
- yolox/__init__.py +4 -0
- yolox/core/__init__.py +6 -0
- yolox/core/launch.py +147 -0
- yolox/core/trainer.py +390 -0
- yolox/data/__init__.py +9 -0
- yolox/data/data_augment.py +243 -0
- yolox/data/data_prefetcher.py +51 -0
- yolox/data/dataloading.py +113 -0
- yolox/data/datasets/__init__.py +9 -0
- yolox/data/datasets/coco.py +188 -0
- yolox/data/datasets/coco_classes.py +5 -0
- yolox/data/datasets/datasets_wrapper.py +300 -0
- yolox/data/datasets/mosaicdetection.py +234 -0
- yolox/data/datasets/voc.py +331 -0
- yolox/data/datasets/voc_classes.py +27 -0
- yolox/data/samplers.py +85 -0
- yolox/evaluators/__init__.py +6 -0
- yolox/evaluators/coco_evaluator.py +317 -0
- yolox/evaluators/voc_eval.py +183 -0
- yolox/evaluators/voc_evaluator.py +187 -0
- yolox/exp/__init__.py +6 -0
- yolox/exp/base_exp.py +90 -0
- yolox/exp/build.py +42 -0
- yolox/exp/default/__init__.py +28 -0
- yolox/exp/yolox_base.py +358 -0
- yolox/layers/__init__.py +13 -0
- yolox/layers/cocoeval/cocoeval.cpp +502 -0
- yolox/layers/cocoeval/cocoeval.h +98 -0
- yolox/layers/fast_coco_eval_api.py +151 -0
- yolox/layers/jit_ops.py +138 -0
- yolox/models/__init__.py +11 -0
.gitignore
ADDED
@@ -0,0 +1,169 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+*.png
+*.jpg
+*.mp4
+
+YOLOX_outputs/
+artifacts/
+*.engine
+*.pth
README.md
CHANGED
@@ -1,6 +1,6 @@
 ---
 title: OpenLenda
-emoji:
+emoji: 🚥
 colorFrom: blue
 colorTo: purple
 sdk: gradio
app.py
ADDED
@@ -0,0 +1,108 @@
+
+from yolox.exp import get_exp
+from yolox.data.datasets import COCO_CLASSES
+from predictor import Predictor
+
+import cv2
+import gradio as gr
+import torch
+
+import subprocess
+import tempfile
+import time
+from pathlib import Path
+
+exp = get_exp("exps/openlenda_s.py", None)
+model = exp.get_model()
+model.eval()
+ckpt_file = "models/openlenda_s.pth"
+model.load_state_dict(torch.load(ckpt_file, map_location="cpu")["model"])
+predictor = Predictor(
+    model, COCO_CLASSES, "cpu", False, False
+)
+
+
+def image_inference(image, confthre, nmsthre):
+    cv2.cvtColor(image, cv2.COLOR_RGB2BGR, image)
+    outputs, img_info = predictor.inference(image, confthre, nmsthre)
+    result_image = predictor.visual(outputs[0], img_info)
+    cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB, result_image)
+    return result_image
+
+
+image_interface = gr.Interface(
+    fn=image_inference,
+    inputs=[
+        "image",
+        gr.Slider(0, 1, value=0.5, step=0.01, label="Confidence Threshold"),
+        gr.Slider(0, 1, value=0.01, step=0.01, label="NMS Threshold")
+    ],
+    examples=[["assets/sample.png", 0.5, 0.01]],
+    outputs=gr.Image(type="pil"),
+    title="OpenLenda image demo"
+)
+
+
+def video_inference(video_file, confthre, nmsthre, start_sec, duration):
+    start_timestamp = time.strftime("%H:%M:%S", time.gmtime(start_sec))
+    end_timestamp = time.strftime("%H:%M:%S", time.gmtime(start_sec + duration))
+
+    suffix = Path(video_file).suffix
+
+    clip_temp_file = tempfile.NamedTemporaryFile(suffix=suffix)
+    subprocess.call(
+        f"ffmpeg -y -ss {start_timestamp} -i {video_file} -to {end_timestamp} -c copy {clip_temp_file.name}".split()
+    )
+
+    cap = cv2.VideoCapture(clip_temp_file.name)
+    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    fps = cap.get(cv2.CAP_PROP_FPS)
+
+    with tempfile.NamedTemporaryFile(suffix=".mp4") as temp_file:
+        out = cv2.VideoWriter(temp_file.name, cv2.VideoWriter_fourcc(*"MP4V"), fps, (width, height))
+
+        num_frames = 0
+        max_frames = duration * fps
+        while cap.isOpened():
+            try:
+                ret, frame = cap.read()
+                if not ret:
+                    break
+            except Exception as e:
+                print(e)
+                continue
+            outputs, img_info = predictor.inference(frame, confthre, nmsthre)
+            result_frame = predictor.visual(outputs[0], img_info)
+            out.write(result_frame)
+            num_frames += 1
+            if num_frames == max_frames:
+                break
+
+        out.release()
+
+        out_file = tempfile.NamedTemporaryFile(suffix="out.mp4", delete=False)
+        subprocess.run(f"ffmpeg -y -loglevel quiet -stats -i {temp_file.name} -c:v libx264 {out_file.name}".split())
+
+    return out_file.name
+
+
+video_interface = gr.Interface(
+    fn=video_inference,
+    inputs=[
+        gr.Video(),
+        gr.Slider(0, 1, value=0.5, step=0.01, label="Confidence Threshold"),
+        gr.Slider(0, 1, value=0.01, step=0.01, label="NMS Threshold"),
+        gr.Slider(0, 60, value=0, step=1, label="Start Second"),
+        gr.Slider(0, 10, value=3, step=1, label="Duration"),
+    ],
+    outputs=gr.Video(),
+    title="OpenLenda video demo"
+)
+
+if __name__ == "__main__":
+    gr.TabbedInterface(
+        [image_interface, video_interface],
+        ["Image", "Video"],
+        title="OpenLenda demo!",
+    ).launch()
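Since the whole demo is driven by the two functions above, the image path can be smoke-tested without launching Gradio. A minimal sketch, assuming the checkpoint and sample image paths used in app.py exist locally (importing app runs its module-level model setup once):

import cv2

from app import image_inference  # module import builds the model and Predictor

img = cv2.imread("assets/sample.png")        # OpenCV loads BGR
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)   # image_inference expects RGB, as Gradio supplies
result = image_inference(img, 0.5, 0.01)     # confidence 0.5, NMS 0.01 (the UI defaults)
cv2.imwrite("result.png", cv2.cvtColor(result, cv2.COLOR_RGB2BGR))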
assets/sample.png
ADDED
exps/default/__init__.py
ADDED
@@ -0,0 +1,3 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
exps/default/yolov3.py
ADDED
@@ -0,0 +1,33 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+import torch.nn as nn
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 1.0
+        self.width = 1.0
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+
+    def get_model(self, sublinear=False):
+        def init_yolo(M):
+            for m in M.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eps = 1e-3
+                    m.momentum = 0.03
+        if "model" not in self.__dict__:
+            from yolox.models import YOLOX, YOLOFPN, YOLOXHead
+            backbone = YOLOFPN()
+            head = YOLOXHead(self.num_classes, self.width, in_channels=[128, 256, 512], act="lrelu")
+            self.model = YOLOX(backbone, head)
+        self.model.apply(init_yolo)
+        self.model.head.initialize_biases(1e-2)
+
+        return self.model
exps/default/yolox_l.py
ADDED
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 1.0
+        self.width = 1.0
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
exps/default/yolox_m.py
ADDED
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.67
+        self.width = 0.75
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
exps/default/yolox_nano.py
ADDED
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+import torch.nn as nn
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.33
+        self.width = 0.25
+        self.input_size = (416, 416)
+        self.random_size = (10, 20)
+        self.mosaic_scale = (0.5, 1.5)
+        self.test_size = (416, 416)
+        self.mosaic_prob = 0.5
+        self.enable_mixup = False
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+
+    def get_model(self, sublinear=False):
+
+        def init_yolo(M):
+            for m in M.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eps = 1e-3
+                    m.momentum = 0.03
+        if "model" not in self.__dict__:
+            from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead
+            in_channels = [256, 512, 1024]
+            # The NANO model uses depthwise = True, which is the main difference.
+            backbone = YOLOPAFPN(
+                self.depth, self.width, in_channels=in_channels,
+                act=self.act, depthwise=True,
+            )
+            head = YOLOXHead(
+                self.num_classes, self.width, in_channels=in_channels,
+                act=self.act, depthwise=True
+            )
+            self.model = YOLOX(backbone, head)
+
+        self.model.apply(init_yolo)
+        self.model.head.initialize_biases(1e-2)
+        return self.model
exps/default/yolox_s.py
ADDED
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.33
+        self.width = 0.50
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
exps/default/yolox_tiny.py
ADDED
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.33
+        self.width = 0.375
+        self.input_size = (416, 416)
+        self.mosaic_scale = (0.5, 1.5)
+        self.random_size = (10, 20)
+        self.test_size = (416, 416)
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+        self.enable_mixup = False
exps/default/yolox_x.py
ADDED
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 1.33
+        self.width = 1.25
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
exps/openlenda_nano.py
ADDED
@@ -0,0 +1,53 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+import torch.nn as nn
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.33
+        self.width = 0.25
+        self.input_size = (416, 416)
+        self.random_size = (10, 20)
+        self.mosaic_scale = (0.5, 1.5)
+        self.test_size = (416, 416)
+        self.mosaic_prob = 0.5
+        self.enable_mixup = False
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+        # max training epoch
+        self.max_epoch = 30
+        self.num_classes = 8
+        # --------------- transform config ----------------- #
+        self.flip_prob = 0
+
+    def get_model(self, sublinear=False):
+
+        def init_yolo(M):
+            for m in M.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eps = 1e-3
+                    m.momentum = 0.03
+        if "model" not in self.__dict__:
+            from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead
+            in_channels = [256, 512, 1024]
+            # The NANO model uses depthwise = True, which is the main difference.
+            backbone = YOLOPAFPN(
+                self.depth, self.width, in_channels=in_channels,
+                act=self.act, depthwise=True,
+            )
+            head = YOLOXHead(
+                self.num_classes, self.width, in_channels=in_channels,
+                act=self.act, depthwise=True
+            )
+            self.model = YOLOX(backbone, head)
+
+        self.model.apply(init_yolo)
+        self.model.head.initialize_biases(1e-2)
+        return self.model
exps/openlenda_s.py
ADDED
@@ -0,0 +1,21 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.33
+        self.width = 0.50
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+        # max training epoch
+        self.max_epoch = 30
+        self.num_classes = 8
+        # --------------- transform config ----------------- #
+        self.flip_prob = 0
+        self.input_size = (1280, 1280)  # (height, width)
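These experiment files are plain Python modules that yolox.exp.get_exp loads and instantiates; app.py consumes this one directly. A minimal sketch of that pattern, using the same paths app.py uses:

import torch

from yolox.exp import get_exp

exp = get_exp("exps/openlenda_s.py", None)   # (exp_file, exp_name); the name is unused here
model = exp.get_model()                      # built from this file's depth/width/num_classes
model.eval()
ckpt = torch.load("models/openlenda_s.pth", map_location="cpu")
model.load_state_dict(ckpt["model"])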
exps/openlenda_tiny.py
ADDED
@@ -0,0 +1,25 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.33
+        self.width = 0.375
+        self.input_size = (416, 416)
+        self.mosaic_scale = (0.5, 1.5)
+        self.random_size = (10, 20)
+        self.test_size = (416, 416)
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+        self.enable_mixup = False
+        # max training epoch
+        self.max_epoch = 30
+        self.num_classes = 8
+        # --------------- transform config ----------------- #
+        self.flip_prob = 0
exps/openlenda_x.py
ADDED
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 1.33
+        self.width = 1.25
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+        # max training epoch
+        self.max_epoch = 30
+        self.num_classes = 8
+        # --------------- transform config ----------------- #
+        self.input_size = (640, 800)  # (height, width)
models/.gitkeep
ADDED
File without changes
predictor.py
ADDED
@@ -0,0 +1,87 @@
+import os
+import time
+from loguru import logger
+
+import cv2
+
+import torch
+
+from yolox.data.data_augment import ValTransform
+from yolox.data.datasets import COCO_CLASSES
+from yolox.utils import postprocess, vis
+
+
+class Predictor(object):
+    def __init__(
+        self,
+        model,
+        cls_names=COCO_CLASSES,
+        device="cpu",
+        fp16=False,
+        legacy=False,
+    ):
+        self.model = model
+        self.cls_names = cls_names
+        self.num_classes = len(COCO_CLASSES)
+        self.confthre = 0.01
+        self.nmsthre = 0.01
+        self.test_size = (640, 640)
+        self.device = device
+        self.fp16 = fp16
+        self.preproc = ValTransform(legacy=legacy)
+
+    def inference(self, img, confthre=None, nmsthre=None, test_size=None):
+        if confthre is not None:
+            self.confthre = confthre
+        if nmsthre is not None:
+            self.nmsthre = nmsthre
+        if test_size is not None:
+            self.test_size = test_size
+        img_info = {"id": 0}
+        if isinstance(img, str):
+            img_info["file_name"] = os.path.basename(img)
+            img = cv2.imread(img)
+        else:
+            img_info["file_name"] = None
+        cv2.imwrite("test.png", img)
+        height, width = img.shape[:2]
+        img_info["height"] = height
+        img_info["width"] = width
+        img_info["raw_img"] = img
+
+        ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
+        img_info["ratio"] = ratio
+
+        img, _ = self.preproc(img, None, self.test_size)
+        img = torch.from_numpy(img).unsqueeze(0)
+        img = img.float()
+        if self.device == "gpu":
+            img = img.cuda()
+            if self.fp16:
+                img = img.half()  # to FP16
+
+        with torch.no_grad():
+            outputs = self.model(img)
+            outputs = postprocess(
+                outputs, self.num_classes, self.confthre,
+                self.nmsthre
+            )
+        return outputs, img_info
+
+    def visual(self, output, img_info):
+        ratio = img_info["ratio"]
+        img = img_info["raw_img"]
+        if output is None:
+            return img
+        output = output.cpu()
+
+        bboxes = output[:, 0:4]
+
+        # preprocessing: resize
+        bboxes /= ratio
+
+        cls = output[:, 6]
+        scores = output[:, 4] * output[:, 5]
+
+        vis_res = vis(img, bboxes, scores, cls, self.confthre, self.cls_names)
+        return vis_res
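Predictor also accepts an image path directly (the isinstance(img, str) branch above), so it can be used standalone. A minimal sketch, reusing the checkpoint and sample image that appear elsewhere in this diff:

import cv2
import torch

from predictor import Predictor
from yolox.data.datasets import COCO_CLASSES
from yolox.exp import get_exp

exp = get_exp("exps/openlenda_s.py", None)
model = exp.get_model()
model.eval()
model.load_state_dict(torch.load("models/openlenda_s.pth", map_location="cpu")["model"])

predictor = Predictor(model, COCO_CLASSES, "cpu", False, False)
outputs, img_info = predictor.inference("assets/sample.png", confthre=0.5, nmsthre=0.01)
annotated = predictor.visual(outputs[0], img_info)   # BGR image with boxes and scores drawn
cv2.imwrite("annotated.png", annotated)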
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+loguru
+tabulate
+psutil
+pycocotools
+torch >= 2.0.1
+torchvision >= 0.15.2
+opencv-python
yolox/__init__.py
ADDED
@@ -0,0 +1,4 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+__version__ = "0.3.0"
yolox/core/__init__.py
ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+from .launch import launch
+from .trainer import Trainer
yolox/core/launch.py
ADDED
@@ -0,0 +1,147 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Code are based on
+# https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/launch.py
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import sys
+from datetime import timedelta
+from loguru import logger
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+import yolox.utils.dist as comm
+
+__all__ = ["launch"]
+
+
+DEFAULT_TIMEOUT = timedelta(minutes=30)
+
+
+def _find_free_port():
+    """
+    Find an available port of current machine / node.
+    """
+    import socket
+
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    # Binding to port 0 will cause the OS to find an available port for us
+    sock.bind(("", 0))
+    port = sock.getsockname()[1]
+    sock.close()
+    # NOTE: there is still a chance the port could be taken by other processes.
+    return port
+
+
+def launch(
+    main_func,
+    num_gpus_per_machine,
+    num_machines=1,
+    machine_rank=0,
+    backend="nccl",
+    dist_url=None,
+    args=(),
+    timeout=DEFAULT_TIMEOUT,
+):
+    """
+    Args:
+        main_func: a function that will be called by `main_func(*args)`
+        num_machines (int): the total number of machines
+        machine_rank (int): the rank of this machine (one per machine)
+        dist_url (str): url to connect to for distributed training, including protocol
+            e.g. "tcp://127.0.0.1:8686".
+            Can be set to "auto" to automatically select a free port on localhost
+        args (tuple): arguments passed to main_func
+    """
+    world_size = num_machines * num_gpus_per_machine
+    if world_size > 1:
+        # https://github.com/pytorch/pytorch/pull/14391
+        # TODO prctl in spawned processes
+
+        if dist_url == "auto":
+            assert (
+                num_machines == 1
+            ), "dist_url=auto cannot work with distributed training."
+            port = _find_free_port()
+            dist_url = f"tcp://127.0.0.1:{port}"
+
+        start_method = "spawn"
+        cache = vars(args[1]).get("cache", False)
+
+        # To use numpy memmap for caching image into RAM, we have to use fork method
+        if cache:
+            assert sys.platform != "win32", (
+                "As Windows platform doesn't support fork method, "
+                "do not add --cache in your training command."
+            )
+            start_method = "fork"
+
+        mp.start_processes(
+            _distributed_worker,
+            nprocs=num_gpus_per_machine,
+            args=(
+                main_func,
+                world_size,
+                num_gpus_per_machine,
+                machine_rank,
+                backend,
+                dist_url,
+                args,
+            ),
+            daemon=False,
+            start_method=start_method,
+        )
+    else:
+        main_func(*args)
+
+
+def _distributed_worker(
+    local_rank,
+    main_func,
+    world_size,
+    num_gpus_per_machine,
+    machine_rank,
+    backend,
+    dist_url,
+    args,
+    timeout=DEFAULT_TIMEOUT,
+):
+    assert (
+        torch.cuda.is_available()
+    ), "cuda is not available. Please check your installation."
+    global_rank = machine_rank * num_gpus_per_machine + local_rank
+    logger.info("Rank {} initialization finished.".format(global_rank))
+    try:
+        dist.init_process_group(
+            backend=backend,
+            init_method=dist_url,
+            world_size=world_size,
+            rank=global_rank,
+            timeout=timeout,
+        )
+    except Exception:
+        logger.error("Process group URL: {}".format(dist_url))
+        raise
+
+    # Setup the local process group (which contains ranks within the same machine)
+    assert comm._LOCAL_PROCESS_GROUP is None
+    num_machines = world_size // num_gpus_per_machine
+    for i in range(num_machines):
+        ranks_on_i = list(
+            range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
+        )
+        pg = dist.new_group(ranks_on_i)
+        if i == machine_rank:
+            comm._LOCAL_PROCESS_GROUP = pg
+
+    # synchronize is needed here to prevent a possible timeout after calling init_process_group
+    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
+    comm.synchronize()
+
+    assert num_gpus_per_machine <= torch.cuda.device_count()
+    torch.cuda.set_device(local_rank)
+
+    main_func(*args)
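A hedged sketch of driving launch() on one machine with two GPUs. Note that the multi-process branch reads vars(args[1]).get("cache", ...), so the second element of the args tuple must be a namespace-like object; the entry point below is illustrative, since this diff does not include a training CLI:

import argparse

import torch.distributed as dist

from yolox.core import launch


def main(tag, worker_args):
    # each spawned worker arrives here with its process group initialized
    print(f"worker {dist.get_rank()}/{dist.get_world_size()} up ({tag})")


if __name__ == "__main__":
    worker_args = argparse.Namespace(cache=False)  # satisfies the cache lookup in launch()
    launch(
        main,
        num_gpus_per_machine=2,   # requires two visible CUDA devices
        dist_url="auto",          # picks a free localhost port via _find_free_port()
        args=("demo", worker_args),
    )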
yolox/core/trainer.py
ADDED
@@ -0,0 +1,390 @@
+#!/usr/bin/env python3
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import datetime
+import os
+import time
+from loguru import logger
+
+import torch
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from yolox.data import DataPrefetcher
+from yolox.exp import Exp
+from yolox.utils import (
+    MeterBuffer,
+    ModelEMA,
+    WandbLogger,
+    adjust_status,
+    all_reduce_norm,
+    get_local_rank,
+    get_model_info,
+    get_rank,
+    get_world_size,
+    gpu_mem_usage,
+    is_parallel,
+    load_ckpt,
+    mem_usage,
+    occupy_mem,
+    save_checkpoint,
+    setup_logger,
+    synchronize
+)
+
+
+class Trainer:
+    def __init__(self, exp: Exp, args):
+        # init function only defines some basic attr, other attrs like model, optimizer are built in
+        # before_train methods.
+        self.exp = exp
+        self.args = args
+
+        # training related attr
+        self.max_epoch = exp.max_epoch
+        self.amp_training = args.fp16
+        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
+        self.is_distributed = get_world_size() > 1
+        self.rank = get_rank()
+        self.local_rank = get_local_rank()
+        self.device = "cuda:{}".format(self.local_rank)
+        self.use_model_ema = exp.ema
+        self.save_history_ckpt = exp.save_history_ckpt
+
+        # data/dataloader related attr
+        self.data_type = torch.float16 if args.fp16 else torch.float32
+        self.input_size = exp.input_size
+        self.best_ap = 0
+
+        # metric record
+        self.meter = MeterBuffer(window_size=exp.print_interval)
+        self.file_name = os.path.join(exp.output_dir, args.experiment_name)
+
+        if self.rank == 0:
+            os.makedirs(self.file_name, exist_ok=True)
+
+        setup_logger(
+            self.file_name,
+            distributed_rank=self.rank,
+            filename="train_log.txt",
+            mode="a",
+        )
+
+    def train(self):
+        self.before_train()
+        try:
+            self.train_in_epoch()
+        except Exception:
+            raise
+        finally:
+            self.after_train()
+
+    def train_in_epoch(self):
+        for self.epoch in range(self.start_epoch, self.max_epoch):
+            self.before_epoch()
+            self.train_in_iter()
+            self.after_epoch()
+
+    def train_in_iter(self):
+        for self.iter in range(self.max_iter):
+            self.before_iter()
+            self.train_one_iter()
+            self.after_iter()
+
+    def train_one_iter(self):
+        iter_start_time = time.time()
+
+        inps, targets = self.prefetcher.next()
+        inps = inps.to(self.data_type)
+        targets = targets.to(self.data_type)
+        targets.requires_grad = False
+        inps, targets = self.exp.preprocess(inps, targets, self.input_size)
+        data_end_time = time.time()
+
+        with torch.cuda.amp.autocast(enabled=self.amp_training):
+            outputs = self.model(inps, targets)
+
+        loss = outputs["total_loss"]
+
+        self.optimizer.zero_grad()
+        self.scaler.scale(loss).backward()
+        self.scaler.step(self.optimizer)
+        self.scaler.update()
+
+        if self.use_model_ema:
+            self.ema_model.update(self.model)
+
+        lr = self.lr_scheduler.update_lr(self.progress_in_iter + 1)
+        for param_group in self.optimizer.param_groups:
+            param_group["lr"] = lr
+
+        iter_end_time = time.time()
+        self.meter.update(
+            iter_time=iter_end_time - iter_start_time,
+            data_time=data_end_time - iter_start_time,
+            lr=lr,
+            **outputs,
+        )
+
+    def before_train(self):
+        logger.info("args: {}".format(self.args))
+        logger.info("exp value:\n{}".format(self.exp))
+
+        # model related init
+        torch.cuda.set_device(self.local_rank)
+        model = self.exp.get_model()
+        logger.info(
+            "Model Summary: {}".format(get_model_info(model, self.exp.test_size))
+        )
+        model.to(self.device)
+
+        # solver related init
+        self.optimizer = self.exp.get_optimizer(self.args.batch_size)
+
+        # value of epoch will be set in `resume_train`
+        model = self.resume_train(model)
+
+        # data related init
+        self.no_aug = self.start_epoch >= self.max_epoch - self.exp.no_aug_epochs
+        self.train_loader = self.exp.get_data_loader(
+            batch_size=self.args.batch_size,
+            is_distributed=self.is_distributed,
+            no_aug=self.no_aug,
+            cache_img=self.args.cache,
+        )
+        logger.info("init prefetcher, this might take one minute or less...")
+        self.prefetcher = DataPrefetcher(self.train_loader)
+        # max_iter means iters per epoch
+        self.max_iter = len(self.train_loader)
+
+        self.lr_scheduler = self.exp.get_lr_scheduler(
+            self.exp.basic_lr_per_img * self.args.batch_size, self.max_iter
+        )
+        if self.args.occupy:
+            occupy_mem(self.local_rank)
+
+        if self.is_distributed:
+            model = DDP(model, device_ids=[self.local_rank], broadcast_buffers=False)
+
+        if self.use_model_ema:
+            self.ema_model = ModelEMA(model, 0.9998)
+            self.ema_model.updates = self.max_iter * self.start_epoch
+
+        self.model = model
+
+        self.evaluator = self.exp.get_evaluator(
+            batch_size=self.args.batch_size, is_distributed=self.is_distributed
+        )
+        # Tensorboard and Wandb loggers
+        if self.rank == 0:
+            if self.args.logger == "tensorboard":
+                self.tblogger = SummaryWriter(os.path.join(self.file_name, "tensorboard"))
+            elif self.args.logger == "wandb":
+                self.wandb_logger = WandbLogger.initialize_wandb_logger(
+                    self.args,
+                    self.exp,
+                    self.evaluator.dataloader.dataset
+                )
+            else:
+                raise ValueError("logger must be either 'tensorboard' or 'wandb'")
+
+        logger.info("Training start...")
+        logger.info("\n{}".format(model))
+
+    def after_train(self):
+        logger.info(
+            "Training of experiment is done and the best AP is {:.2f}".format(self.best_ap * 100)
+        )
+        if self.rank == 0:
+            if self.args.logger == "wandb":
+                self.wandb_logger.finish()
+
+    def before_epoch(self):
+        logger.info("---> start train epoch{}".format(self.epoch + 1))
+
+        if self.epoch + 1 == self.max_epoch - self.exp.no_aug_epochs or self.no_aug:
+            logger.info("--->No mosaic aug now!")
+            self.train_loader.close_mosaic()
+            logger.info("--->Add additional L1 loss now!")
+            if self.is_distributed:
+                self.model.module.head.use_l1 = True
+            else:
+                self.model.head.use_l1 = True
+            self.exp.eval_interval = 1
+            if not self.no_aug:
+                self.save_ckpt(ckpt_name="last_mosaic_epoch")
+
+    def after_epoch(self):
+        self.save_ckpt(ckpt_name="latest")
+
+        if (self.epoch + 1) % self.exp.eval_interval == 0:
+            all_reduce_norm(self.model)
+            self.evaluate_and_save_model()
+
+    def before_iter(self):
+        pass
+
+    def after_iter(self):
+        """
+        `after_iter` contains two parts of logic:
+        * log information
+        * reset setting of resize
+        """
+        # log needed information
+        if (self.iter + 1) % self.exp.print_interval == 0:
+            # TODO check ETA logic
+            left_iters = self.max_iter * self.max_epoch - (self.progress_in_iter + 1)
+            eta_seconds = self.meter["iter_time"].global_avg * left_iters
+            eta_str = "ETA: {}".format(datetime.timedelta(seconds=int(eta_seconds)))
+
+            progress_str = "epoch: {}/{}, iter: {}/{}".format(
+                self.epoch + 1, self.max_epoch, self.iter + 1, self.max_iter
+            )
+            loss_meter = self.meter.get_filtered_meter("loss")
+            loss_str = ", ".join(
+                ["{}: {:.1f}".format(k, v.latest) for k, v in loss_meter.items()]
+            )
+
+            time_meter = self.meter.get_filtered_meter("time")
+            time_str = ", ".join(
+                ["{}: {:.3f}s".format(k, v.avg) for k, v in time_meter.items()]
+            )
+
+            mem_str = "gpu mem: {:.0f}Mb, mem: {:.1f}Gb".format(gpu_mem_usage(), mem_usage())
+
+            logger.info(
+                "{}, {}, {}, {}, lr: {:.3e}".format(
+                    progress_str,
+                    mem_str,
+                    time_str,
+                    loss_str,
+                    self.meter["lr"].latest,
+                )
+                + (", size: {:d}, {}".format(self.input_size[0], eta_str))
+            )
+
+            if self.rank == 0:
+                if self.args.logger == "tensorboard":
+                    self.tblogger.add_scalar(
+                        "train/lr", self.meter["lr"].latest, self.progress_in_iter)
+                    for k, v in loss_meter.items():
+                        self.tblogger.add_scalar(
+                            f"train/{k}", v.latest, self.progress_in_iter)
+                if self.args.logger == "wandb":
+                    metrics = {"train/" + k: v.latest for k, v in loss_meter.items()}
+                    metrics.update({
+                        "train/lr": self.meter["lr"].latest
+                    })
+                    self.wandb_logger.log_metrics(metrics, step=self.progress_in_iter)
+
+            self.meter.clear_meters()
+
+        # random resizing
+        if (self.progress_in_iter + 1) % 10 == 0:
+            self.input_size = self.exp.random_resize(
+                self.train_loader, self.epoch, self.rank, self.is_distributed
+            )
+
+    @property
+    def progress_in_iter(self):
+        return self.epoch * self.max_iter + self.iter
+
+    def resume_train(self, model):
+        if self.args.resume:
+            logger.info("resume training")
+            if self.args.ckpt is None:
+                ckpt_file = os.path.join(self.file_name, "latest" + "_ckpt.pth")
+            else:
+                ckpt_file = self.args.ckpt
+
+            ckpt = torch.load(ckpt_file, map_location=self.device)
+            # resume the model/optimizer state dict
+            model.load_state_dict(ckpt["model"])
+            self.optimizer.load_state_dict(ckpt["optimizer"])
+            self.best_ap = ckpt.pop("best_ap", 0)
+            # resume the training states variables
+            start_epoch = (
+                self.args.start_epoch - 1
+                if self.args.start_epoch is not None
+                else ckpt["start_epoch"]
+            )
+            self.start_epoch = start_epoch
+            logger.info(
+                "loaded checkpoint '{}' (epoch {})".format(
+                    self.args.resume, self.start_epoch
+                )
+            )  # noqa
+        else:
+            if self.args.ckpt is not None:
+                logger.info("loading checkpoint for fine tuning")
+                ckpt_file = self.args.ckpt
+                ckpt = torch.load(ckpt_file, map_location=self.device)["model"]
+                model = load_ckpt(model, ckpt)
+            self.start_epoch = 0
+
+        return model
+
+    def evaluate_and_save_model(self):
+        if self.use_model_ema:
+            evalmodel = self.ema_model.ema
+        else:
+            evalmodel = self.model
+            if is_parallel(evalmodel):
+                evalmodel = evalmodel.module
+
+        with adjust_status(evalmodel, training=False):
+            (ap50_95, ap50, summary), predictions = self.exp.eval(
+                evalmodel, self.evaluator, self.is_distributed, return_outputs=True
+            )
+
+        update_best_ckpt = ap50_95 > self.best_ap
+        self.best_ap = max(self.best_ap, ap50_95)
+
+        if self.rank == 0:
+            if self.args.logger == "tensorboard":
+                self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
+                self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
+            if self.args.logger == "wandb":
+                self.wandb_logger.log_metrics({
+                    "val/COCOAP50": ap50,
+                    "val/COCOAP50_95": ap50_95,
+                    "train/epoch": self.epoch + 1,
+                })
+                self.wandb_logger.log_images(predictions)
+            logger.info("\n" + summary)
+        synchronize()
+
+        self.save_ckpt("last_epoch", update_best_ckpt, ap=ap50_95)
+        if self.save_history_ckpt:
+            self.save_ckpt(f"epoch_{self.epoch + 1}", ap=ap50_95)
+
+    def save_ckpt(self, ckpt_name, update_best_ckpt=False, ap=None):
+        if self.rank == 0:
+            save_model = self.ema_model.ema if self.use_model_ema else self.model
+            logger.info("Save weights to {}".format(self.file_name))
+            ckpt_state = {
+                "start_epoch": self.epoch + 1,
+                "model": save_model.state_dict(),
+                "optimizer": self.optimizer.state_dict(),
+                "best_ap": self.best_ap,
+                "curr_ap": ap,
+            }
+            save_checkpoint(
+                ckpt_state,
+                update_best_ckpt,
+                self.file_name,
+                ckpt_name,
+            )
+
+            if self.args.logger == "wandb":
+                self.wandb_logger.save_checkpoint(
+                    self.file_name,
+                    ckpt_name,
+                    update_best_ckpt,
+                    metadata={
+                        "epoch": self.epoch + 1,
+                        "optimizer": self.optimizer.state_dict(),
+                        "best_ap": self.best_ap,
+                        "curr_ap": ap
+                    }
+                )
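Trainer expects an argparse-style namespace carrying the attributes it reads (batch_size, fp16, cache, occupy, experiment_name, logger, resume, ckpt, start_epoch). A hedged single-GPU sketch — the real CLI that builds these arguments is not part of this diff:

import argparse

from yolox.core import Trainer
from yolox.exp import get_exp

args = argparse.Namespace(
    batch_size=16,
    fp16=False,
    cache=False,
    occupy=False,
    experiment_name="openlenda_s",
    logger="tensorboard",
    resume=False,
    ckpt=None,
    start_epoch=None,
)
exp = get_exp("exps/openlenda_s.py", None)
trainer = Trainer(exp, args)   # requires CUDA: before_train() calls torch.cuda.set_device
trainer.train()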
yolox/data/__init__.py
ADDED
@@ -0,0 +1,9 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+from .data_augment import TrainTransform, ValTransform
+from .data_prefetcher import DataPrefetcher
+from .dataloading import DataLoader, get_yolox_datadir, worker_init_reset_seed
+from .datasets import *
+from .samplers import InfiniteSampler, YoloBatchSampler
yolox/data/data_augment.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
# Copyright (c) Megvii, Inc. and its affiliates.
|
4 |
+
"""
|
5 |
+
Data augmentation functionality. Passed as callable transformations to
|
6 |
+
Dataset classes.
|
7 |
+
|
8 |
+
The data augmentation procedures were interpreted from @weiliu89's SSD paper
|
9 |
+
http://arxiv.org/abs/1512.02325
|
10 |
+
"""
|
11 |
+
|
12 |
+
import math
|
13 |
+
import random
|
14 |
+
|
15 |
+
import cv2
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
from yolox.utils import xyxy2cxcywh
|
19 |
+
|
20 |
+
|
21 |
+
def augment_hsv(img, hgain=5, sgain=30, vgain=30):
|
22 |
+
hsv_augs = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] # random gains
|
23 |
+
hsv_augs *= np.random.randint(0, 2, 3) # random selection of h, s, v
|
24 |
+
hsv_augs = hsv_augs.astype(np.int16)
|
25 |
+
img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.int16)
|
26 |
+
|
27 |
+
img_hsv[..., 0] = (img_hsv[..., 0] + hsv_augs[0]) % 180
|
28 |
+
img_hsv[..., 1] = np.clip(img_hsv[..., 1] + hsv_augs[1], 0, 255)
|
29 |
+
img_hsv[..., 2] = np.clip(img_hsv[..., 2] + hsv_augs[2], 0, 255)
|
30 |
+
|
31 |
+
cv2.cvtColor(img_hsv.astype(img.dtype), cv2.COLOR_HSV2BGR, dst=img) # no return needed
|
32 |
+
|
33 |
+
|
34 |
+
def get_aug_params(value, center=0):
|
35 |
+
if isinstance(value, float):
|
36 |
+
return random.uniform(center - value, center + value)
|
37 |
+
elif len(value) == 2:
|
38 |
+
return random.uniform(value[0], value[1])
|
39 |
+
else:
|
40 |
+
raise ValueError(
|
41 |
+
"Affine params should be either a sequence containing two values\
|
42 |
+
or single float values. Got {}".format(value)
|
43 |
+
)
|
44 |
+
|
45 |
+
|
46 |
+
def get_affine_matrix(
|
47 |
+
target_size,
|
48 |
+
degrees=10,
|
49 |
+
translate=0.1,
|
50 |
+
scales=0.1,
|
51 |
+
shear=10,
|
52 |
+
):
|
53 |
+
twidth, theight = target_size
|
54 |
+
|
55 |
+
# Rotation and Scale
|
56 |
+
angle = get_aug_params(degrees)
|
57 |
+
scale = get_aug_params(scales, center=1.0)
|
58 |
+
|
59 |
+
if scale <= 0.0:
|
60 |
+
raise ValueError("Argument scale should be positive")
|
61 |
+
|
62 |
+
R = cv2.getRotationMatrix2D(angle=angle, center=(0, 0), scale=scale)
|
63 |
+
|
64 |
+
M = np.ones([2, 3])
|
65 |
+
# Shear
|
66 |
+
shear_x = math.tan(get_aug_params(shear) * math.pi / 180)
|
67 |
+
shear_y = math.tan(get_aug_params(shear) * math.pi / 180)
|
68 |
+
|
69 |
+
M[0] = R[0] + shear_y * R[1]
|
70 |
+
M[1] = R[1] + shear_x * R[0]
|
71 |
+
|
72 |
+
# Translation
|
73 |
+
translation_x = get_aug_params(translate) * twidth # x translation (pixels)
|
74 |
+
translation_y = get_aug_params(translate) * theight # y translation (pixels)
|
75 |
+
|
76 |
+
M[0, 2] = translation_x
|
77 |
+
M[1, 2] = translation_y
|
78 |
+
|
79 |
+
return M, scale
|
80 |
+
|
81 |
+
|
82 |
+
def apply_affine_to_bboxes(targets, target_size, M, scale):
|
83 |
+
num_gts = len(targets)
|
84 |
+
|
85 |
+
# warp corner points
|
86 |
+
twidth, theight = target_size
|
87 |
+
corner_points = np.ones((4 * num_gts, 3))
|
88 |
+
corner_points[:, :2] = targets[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
|
89 |
+
4 * num_gts, 2
|
90 |
+
) # x1y1, x2y2, x1y2, x2y1
|
91 |
+
corner_points = corner_points @ M.T # apply affine transform
|
92 |
+
corner_points = corner_points.reshape(num_gts, 8)
|
93 |
+
|
94 |
+
# create new boxes
|
95 |
+
corner_xs = corner_points[:, 0::2]
|
96 |
+
corner_ys = corner_points[:, 1::2]
|
97 |
+
new_bboxes = (
|
98 |
+
np.concatenate(
|
99 |
+
(corner_xs.min(1), corner_ys.min(1), corner_xs.max(1), corner_ys.max(1))
|
100 |
+
)
|
101 |
+
.reshape(4, num_gts)
|
102 |
+
.T
|
103 |
+
)
|
104 |
+
|
105 |
+
# clip boxes
|
106 |
+
new_bboxes[:, 0::2] = new_bboxes[:, 0::2].clip(0, twidth)
|
107 |
+
new_bboxes[:, 1::2] = new_bboxes[:, 1::2].clip(0, theight)
|
108 |
+
|
109 |
+
targets[:, :4] = new_bboxes
|
110 |
+
|
111 |
+
return targets
|
112 |
+
|
113 |
+
|
114 |
+
def random_affine(
|
115 |
+
img,
|
116 |
+
targets=(),
|
117 |
+
target_size=(640, 640),
|
118 |
+
degrees=10,
|
119 |
+
translate=0.1,
|
120 |
+
scales=0.1,
|
121 |
+
shear=10,
|
122 |
+
):
|
123 |
+
M, scale = get_affine_matrix(target_size, degrees, translate, scales, shear)
|
124 |
+
|
125 |
+
img = cv2.warpAffine(img, M, dsize=target_size, borderValue=(114, 114, 114))
|
126 |
+
|
127 |
+
# Transform label coordinates
|
128 |
+
if len(targets) > 0:
|
129 |
+
        targets = apply_affine_to_bboxes(targets, target_size, M, scale)

    return img, targets


def _mirror(image, boxes, prob=0.5):
    _, width, _ = image.shape
    if random.random() < prob:
        image = image[:, ::-1]
        boxes[:, 0::2] = width - boxes[:, 2::-2]
    return image, boxes


def preproc(img, input_size, swap=(2, 0, 1)):
    if len(img.shape) == 3:
        padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
    else:
        padded_img = np.ones(input_size, dtype=np.uint8) * 114

    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
    resized_img = cv2.resize(
        img,
        (int(img.shape[1] * r), int(img.shape[0] * r)),
        interpolation=cv2.INTER_LINEAR,
    ).astype(np.uint8)
    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img

    padded_img = padded_img.transpose(swap)
    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
    return padded_img, r


class TrainTransform:
    def __init__(self, max_labels=50, flip_prob=0.5, hsv_prob=1.0):
        self.max_labels = max_labels
        self.flip_prob = flip_prob
        self.hsv_prob = hsv_prob

    def __call__(self, image, targets, input_dim):
        boxes = targets[:, :4].copy()
        labels = targets[:, 4].copy()
        if len(boxes) == 0:
            targets = np.zeros((self.max_labels, 5), dtype=np.float32)
            image, r_o = preproc(image, input_dim)
            return image, targets

        image_o = image.copy()
        targets_o = targets.copy()
        height_o, width_o, _ = image_o.shape
        boxes_o = targets_o[:, :4]
        labels_o = targets_o[:, 4]
        # bbox_o: [xyxy] to [c_x, c_y, w, h]
        boxes_o = xyxy2cxcywh(boxes_o)

        if random.random() < self.hsv_prob:
            augment_hsv(image)
        image_t, boxes = _mirror(image, boxes, self.flip_prob)
        height, width, _ = image_t.shape
        image_t, r_ = preproc(image_t, input_dim)
        # boxes [xyxy] to [cx, cy, w, h]
        boxes = xyxy2cxcywh(boxes)
        boxes *= r_

        mask_b = np.minimum(boxes[:, 2], boxes[:, 3]) > 1
        boxes_t = boxes[mask_b]
        labels_t = labels[mask_b]

        if len(boxes_t) == 0:
            image_t, r_o = preproc(image_o, input_dim)
            boxes_o *= r_o
            boxes_t = boxes_o
            labels_t = labels_o

        labels_t = np.expand_dims(labels_t, 1)

        targets_t = np.hstack((labels_t, boxes_t))
        padded_labels = np.zeros((self.max_labels, 5))
        padded_labels[range(len(targets_t))[: self.max_labels]] = targets_t[
            : self.max_labels
        ]
        padded_labels = np.ascontiguousarray(padded_labels, dtype=np.float32)
        return image_t, padded_labels


class ValTransform:
    """
    Defines the transformations that should be applied to a test image
    (cv2/numpy, BGR) for input into the network:

        letterbox resize -> channel swap -> (optional legacy normalization)

    Arguments:
        swap ((int,int,int)): final order of channels, default (2, 0, 1) (HWC -> CHW)
        legacy (bool): if True, scale to [0, 1], convert to RGB and apply
            ImageNet mean/std normalization (for older checkpoints)

    Returns:
        transform (transform): callable transform to be applied to test/val data
    """

    def __init__(self, swap=(2, 0, 1), legacy=False):
        self.swap = swap
        self.legacy = legacy

    # assume input is cv2 img for now
    def __call__(self, img, res, input_size):
        img, _ = preproc(img, input_size, self.swap)
        if self.legacy:
            img = img[::-1, :, :].copy()
            img /= 255.0
            img -= np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
            img /= np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
        return img, np.zeros((1, 5))
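As a quick illustration of the letterbox contract above, here is a minimal, hedged sketch (the 480x640 input and 640x640 target size are arbitrary assumptions, not values this repo mandates): `preproc` scales by `r = min(H_in/H, W_in/W)`, pads the remainder with gray value 114, and the same `r` maps boxes back to the original image.

import numpy as np

from yolox.data.data_augment import preproc

# Arbitrary example image (HxWx3, uint8); any cv2-loaded frame works the same way.
img = np.zeros((480, 640, 3), dtype=np.uint8)

inp, r = preproc(img, (640, 640))
print(inp.shape, r)  # (3, 640, 640) float32 CHW; r = min(640/480, 640/640) = 1.0

# A box predicted on the padded 640x640 input maps back by dividing by r.
box_on_input = np.array([100.0, 50.0, 300.0, 200.0])  # hypothetical xyxy box
box_on_original = box_on_input / r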
yolox/data/data_prefetcher.py
ADDED
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import torch


class DataPrefetcher:
    """
    DataPrefetcher is inspired by the code of the following file:
    https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py
    It can speed up your PyTorch dataloader. For more information, please check
    https://github.com/NVIDIA/apex/issues/304#issuecomment-493562789.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.input_cuda = self._input_cuda_for_image
        self.record_stream = DataPrefetcher._record_stream_for_image
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target, _, _ = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return

        with torch.cuda.stream(self.stream):
            self.input_cuda()
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        if input is not None:
            self.record_stream(input)
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return input, target

    def _input_cuda_for_image(self):
        self.next_input = self.next_input.cuda(non_blocking=True)

    @staticmethod
    def _record_stream_for_image(input):
        input.record_stream(torch.cuda.current_stream())
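A hedged sketch of the loop pattern this class is built for, assuming CUDA is available and a loader that yields 4-tuples the way `preload` expects; the toy `TensorDataset` stands in for a real detection dataset. The point of the design: host-to-device copies for the next batch run on a side stream while the current batch is being consumed.

import torch
from torch.utils.data import DataLoader, TensorDataset

from yolox.data.data_prefetcher import DataPrefetcher

# Toy loader yielding 4-tuples, mimicking COCODataset's (img, target, info, id).
ds = TensorDataset(
    torch.randn(16, 3, 64, 64), torch.zeros(16, 5, 5),
    torch.zeros(16, 2), torch.zeros(16, 1),
)
loader = DataLoader(ds, batch_size=4)

if torch.cuda.is_available():  # DataPrefetcher requires a CUDA device
    prefetcher = DataPrefetcher(loader)
    inps, targets = prefetcher.next()
    while inps is not None:
        # The H2D copy of the *next* batch was already issued on a side
        # stream in preload(), so it overlaps with this iteration's compute.
        _ = inps.sum() + targets.sum()  # stand-in for a model forward/backward
        inps, targets = prefetcher.next()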
yolox/data/dataloading.py
ADDED
@@ -0,0 +1,113 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import os
import random
import uuid

import numpy as np

import torch
from torch.utils.data.dataloader import DataLoader as torchDataLoader
from torch.utils.data.dataloader import default_collate

from .samplers import YoloBatchSampler


def get_yolox_datadir():
    """
    get dataset dir of YOLOX. If environment variable named `YOLOX_DATADIR` is set,
    this function will return the value of the environment variable. Otherwise, it
    falls back to the `datasets` directory next to the `yolox` package.
    """
    yolox_datadir = os.getenv("YOLOX_DATADIR", None)
    if yolox_datadir is None:
        import yolox

        yolox_path = os.path.dirname(os.path.dirname(yolox.__file__))
        yolox_datadir = os.path.join(yolox_path, "datasets")
    return yolox_datadir


class DataLoader(torchDataLoader):
    """
    Lightnet dataloader that enables on the fly resizing of the images.
    See :class:`torch.utils.data.DataLoader` for more information on the arguments.
    Check more on the following website:
    https://gitlab.com/EAVISE/lightnet/-/blob/master/lightnet/data/_dataloading.py
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__initialized = False
        shuffle = False
        sampler = None
        batch_sampler = None
        if len(args) > 5:
            shuffle = args[2]
            sampler = args[3]
            batch_sampler = args[4]
        elif len(args) > 4:
            shuffle = args[2]
            sampler = args[3]
            if "batch_sampler" in kwargs:
                batch_sampler = kwargs["batch_sampler"]
        elif len(args) > 3:
            shuffle = args[2]
            if "sampler" in kwargs:
                sampler = kwargs["sampler"]
            if "batch_sampler" in kwargs:
                batch_sampler = kwargs["batch_sampler"]
        else:
            if "shuffle" in kwargs:
                shuffle = kwargs["shuffle"]
            if "sampler" in kwargs:
                sampler = kwargs["sampler"]
            if "batch_sampler" in kwargs:
                batch_sampler = kwargs["batch_sampler"]

        # Use custom BatchSampler
        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = torch.utils.data.sampler.RandomSampler(self.dataset)
                    # sampler = torch.utils.data.DistributedSampler(self.dataset)
                else:
                    sampler = torch.utils.data.sampler.SequentialSampler(self.dataset)
            batch_sampler = YoloBatchSampler(
                sampler,
                self.batch_size,
                self.drop_last,
                input_dimension=self.dataset.input_dim,
            )
            # batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations =

        self.batch_sampler = batch_sampler

        self.__initialized = True

    def close_mosaic(self):
        self.batch_sampler.mosaic = False


def list_collate(batch):
    """
    Function that collates lists or tuples together into one list (of lists/tuples).
    Use this as the collate function in a Dataloader, if you want to have a list of
    items as an output, as opposed to tensors (eg. Brambox.boxes).
    """
    items = list(zip(*batch))

    for i in range(len(items)):
        if isinstance(items[i][0], (list, tuple)):
            items[i] = list(items[i])
        else:
            items[i] = default_collate(items[i])

    return items


def worker_init_reset_seed(worker_id):
    seed = uuid.uuid4().int % 2**32
    random.seed(seed)
    torch.set_rng_state(torch.manual_seed(seed).get_state())
    np.random.seed(seed)
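A short sketch of where `worker_init_reset_seed` is meant to plug in, assuming an ordinary map-style dataset (`ToyDS` below is a stand-in, not part of this repo): each worker process draws a fresh uuid-derived seed, so augmentation randomness is not duplicated across workers or across epoch restarts.

import torch
from torch.utils.data import DataLoader, Dataset

from yolox.data.dataloading import worker_init_reset_seed


class ToyDS(Dataset):  # stand-in dataset for illustration
    def __len__(self):
        return 8

    def __getitem__(self, i):
        return torch.rand(3)


loader = DataLoader(
    ToyDS(),
    batch_size=4,
    num_workers=2,
    worker_init_fn=worker_init_reset_seed,  # re-seeds random/torch/numpy per worker
)
for batch in loader:
    pass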
yolox/data/datasets/__init__.py
ADDED
@@ -0,0 +1,9 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

from .coco import COCODataset
from .coco_classes import COCO_CLASSES
from .datasets_wrapper import CacheDataset, ConcatDataset, Dataset, MixConcatDataset
from .mosaicdetection import MosaicDetection
from .voc import VOCDetection
yolox/data/datasets/coco.py
ADDED
@@ -0,0 +1,188 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import copy
import os

import cv2
import numpy as np
from pycocotools.coco import COCO

from ..dataloading import get_yolox_datadir
from .datasets_wrapper import CacheDataset, cache_read_img


def remove_useless_info(coco):
    """
    Remove useless info in coco dataset. COCO object is modified inplace.
    This function is mainly used for saving memory (saves about 30% mem).
    """
    if isinstance(coco, COCO):
        dataset = coco.dataset
        dataset.pop("info", None)
        dataset.pop("licenses", None)
        for img in dataset["images"]:
            img.pop("license", None)
            img.pop("coco_url", None)
            img.pop("date_captured", None)
            img.pop("flickr_url", None)
        if "annotations" in coco.dataset:
            for anno in coco.dataset["annotations"]:
                anno.pop("segmentation", None)


class COCODataset(CacheDataset):
    """
    COCO dataset class.
    """

    def __init__(
        self,
        data_dir=None,
        json_file="instances_train2017.json",
        name="train2017",
        img_size=(416, 416),
        preproc=None,
        cache=False,
        cache_type="ram",
    ):
        """
        COCO dataset initialization. Annotation data are read into memory by COCO API.
        Args:
            data_dir (str): dataset root directory
            json_file (str): COCO json file name
            name (str): COCO data name (e.g. 'train2017' or 'val2017')
            img_size (tuple): target image size after pre-processing
            preproc: data augmentation strategy
        """
        if data_dir is None:
            data_dir = os.path.join(get_yolox_datadir(), "COCO")
        self.data_dir = data_dir
        self.json_file = json_file

        self.coco = COCO(os.path.join(self.data_dir, "annotations", self.json_file))
        remove_useless_info(self.coco)
        self.ids = self.coco.getImgIds()
        self.num_imgs = len(self.ids)
        self.class_ids = sorted(self.coco.getCatIds())
        self.cats = self.coco.loadCats(self.coco.getCatIds())
        self._classes = tuple([c["name"] for c in self.cats])
        self.name = name
        self.img_size = img_size
        self.preproc = preproc
        self.annotations = self._load_coco_annotations()

        path_filename = [os.path.join(name, anno[3]) for anno in self.annotations]
        super().__init__(
            input_dimension=img_size,
            num_imgs=self.num_imgs,
            data_dir=data_dir,
            cache_dir_name=f"cache_{name}",
            path_filename=path_filename,
            cache=cache,
            cache_type=cache_type
        )

    def __len__(self):
        return self.num_imgs

    def _load_coco_annotations(self):
        return [self.load_anno_from_ids(_ids) for _ids in self.ids]

    def load_anno_from_ids(self, id_):
        im_ann = self.coco.loadImgs(id_)[0]
        width = im_ann["width"]
        height = im_ann["height"]
        anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=False)
        annotations = self.coco.loadAnns(anno_ids)
        objs = []
        for obj in annotations:
            x1 = np.max((0, obj["bbox"][0]))
            y1 = np.max((0, obj["bbox"][1]))
            x2 = np.min((width, x1 + np.max((0, obj["bbox"][2]))))
            y2 = np.min((height, y1 + np.max((0, obj["bbox"][3]))))
            if obj["area"] > 0 and x2 >= x1 and y2 >= y1:
                obj["clean_bbox"] = [x1, y1, x2, y2]
                objs.append(obj)

        num_objs = len(objs)

        res = np.zeros((num_objs, 5))
        for ix, obj in enumerate(objs):
            cls = self.class_ids.index(obj["category_id"])
            res[ix, 0:4] = obj["clean_bbox"]
            res[ix, 4] = cls

        r = min(self.img_size[0] / height, self.img_size[1] / width)
        res[:, :4] *= r

        img_info = (height, width)
        resized_info = (int(height * r), int(width * r))

        file_name = (
            im_ann["file_name"]
            if "file_name" in im_ann
            else "{:012}".format(id_) + ".jpg"
        )

        return (res, img_info, resized_info, file_name)

    def load_anno(self, index):
        return self.annotations[index][0]

    def load_resized_img(self, index):
        img = self.load_image(index)
        r = min(self.img_size[0] / img.shape[0], self.img_size[1] / img.shape[1])
        resized_img = cv2.resize(
            img,
            (int(img.shape[1] * r), int(img.shape[0] * r)),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.uint8)
        return resized_img

    def load_image(self, index):
        file_name = self.annotations[index][3]

        img_file = os.path.join(self.data_dir, self.name, file_name)

        img = cv2.imread(img_file)
        assert img is not None, f"file named {img_file} not found"

        return img

    @cache_read_img(use_cache=True)
    def read_img(self, index):
        return self.load_resized_img(index)

    def pull_item(self, index):
        id_ = self.ids[index]
        label, origin_image_size, _, _ = self.annotations[index]
        img = self.read_img(index)

        return img, copy.deepcopy(label), origin_image_size, np.array([id_])

    @CacheDataset.mosaic_getitem
    def __getitem__(self, index):
        """
        One image / label pair for the given index is picked up and pre-processed.

        Args:
            index (int): data index

        Returns:
            img (numpy.ndarray): pre-processed image
            padded_labels (torch.Tensor): pre-processed label data.
                The shape is :math:`[max_labels, 5]`.
                Each label consists of [class, xc, yc, w, h]:
                    class (float): class index.
                    xc, yc (float): center of bbox, in pixels of the preprocessed input.
                    w, h (float): size of bbox, in pixels of the preprocessed input.
            info_img : tuple of h, w.
                h, w (int): original shape of the image
            img_id (int): COCO image id of the sample. Used for evaluation.
        """
        img, target, img_info, img_id = self.pull_item(index)

        if self.preproc is not None:
            img, target = self.preproc(img, target, self.input_dim)
        return img, target, img_info, img_id
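To ground the class above, a hedged construction sketch follows. The directory layout and the 640x640 size are assumptions (the annotation file must exist locally for this to run), and it assumes `yolox.data` re-exports `COCODataset` and `TrainTransform` as in upstream YOLOX.

from yolox.data import COCODataset, TrainTransform

dataset = COCODataset(
    data_dir="datasets/COCO",               # assumption: local dataset root
    json_file="instances_train2017.json",
    name="train2017",
    img_size=(640, 640),
    preproc=TrainTransform(max_labels=50, flip_prob=0.5, hsv_prob=1.0),
)
img, target, img_info, img_id = dataset[0]
# img: float32 CHW letterboxed to 640x640; target: (50, 5) rows of [cls, cx, cy, w, h]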
yolox/data/datasets/coco_classes.py
ADDED
@@ -0,0 +1,5 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

COCO_CLASSES = ("red", "green", "yellow", "empty", "straight", "left", "right", "other")
yolox/data/datasets/datasets_wrapper.py
ADDED
@@ -0,0 +1,300 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import bisect
import copy
import os
import random
from abc import ABCMeta, abstractmethod
from functools import partial, wraps
from multiprocessing.pool import ThreadPool
import psutil
from loguru import logger
from tqdm import tqdm

import numpy as np

from torch.utils.data.dataset import ConcatDataset as torchConcatDataset
from torch.utils.data.dataset import Dataset as torchDataset


class ConcatDataset(torchConcatDataset):
    def __init__(self, datasets):
        super(ConcatDataset, self).__init__(datasets)
        if hasattr(self.datasets[0], "input_dim"):
            self._input_dim = self.datasets[0].input_dim
            self.input_dim = self.datasets[0].input_dim

    def pull_item(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx].pull_item(sample_idx)


class MixConcatDataset(torchConcatDataset):
    def __init__(self, datasets):
        super(MixConcatDataset, self).__init__(datasets)
        if hasattr(self.datasets[0], "input_dim"):
            self._input_dim = self.datasets[0].input_dim
            self.input_dim = self.datasets[0].input_dim

    def __getitem__(self, index):

        if not isinstance(index, int):
            idx = index[1]
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        if not isinstance(index, int):
            index = (index[0], sample_idx, index[2])

        return self.datasets[dataset_idx][index]


class Dataset(torchDataset):
    """ This class is a subclass of the base :class:`torch.utils.data.Dataset`,
    that enables on the fly resizing of the ``input_dim``.

    Args:
        input_dimension (tuple): (width,height) tuple with default dimensions of the network
    """

    def __init__(self, input_dimension, mosaic=True):
        super().__init__()
        self.__input_dim = input_dimension[:2]
        self.enable_mosaic = mosaic

    @property
    def input_dim(self):
        """
        Dimension that can be used by transforms to set the correct image size, etc.
        This allows transforms to have a single source of truth
        for the input dimension of the network.

        Return:
            list: Tuple containing the current width,height
        """
        if hasattr(self, "_input_dim"):
            return self._input_dim
        return self.__input_dim

    @staticmethod
    def mosaic_getitem(getitem_fn):
        """
        Decorator method that needs to be used around the ``__getitem__`` method. |br|
        This decorator enables turning the mosaic augmentation off mid-training
        ("closing" it) via the ``(mosaic_flag, index)`` tuples that
        :class:`YoloBatchSampler` yields.

        Example:
            >>> class CustomSet(ln.data.Dataset):
            ...     def __len__(self):
            ...         return 10
            ...     @ln.data.Dataset.mosaic_getitem
            ...     def __getitem__(self, index):
            ...         return self.enable_mosaic
        """

        @wraps(getitem_fn)
        def wrapper(self, index):
            if not isinstance(index, int):
                self.enable_mosaic = index[0]
                index = index[1]

            ret_val = getitem_fn(self, index)

            return ret_val

        return wrapper


class CacheDataset(Dataset, metaclass=ABCMeta):
    """ This class is a subclass of the base :class:`yolox.data.datasets.Dataset`,
    that enables caching images to RAM or disk.

    Args:
        input_dimension (tuple): (width,height) tuple with default dimensions of the network
        num_imgs (int): dataset size
        data_dir (str): the root directory of the dataset, e.g. `/path/to/COCO`.
        cache_dir_name (str): the name of the directory to cache to disk,
            e.g. `"custom_cache"`. The files cached to disk will be saved
            under `/path/to/COCO/custom_cache`.
        path_filename (list[str]): a list of paths to the data relative to the `data_dir`,
            e.g. if you have data `/path/to/COCO/train/1.jpg`, `/path/to/COCO/train/2.jpg`,
            then `path_filename = ['train/1.jpg', 'train/2.jpg']`.
        cache (bool): whether to cache the images to RAM or disk.
        cache_type (str): the type of cache,
            "ram" : Caching imgs to ram for fast training.
            "disk": Caching imgs to disk for fast training.
    """

    def __init__(
        self,
        input_dimension,
        num_imgs=None,
        data_dir=None,
        cache_dir_name=None,
        path_filename=None,
        cache=False,
        cache_type="ram",
    ):
        super().__init__(input_dimension)
        self.cache = cache
        self.cache_type = cache_type
        self.imgs = None

        if self.cache and self.cache_type == "disk":
            self.cache_dir = os.path.join(data_dir, cache_dir_name)
            self.path_filename = path_filename

        if self.cache:
            self.cache_images(
                num_imgs=num_imgs,
                data_dir=data_dir,
                cache_dir_name=cache_dir_name,
                path_filename=path_filename,
            )

    def __del__(self):
        if self.cache and self.cache_type == "ram":
            del self.imgs

    @abstractmethod
    def read_img(self, index):
        """
        Given index, return the corresponding image

        Args:
            index (int): image index
        """
        raise NotImplementedError

    def cache_images(
        self,
        num_imgs=None,
        data_dir=None,
        cache_dir_name=None,
        path_filename=None,
    ):
        assert num_imgs is not None, "num_imgs must be specified as the size of the dataset"
        if self.cache_type == "disk":
            assert (data_dir and cache_dir_name and path_filename) is not None, \
                "data_dir, cache_dir_name and path_filename must be specified if cache_type is disk"
            self.path_filename = path_filename

        mem = psutil.virtual_memory()
        mem_required = self.cal_cache_occupy(num_imgs)
        gb = 1 << 30

        if self.cache_type == "ram":
            if mem_required > mem.available:
                self.cache = False
            else:
                logger.info(
                    f"{mem_required / gb:.1f}GB RAM required, "
                    f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB RAM available, "
                    f"Since the first thing we do is cache, "
                    f"there is no guarantee that the remaining memory space is sufficient"
                )

        if self.cache and self.imgs is None:
            if self.cache_type == 'ram':
                self.imgs = [None] * num_imgs
                logger.info("You are using cached images in RAM to accelerate training!")
            else:   # 'disk'
                if not os.path.exists(self.cache_dir):
                    os.mkdir(self.cache_dir)
                    logger.warning(
                        f"\n*******************************************************************\n"
                        f"You are using cached images in DISK to accelerate training.\n"
                        f"This requires large DISK space.\n"
                        f"Make sure you have {mem_required / gb:.1f} "
                        f"available DISK space for training your dataset.\n"
                        f"*******************************************************************\n"
                    )
                else:
                    logger.info(f"Found disk cache at {self.cache_dir}")
                    return

            logger.info(
                "Caching images...\n"
                "This might take some time for your dataset"
            )

            num_threads = min(8, max(1, os.cpu_count() - 1))
            b = 0
            load_imgs = ThreadPool(num_threads).imap(
                partial(self.read_img, use_cache=False),
                range(num_imgs)
            )
            pbar = tqdm(enumerate(load_imgs), total=num_imgs)
            for i, x in pbar:   # x = self.read_img(self, i, use_cache=False)
                if self.cache_type == 'ram':
                    self.imgs[i] = x
                else:   # 'disk'
                    cache_filename = f'{self.path_filename[i].split(".")[0]}.npy'
                    cache_path_filename = os.path.join(self.cache_dir, cache_filename)
                    os.makedirs(os.path.dirname(cache_path_filename), exist_ok=True)
                    np.save(cache_path_filename, x)
                b += x.nbytes
                pbar.desc = \
                    f'Caching images ({b / gb:.1f}/{mem_required / gb:.1f}GB {self.cache_type})'
            pbar.close()

    def cal_cache_occupy(self, num_imgs):
        cache_bytes = 0
        num_samples = min(num_imgs, 32)
        for _ in range(num_samples):
            img = self.read_img(index=random.randint(0, num_imgs - 1), use_cache=False)
            cache_bytes += img.nbytes
        mem_required = cache_bytes * num_imgs / num_samples
        return mem_required


def cache_read_img(use_cache=True):
    def decorator(read_img_fn):
        """
        Decorate the read_img function to cache the image

        Args:
            read_img_fn: read_img function
            use_cache (bool, optional): For the decorated read_img function,
                whether to read the image from cache.
                Defaults to True.
        """
        @wraps(read_img_fn)
        def wrapper(self, index, use_cache=use_cache):
            cache = self.cache and use_cache
            if cache:
                if self.cache_type == "ram":
                    img = self.imgs[index]
                    img = copy.deepcopy(img)
                elif self.cache_type == "disk":
                    img = np.load(
                        os.path.join(
                            self.cache_dir, f"{self.path_filename[index].split('.')[0]}.npy"))
                else:
                    raise ValueError(f"Unknown cache type: {self.cache_type}")
            else:
                img = read_img_fn(self, index)
            return img
        return wrapper
    return decorator
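A minimal sketch of the caching contract, using a hypothetical subclass (`ToyCacheDataset` is not part of this repo): `cache_images` warms the cache by calling the decorated `read_img` with `use_cache=False`, estimating RAM from a small sample first; subsequent reads are served from `self.imgs`.

import numpy as np

from yolox.data.datasets.datasets_wrapper import CacheDataset, cache_read_img


class ToyCacheDataset(CacheDataset):
    """Illustrative only: four tiny synthetic images cached to RAM."""

    def __init__(self):
        self._data = [np.full((8, 8, 3), i, dtype=np.uint8) for i in range(4)]
        super().__init__(input_dimension=(8, 8), num_imgs=4, cache=True, cache_type="ram")

    @cache_read_img(use_cache=True)
    def read_img(self, index):
        # Called with use_cache=False while cache_images() warms the cache;
        # afterwards, hits are served (deep-copied) from self.imgs.
        return self._data[index]


ds = ToyCacheDataset()
img = ds.read_img(0)  # served from the RAM cache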
yolox/data/datasets/mosaicdetection.py
ADDED
@@ -0,0 +1,234 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import random

import cv2
import numpy as np

from yolox.utils import adjust_box_anns, get_local_rank

from ..data_augment import random_affine
from .datasets_wrapper import Dataset


def get_mosaic_coordinate(mosaic_image, mosaic_index, xc, yc, w, h, input_h, input_w):
    # TODO update doc
    # index0 to top left part of image
    if mosaic_index == 0:
        x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
        small_coord = w - (x2 - x1), h - (y2 - y1), w, h
    # index1 to top right part of image
    elif mosaic_index == 1:
        x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, input_w * 2), yc
        small_coord = 0, h - (y2 - y1), min(w, x2 - x1), h
    # index2 to bottom left part of image
    elif mosaic_index == 2:
        x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(input_h * 2, yc + h)
        small_coord = w - (x2 - x1), 0, w, min(y2 - y1, h)
    # index3 to bottom right part of image
    elif mosaic_index == 3:
        x1, y1, x2, y2 = xc, yc, min(xc + w, input_w * 2), min(input_h * 2, yc + h)  # noqa
        small_coord = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
    return (x1, y1, x2, y2), small_coord


class MosaicDetection(Dataset):
    """Detection dataset wrapper that performs mosaic and mixup for a normal dataset."""

    def __init__(
        self, dataset, img_size, mosaic=True, preproc=None,
        degrees=10.0, translate=0.1, mosaic_scale=(0.5, 1.5),
        mixup_scale=(0.5, 1.5), shear=2.0, enable_mixup=True,
        mosaic_prob=1.0, mixup_prob=1.0, *args
    ):
        """

        Args:
            dataset (Dataset): Pytorch dataset object.
            img_size (tuple): expected (height, width) of the network input.
            mosaic (bool): enable mosaic augmentation or not.
            preproc (func): transform applied to the mosaic/mixup output.
            degrees (float): rotation range of the random affine.
            translate (float): translation range of the random affine.
            mosaic_scale (tuple): scale range of the random affine.
            mixup_scale (tuple): jitter scale range used in mixup.
            shear (float): shear range of the random affine.
            enable_mixup (bool): enable mixup augmentation or not.
            *args (tuple): additional arguments for mixup random sampler.
        """
        super().__init__(img_size, mosaic=mosaic)
        self._dataset = dataset
        self.preproc = preproc
        self.degrees = degrees
        self.translate = translate
        self.scale = mosaic_scale
        self.shear = shear
        self.mixup_scale = mixup_scale
        self.enable_mosaic = mosaic
        self.enable_mixup = enable_mixup
        self.mosaic_prob = mosaic_prob
        self.mixup_prob = mixup_prob
        self.local_rank = get_local_rank()

    def __len__(self):
        return len(self._dataset)

    @Dataset.mosaic_getitem
    def __getitem__(self, idx):
        if self.enable_mosaic and random.random() < self.mosaic_prob:
            mosaic_labels = []
            input_dim = self._dataset.input_dim
            input_h, input_w = input_dim[0], input_dim[1]

            # yc, xc = s, s  # mosaic center x, y
            yc = int(random.uniform(0.5 * input_h, 1.5 * input_h))
            xc = int(random.uniform(0.5 * input_w, 1.5 * input_w))

            # 3 additional image indices
            indices = [idx] + [random.randint(0, len(self._dataset) - 1) for _ in range(3)]

            for i_mosaic, index in enumerate(indices):
                img, _labels, _, img_id = self._dataset.pull_item(index)
                h0, w0 = img.shape[:2]  # orig hw
                scale = min(1. * input_h / h0, 1. * input_w / w0)
                img = cv2.resize(
                    img, (int(w0 * scale), int(h0 * scale)), interpolation=cv2.INTER_LINEAR
                )
                # generate output mosaic image
                (h, w, c) = img.shape[:3]
                if i_mosaic == 0:
                    mosaic_img = np.full((input_h * 2, input_w * 2, c), 114, dtype=np.uint8)

                # suffix l means large image, while s means small image in mosaic aug.
                (l_x1, l_y1, l_x2, l_y2), (s_x1, s_y1, s_x2, s_y2) = get_mosaic_coordinate(
                    mosaic_img, i_mosaic, xc, yc, w, h, input_h, input_w
                )

                mosaic_img[l_y1:l_y2, l_x1:l_x2] = img[s_y1:s_y2, s_x1:s_x2]
                padw, padh = l_x1 - s_x1, l_y1 - s_y1

                labels = _labels.copy()
                # shift pixel xyxy coords onto the 2x mosaic canvas
                if _labels.size > 0:
                    labels[:, 0] = scale * _labels[:, 0] + padw
                    labels[:, 1] = scale * _labels[:, 1] + padh
                    labels[:, 2] = scale * _labels[:, 2] + padw
                    labels[:, 3] = scale * _labels[:, 3] + padh
                mosaic_labels.append(labels)

            if len(mosaic_labels):
                mosaic_labels = np.concatenate(mosaic_labels, 0)
                np.clip(mosaic_labels[:, 0], 0, 2 * input_w, out=mosaic_labels[:, 0])
                np.clip(mosaic_labels[:, 1], 0, 2 * input_h, out=mosaic_labels[:, 1])
                np.clip(mosaic_labels[:, 2], 0, 2 * input_w, out=mosaic_labels[:, 2])
                np.clip(mosaic_labels[:, 3], 0, 2 * input_h, out=mosaic_labels[:, 3])

            mosaic_img, mosaic_labels = random_affine(
                mosaic_img,
                mosaic_labels,
                target_size=(input_w, input_h),
                degrees=self.degrees,
                translate=self.translate,
                scales=self.scale,
                shear=self.shear,
            )

            # -----------------------------------------------------------------
            # CopyPaste: https://arxiv.org/abs/2012.07177
            # -----------------------------------------------------------------
            if (
                self.enable_mixup
                and not len(mosaic_labels) == 0
                and random.random() < self.mixup_prob
            ):
                mosaic_img, mosaic_labels = self.mixup(mosaic_img, mosaic_labels, self.input_dim)
            mix_img, padded_labels = self.preproc(mosaic_img, mosaic_labels, self.input_dim)
            img_info = (mix_img.shape[1], mix_img.shape[0])

            # -----------------------------------------------------------------
            # img_info and img_id are not used for training.
            # They are also hard to be specified on a mosaic image.
            # -----------------------------------------------------------------
            return mix_img, padded_labels, img_info, img_id

        else:
            self._dataset._input_dim = self.input_dim
            img, label, img_info, img_id = self._dataset.pull_item(idx)
            img, label = self.preproc(img, label, self.input_dim)
            return img, label, img_info, img_id

    def mixup(self, origin_img, origin_labels, input_dim):
        jit_factor = random.uniform(*self.mixup_scale)
        FLIP = random.uniform(0, 1) > 0.5
        cp_labels = []
        while len(cp_labels) == 0:
            cp_index = random.randint(0, self.__len__() - 1)
            cp_labels = self._dataset.load_anno(cp_index)
        img, cp_labels, _, _ = self._dataset.pull_item(cp_index)

        if len(img.shape) == 3:
            cp_img = np.ones((input_dim[0], input_dim[1], 3), dtype=np.uint8) * 114
        else:
            cp_img = np.ones(input_dim, dtype=np.uint8) * 114

        cp_scale_ratio = min(input_dim[0] / img.shape[0], input_dim[1] / img.shape[1])
        resized_img = cv2.resize(
            img,
            (int(img.shape[1] * cp_scale_ratio), int(img.shape[0] * cp_scale_ratio)),
            interpolation=cv2.INTER_LINEAR,
        )

        cp_img[
            : int(img.shape[0] * cp_scale_ratio), : int(img.shape[1] * cp_scale_ratio)
        ] = resized_img

        cp_img = cv2.resize(
            cp_img,
            (int(cp_img.shape[1] * jit_factor), int(cp_img.shape[0] * jit_factor)),
        )
        cp_scale_ratio *= jit_factor

        if FLIP:
            cp_img = cp_img[:, ::-1, :]

        origin_h, origin_w = cp_img.shape[:2]
        target_h, target_w = origin_img.shape[:2]
        padded_img = np.zeros(
            (max(origin_h, target_h), max(origin_w, target_w), 3), dtype=np.uint8
        )
        padded_img[:origin_h, :origin_w] = cp_img

        x_offset, y_offset = 0, 0
        if padded_img.shape[0] > target_h:
            y_offset = random.randint(0, padded_img.shape[0] - target_h - 1)
        if padded_img.shape[1] > target_w:
            x_offset = random.randint(0, padded_img.shape[1] - target_w - 1)
        padded_cropped_img = padded_img[
            y_offset: y_offset + target_h, x_offset: x_offset + target_w
        ]

        cp_bboxes_origin_np = adjust_box_anns(
            cp_labels[:, :4].copy(), cp_scale_ratio, 0, 0, origin_w, origin_h
        )
        if FLIP:
            cp_bboxes_origin_np[:, 0::2] = (
                origin_w - cp_bboxes_origin_np[:, 0::2][:, ::-1]
            )
        cp_bboxes_transformed_np = cp_bboxes_origin_np.copy()
        cp_bboxes_transformed_np[:, 0::2] = np.clip(
            cp_bboxes_transformed_np[:, 0::2] - x_offset, 0, target_w
        )
        cp_bboxes_transformed_np[:, 1::2] = np.clip(
            cp_bboxes_transformed_np[:, 1::2] - y_offset, 0, target_h
        )

        cls_labels = cp_labels[:, 4:5].copy()
        box_labels = cp_bboxes_transformed_np
        labels = np.hstack((box_labels, cls_labels))
        origin_labels = np.vstack((origin_labels, labels))
        origin_img = origin_img.astype(np.float32)
        origin_img = 0.5 * origin_img + 0.5 * padded_cropped_img.astype(np.float32)

        return origin_img.astype(np.uint8), origin_labels
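A hedged sketch of how this wrapper is typically composed with the COCO dataset above; paths and sizes are assumptions, and the base dataset deliberately gets `preproc=None` because the wrapper applies `preproc` itself after building the mosaic.

from yolox.data import COCODataset, MosaicDetection, TrainTransform

base = COCODataset(
    data_dir="datasets/COCO",        # assumption: local dataset root
    json_file="instances_train2017.json",
    img_size=(640, 640),
    preproc=None,                    # the wrapper applies preproc itself
)
train_set = MosaicDetection(
    base,
    img_size=(640, 640),
    mosaic=True,
    preproc=TrainTransform(max_labels=120),  # mosaic can roughly 4x the label count
    degrees=10.0,
    translate=0.1,
    mosaic_scale=(0.5, 1.5),
    mosaic_prob=1.0,
)
img, label, img_info, img_id = train_set[0]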
yolox/data/datasets/voc.py
ADDED
@@ -0,0 +1,331 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Code are based on
# https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py
# Copyright (c) Francisco Massa.
# Copyright (c) Ellis Brown, Max deGroot.
# Copyright (c) Megvii, Inc. and its affiliates.

import os
import os.path
import pickle
import xml.etree.ElementTree as ET

import cv2
import numpy as np

from yolox.evaluators.voc_eval import voc_eval

from .datasets_wrapper import CacheDataset, cache_read_img
from .voc_classes import VOC_CLASSES


class AnnotationTransform(object):

    """Transforms a VOC annotation into a Tensor of bbox coords and label index
    Initialized with a dictionary lookup of classnames to indexes

    Arguments:
        class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
            (default: alphabetic indexing of VOC's 20 classes)
        keep_difficult (bool, optional): keep difficult instances or not
            (default: False)
    """

    def __init__(self, class_to_ind=None, keep_difficult=True):
        self.class_to_ind = class_to_ind or dict(
            zip(VOC_CLASSES, range(len(VOC_CLASSES)))
        )
        self.keep_difficult = keep_difficult

    def __call__(self, target):
        """
        Arguments:
            target (annotation) : the target annotation to be made usable
                will be an ET.Element
        Returns:
            a tuple (res, img_info): res is an (N, 5) array with one
            [xmin, ymin, xmax, ymax, label_ind] row per object,
            img_info is the (height, width) of the image
        """
        res = np.empty((0, 5))
        for obj in target.iter("object"):
            difficult = obj.find("difficult")
            if difficult is not None:
                difficult = int(difficult.text) == 1
            else:
                difficult = False
            if not self.keep_difficult and difficult:
                continue
            name = obj.find("name").text.strip()
            bbox = obj.find("bndbox")

            pts = ["xmin", "ymin", "xmax", "ymax"]
            bndbox = []
            for i, pt in enumerate(pts):
                cur_pt = int(float(bbox.find(pt).text)) - 1
                # scale height or width
                # cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                bndbox.append(cur_pt)
            label_idx = self.class_to_ind[name]
            bndbox.append(label_idx)
            res = np.vstack((res, bndbox))  # [xmin, ymin, xmax, ymax, label_ind]
            # img_id = target.find('filename').text[:-4]

        width = int(target.find("size").find("width").text)
        height = int(target.find("size").find("height").text)
        img_info = (height, width)

        return res, img_info


class VOCDetection(CacheDataset):

    """
    VOC Detection Dataset Object

    input is image, target is annotation

    Args:
        root (string): filepath to VOCdevkit folder.
        image_set (string): imageset to use (eg. 'train', 'val', 'test')
        transform (callable, optional): transformation to perform on the
            input image
        target_transform (callable, optional): transformation to perform on the
            target `annotation`
            (eg: take in caption string, return tensor of word indices)
        dataset_name (string, optional): which dataset to load
            (default: 'VOC2007')
    """

    def __init__(
        self,
        data_dir,
        image_sets=[("2007", "trainval"), ("2012", "trainval")],
        img_size=(416, 416),
        preproc=None,
        target_transform=AnnotationTransform(),
        dataset_name="VOC0712",
        cache=False,
        cache_type="ram",
    ):
        self.root = data_dir
        self.image_set = image_sets
        self.img_size = img_size
        self.preproc = preproc
        self.target_transform = target_transform
        self.name = dataset_name
        self._annopath = os.path.join("%s", "Annotations", "%s.xml")
        self._imgpath = os.path.join("%s", "JPEGImages", "%s.jpg")
        self._classes = VOC_CLASSES
        self.cats = [
            {"id": idx, "name": val} for idx, val in enumerate(VOC_CLASSES)
        ]
        self.class_ids = list(range(len(VOC_CLASSES)))
        self.ids = list()
        for (year, name) in image_sets:
            self._year = year
            rootpath = os.path.join(self.root, "VOC" + year)
            for line in open(
                os.path.join(rootpath, "ImageSets", "Main", name + ".txt")
            ):
                self.ids.append((rootpath, line.strip()))
        self.num_imgs = len(self.ids)

        self.annotations = self._load_coco_annotations()

        path_filename = [
            (self._imgpath % self.ids[i]).split(self.root + "/")[1]
            for i in range(self.num_imgs)
        ]
        super().__init__(
            input_dimension=img_size,
            num_imgs=self.num_imgs,
            data_dir=self.root,
            cache_dir_name=f"cache_{self.name}",
            path_filename=path_filename,
            cache=cache,
            cache_type=cache_type
        )

    def __len__(self):
        return self.num_imgs

    def _load_coco_annotations(self):
        return [self.load_anno_from_ids(_ids) for _ids in range(self.num_imgs)]

    def load_anno_from_ids(self, index):
        img_id = self.ids[index]
        target = ET.parse(self._annopath % img_id).getroot()

        assert self.target_transform is not None
        res, img_info = self.target_transform(target)
        height, width = img_info

        r = min(self.img_size[0] / height, self.img_size[1] / width)
        res[:, :4] *= r
        resized_info = (int(height * r), int(width * r))

        return (res, img_info, resized_info)

    def load_anno(self, index):
        return self.annotations[index][0]

    def load_resized_img(self, index):
        img = self.load_image(index)
        r = min(self.img_size[0] / img.shape[0], self.img_size[1] / img.shape[1])
        resized_img = cv2.resize(
            img,
            (int(img.shape[1] * r), int(img.shape[0] * r)),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.uint8)

        return resized_img

    def load_image(self, index):
        img_id = self.ids[index]
        img = cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
        assert img is not None, f"file named {self._imgpath % img_id} not found"

        return img

    @cache_read_img(use_cache=True)
    def read_img(self, index):
        return self.load_resized_img(index)

    def pull_item(self, index):
        """Returns the original image and target at an index for mixup

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to show
        Return:
            img, target
        """
        target, img_info, _ = self.annotations[index]
        img = self.read_img(index)

        return img, target, img_info, index

    @CacheDataset.mosaic_getitem
    def __getitem__(self, index):
        img, target, img_info, img_id = self.pull_item(index)

        if self.preproc is not None:
            img, target = self.preproc(img, target, self.input_dim)

        return img, target, img_info, img_id

    def evaluate_detections(self, all_boxes, output_dir=None):
        """
        all_boxes is a list of length number-of-classes.
        Each list element is a list of length number-of-images.
        Each of those list elements is either an empty list []
        or a numpy array of detections.

        all_boxes[class][image] = [] or np.array of shape #dets x 5
        """
        self._write_voc_results_file(all_boxes)
        IouTh = np.linspace(
            0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True
        )
        mAPs = []
        for iou in IouTh:
            mAP = self._do_python_eval(output_dir, iou)
            mAPs.append(mAP)

        print("--------------------------------------------------------------")
        print("map_5095:", np.mean(mAPs))
        print("map_50:", mAPs[0])
        print("--------------------------------------------------------------")
        return np.mean(mAPs), mAPs[0]

    def _get_voc_results_file_template(self):
        filename = "comp4_det_test" + "_{:s}.txt"
        filedir = os.path.join(self.root, "results", "VOC" + self._year, "Main")
        if not os.path.exists(filedir):
            os.makedirs(filedir)
        path = os.path.join(filedir, filename)
        return path

    def _write_voc_results_file(self, all_boxes):
        for cls_ind, cls in enumerate(VOC_CLASSES):
            if cls == "__background__":
                continue
            print("Writing {} VOC results file".format(cls))
            filename = self._get_voc_results_file_template().format(cls)
            with open(filename, "wt") as f:
                for im_ind, index in enumerate(self.ids):
                    index = index[1]
                    dets = all_boxes[cls_ind][im_ind]
                    if len(dets) == 0:
                        continue
                    for k in range(dets.shape[0]):
                        f.write(
                            "{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n".format(
                                index,
                                dets[k, -1],
                                dets[k, 0] + 1,
                                dets[k, 1] + 1,
                                dets[k, 2] + 1,
                                dets[k, 3] + 1,
                            )
                        )

    def _do_python_eval(self, output_dir="output", iou=0.5):
        rootpath = os.path.join(self.root, "VOC" + self._year)
        name = self.image_set[0][1]
        annopath = os.path.join(rootpath, "Annotations", "{:s}.xml")
        imagesetfile = os.path.join(rootpath, "ImageSets", "Main", name + ".txt")
        cachedir = os.path.join(
            self.root, "annotations_cache", "VOC" + self._year, name
        )
        if not os.path.exists(cachedir):
            os.makedirs(cachedir)
        aps = []
        # The PASCAL VOC metric changed in 2010
        use_07_metric = True if int(self._year) < 2010 else False
        print("Eval IoU : {:.2f}".format(iou))
        if output_dir is not None and not os.path.isdir(output_dir):
            os.mkdir(output_dir)
        for i, cls in enumerate(VOC_CLASSES):

            if cls == "__background__":
                continue

            filename = self._get_voc_results_file_template().format(cls)
            rec, prec, ap = voc_eval(
                filename,
                annopath,
                imagesetfile,
                cls,
                cachedir,
                ovthresh=iou,
                use_07_metric=use_07_metric,
            )
            aps += [ap]
            if iou == 0.5:
                print("AP for {} = {:.4f}".format(cls, ap))
            if output_dir is not None:
                with open(os.path.join(output_dir, cls + "_pr.pkl"), "wb") as f:
                    pickle.dump({"rec": rec, "prec": prec, "ap": ap}, f)
        if iou == 0.5:
            print("Mean AP = {:.4f}".format(np.mean(aps)))
            print("~~~~~~~~")
            print("Results:")
            for ap in aps:
                print("{:.3f}".format(ap))
            print("{:.3f}".format(np.mean(aps)))
            print("~~~~~~~~")
            print("")
            print("--------------------------------------------------------------")
            print("Results computed with the **unofficial** Python eval code.")
            print("Results should be very close to the official MATLAB eval code.")
            print("Recompute with `./tools/reval.py --matlab ...` for your paper.")
            print("-- Thanks, The Management")
            print("--------------------------------------------------------------")

        return np.mean(aps)
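For completeness, a hedged construction sketch for the VOC path; it assumes the standard VOCdevkit directory layout under the given `data_dir` (which must exist locally) and the same `yolox.data` re-exports as the COCO example above.

from yolox.data import VOCDetection, TrainTransform

voc = VOCDetection(
    data_dir="datasets/VOCdevkit",   # assumption: standard VOCdevkit layout
    image_sets=[("2007", "trainval")],
    img_size=(640, 640),
    preproc=TrainTransform(max_labels=50),
)
print(len(voc), "images,", len(voc.cats), "classes")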
yolox/data/datasets/voc_classes.py
ADDED
@@ -0,0 +1,27 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

# VOC_CLASSES = ( '__background__', # always index 0
VOC_CLASSES = (
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
)
yolox/data/samplers.py
ADDED
@@ -0,0 +1,85 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import itertools
from typing import Optional

import torch
import torch.distributed as dist
from torch.utils.data.sampler import BatchSampler as torchBatchSampler
from torch.utils.data.sampler import Sampler


class YoloBatchSampler(torchBatchSampler):
    """
    This batch sampler will generate mini-batches of (mosaic, index) tuples from another sampler.
    It works just like the :class:`torch.utils.data.sampler.BatchSampler`,
    but it will turn on/off the mosaic aug.
    """

    def __init__(self, *args, mosaic=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.mosaic = mosaic

    def __iter__(self):
        for batch in super().__iter__():
            yield [(self.mosaic, idx) for idx in batch]


class InfiniteSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices and
    all workers cooperate to correctly shuffle the indices and sample different indices.
    The samplers in each worker effectively produce `indices[worker_id::num_workers]`
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False)
    """

    def __init__(
        self,
        size: int,
        shuffle: bool = True,
        seed: Optional[int] = 0,
        rank=0,
        world_size=1,
    ):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).
        """
        self._size = size
        assert size > 0
        self._shuffle = shuffle
        self._seed = int(seed)

        if dist.is_available() and dist.is_initialized():
            self._rank = dist.get_rank()
            self._world_size = dist.get_world_size()
        else:
            self._rank = rank
            self._world_size = world_size

    def __iter__(self):
        start = self._rank
        yield from itertools.islice(
            self._infinite_indices(), start, None, self._world_size
        )

    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if self._shuffle:
                yield from torch.randperm(self._size, generator=g)
            else:
                yield from torch.arange(self._size)

    def __len__(self):
        return self._size // self._world_size
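A minimal sketch of composing the two samplers above by hand, which is what the custom `DataLoader` in yolox/data/dataloading.py does implicitly; `size=100` is an arbitrary stand-in for `len(dataset)`.

from yolox.data.samplers import InfiniteSampler, YoloBatchSampler

sampler = InfiniteSampler(size=100, shuffle=True, seed=0)
batch_sampler = YoloBatchSampler(sampler, batch_size=8, drop_last=False, mosaic=True)

# Each yielded batch is a list of (mosaic_flag, index) tuples; the
# Dataset.mosaic_getitem decorator unpacks them inside __getitem__.
# Note the indices are 0-dim torch tensors coming from torch.randperm.
first_batch = next(iter(batch_sampler))
print(first_batch[:2])  # e.g. [(True, tensor(42)), (True, tensor(7))]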
yolox/evaluators/__init__.py
ADDED
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

from .coco_evaluator import COCOEvaluator
from .voc_evaluator import VOCEvaluator
yolox/evaluators/coco_evaluator.py
ADDED
@@ -0,0 +1,317 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import contextlib
import io
import itertools
import json
import tempfile
import time
from collections import ChainMap, defaultdict
from loguru import logger
from tabulate import tabulate
from tqdm import tqdm

import numpy as np

import torch

from yolox.data.datasets import COCO_CLASSES
from yolox.utils import (
    gather,
    is_main_process,
    postprocess,
    synchronize,
    time_synchronized,
    xyxy2xywh
)


def per_class_AR_table(coco_eval, class_names=COCO_CLASSES, headers=["class", "AR"], colums=6):
    per_class_AR = {}
    recalls = coco_eval.eval["recall"]
    # dimension of recalls: [TxKxAxM]
    # recall has dims (iou, cls, area range, max dets)
    assert len(class_names) == recalls.shape[1]

    for idx, name in enumerate(class_names):
        recall = recalls[:, idx, 0, -1]
        recall = recall[recall > -1]
        ar = np.mean(recall) if recall.size else float("nan")
        per_class_AR[name] = float(ar * 100)

    num_cols = min(colums, len(per_class_AR) * len(headers))
    result_pair = [x for pair in per_class_AR.items() for x in pair]
    row_pair = itertools.zip_longest(*[result_pair[i::num_cols] for i in range(num_cols)])
    table_headers = headers * (num_cols // len(headers))
    table = tabulate(
        row_pair, tablefmt="pipe", floatfmt=".3f", headers=table_headers, numalign="left",
    )
    return table


def per_class_AP_table(coco_eval, class_names=COCO_CLASSES, headers=["class", "AP"], colums=6):
    per_class_AP = {}
    precisions = coco_eval.eval["precision"]
    # dimension of precisions: [TxRxKxAxM]
    # precision has dims (iou, recall, cls, area range, max dets)
    assert len(class_names) == precisions.shape[2]

    for idx, name in enumerate(class_names):
        # area range index 0: all area ranges
        # max dets index -1: typically 100 per image
        precision = precisions[:, :, idx, 0, -1]
        precision = precision[precision > -1]
        ap = np.mean(precision) if precision.size else float("nan")
        per_class_AP[name] = float(ap * 100)

    num_cols = min(colums, len(per_class_AP) * len(headers))
    result_pair = [x for pair in per_class_AP.items() for x in pair]
    row_pair = itertools.zip_longest(*[result_pair[i::num_cols] for i in range(num_cols)])
    table_headers = headers * (num_cols // len(headers))
    table = tabulate(
        row_pair, tablefmt="pipe", floatfmt=".3f", headers=table_headers, numalign="left",
    )
    return table


class COCOEvaluator:
    """
    COCO AP Evaluation class. All the data in the val2017 dataset are processed
    and evaluated by COCO API.
    """

    def __init__(
        self,
        dataloader,
        img_size: int,
        confthre: float,
        nmsthre: float,
        num_classes: int,
        testdev: bool = False,
        per_class_AP: bool = True,
        per_class_AR: bool = True,
    ):
        """
        Args:
            dataloader (Dataloader): evaluate dataloader.
            img_size: image size after preprocess. images are resized
                to squares whose shape is (img_size, img_size).
            confthre: confidence threshold ranging from 0 to 1, which
                is defined in the config file.
            nmsthre: IoU threshold of non-max suppression ranging from 0 to 1.
            per_class_AP: Show per class AP during evaluation or not. Default to True.
            per_class_AR: Show per class AR during evaluation or not. Default to True.
        """
        self.dataloader = dataloader
        self.img_size = img_size
        self.confthre = confthre
        self.nmsthre = nmsthre
        self.num_classes = num_classes
        self.testdev = testdev
        self.per_class_AP = per_class_AP
        self.per_class_AR = per_class_AR

    def evaluate(
        self, model, distributed=False, half=False, trt_file=None,
        decoder=None, test_size=None, return_outputs=False
    ):
        """
        COCO average precision (AP) Evaluation. Iterate inference on the test dataset
        and the results are evaluated by COCO API.

        NOTE: This function will change training mode to False, please save states if needed.

        Args:
            model : model to evaluate.

        Returns:
            ap50_95 (float) : COCO AP of IoU=50:95
            ap50 (float) : COCO AP of IoU=50
            summary (str): summary info of evaluation.
        """
        # TODO half to amp_test
        tensor_type = torch.cuda.HalfTensor if half else torch.cuda.FloatTensor
        model = model.eval()
        if half:
            model = model.half()
        ids = []
        data_list = []
        output_data = defaultdict()
        progress_bar = tqdm if is_main_process() else iter

        inference_time = 0
        nms_time = 0
        n_samples = max(len(self.dataloader) - 1, 1)

        if trt_file is not None:
            from torch2trt import TRTModule

            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))

            x = torch.ones(1, 3, test_size[0], test_size[1]).cuda()
            model(x)
            model = model_trt

        for cur_iter, (imgs, _, info_imgs, ids) in enumerate(
            progress_bar(self.dataloader)
        ):
            with torch.no_grad():
                imgs = imgs.type(tensor_type)

                # skip the last iters since batchsize might be not enough for batch inference
                is_time_record = cur_iter < len(self.dataloader) - 1
                if is_time_record:
                    start = time.time()

                outputs = model(imgs)
                if decoder is not None:
                    outputs = decoder(outputs, dtype=outputs.type())

                if is_time_record:
                    infer_end = time_synchronized()
                    inference_time += infer_end - start

                outputs = postprocess(
                    outputs, self.num_classes, self.confthre, self.nmsthre
                )
                if is_time_record:
                    nms_end = time_synchronized()
                    nms_time += nms_end - infer_end

            data_list_elem, image_wise_data = self.convert_to_coco_format(
                outputs, info_imgs, ids, return_outputs=True)
            data_list.extend(data_list_elem)
            output_data.update(image_wise_data)

        statistics = torch.cuda.FloatTensor([inference_time, nms_time, n_samples])
        if distributed:
            # different process/device might have different speed,
            # to make sure the process will not get stuck, sync func is used here.
            synchronize()
            data_list = gather(data_list, dst=0)
            output_data = gather(output_data, dst=0)
            data_list = list(itertools.chain(*data_list))
            output_data = dict(ChainMap(*output_data))
            torch.distributed.reduce(statistics, dst=0)

        eval_results = self.evaluate_prediction(data_list, statistics)
        synchronize()

        if return_outputs:
            return eval_results, output_data
        return eval_results

    def convert_to_coco_format(self, outputs, info_imgs, ids, return_outputs=False):
        data_list = []
        image_wise_data = defaultdict(dict)
        for (output, img_h, img_w, img_id) in zip(
            outputs, info_imgs[0], info_imgs[1], ids
        ):
            if output is None:
                continue
            output = output.cpu()

            bboxes = output[:, 0:4]

            # preprocessing: resize
            scale = min(
                self.img_size[0] / float(img_h), self.img_size[1] / float(img_w)
            )
            bboxes /= scale
            cls = output[:, 6]
            scores = output[:, 4] * output[:, 5]

            image_wise_data.update({
                int(img_id): {
                    "bboxes": [box.numpy().tolist() for box in bboxes],
                    "scores": [score.numpy().item() for score in scores],
                    "categories": [
                        self.dataloader.dataset.class_ids[int(cls[ind])]
                        for ind in range(bboxes.shape[0])
                    ],
                }
            })

            bboxes = xyxy2xywh(bboxes)

            for ind in range(bboxes.shape[0]):
                label = self.dataloader.dataset.class_ids[int(cls[ind])]
                pred_data = {
                    "image_id": int(img_id),
                    "category_id": label,
                    "bbox": bboxes[ind].numpy().tolist(),
                    "score": scores[ind].numpy().item(),
                    "segmentation": [],
                }  # COCO json format
                data_list.append(pred_data)

        if return_outputs:
            return data_list, image_wise_data
        return data_list

    def evaluate_prediction(self, data_dict, statistics):
        if not is_main_process():
            return 0, 0, None

        logger.info("Evaluate in main process...")

        annType = ["segm", "bbox", "keypoints"]

        inference_time = statistics[0].item()
        nms_time = statistics[1].item()
        n_samples = statistics[2].item()

        a_infer_time = 1000 * inference_time / (n_samples * self.dataloader.batch_size)
        a_nms_time = 1000 * nms_time / (n_samples * self.dataloader.batch_size)

        time_info = ", ".join(
            [
                "Average {} time: {:.2f} ms".format(k, v)
                for k, v in zip(
                    ["forward", "NMS", "inference"],
                    [a_infer_time, a_nms_time, (a_infer_time + a_nms_time)],
                )
            ]
        )

        info = time_info + "\n"

        # Evaluate the Dt (detection) json comparing with the ground truth
        if len(data_dict) > 0:
            cocoGt = self.dataloader.dataset.coco
            # TODO: since pycocotools can't process dict in py36, write data to json file.
            if self.testdev:
                json.dump(data_dict, open("./yolox_testdev_2017.json", "w"))
                cocoDt = cocoGt.loadRes("./yolox_testdev_2017.json")
            else:
                _, tmp = tempfile.mkstemp()
                json.dump(data_dict, open(tmp, "w"))
                cocoDt = cocoGt.loadRes(tmp)
            try:
                from yolox.layers import COCOeval_opt as COCOeval
            except ImportError:
                from pycocotools.cocoeval import COCOeval

                logger.warning("Use standard COCOeval.")

            cocoEval = COCOeval(cocoGt, cocoDt, annType[1])
            cocoEval.evaluate()
            cocoEval.accumulate()
            redirect_string = io.StringIO()
            with contextlib.redirect_stdout(redirect_string):
                cocoEval.summarize()
            info += redirect_string.getvalue()
            cat_ids = list(cocoGt.cats.keys())
            cat_names = [cocoGt.cats[catId]['name'] for catId in sorted(cat_ids)]
            if self.per_class_AP:
                AP_table = per_class_AP_table(cocoEval, class_names=cat_names)
                info += "per class AP:\n" + AP_table + "\n"
            if self.per_class_AR:
                AR_table = per_class_AR_table(cocoEval, class_names=cat_names)
                info += "per class AR:\n" + AR_table + "\n"
            return cocoEval.stats[0], cocoEval.stats[1], info
        else:
            return 0, 0, info
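
A small worked sketch of the box post-processing done in convert_to_coco_format above: undo the letterbox resize, then convert xyxy to COCO's xywh. The image sizes and box values are made up, and xyxy2xywh (which lives in yolox.utils) is re-implemented inline so the snippet runs on its own.

import torch

img_size = (640, 640)    # network input size (height, width)
img_h, img_w = 480, 320  # original image size
box_xyxy = torch.tensor([[64.0, 128.0, 192.0, 256.0]])  # box on the 640x640 canvas

# Same scale factor the evaluator uses: images were resized by the min ratio.
scale = min(img_size[0] / float(img_h), img_size[1] / float(img_w))  # = 4/3
box_xyxy /= scale  # back to original-image coordinates: [48, 96, 144, 192]

# xyxy -> COCO xywh (inline stand-in for yolox.utils.xyxy2xywh).
box_xywh = box_xyxy.clone()
box_xywh[:, 2] = box_xyxy[:, 2] - box_xyxy[:, 0]
box_xywh[:, 3] = box_xyxy[:, 3] - box_xyxy[:, 1]
print(box_xywh)  # tensor([[48., 96., 96., 96.]])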
yolox/evaluators/voc_eval.py
ADDED
@@ -0,0 +1,183 @@
#!/usr/bin/env python3
# Code is based on
# https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/voc_eval.py
# Copyright (c) Bharath Hariharan.
# Copyright (c) Megvii, Inc. and its affiliates.

import os
import pickle
import xml.etree.ElementTree as ET

import numpy as np


def parse_rec(filename):
    """Parse a PASCAL VOC xml file"""
    tree = ET.parse(filename)
    objects = []
    for obj in tree.findall("object"):
        obj_struct = {}
        obj_struct["name"] = obj.find("name").text
        obj_struct["pose"] = obj.find("pose").text
        obj_struct["truncated"] = int(obj.find("truncated").text)
        obj_struct["difficult"] = int(obj.find("difficult").text)
        bbox = obj.find("bndbox")
        obj_struct["bbox"] = [
            int(bbox.find("xmin").text),
            int(bbox.find("ymin").text),
            int(bbox.find("xmax").text),
            int(bbox.find("ymax").text),
        ]
        objects.append(obj_struct)

    return objects


def voc_ap(rec, prec, use_07_metric=False):
    """
    Compute VOC AP given precision and recall.
    If use_07_metric is true, uses the
    VOC 07 11-point method (default: False).
    """
    if use_07_metric:
        # 11 point metric
        ap = 0.0
        for t in np.arange(0.0, 1.1, 0.1):
            if np.sum(rec >= t) == 0:
                p = 0
            else:
                p = np.max(prec[rec >= t])
            ap = ap + p / 11.0
    else:
        # correct AP calculation
        # first append sentinel values at the end
        mrec = np.concatenate(([0.0], rec, [1.0]))
        mpre = np.concatenate(([0.0], prec, [0.0]))

        # compute the precision envelope
        for i in range(mpre.size - 1, 0, -1):
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

        # to calculate area under PR curve, look for points
        # where X axis (recall) changes value
        i = np.where(mrec[1:] != mrec[:-1])[0]

        # and sum (\Delta recall) * prec
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap


def voc_eval(
    detpath,
    annopath,
    imagesetfile,
    classname,
    cachedir,
    ovthresh=0.5,
    use_07_metric=False,
):
    # first load gt
    if not os.path.isdir(cachedir):
        os.mkdir(cachedir)
    cachefile = os.path.join(cachedir, "annots.pkl")
    # read list of images
    with open(imagesetfile, "r") as f:
        lines = f.readlines()
    imagenames = [x.strip() for x in lines]

    if not os.path.isfile(cachefile):
        # load annots
        recs = {}
        for i, imagename in enumerate(imagenames):
            recs[imagename] = parse_rec(annopath.format(imagename))
            if i % 100 == 0:
                print(f"Reading annotation for {i + 1}/{len(imagenames)}")
        # save
        print(f"Saving cached annotations to {cachefile}")
        with open(cachefile, "wb") as f:
            pickle.dump(recs, f)
    else:
        # load
        with open(cachefile, "rb") as f:
            recs = pickle.load(f)

    # extract gt objects for this class
    class_recs = {}
    npos = 0
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj["name"] == classname]
        bbox = np.array([x["bbox"] for x in R])
        difficult = np.array([x["difficult"] for x in R]).astype(bool)
        det = [False] * len(R)
        npos = npos + sum(~difficult)
        class_recs[imagename] = {"bbox": bbox, "difficult": difficult, "det": det}

    # read dets
    detfile = detpath.format(classname)
    with open(detfile, "r") as f:
        lines = f.readlines()

    if len(lines) == 0:
        return 0, 0, 0

    splitlines = [x.strip().split(" ") for x in lines]
    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines])
    BB = np.array([[float(z) for z in x[2:]] for x in splitlines])

    # sort by confidence
    sorted_ind = np.argsort(-confidence)
    BB = BB[sorted_ind, :]
    image_ids = [image_ids[x] for x in sorted_ind]

    # go down dets and mark TPs and FPs
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d in range(nd):
        R = class_recs[image_ids[d]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R["bbox"].astype(float)

        if BBGT.size > 0:
            # compute overlaps
            # intersection
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
            ih = np.maximum(iymax - iymin + 1.0, 0.0)
            inters = iw * ih

            # union
            uni = (
                (bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0)
                + (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0) - inters
            )

            overlaps = inters / uni
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)

        if ovmax > ovthresh:
            if not R["difficult"][jmax]:
                if not R["det"][jmax]:
                    tp[d] = 1.0
                    R["det"][jmax] = 1
                else:
                    fp[d] = 1.0
        else:
            fp[d] = 1.0

    # compute precision recall
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap
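
A tiny worked example for the area-under-curve branch of voc_ap above. The precision/recall arrays are synthetic: three detections, two of them true positives, against npos = 2 ground-truth boxes.

import numpy as np

rec = np.array([0.5, 0.5, 1.0])         # cumulative recall after each detection
prec = np.array([1.0, 0.5, 2.0 / 3.0])  # cumulative precision after each detection

# The precision envelope is 1.0 up to recall 0.5 and 2/3 from 0.5 to 1.0,
# so AP = 0.5 * 1.0 + 0.5 * (2/3) ~= 0.8333.
ap = voc_ap(rec, prec, use_07_metric=False)
print(round(float(ap), 4))  # 0.8333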
yolox/evaluators/voc_evaluator.py
ADDED
@@ -0,0 +1,187 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import sys
import tempfile
import time
from collections import ChainMap
from loguru import logger
from tqdm import tqdm

import numpy as np

import torch

from yolox.utils import gather, is_main_process, postprocess, synchronize, time_synchronized


class VOCEvaluator:
    """
    VOC AP Evaluation class.
    """

    def __init__(self, dataloader, img_size, confthre, nmsthre, num_classes):
        """
        Args:
            dataloader (Dataloader): evaluate dataloader.
            img_size (int): image size after preprocess. images are resized
                to squares whose shape is (img_size, img_size).
            confthre (float): confidence threshold ranging from 0 to 1, which
                is defined in the config file.
            nmsthre (float): IoU threshold of non-max suppression ranging from 0 to 1.
        """
        self.dataloader = dataloader
        self.img_size = img_size
        self.confthre = confthre
        self.nmsthre = nmsthre
        self.num_classes = num_classes
        self.num_images = len(dataloader.dataset)

    def evaluate(
        self, model, distributed=False, half=False, trt_file=None,
        decoder=None, test_size=None, return_outputs=False,
    ):
        """
        VOC average precision (AP) Evaluation. Iterate inference on the test dataset
        and evaluate the results with the VOC metric.

        NOTE: This function will change training mode to False, please save states if needed.

        Args:
            model : model to evaluate.

        Returns:
            ap50_95 (float) : COCO-style AP of IoU=50:95
            ap50 (float) : VOC 2007 metric AP of IoU=50
            summary (str): summary info of evaluation.
        """
        # TODO half to amp_test
        tensor_type = torch.cuda.HalfTensor if half else torch.cuda.FloatTensor
        model = model.eval()
        if half:
            model = model.half()
        ids = []
        data_list = {}
        progress_bar = tqdm if is_main_process() else iter

        inference_time = 0
        nms_time = 0
        n_samples = max(len(self.dataloader) - 1, 1)

        if trt_file is not None:
            from torch2trt import TRTModule

            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))

            x = torch.ones(1, 3, test_size[0], test_size[1]).cuda()
            model(x)
            model = model_trt

        for cur_iter, (imgs, _, info_imgs, ids) in enumerate(progress_bar(self.dataloader)):
            with torch.no_grad():
                imgs = imgs.type(tensor_type)

                # skip the last iters since batchsize might be not enough for batch inference
                is_time_record = cur_iter < len(self.dataloader) - 1
                if is_time_record:
                    start = time.time()

                outputs = model(imgs)
                if decoder is not None:
                    outputs = decoder(outputs, dtype=outputs.type())

                if is_time_record:
                    infer_end = time_synchronized()
                    inference_time += infer_end - start

                outputs = postprocess(
                    outputs, self.num_classes, self.confthre, self.nmsthre
                )
                if is_time_record:
                    nms_end = time_synchronized()
                    nms_time += nms_end - infer_end

            data_list.update(self.convert_to_voc_format(outputs, info_imgs, ids))

        statistics = torch.cuda.FloatTensor([inference_time, nms_time, n_samples])
        if distributed:
            data_list = gather(data_list, dst=0)
            data_list = ChainMap(*data_list)
            torch.distributed.reduce(statistics, dst=0)

        eval_results = self.evaluate_prediction(data_list, statistics)
        synchronize()
        if return_outputs:
            return eval_results, data_list
        return eval_results

    def convert_to_voc_format(self, outputs, info_imgs, ids):
        predictions = {}
        for output, img_h, img_w, img_id in zip(outputs, info_imgs[0], info_imgs[1], ids):
            if output is None:
                predictions[int(img_id)] = (None, None, None)
                continue
            output = output.cpu()

            bboxes = output[:, 0:4]

            # preprocessing: resize
            scale = min(self.img_size[0] / float(img_h), self.img_size[1] / float(img_w))
            bboxes /= scale

            cls = output[:, 6]
            scores = output[:, 4] * output[:, 5]

            predictions[int(img_id)] = (bboxes, cls, scores)
        return predictions

    def evaluate_prediction(self, data_dict, statistics):
        if not is_main_process():
            return 0, 0, None

        logger.info("Evaluate in main process...")

        inference_time = statistics[0].item()
        nms_time = statistics[1].item()
        n_samples = statistics[2].item()

        a_infer_time = 1000 * inference_time / (n_samples * self.dataloader.batch_size)
        a_nms_time = 1000 * nms_time / (n_samples * self.dataloader.batch_size)

        time_info = ", ".join(
            [
                "Average {} time: {:.2f} ms".format(k, v)
                for k, v in zip(
                    ["forward", "NMS", "inference"],
                    [a_infer_time, a_nms_time, (a_infer_time + a_nms_time)],
                )
            ]
        )
        info = time_info + "\n"

        all_boxes = [
            [[] for _ in range(self.num_images)] for _ in range(self.num_classes)
        ]
        for img_num in range(self.num_images):
            bboxes, cls, scores = data_dict[img_num]
            if bboxes is None:
                for j in range(self.num_classes):
                    all_boxes[j][img_num] = np.empty([0, 5], dtype=np.float32)
                continue
            for j in range(self.num_classes):
                mask_c = cls == j
                if sum(mask_c) == 0:
                    all_boxes[j][img_num] = np.empty([0, 5], dtype=np.float32)
                    continue

                c_dets = torch.cat((bboxes, scores.unsqueeze(1)), dim=1)
                all_boxes[j][img_num] = c_dets[mask_c].numpy()

            sys.stdout.write(f"im_eval: {img_num + 1}/{self.num_images} \r")
            sys.stdout.flush()

        with tempfile.TemporaryDirectory() as tempdir:
            mAP50, mAP70 = self.dataloader.dataset.evaluate_detections(all_boxes, tempdir)
            return mAP50, mAP70, info
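
A sketch of the all_boxes layout that evaluate_prediction above hands to dataset.evaluate_detections: all_boxes[class_idx][image_idx] is an (n, 5) float32 array of [x1, y1, x2, y2, score]. The sizes and values below are illustrative.

import numpy as np

num_classes, num_images = 2, 3
all_boxes = [
    [np.empty((0, 5), dtype=np.float32) for _ in range(num_images)]
    for _ in range(num_classes)
]

# One detection of class 0 in image 1.
all_boxes[0][1] = np.array([[10.0, 20.0, 50.0, 80.0, 0.9]], dtype=np.float32)
print(all_boxes[0][1].shape)  # (1, 5)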
yolox/exp/__init__.py
ADDED
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) Megvii Inc. All rights reserved.

from .base_exp import BaseExp
from .build import get_exp
from .yolox_base import Exp, check_exp_value
yolox/exp/base_exp.py
ADDED
@@ -0,0 +1,90 @@
#!/usr/bin/env python3
# Copyright (c) Megvii Inc. All rights reserved.

import ast
import pprint
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Tuple
from tabulate import tabulate

import torch
from torch.nn import Module

from yolox.utils import LRScheduler


class BaseExp(metaclass=ABCMeta):
    """Basic class for any experiment."""

    def __init__(self):
        self.seed = None
        self.output_dir = "./YOLOX_outputs"
        self.print_interval = 100
        self.eval_interval = 10
        self.dataset = None

    @abstractmethod
    def get_model(self) -> Module:
        pass

    @abstractmethod
    def get_dataset(self, cache: bool = False, cache_type: str = "ram"):
        pass

    @abstractmethod
    def get_data_loader(
        self, batch_size: int, is_distributed: bool
    ) -> Dict[str, torch.utils.data.DataLoader]:
        pass

    @abstractmethod
    def get_optimizer(self, batch_size: int) -> torch.optim.Optimizer:
        pass

    @abstractmethod
    def get_lr_scheduler(
        self, lr: float, iters_per_epoch: int, **kwargs
    ) -> LRScheduler:
        pass

    @abstractmethod
    def get_evaluator(self):
        pass

    @abstractmethod
    def eval(self, model, evaluator, weights):
        pass

    def __repr__(self):
        table_header = ["keys", "values"]
        exp_table = [
            (str(k), pprint.pformat(v))
            for k, v in vars(self).items()
            if not k.startswith("_")
        ]
        return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")

    def merge(self, cfg_list):
        assert len(cfg_list) % 2 == 0, f"length must be even, check value here: {cfg_list}"
        for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
            # only update value with same key
            if hasattr(self, k):
                src_value = getattr(self, k)
                src_type = type(src_value)

                # pre-process input if source type is list or tuple
                if isinstance(src_value, (List, Tuple)):
                    v = v.strip("[]()")
                    v = [t.strip() for t in v.split(",")]

                    # find type of tuple
                    if len(src_value) > 0:
                        src_item_type = type(src_value[0])
                        v = [src_item_type(t) for t in v]

                if src_value is not None and src_type != type(v):
                    try:
                        v = src_type(v)
                    except Exception:
                        v = ast.literal_eval(v)
                setattr(self, k, v)
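
A sketch of BaseExp.merge in action: command-line style ["key", "value"] pairs are coerced to the type of the existing attribute. The _DemoExp subclass below exists only to make the abstract class instantiable for the demo.

class _DemoExp(BaseExp):
    def __init__(self):
        super().__init__()
        self.max_epoch = 300
        self.input_size = (640, 640)

    # stub out the abstract interface; irrelevant to the demo
    def get_model(self): pass
    def get_dataset(self, cache=False, cache_type="ram"): pass
    def get_data_loader(self, batch_size, is_distributed): pass
    def get_optimizer(self, batch_size): pass
    def get_lr_scheduler(self, lr, iters_per_epoch, **kwargs): pass
    def get_evaluator(self): pass
    def eval(self, model, evaluator, weights): pass

exp = _DemoExp()
exp.merge(["max_epoch", "100", "input_size", "(416, 416)"])
print(exp.max_epoch, exp.input_size)  # 100 (416, 416)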
yolox/exp/build.py
ADDED
@@ -0,0 +1,42 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

import importlib
import os
import sys


def get_exp_by_file(exp_file):
    try:
        sys.path.append(os.path.dirname(exp_file))
        current_exp = importlib.import_module(os.path.basename(exp_file).split(".")[0])
        exp = current_exp.Exp()
    except Exception:
        raise ImportError("{} doesn't contain a class named 'Exp'".format(exp_file))
    return exp


def get_exp_by_name(exp_name):
    exp = exp_name.replace("-", "_")  # convert string like "yolox-s" to "yolox_s"
    module_name = ".".join(["yolox", "exp", "default", exp])
    exp_object = importlib.import_module(module_name).Exp()
    return exp_object


def get_exp(exp_file=None, exp_name=None):
    """
    get Exp object by file or name. If exp_file and exp_name
    are both provided, get Exp by exp_file.

    Args:
        exp_file (str): file path of experiment.
        exp_name (str): name of experiment, e.g. "yolox-s".
    """
    assert (
        exp_file is not None or exp_name is not None
    ), "please provide exp file or exp name."
    if exp_file is not None:
        return get_exp_by_file(exp_file)
    else:
        return get_exp_by_name(exp_name)
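
The two entry points above in use; the file path assumes the exps/default directory shipped in this repo.

# By name: "yolox-s" resolves to the module yolox.exp.default.yolox_s.
exp = get_exp(exp_name="yolox-s")

# By file: the module at that path must define a class named Exp.
exp = get_exp(exp_file="exps/default/yolox_s.py")
print(exp.exp_name)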
yolox/exp/default/__init__.py
ADDED
@@ -0,0 +1,28 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

# This file is used for package installation; it locates the default exp files.

import sys
from importlib import abc, util
from pathlib import Path

_EXP_PATH = Path(__file__).resolve().parent.parent.parent.parent / "exps" / "default"

if _EXP_PATH.is_dir():
    # This is true only for in-place installation (pip install -e, setup.py develop),
    # where setup(package_dir=) does not work: https://github.com/pypa/setuptools/issues/230

    class _ExpFinder(abc.MetaPathFinder):

        def find_spec(self, name, path, target=None):
            if not name.startswith("yolox.exp.default"):
                return
            project_name = name.split(".")[-1] + ".py"
            target_file = _EXP_PATH / project_name
            if not target_file.is_file():
                return
            return util.spec_from_file_location(name, target_file)

    sys.meta_path.append(_ExpFinder())
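
What the meta path finder above enables in an in-place install: with exps/default/yolox_s.py on disk (as in this repo), the import below loads it as a regular module under the yolox.exp.default package.

import importlib

module = importlib.import_module("yolox.exp.default.yolox_s")
exp = module.Exp()
print(type(exp).__name__)  # Exp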
yolox/exp/yolox_base.py
ADDED
@@ -0,0 +1,358 @@
#!/usr/bin/env python3
# Copyright (c) Megvii Inc. All rights reserved.

import os
import random

import torch
import torch.distributed as dist
import torch.nn as nn

from .base_exp import BaseExp

__all__ = ["Exp", "check_exp_value"]


class Exp(BaseExp):
    def __init__(self):
        super().__init__()

        # ---------------- model config ---------------- #
        # number of classes the model detects
        self.num_classes = 80
        # factor of model depth
        self.depth = 1.00
        # factor of model width
        self.width = 1.00
        # activation name. For example, if using "relu", then "silu" will be replaced with "relu".
        self.act = "silu"

        # ---------------- dataloader config ---------------- #
        # set worker to 4 for shorter dataloader init time
        # If your training process costs too much memory, reduce this value.
        self.data_num_workers = 4
        self.input_size = (640, 640)  # (height, width)
        # Actual multiscale ranges: [640 - 5 * 32, 640 + 5 * 32].
        # To disable multiscale training, set the value to 0.
        self.multiscale_range = 5
        # You can uncomment this line to specify a multiscale range
        # self.random_size = (14, 26)
        # dir of dataset images, if data_dir is None, this project will use `datasets` dir
        self.data_dir = None
        # name of annotation file for training
        self.train_ann = "instances_train2017.json"
        # name of annotation file for evaluation
        self.val_ann = "instances_val2017.json"
        # name of annotation file for testing
        self.test_ann = "instances_test2017.json"

        # --------------- transform config ----------------- #
        # prob of applying mosaic aug
        self.mosaic_prob = 1.0
        # prob of applying mixup aug
        self.mixup_prob = 1.0
        # prob of applying hsv aug
        self.hsv_prob = 1.0
        # prob of applying flip aug
        self.flip_prob = 0.5
        # rotation angle range, for example, if set to 2, the true range is (-2, 2)
        self.degrees = 10.0
        # translate range, for example, if set to 0.1, the true range is (-0.1, 0.1)
        self.translate = 0.1
        self.mosaic_scale = (0.1, 2)
        # apply mixup aug or not
        self.enable_mixup = True
        self.mixup_scale = (0.5, 1.5)
        # shear angle range, for example, if set to 2, the true range is (-2, 2)
        self.shear = 2.0

        # -------------- training config --------------------- #
        # epoch number used for warmup
        self.warmup_epochs = 5
        # max training epoch
        self.max_epoch = 300
        # minimum learning rate during warmup
        self.warmup_lr = 0
        self.min_lr_ratio = 0.05
        # learning rate for one image. During training, lr will multiply batchsize.
        self.basic_lr_per_img = 0.01 / 64.0
        # name of LRScheduler
        self.scheduler = "yoloxwarmcos"
        # number of final epochs during which augmentation such as mosaic is disabled
        self.no_aug_epochs = 15
        # apply EMA during training
        self.ema = True

        # weight decay of optimizer
        self.weight_decay = 5e-4
        # momentum of optimizer
        self.momentum = 0.9
        # log period in iter, for example,
        # if set to 1, user could see log every iteration.
        self.print_interval = 10
        # eval period in epoch, for example,
        # if set to 1, model will be evaluated after every epoch.
        self.eval_interval = 10
        # save history checkpoint or not.
        # If set to False, yolox will only save latest and best ckpt.
        self.save_history_ckpt = True
        # name of experiment
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

        # ----------------- testing config ------------------ #
        # output image size during evaluation/test
        self.test_size = (640, 640)
        # confidence threshold during evaluation/test,
        # boxes whose scores are less than test_conf will be filtered
        self.test_conf = 0.01
        # nms threshold
        self.nmsthre = 0.65

    def get_model(self):
        from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead

        def init_yolo(M):
            for m in M.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-3
                    m.momentum = 0.03

        if getattr(self, "model", None) is None:
            in_channels = [256, 512, 1024]
            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels, act=self.act)
            head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels, act=self.act)
            self.model = YOLOX(backbone, head)

        self.model.apply(init_yolo)
        self.model.head.initialize_biases(1e-2)
        self.model.train()
        return self.model

    def get_dataset(self, cache: bool = False, cache_type: str = "ram"):
        """
        Get dataset according to cache and cache_type parameters.
        Args:
            cache (bool): Whether to cache imgs to ram or disk.
            cache_type (str, optional): Defaults to "ram".
                "ram" : Caching imgs to ram for fast training.
                "disk": Caching imgs to disk for fast training.
        """
        from yolox.data import COCODataset, TrainTransform

        return COCODataset(
            data_dir=self.data_dir,
            json_file=self.train_ann,
            img_size=self.input_size,
            preproc=TrainTransform(
                max_labels=50,
                flip_prob=self.flip_prob,
                hsv_prob=self.hsv_prob
            ),
            cache=cache,
            cache_type=cache_type,
        )

    def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img: str = None):
        """
        Get dataloader according to cache_img parameter.
        Args:
            no_aug (bool, optional): Whether to turn off mosaic data enhancement. Defaults to False.
            cache_img (str, optional): cache_img is equivalent to cache_type. Defaults to None.
                "ram" : Caching imgs to ram for fast training.
                "disk": Caching imgs to disk for fast training.
                None: Do not use cache, in this case cache_data is also None.
        """
        from yolox.data import (
            TrainTransform,
            YoloBatchSampler,
            DataLoader,
            InfiniteSampler,
            MosaicDetection,
            worker_init_reset_seed,
        )
        from yolox.utils import wait_for_the_master

        # if cache is True, we will create self.dataset before launch
        # else we will create self.dataset after launch
        if self.dataset is None:
            with wait_for_the_master():
                assert cache_img is None, \
                    "cache_img must be None if you didn't create self.dataset before launch"
                self.dataset = self.get_dataset(cache=False, cache_type=cache_img)

        self.dataset = MosaicDetection(
            dataset=self.dataset,
            mosaic=not no_aug,
            img_size=self.input_size,
            preproc=TrainTransform(
                max_labels=120,
                flip_prob=self.flip_prob,
                hsv_prob=self.hsv_prob),
            degrees=self.degrees,
            translate=self.translate,
            mosaic_scale=self.mosaic_scale,
            mixup_scale=self.mixup_scale,
            shear=self.shear,
            enable_mixup=self.enable_mixup,
            mosaic_prob=self.mosaic_prob,
            mixup_prob=self.mixup_prob,
        )

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()

        sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)

        batch_sampler = YoloBatchSampler(
            sampler=sampler,
            batch_size=batch_size,
            drop_last=False,
            mosaic=not no_aug,
        )

        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
        dataloader_kwargs["batch_sampler"] = batch_sampler

        # Make sure each process has different random seed, especially for 'fork' method.
        # Check https://github.com/pytorch/pytorch/issues/63311 for more details.
        dataloader_kwargs["worker_init_fn"] = worker_init_reset_seed

        train_loader = DataLoader(self.dataset, **dataloader_kwargs)

        return train_loader

    def random_resize(self, data_loader, epoch, rank, is_distributed):
        tensor = torch.LongTensor(2).cuda()

        if rank == 0:
            size_factor = self.input_size[1] * 1.0 / self.input_size[0]
            if not hasattr(self, 'random_size'):
                min_size = int(self.input_size[0] / 32) - self.multiscale_range
                max_size = int(self.input_size[0] / 32) + self.multiscale_range
                self.random_size = (min_size, max_size)
            size = random.randint(*self.random_size)
            size = (int(32 * size), 32 * int(size * size_factor))
            tensor[0] = size[0]
            tensor[1] = size[1]

        if is_distributed:
            dist.barrier()
            dist.broadcast(tensor, 0)

        input_size = (tensor[0].item(), tensor[1].item())
        return input_size

    def preprocess(self, inputs, targets, tsize):
        scale_y = tsize[0] / self.input_size[0]
        scale_x = tsize[1] / self.input_size[1]
        if scale_x != 1 or scale_y != 1:
            inputs = nn.functional.interpolate(
                inputs, size=tsize, mode="bilinear", align_corners=False
            )
            targets[..., 1::2] = targets[..., 1::2] * scale_x
            targets[..., 2::2] = targets[..., 2::2] * scale_y
        return inputs, targets

    def get_optimizer(self, batch_size):
        if "optimizer" not in self.__dict__:
            if self.warmup_epochs > 0:
                lr = self.warmup_lr
            else:
                lr = self.basic_lr_per_img * batch_size

            pg0, pg1, pg2 = [], [], []  # optimizer parameter groups

            for k, v in self.model.named_modules():
                if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
                    pg2.append(v.bias)  # biases
                if isinstance(v, nn.BatchNorm2d) or "bn" in k:
                    pg0.append(v.weight)  # no decay
                elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
                    pg1.append(v.weight)  # apply decay

            optimizer = torch.optim.SGD(
                pg0, lr=lr, momentum=self.momentum, nesterov=True
            )
            optimizer.add_param_group(
                {"params": pg1, "weight_decay": self.weight_decay}
            )  # add pg1 with weight_decay
            optimizer.add_param_group({"params": pg2})
            self.optimizer = optimizer

        return self.optimizer

    def get_lr_scheduler(self, lr, iters_per_epoch):
        from yolox.utils import LRScheduler

        scheduler = LRScheduler(
            self.scheduler,
            lr,
            iters_per_epoch,
            self.max_epoch,
            warmup_epochs=self.warmup_epochs,
            warmup_lr_start=self.warmup_lr,
            no_aug_epochs=self.no_aug_epochs,
            min_lr_ratio=self.min_lr_ratio,
        )
        return scheduler

    def get_eval_dataset(self, **kwargs):
        from yolox.data import COCODataset, ValTransform
        testdev = kwargs.get("testdev", False)
        legacy = kwargs.get("legacy", False)

        return COCODataset(
            data_dir=self.data_dir,
            json_file=self.val_ann if not testdev else self.test_ann,
            name="val2017" if not testdev else "test2017",
            img_size=self.test_size,
            preproc=ValTransform(legacy=legacy),
        )

    def get_eval_loader(self, batch_size, is_distributed, **kwargs):
        valdataset = self.get_eval_dataset(**kwargs)

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()
            sampler = torch.utils.data.distributed.DistributedSampler(
                valdataset, shuffle=False
            )
        else:
            sampler = torch.utils.data.SequentialSampler(valdataset)

        dataloader_kwargs = {
            "num_workers": self.data_num_workers,
            "pin_memory": True,
            "sampler": sampler,
        }
        dataloader_kwargs["batch_size"] = batch_size
        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)

        return val_loader

    def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
        from yolox.evaluators import COCOEvaluator

        return COCOEvaluator(
            dataloader=self.get_eval_loader(batch_size, is_distributed,
                                            testdev=testdev, legacy=legacy),
            img_size=self.test_size,
            confthre=self.test_conf,
            nmsthre=self.nmsthre,
            num_classes=self.num_classes,
            testdev=testdev,
        )

    def get_trainer(self, args):
        from yolox.core import Trainer
        trainer = Trainer(self, args)
        # NOTE: trainer shouldn't be an attribute of exp object
        return trainer

    def eval(self, model, evaluator, is_distributed, half=False, return_outputs=False):
        return evaluator.evaluate(model, is_distributed, half, return_outputs=return_outputs)


def check_exp_value(exp: Exp):
    h, w = exp.input_size
    assert h % 32 == 0 and w % 32 == 0, "input size must be multiples of 32"
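
A worked example of the multiscale arithmetic in random_resize above: with input_size (640, 640) and multiscale_range 5, the candidate side lengths are 32 * k for k in [15, 25].

input_size = (640, 640)
multiscale_range = 5

min_size = int(input_size[0] / 32) - multiscale_range  # 20 - 5 = 15
max_size = int(input_size[0] / 32) + multiscale_range  # 20 + 5 = 25
sizes = [32 * k for k in range(min_size, max_size + 1)]
print(sizes[0], sizes[-1], len(sizes))  # 480 800 11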
yolox/layers/__init__.py
ADDED
@@ -0,0 +1,13 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

# import torch first to make jit op work without `ImportError of libc10.so`
import torch  # noqa

from .jit_ops import FastCOCOEvalOp, JitOp

try:
    from .fast_coco_eval_api import COCOeval_opt
except ImportError:  # exception will be raised when users build yolox from source
    pass
yolox/layers/cocoeval/cocoeval.cpp
ADDED
@@ -0,0 +1,502 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "cocoeval.h"
#include <time.h>
#include <algorithm>
#include <cstdint>
#include <numeric>

using namespace pybind11::literals;

namespace COCOeval {

// Sort detections from highest score to lowest, such that
// detection_instances[detection_sorted_indices[t]] >=
// detection_instances[detection_sorted_indices[t+1]]. Use stable_sort to match
// original COCO API
void SortInstancesByDetectionScore(
    const std::vector<InstanceAnnotation>& detection_instances,
    std::vector<uint64_t>* detection_sorted_indices) {
  detection_sorted_indices->resize(detection_instances.size());
  std::iota(
      detection_sorted_indices->begin(), detection_sorted_indices->end(), 0);
  std::stable_sort(
      detection_sorted_indices->begin(),
      detection_sorted_indices->end(),
      [&detection_instances](size_t j1, size_t j2) {
        return detection_instances[j1].score > detection_instances[j2].score;
      });
}

// Partition the ground truth objects based on whether or not to ignore them
// based on area
void SortInstancesByIgnore(
    const std::array<double, 2>& area_range,
    const std::vector<InstanceAnnotation>& ground_truth_instances,
    std::vector<uint64_t>* ground_truth_sorted_indices,
    std::vector<bool>* ignores) {
  ignores->clear();
  ignores->reserve(ground_truth_instances.size());
  for (auto o : ground_truth_instances) {
    ignores->push_back(
        o.ignore || o.area < area_range[0] || o.area > area_range[1]);
  }

  ground_truth_sorted_indices->resize(ground_truth_instances.size());
  std::iota(
      ground_truth_sorted_indices->begin(),
      ground_truth_sorted_indices->end(),
      0);
  std::stable_sort(
      ground_truth_sorted_indices->begin(),
      ground_truth_sorted_indices->end(),
      [&ignores](size_t j1, size_t j2) {
        return (int)(*ignores)[j1] < (int)(*ignores)[j2];
      });
}

// For each IOU threshold, greedily match each detected instance to a ground
// truth instance (if possible) and store the results
void MatchDetectionsToGroundTruth(
    const std::vector<InstanceAnnotation>& detection_instances,
    const std::vector<uint64_t>& detection_sorted_indices,
    const std::vector<InstanceAnnotation>& ground_truth_instances,
    const std::vector<uint64_t>& ground_truth_sorted_indices,
    const std::vector<bool>& ignores,
    const std::vector<std::vector<double>>& ious,
    const std::vector<double>& iou_thresholds,
    const std::array<double, 2>& area_range,
    ImageEvaluation* results) {
  // Initialize memory to store return data matches and ignore
  const int num_iou_thresholds = iou_thresholds.size();
  const int num_ground_truth = ground_truth_sorted_indices.size();
  const int num_detections = detection_sorted_indices.size();
  std::vector<uint64_t> ground_truth_matches(
      num_iou_thresholds * num_ground_truth, 0);
  std::vector<uint64_t>& detection_matches = results->detection_matches;
  std::vector<bool>& detection_ignores = results->detection_ignores;
  std::vector<bool>& ground_truth_ignores = results->ground_truth_ignores;
  detection_matches.resize(num_iou_thresholds * num_detections, 0);
  detection_ignores.resize(num_iou_thresholds * num_detections, false);
  ground_truth_ignores.resize(num_ground_truth);
  for (auto g = 0; g < num_ground_truth; ++g) {
    ground_truth_ignores[g] = ignores[ground_truth_sorted_indices[g]];
  }

  for (auto t = 0; t < num_iou_thresholds; ++t) {
    for (auto d = 0; d < num_detections; ++d) {
      // information about best match so far (match=-1 -> unmatched)
      double best_iou = std::min(iou_thresholds[t], 1 - 1e-10);
      int match = -1;
      for (auto g = 0; g < num_ground_truth; ++g) {
        // if this ground truth instance is already matched and not a
        // crowd, it cannot be matched to another detection
        if (ground_truth_matches[t * num_ground_truth + g] > 0 &&
            !ground_truth_instances[ground_truth_sorted_indices[g]].is_crowd) {
          continue;
        }

        // if detected instance matched to a regular ground truth
        // instance, we can break on the first ground truth instance
        // tagged as ignore (because they are sorted by the ignore tag)
        if (match >= 0 && !ground_truth_ignores[match] &&
            ground_truth_ignores[g]) {
          break;
        }

        // if IOU overlap is the best so far, store the match appropriately
        if (ious[d][ground_truth_sorted_indices[g]] >= best_iou) {
          best_iou = ious[d][ground_truth_sorted_indices[g]];
          match = g;
        }
      }
      // if match was made, store id of match for both detection and
      // ground truth
      if (match >= 0) {
        detection_ignores[t * num_detections + d] = ground_truth_ignores[match];
        detection_matches[t * num_detections + d] =
            ground_truth_instances[ground_truth_sorted_indices[match]].id;
        ground_truth_matches[t * num_ground_truth + match] =
            detection_instances[detection_sorted_indices[d]].id;
      }

      // set unmatched detections outside of area range to ignore
      const InstanceAnnotation& detection =
          detection_instances[detection_sorted_indices[d]];
      detection_ignores[t * num_detections + d] =
          detection_ignores[t * num_detections + d] ||
          (detection_matches[t * num_detections + d] == 0 &&
           (detection.area < area_range[0] || detection.area > area_range[1]));
    }
  }

  // store detection score results
  results->detection_scores.resize(detection_sorted_indices.size());
  for (size_t d = 0; d < detection_sorted_indices.size(); ++d) {
    results->detection_scores[d] =
        detection_instances[detection_sorted_indices[d]].score;
  }
}

std::vector<ImageEvaluation> EvaluateImages(
    const std::vector<std::array<double, 2>>& area_ranges,
    int max_detections,
    const std::vector<double>& iou_thresholds,
    const ImageCategoryInstances<std::vector<double>>& image_category_ious,
    const ImageCategoryInstances<InstanceAnnotation>&
        image_category_ground_truth_instances,
    const ImageCategoryInstances<InstanceAnnotation>&
        image_category_detection_instances) {
  const int num_area_ranges = area_ranges.size();
  const int num_images = image_category_ground_truth_instances.size();
  const int num_categories =
      image_category_ious.size() > 0 ? image_category_ious[0].size() : 0;
  std::vector<uint64_t> detection_sorted_indices;
  std::vector<uint64_t> ground_truth_sorted_indices;
  std::vector<bool> ignores;
  std::vector<ImageEvaluation> results_all(
      num_images * num_area_ranges * num_categories);

  // Store results for each image, category, and area range combination. Results
  // for each IOU threshold are packed into the same ImageEvaluation object
  for (auto i = 0; i < num_images; ++i) {
    for (auto c = 0; c < num_categories; ++c) {
      const std::vector<InstanceAnnotation>& ground_truth_instances =
          image_category_ground_truth_instances[i][c];
      const std::vector<InstanceAnnotation>& detection_instances =
          image_category_detection_instances[i][c];

      SortInstancesByDetectionScore(
          detection_instances, &detection_sorted_indices);
|
170 |
+
if ((int)detection_sorted_indices.size() > max_detections) {
|
171 |
+
detection_sorted_indices.resize(max_detections);
|
172 |
+
}
|
173 |
+
|
174 |
+
for (size_t a = 0; a < area_ranges.size(); ++a) {
|
175 |
+
SortInstancesByIgnore(
|
176 |
+
area_ranges[a],
|
177 |
+
ground_truth_instances,
|
178 |
+
&ground_truth_sorted_indices,
|
179 |
+
&ignores);
|
180 |
+
|
181 |
+
MatchDetectionsToGroundTruth(
|
182 |
+
detection_instances,
|
183 |
+
detection_sorted_indices,
|
184 |
+
ground_truth_instances,
|
185 |
+
ground_truth_sorted_indices,
|
186 |
+
ignores,
|
187 |
+
image_category_ious[i][c],
|
188 |
+
iou_thresholds,
|
189 |
+
area_ranges[a],
|
190 |
+
&results_all
|
191 |
+
[c * num_area_ranges * num_images + a * num_images + i]);
|
192 |
+
}
|
193 |
+
}
|
194 |
+
}
|
195 |
+
|
196 |
+
return results_all;
|
197 |
+
}
|
198 |
+
|
199 |
+
// Convert a python list to a vector
|
200 |
+
template <typename T>
|
201 |
+
std::vector<T> list_to_vec(const py::list& l) {
|
202 |
+
std::vector<T> v(py::len(l));
|
203 |
+
for (int i = 0; i < (int)py::len(l); ++i) {
|
204 |
+
v[i] = l[i].cast<T>();
|
205 |
+
}
|
206 |
+
return v;
|
207 |
+
}
|
208 |
+
|
209 |
+
// Helper function to Accumulate()
|
210 |
+
// Considers the evaluation results applicable to a particular category, area
|
211 |
+
// range, and max_detections parameter setting, which begin at
|
212 |
+
// evaluations[evaluation_index]. Extracts a sorted list of length n of all
|
213 |
+
// applicable detection instances concatenated across all images in the dataset,
|
214 |
+
// which are represented by the outputs evaluation_indices, detection_scores,
|
215 |
+
// image_detection_indices, and detection_sorted_indices--all of which are
|
216 |
+
// length n. evaluation_indices[i] stores the applicable index into
|
217 |
+
// evaluations[] for instance i, which has detection score detection_score[i],
|
218 |
+
// and is the image_detection_indices[i]'th of the list of detections
|
219 |
+
// for the image containing i. detection_sorted_indices[] defines a sorted
|
220 |
+
// permutation of the 3 other outputs
|
221 |
+
int BuildSortedDetectionList(
|
222 |
+
const std::vector<ImageEvaluation>& evaluations,
|
223 |
+
const int64_t evaluation_index,
|
224 |
+
const int64_t num_images,
|
225 |
+
const int max_detections,
|
226 |
+
std::vector<uint64_t>* evaluation_indices,
|
227 |
+
std::vector<double>* detection_scores,
|
228 |
+
std::vector<uint64_t>* detection_sorted_indices,
|
229 |
+
std::vector<uint64_t>* image_detection_indices) {
|
230 |
+
assert(evaluations.size() >= evaluation_index + num_images);
|
231 |
+
|
232 |
+
// Extract a list of object instances of the applicable category, area
|
233 |
+
// range, and max detections requirements such that they can be sorted
|
234 |
+
image_detection_indices->clear();
|
235 |
+
evaluation_indices->clear();
|
236 |
+
detection_scores->clear();
|
237 |
+
image_detection_indices->reserve(num_images * max_detections);
|
238 |
+
evaluation_indices->reserve(num_images * max_detections);
|
239 |
+
detection_scores->reserve(num_images * max_detections);
|
240 |
+
int num_valid_ground_truth = 0;
|
241 |
+
for (auto i = 0; i < num_images; ++i) {
|
242 |
+
const ImageEvaluation& evaluation = evaluations[evaluation_index + i];
|
243 |
+
|
244 |
+
for (int d = 0;
|
245 |
+
d < (int)evaluation.detection_scores.size() && d < max_detections;
|
246 |
+
++d) { // detected instances
|
247 |
+
evaluation_indices->push_back(evaluation_index + i);
|
248 |
+
image_detection_indices->push_back(d);
|
249 |
+
detection_scores->push_back(evaluation.detection_scores[d]);
|
250 |
+
}
|
251 |
+
for (auto ground_truth_ignore : evaluation.ground_truth_ignores) {
|
252 |
+
if (!ground_truth_ignore) {
|
253 |
+
++num_valid_ground_truth;
|
254 |
+
}
|
255 |
+
}
|
256 |
+
}
|
257 |
+
|
258 |
+
// Sort detections by decreasing score, using stable sort to match
|
259 |
+
// python implementation
|
260 |
+
detection_sorted_indices->resize(detection_scores->size());
|
261 |
+
std::iota(
|
262 |
+
detection_sorted_indices->begin(), detection_sorted_indices->end(), 0);
|
263 |
+
std::stable_sort(
|
264 |
+
detection_sorted_indices->begin(),
|
265 |
+
detection_sorted_indices->end(),
|
266 |
+
[&detection_scores](size_t j1, size_t j2) {
|
267 |
+
return (*detection_scores)[j1] > (*detection_scores)[j2];
|
268 |
+
});
|
269 |
+
|
270 |
+
return num_valid_ground_truth;
|
271 |
+
}
|
272 |
+
|
273 |
+
// Helper function to Accumulate()
|
274 |
+
// Compute a precision recall curve given a sorted list of detected instances
|
275 |
+
// encoded in evaluations, evaluation_indices, detection_scores,
|
276 |
+
// detection_sorted_indices, image_detection_indices (see
|
277 |
+
// BuildSortedDetectionList()). Using vectors precisions and recalls
|
278 |
+
// and temporary storage, output the results into precisions_out, recalls_out,
|
279 |
+
// and scores_out, which are large buffers containing many precion/recall curves
|
280 |
+
// for all possible parameter settings, with precisions_out_index and
|
281 |
+
// recalls_out_index defining the applicable indices to store results.
|
282 |
+
void ComputePrecisionRecallCurve(
|
283 |
+
const int64_t precisions_out_index,
|
284 |
+
const int64_t precisions_out_stride,
|
285 |
+
const int64_t recalls_out_index,
|
286 |
+
const std::vector<double>& recall_thresholds,
|
287 |
+
const int iou_threshold_index,
|
288 |
+
const int num_iou_thresholds,
|
289 |
+
const int num_valid_ground_truth,
|
290 |
+
const std::vector<ImageEvaluation>& evaluations,
|
291 |
+
const std::vector<uint64_t>& evaluation_indices,
|
292 |
+
const std::vector<double>& detection_scores,
|
293 |
+
const std::vector<uint64_t>& detection_sorted_indices,
|
294 |
+
const std::vector<uint64_t>& image_detection_indices,
|
295 |
+
std::vector<double>* precisions,
|
296 |
+
std::vector<double>* recalls,
|
297 |
+
std::vector<double>* precisions_out,
|
298 |
+
std::vector<double>* scores_out,
|
299 |
+
std::vector<double>* recalls_out) {
|
300 |
+
assert(recalls_out->size() > recalls_out_index);
|
301 |
+
|
302 |
+
// Compute precision/recall for each instance in the sorted list of detections
|
303 |
+
int64_t true_positives_sum = 0, false_positives_sum = 0;
|
304 |
+
precisions->clear();
|
305 |
+
recalls->clear();
|
306 |
+
precisions->reserve(detection_sorted_indices.size());
|
307 |
+
recalls->reserve(detection_sorted_indices.size());
|
308 |
+
assert(!evaluations.empty() || detection_sorted_indices.empty());
|
309 |
+
for (auto detection_sorted_index : detection_sorted_indices) {
|
310 |
+
const ImageEvaluation& evaluation =
|
311 |
+
evaluations[evaluation_indices[detection_sorted_index]];
|
312 |
+
const auto num_detections =
|
313 |
+
evaluation.detection_matches.size() / num_iou_thresholds;
|
314 |
+
const auto detection_index = iou_threshold_index * num_detections +
|
315 |
+
image_detection_indices[detection_sorted_index];
|
316 |
+
assert(evaluation.detection_matches.size() > detection_index);
|
317 |
+
assert(evaluation.detection_ignores.size() > detection_index);
|
318 |
+
const int64_t detection_match =
|
319 |
+
evaluation.detection_matches[detection_index];
|
320 |
+
const bool detection_ignores =
|
321 |
+
evaluation.detection_ignores[detection_index];
|
322 |
+
const auto true_positive = detection_match > 0 && !detection_ignores;
|
323 |
+
const auto false_positive = detection_match == 0 && !detection_ignores;
|
324 |
+
if (true_positive) {
|
325 |
+
++true_positives_sum;
|
326 |
+
}
|
327 |
+
if (false_positive) {
|
328 |
+
++false_positives_sum;
|
329 |
+
}
|
330 |
+
|
331 |
+
const double recall =
|
332 |
+
static_cast<double>(true_positives_sum) / num_valid_ground_truth;
|
333 |
+
recalls->push_back(recall);
|
334 |
+
const int64_t num_valid_detections =
|
335 |
+
true_positives_sum + false_positives_sum;
|
336 |
+
const double precision = num_valid_detections > 0
|
337 |
+
? static_cast<double>(true_positives_sum) / num_valid_detections
|
338 |
+
: 0.0;
|
339 |
+
precisions->push_back(precision);
|
340 |
+
}
|
341 |
+
|
342 |
+
(*recalls_out)[recalls_out_index] = !recalls->empty() ? recalls->back() : 0;
|
343 |
+
|
344 |
+
for (int64_t i = static_cast<int64_t>(precisions->size()) - 1; i > 0; --i) {
|
345 |
+
if ((*precisions)[i] > (*precisions)[i - 1]) {
|
346 |
+
(*precisions)[i - 1] = (*precisions)[i];
|
347 |
+
}
|
348 |
+
}
|
349 |
+
|
350 |
+
// Sample the per instance precision/recall list at each recall threshold
|
351 |
+
for (size_t r = 0; r < recall_thresholds.size(); ++r) {
|
352 |
+
// first index in recalls >= recall_thresholds[r]
|
353 |
+
std::vector<double>::iterator low = std::lower_bound(
|
354 |
+
recalls->begin(), recalls->end(), recall_thresholds[r]);
|
355 |
+
size_t precisions_index = low - recalls->begin();
|
356 |
+
|
357 |
+
const auto results_ind = precisions_out_index + r * precisions_out_stride;
|
358 |
+
assert(results_ind < precisions_out->size());
|
359 |
+
assert(results_ind < scores_out->size());
|
360 |
+
if (precisions_index < precisions->size()) {
|
361 |
+
(*precisions_out)[results_ind] = (*precisions)[precisions_index];
|
362 |
+
(*scores_out)[results_ind] =
|
363 |
+
detection_scores[detection_sorted_indices[precisions_index]];
|
364 |
+
} else {
|
365 |
+
(*precisions_out)[results_ind] = 0;
|
366 |
+
(*scores_out)[results_ind] = 0;
|
367 |
+
}
|
368 |
+
}
|
369 |
+
}
|
370 |
+
py::dict Accumulate(
|
371 |
+
const py::object& params,
|
372 |
+
const std::vector<ImageEvaluation>& evaluations) {
|
373 |
+
const std::vector<double> recall_thresholds =
|
374 |
+
list_to_vec<double>(params.attr("recThrs"));
|
375 |
+
const std::vector<int> max_detections =
|
376 |
+
list_to_vec<int>(params.attr("maxDets"));
|
377 |
+
const int num_iou_thresholds = py::len(params.attr("iouThrs"));
|
378 |
+
const int num_recall_thresholds = py::len(params.attr("recThrs"));
|
379 |
+
const int num_categories = params.attr("useCats").cast<int>() == 1
|
380 |
+
? py::len(params.attr("catIds"))
|
381 |
+
: 1;
|
382 |
+
const int num_area_ranges = py::len(params.attr("areaRng"));
|
383 |
+
const int num_max_detections = py::len(params.attr("maxDets"));
|
384 |
+
const int num_images = py::len(params.attr("imgIds"));
|
385 |
+
|
386 |
+
std::vector<double> precisions_out(
|
387 |
+
num_iou_thresholds * num_recall_thresholds * num_categories *
|
388 |
+
num_area_ranges * num_max_detections,
|
389 |
+
-1);
|
390 |
+
std::vector<double> recalls_out(
|
391 |
+
num_iou_thresholds * num_categories * num_area_ranges *
|
392 |
+
num_max_detections,
|
393 |
+
-1);
|
394 |
+
std::vector<double> scores_out(
|
395 |
+
num_iou_thresholds * num_recall_thresholds * num_categories *
|
396 |
+
num_area_ranges * num_max_detections,
|
397 |
+
-1);
|
398 |
+
|
399 |
+
// Consider the list of all detected instances in the entire dataset in one
|
400 |
+
// large list. evaluation_indices, detection_scores,
|
401 |
+
// image_detection_indices, and detection_sorted_indices all have the same
|
402 |
+
// length as this list, such that each entry corresponds to one detected
|
403 |
+
// instance
|
404 |
+
std::vector<uint64_t> evaluation_indices; // indices into evaluations[]
|
405 |
+
std::vector<double> detection_scores; // detection scores of each instance
|
406 |
+
std::vector<uint64_t> detection_sorted_indices; // sorted indices of all
|
407 |
+
// instances in the dataset
|
408 |
+
std::vector<uint64_t>
|
409 |
+
image_detection_indices; // indices into the list of detected instances in
|
410 |
+
// the same image as each instance
|
411 |
+
std::vector<double> precisions, recalls;
|
412 |
+
|
413 |
+
for (auto c = 0; c < num_categories; ++c) {
|
414 |
+
for (auto a = 0; a < num_area_ranges; ++a) {
|
415 |
+
for (auto m = 0; m < num_max_detections; ++m) {
|
416 |
+
// The COCO PythonAPI assumes evaluations[] (the return value of
|
417 |
+
// COCOeval::EvaluateImages() is one long list storing results for each
|
418 |
+
// combination of category, area range, and image id, with categories in
|
419 |
+
// the outermost loop and images in the innermost loop.
|
420 |
+
const int64_t evaluations_index =
|
421 |
+
c * num_area_ranges * num_images + a * num_images;
|
422 |
+
int num_valid_ground_truth = BuildSortedDetectionList(
|
423 |
+
evaluations,
|
424 |
+
evaluations_index,
|
425 |
+
num_images,
|
426 |
+
max_detections[m],
|
427 |
+
&evaluation_indices,
|
428 |
+
&detection_scores,
|
429 |
+
&detection_sorted_indices,
|
430 |
+
&image_detection_indices);
|
431 |
+
|
432 |
+
if (num_valid_ground_truth == 0) {
|
433 |
+
continue;
|
434 |
+
}
|
435 |
+
|
436 |
+
for (auto t = 0; t < num_iou_thresholds; ++t) {
|
437 |
+
// recalls_out is a flattened vectors representing a
|
438 |
+
// num_iou_thresholds X num_categories X num_area_ranges X
|
439 |
+
// num_max_detections matrix
|
440 |
+
const int64_t recalls_out_index =
|
441 |
+
t * num_categories * num_area_ranges * num_max_detections +
|
442 |
+
c * num_area_ranges * num_max_detections +
|
443 |
+
a * num_max_detections + m;
|
444 |
+
|
445 |
+
// precisions_out and scores_out are flattened vectors
|
446 |
+
// representing a num_iou_thresholds X num_recall_thresholds X
|
447 |
+
// num_categories X num_area_ranges X num_max_detections matrix
|
448 |
+
const int64_t precisions_out_stride =
|
449 |
+
num_categories * num_area_ranges * num_max_detections;
|
450 |
+
const int64_t precisions_out_index = t * num_recall_thresholds *
|
451 |
+
num_categories * num_area_ranges * num_max_detections +
|
452 |
+
c * num_area_ranges * num_max_detections +
|
453 |
+
a * num_max_detections + m;
|
454 |
+
|
455 |
+
ComputePrecisionRecallCurve(
|
456 |
+
precisions_out_index,
|
457 |
+
precisions_out_stride,
|
458 |
+
recalls_out_index,
|
459 |
+
recall_thresholds,
|
460 |
+
t,
|
461 |
+
num_iou_thresholds,
|
462 |
+
num_valid_ground_truth,
|
463 |
+
evaluations,
|
464 |
+
evaluation_indices,
|
465 |
+
detection_scores,
|
466 |
+
detection_sorted_indices,
|
467 |
+
image_detection_indices,
|
468 |
+
&precisions,
|
469 |
+
&recalls,
|
470 |
+
&precisions_out,
|
471 |
+
&scores_out,
|
472 |
+
&recalls_out);
|
473 |
+
}
|
474 |
+
}
|
475 |
+
}
|
476 |
+
}
|
477 |
+
|
478 |
+
time_t rawtime;
|
479 |
+
struct tm local_time;
|
480 |
+
std::array<char, 200> buffer;
|
481 |
+
time(&rawtime);
|
482 |
+
#ifdef _WIN32
|
483 |
+
localtime_s(&local_time, &rawtime);
|
484 |
+
#else
|
485 |
+
localtime_r(&rawtime, &local_time);
|
486 |
+
#endif
|
487 |
+
strftime(
|
488 |
+
buffer.data(), 200, "%Y-%m-%d %H:%num_max_detections:%S", &local_time);
|
489 |
+
return py::dict(
|
490 |
+
"params"_a = params,
|
491 |
+
"counts"_a = std::vector<int64_t>({num_iou_thresholds,
|
492 |
+
num_recall_thresholds,
|
493 |
+
num_categories,
|
494 |
+
num_area_ranges,
|
495 |
+
num_max_detections}),
|
496 |
+
"date"_a = buffer,
|
497 |
+
"precision"_a = precisions_out,
|
498 |
+
"recall"_a = recalls_out,
|
499 |
+
"scores"_a = scores_out);
|
500 |
+
}
|
501 |
+
|
502 |
+
} // namespace COCOeval
|
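The heart of this file is the greedy rule in MatchDetectionsToGroundTruth: detections are walked in decreasing-score order, and each one claims the best still-free ground truth at or above the IoU threshold. A minimal pure-Python sketch of that rule for a single threshold (illustrative only, not part of this commit; crowd re-matching and area-range ignores are omitted, and all names are made up):

# Pure-Python sketch of the greedy matching rule above (illustrative only).
def greedy_match(det_scores, gt_ignores, ious, iou_threshold):
    """ious[d][g] is the IoU of detection d and ground truth g; gt_ignores
    must be sorted so ignored ground truths come last, as SortInstancesByIgnore
    arranges them. Returns the matched GT index per detection, or -1."""
    order = sorted(range(len(det_scores)), key=lambda d: -det_scores[d])
    taken = [False] * len(gt_ignores)
    matches = [-1] * len(det_scores)
    for d in order:
        best_iou, match = min(iou_threshold, 1 - 1e-10), -1
        for g in range(len(gt_ignores)):
            if taken[g]:
                continue  # a non-crowd GT can only be matched once
            # once a real (non-ignored) match exists, the first ignored GT
            # ends the scan, because ignored GTs are sorted to the back
            if match >= 0 and not gt_ignores[match] and gt_ignores[g]:
                break
            if ious[d][g] >= best_iou:
                best_iou, match = ious[d][g], g
        if match >= 0:
            matches[d] = match
            taken[match] = True
    return matches

For example, greedy_match([0.9, 0.8], [False, False], [[0.6, 0.1], [0.2, 0.7]], 0.5) returns [0, 1]: each detection, taken in decreasing score order, claims the best still-free ground truth above the threshold.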
yolox/layers/cocoeval/cocoeval.h
ADDED
@@ -0,0 +1,98 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#pragma once

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <vector>

namespace py = pybind11;

namespace COCOeval {

// Annotation data for a single object instance in an image
struct InstanceAnnotation {
  InstanceAnnotation(
      uint64_t id,
      double score,
      double area,
      bool is_crowd,
      bool ignore)
      : id{id}, score{score}, area{area}, is_crowd{is_crowd}, ignore{ignore} {}
  uint64_t id;
  double score = 0.;
  double area = 0.;
  bool is_crowd = false;
  bool ignore = false;
};

// Stores intermediate results for evaluating detection results for a single
// image that has D detected instances and G ground truth instances. This stores
// matches between detected and ground truth instances
struct ImageEvaluation {
  // For each of the D detected instances, the id of the matched ground truth
  // instance, or 0 if unmatched
  std::vector<uint64_t> detection_matches;

  // The detection score of each of the D detected instances
  std::vector<double> detection_scores;

  // Marks whether or not each of G instances was ignored from evaluation (e.g.,
  // because it's outside area_range)
  std::vector<bool> ground_truth_ignores;

  // Marks whether or not each of D instances was ignored from evaluation (e.g.,
  // because it's outside aRng)
  std::vector<bool> detection_ignores;
};

template <class T>
using ImageCategoryInstances = std::vector<std::vector<std::vector<T>>>;

// C++ implementation of COCO API cocoeval.py::COCOeval.evaluateImg(). For each
// combination of image, category, area range settings, and IOU thresholds to
// evaluate, it matches detected instances to ground truth instances and stores
// the results into a vector of ImageEvaluation results, which will be
// interpreted by the COCOeval::Accumulate() function to produce precision-recall
// curves. The parameters of nested vectors have the following semantics:
//   image_category_ious[i][c][d][g] is the intersection over union of the d'th
//     detected instance and g'th ground truth instance of
//     category category_ids[c] in image image_ids[i]
//   image_category_ground_truth_instances[i][c] is a vector of ground truth
//     instances in image image_ids[i] of category category_ids[c]
//   image_category_detection_instances[i][c] is a vector of detected
//     instances in image image_ids[i] of category category_ids[c]
std::vector<ImageEvaluation> EvaluateImages(
    const std::vector<std::array<double, 2>>& area_ranges, // vector of 2-tuples
    int max_detections,
    const std::vector<double>& iou_thresholds,
    const ImageCategoryInstances<std::vector<double>>& image_category_ious,
    const ImageCategoryInstances<InstanceAnnotation>&
        image_category_ground_truth_instances,
    const ImageCategoryInstances<InstanceAnnotation>&
        image_category_detection_instances);

// C++ implementation of COCOeval.accumulate(), which generates precision
// recall curves for each set of category, IOU threshold, detection area range,
// and max number of detections parameters. It is assumed that the parameter
// evaluations is the return value of the function COCOeval::EvaluateImages(),
// which was called with the same parameter settings params
py::dict Accumulate(
    const py::object& params,
    const std::vector<ImageEvaluation>& evaluations);

} // namespace COCOeval

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("COCOevalAccumulate", &COCOeval::Accumulate, "COCOeval::Accumulate");
  m.def(
      "COCOevalEvaluateImages",
      &COCOeval::EvaluateImages,
      "COCOeval::EvaluateImages");
  pybind11::class_<COCOeval::InstanceAnnotation>(m, "InstanceAnnotation")
      .def(pybind11::init<uint64_t, double, double, bool, bool>());
  pybind11::class_<COCOeval::ImageEvaluation>(m, "ImageEvaluation")
      .def(pybind11::init<>());
}
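Once this header is compiled as a torch extension, the bound symbols are reachable directly from Python. A hedged sketch of a direct call (illustrative only, not part of this commit; in this repo the module is normally loaded through FastCOCOEvalOp from yolox/layers/jit_ops.py, shown later in this diff, and the returned ImageEvaluation objects are opaque handles meant only to be passed back into COCOevalAccumulate):

# Illustrative only: exercising the pybind11 bindings declared above.
from yolox.layers.jit_ops import FastCOCOEvalOp

module = FastCOCOEvalOp().load()  # imports the prebuilt op or JIT-compiles it
# InstanceAnnotation(id, score, area, is_crowd, ignore)
gt = module.InstanceAnnotation(1, 0.0, 2500.0, False, False)
det = module.InstanceAnnotation(2, 0.9, 2304.0, False, False)
evals = module.COCOevalEvaluateImages(
    [[0.0, 1e10]],    # area_ranges: one [min_area, max_area] pair
    100,              # max_detections
    [0.5, 0.75],      # iou_thresholds
    [[[[0.8]]]],      # image_category_ious[i][c][d][g]
    [[[gt]]],         # image_category_ground_truth_instances[i][c]
    [[[det]]],        # image_category_detection_instances[i][c]
)
print(len(evals))     # 1: one ImageEvaluation per (image, category, area range)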
yolox/layers/fast_coco_eval_api.py
ADDED
@@ -0,0 +1,151 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# This file comes from
# https://github.com/facebookresearch/detectron2/blob/master/detectron2/evaluation/fast_eval_api.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) Megvii Inc. All rights reserved.

import copy
import time

import numpy as np
from pycocotools.cocoeval import COCOeval

from .jit_ops import FastCOCOEvalOp


class COCOeval_opt(COCOeval):
    """
    This is a slightly modified version of the original COCO API, where the functions evaluateImg()
    and accumulate() are implemented in C++ to speed up evaluation
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.module = FastCOCOEvalOp().load()

    def evaluate(self):
        """
        Run per image evaluation on given images and store results in self.evalImgs_cpp, a
        datastructure that isn't readable from Python but is used by a C++ implementation of
        accumulate(). Unlike the original COCO PythonAPI, we don't populate the datastructure
        self.evalImgs because this datastructure is a computational bottleneck.
        :return: None
        """
        tic = time.time()

        print("Running per image evaluation...")
        p = self.params
        # add backward compatibility if useSegm is specified in params
        if p.useSegm is not None:
            p.iouType = "segm" if p.useSegm == 1 else "bbox"
            print(
                "useSegm (deprecated) is not None. Running {} evaluation".format(
                    p.iouType
                )
            )
        print("Evaluate annotation type *{}*".format(p.iouType))
        p.imgIds = list(np.unique(p.imgIds))
        if p.useCats:
            p.catIds = list(np.unique(p.catIds))
        p.maxDets = sorted(p.maxDets)
        self.params = p

        self._prepare()

        # loop through images, area range, max detection number
        catIds = p.catIds if p.useCats else [-1]

        if p.iouType == "segm" or p.iouType == "bbox":
            computeIoU = self.computeIoU
        elif p.iouType == "keypoints":
            computeIoU = self.computeOks
        self.ious = {
            (imgId, catId): computeIoU(imgId, catId)
            for imgId in p.imgIds
            for catId in catIds
        }

        maxDet = p.maxDets[-1]

        # <<<< Beginning of code differences with original COCO API
        def convert_instances_to_cpp(instances, is_det=False):
            # Convert annotations for a list of instances in an image to a format that's fast
            # to access in C++
            instances_cpp = []
            for instance in instances:
                instance_cpp = self.module.InstanceAnnotation(
                    int(instance["id"]),
                    instance["score"] if is_det else instance.get("score", 0.0),
                    instance["area"],
                    bool(instance.get("iscrowd", 0)),
                    bool(instance.get("ignore", 0)),
                )
                instances_cpp.append(instance_cpp)
            return instances_cpp

        # Convert GT annotations, detections, and IOUs to a format that's fast to access in C++
        ground_truth_instances = [
            [convert_instances_to_cpp(self._gts[imgId, catId]) for catId in p.catIds]
            for imgId in p.imgIds
        ]
        detected_instances = [
            [
                convert_instances_to_cpp(self._dts[imgId, catId], is_det=True)
                for catId in p.catIds
            ]
            for imgId in p.imgIds
        ]
        ious = [[self.ious[imgId, catId] for catId in catIds] for imgId in p.imgIds]

        if not p.useCats:
            # For each image, flatten per-category lists into a single list
            ground_truth_instances = [
                [[o for c in i for o in c]] for i in ground_truth_instances
            ]
            detected_instances = [
                [[o for c in i for o in c]] for i in detected_instances
            ]

        # Call C++ implementation of self.evaluateImgs()
        self._evalImgs_cpp = self.module.COCOevalEvaluateImages(
            p.areaRng,
            maxDet,
            p.iouThrs,
            ious,
            ground_truth_instances,
            detected_instances,
        )
        self._evalImgs = None

        self._paramsEval = copy.deepcopy(self.params)
        toc = time.time()
        print("COCOeval_opt.evaluate() finished in {:0.2f} seconds.".format(toc - tic))
        # >>>> End of code differences with original COCO API

    def accumulate(self):
        """
        Accumulate per image evaluation results and store the result in self.eval. Does not
        support changing parameter settings from those used by self.evaluate()
        """
        print("Accumulating evaluation results...")
        tic = time.time()
        if not hasattr(self, "_evalImgs_cpp"):
            print("Please run evaluate() first")

        self.eval = self.module.COCOevalAccumulate(self._paramsEval, self._evalImgs_cpp)

        # recall is num_iou_thresholds X num_categories X num_area_ranges X num_max_detections
        self.eval["recall"] = np.array(self.eval["recall"]).reshape(
            self.eval["counts"][:1] + self.eval["counts"][2:]
        )

        # precision and scores are num_iou_thresholds X num_recall_thresholds X num_categories X
        # num_area_ranges X num_max_detections
        self.eval["precision"] = np.array(self.eval["precision"]).reshape(
            self.eval["counts"]
        )
        self.eval["scores"] = np.array(self.eval["scores"]).reshape(self.eval["counts"])
        toc = time.time()
        print(
            "COCOeval_opt.accumulate() finished in {:0.2f} seconds.".format(toc - tic)
        )
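In use, COCOeval_opt is a drop-in replacement for pycocotools' COCOeval: only evaluate() and accumulate() are overridden, and summarize() still comes from the base class. A sketch (not part of this commit; the annotation file names are placeholders):

from pycocotools.coco import COCO
from yolox.layers.fast_coco_eval_api import COCOeval_opt

cocoGt = COCO("instances_val2017.json")      # placeholder ground-truth file
cocoDt = cocoGt.loadRes("detections.json")   # placeholder detection results
evaluator = COCOeval_opt(cocoGt, cocoDt, "bbox")
evaluator.evaluate()    # per-image matching in C++
evaluator.accumulate()  # precision/recall accumulation in C++
evaluator.summarize()   # unchanged pycocotools summary printout

The plain reshape against self.eval["counts"] in accumulate() works because the C++ side fills the flat precision/scores buffers in row-major order over [num_iou_thresholds, num_recall_thresholds, num_categories, num_area_ranges, num_max_detections].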
yolox/layers/jit_ops.py
ADDED
@@ -0,0 +1,138 @@
#!/usr/bin/env python3
# Copyright (c) Megvii, Inc. and its affiliates. All Rights Reserved

import glob
import importlib
import os
import sys
import time
from typing import List

__all__ = ["JitOp", "FastCOCOEvalOp"]


class JitOp:
    """
    Just-in-time compilation of ops.

    Some code of `JitOp` is inspired by `deepspeed.op_builder`,
    check the following link for more details:
    https://github.com/microsoft/DeepSpeed/blob/master/op_builder/builder.py
    """

    def __init__(self, name):
        self.name = name

    def absolute_name(self) -> str:
        """Get absolute build path for cases where the op is pre-installed."""
        pass

    def sources(self) -> List:
        """Get path list of source files of op.

        NOTE: the path should be relative to root of package during building,
        otherwise an exception will be raised when building the package.
        However, for runtime building, the path will be absolute.
        """
        pass

    def include_dirs(self) -> List:
        """
        Get list of include paths, relative to root of package.

        NOTE: the path should be relative to root of package,
        otherwise an exception will be raised when building the package.
        """
        return []

    def define_macros(self) -> List:
        """Get list of macros to define for op"""
        return []

    def cxx_args(self) -> List:
        """Get optional list of compiler flags to forward"""
        args = ["-O2"] if sys.platform == "win32" else ["-O3", "-std=c++14", "-g", "-Wno-reorder"]
        return args

    def nvcc_args(self) -> List:
        """Get optional list of compiler flags to forward to nvcc when building CUDA sources"""
        args = [
            "-O3", "--use_fast_math",
            "-std=c++17" if sys.platform == "win32" else "-std=c++14",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_HALF2_OPERATORS__",
        ]
        return args

    def build_op(self):
        from torch.utils.cpp_extension import CppExtension
        return CppExtension(
            name=self.absolute_name(),
            sources=self.sources(),
            include_dirs=self.include_dirs(),
            define_macros=self.define_macros(),
            extra_compile_args={
                "cxx": self.cxx_args(),
            },
        )

    def load(self, verbose=True):
        try:
            # try to import op from pre-installed package
            return importlib.import_module(self.absolute_name())
        except Exception:  # op not compiled, jit load
            from yolox.utils import wait_for_the_master
            with wait_for_the_master():  # to avoid race condition
                return self.jit_load(verbose)

    def jit_load(self, verbose=True):
        from torch.utils.cpp_extension import load
        from loguru import logger
        try:
            import ninja  # noqa
        except ImportError:
            if verbose:
                logger.warning(
                    f"Ninja is not installed, fall back to normal installation for {self.name}."
                )

        build_tik = time.time()
        # build op and load
        op_module = load(
            name=self.name,
            sources=self.sources(),
            extra_cflags=self.cxx_args(),
            extra_cuda_cflags=self.nvcc_args(),
            verbose=verbose,
        )
        build_duration = time.time() - build_tik
        if verbose:
            logger.info(f"Load {self.name} op in {build_duration:.3f}s.")
        return op_module

    def clear_dynamic_library(self):
        """Remove dynamic library files generated by JIT compilation."""
        module = self.load()
        os.remove(module.__file__)


class FastCOCOEvalOp(JitOp):

    def __init__(self, name="fast_cocoeval"):
        super().__init__(name=name)

    def absolute_name(self):
        return f'yolox.layers.{self.name}'

    def sources(self):
        sources = glob.glob(os.path.join("yolox", "layers", "cocoeval", "*.cpp"))
        if not sources:  # source will be empty list if the so file is removed after install
            # use absolute path to compile
            import yolox
            code_path = os.path.join(yolox.__path__[0], "layers", "cocoeval", "*.cpp")
            sources = glob.glob(code_path)
        return sources

    def include_dirs(self):
        return [os.path.join("yolox", "layers", "cocoeval")]
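A new C++ op would plug into the same machinery by subclassing JitOp the way FastCOCOEvalOp does: point sources() at its .cpp files and absolute_name() at the import path of the built module. A hypothetical sketch (not part of this commit; the op name and source directory are made up):

import glob
import os

from yolox.layers.jit_ops import JitOp


class MyCustomOp(JitOp):  # hypothetical example op
    def __init__(self, name="my_custom_op"):
        super().__init__(name=name)

    def absolute_name(self):
        return f"yolox.layers.{self.name}"

    def sources(self):
        return glob.glob(os.path.join("yolox", "layers", "my_custom_op", "*.cpp"))


# MyCustomOp().load() imports the op if pre-built, otherwise JIT-compiles it.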
yolox/models/__init__.py
ADDED
@@ -0,0 +1,11 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

from .build import *
from .darknet import CSPDarknet, Darknet
from .losses import IOUloss
from .yolo_fpn import YOLOFPN
from .yolo_head import YOLOXHead
from .yolo_pafpn import YOLOPAFPN
from .yolox import YOLOX