Spaces: Build error
Matteo Sirri committed
Commit • 169e11c
1 Parent(s): b12f810
feat: initial commit
Browse files
- .gitignore +13 -0
- LICENSE +21 -0
- README.md +75 -13
- __init__.py +0 -0
- app.py +63 -0
- configs/__init__.py +0 -0
- configs/path_cfg.py +19 -0
- deps/win/conda_environment.yml +42 -0
- deps/win/conda_requirements.txt +39 -0
- deps/win/pip_requirements.txt +30 -0
- notebook/colab/detector_show.ipynb +0 -0
- notebook/colab/train_detector.ipynb +0 -0
- scripts/evaluate_detector.sh +3 -0
- scripts/inference_detector.sh +3 -0
- scripts/train_detector.sh +3 -0
- src/__init__.py +0 -0
- src/detection/__init__.py +0 -0
- src/detection/graph_utils.py +87 -0
- src/detection/model_factory.py +54 -0
- src/detection/mot_dataset.py +48 -0
- src/detection/vision/README.md +88 -0
- src/detection/vision/__init__.py +0 -0
- src/detection/vision/coco_eval.py +194 -0
- src/detection/vision/coco_utils.py +263 -0
- src/detection/vision/engine.py +137 -0
- src/detection/vision/group_by_aspect_ratio.py +196 -0
- src/detection/vision/mot_data.py +370 -0
- src/detection/vision/presets.py +48 -0
- src/detection/vision/transforms.py +284 -0
- src/detection/vision/utils.py +282 -0
- tools/__init__.py +0 -0
- tools/anns/combine_anns.py +87 -0
- tools/anns/generate_mot_format_files.py +73 -0
- tools/anns/generate_mots_format_files.py +102 -0
- tools/anns/motcha_to_coco.py +145 -0
- tools/anns/splits/motsynth_split1.txt +16 -0
- tools/anns/splits/motsynth_split2.txt +31 -0
- tools/anns/splits/motsynth_split3.txt +62 -0
- tools/anns/splits/motsynth_split4.txt +123 -0
- tools/anns/store_reid_imgs.py +84 -0
- tools/anns/to_frames.py +56 -0
- tools/inference_detector.py +46 -0
- tools/train_detector.py +408 -0
.gitignore
ADDED
@@ -0,0 +1,13 @@
#storage folders
storage/**

error_causing_batch.pth

configs/__pycache__
src/__pycache__
src/detection/__pycache__
tools/__pycache__
src/detection/vision/__pycache__

custom_out.png
baseline_out.png
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 sir3mat

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,13 +1,75 @@
# School in AI Project Work

This repository contains the code to train and evaluate a pedestrian detector for
the "School in AI 2nd edition" at [UNIMORE](https://www.unimore.it/).

## Installation

N.B.: installation is only available for win64 environments.

Create and activate an environment with all required packages:

```
conda create --name cvcspw --file deps/win/conda_requirements.txt
# or conda env create -f deps/win/conda_environment.yml
conda activate cvcspw
pip install -r deps/win/pip_requirements.txt
```

## Dataset download and preparation
### Solution 1 - From Google Drive
Download the storage folder directly from Google Drive [here](link google drive)
and place it in the root dir of the project.
After running this step, your storage directory should look like this:
```text
storage
├── MOTChallenge
    ├── MOT17
    ├── motcha_coco_annotations
├── MOTSynth
    ├── annotations
    ├── comb_annotations
    ├── frames
├── motsynth_output
```
### Solution 2 - From scratch
#### Prepare MOTSynth dataset
1. Download MOTSynth_1:
```
wget -P ./storage/MOTSynth https://motchallenge.net/data/MOTSynth_1.zip
unzip ./storage/MOTSynth/MOTSynth_1.zip
rm ./storage/MOTSynth/MOTSynth_1.zip
```
2. Delete videos 123 to 256.
3. Extract frames from the videos:
```
python tools/anns/to_frames.py --motsynth-root ./storage/MOTSynth

# now you can delete the remaining videos
rm -r ./storage/MOTSynth/MOTSynth_1
```
4. Download and extract the annotations:
```
wget -P ./storage/MOTSynth https://motchallenge.net/data/MOTSynth_coco_annotations.zip
unzip ./storage/MOTSynth/MOTSynth_coco_annotations.zip
rm ./storage/MOTSynth/MOTSynth_coco_annotations.zip
```
5. Prepare combined annotations for MOTSynth from the original COCO annotations:
```
python tools/anns/combine_anns.py --motsynth-path ./storage/MOTSynth
```
#### Prepare MOT17 dataset


## Colab Usage

You can also use [Google Colab](https://colab.research.google.com) if you need remote resources like GPUs.
The notebook folder contains some useful .ipynb files; remember to upload the storage folder to your Google Drive before using them.

## Object Detection

An adaptation of torchvision's detection reference code is used to train Faster R-CNN on a portion of the MOTSynth dataset. To train the model you can run:
```
./scripts/train_detector.sh
```
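Before launching training, a quick way to confirm that the storage layout above is in place is to check the roots defined in `configs/path_cfg.py`. This is a minimal sketch, not part of this commit, and assumes it is run from the repository root with the environment activated:

```python
# Minimal sketch (not part of this commit): verify the expected storage layout
# from "Solution 1" using the roots defined in configs/path_cfg.py.
import os.path as osp
from configs.path_cfg import MOTSYNTH_ROOT, MOTCHA_ROOT, OUTPUT_DIR

expected = [
    osp.join(MOTSYNTH_ROOT, "annotations"),
    osp.join(MOTSYNTH_ROOT, "comb_annotations"),
    osp.join(MOTSYNTH_ROOT, "frames"),
    osp.join(MOTCHA_ROOT, "MOT17"),
    osp.join(MOTCHA_ROOT, "motcha_coco_annotations"),
    OUTPUT_DIR,
]
for path in expected:
    # print a simple ok/MISSING report per expected folder
    print(("ok      " if osp.isdir(path) else "MISSING ") + path)
```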
__init__.py
ADDED
File without changes
app.py
ADDED
@@ -0,0 +1,63 @@
import os.path as osp
from tkinter.ttk import Style
import gradio as gr
import torch
import logging
import torchvision
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from configs.path_cfg import MOTCHA_ROOT, OUTPUT_DIR
from src.detection.graph_utils import add_bbox
from src.detection.vision import presets
logging.getLogger('PIL').setLevel(logging.CRITICAL)


def load_model(baseline: bool = False):
    if baseline:
        model = fasterrcnn_resnet50_fpn_v2(
            weights="DEFAULT")
    else:
        model = fasterrcnn_resnet50_fpn_v2()
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
        checkpoint = torch.load(osp.join(OUTPUT_DIR, "detection_logs",
                                         "fasterrcnn_training", "checkpoint.pth"), map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    model.eval()
    return model


def detect_with_resnet50Model_finetuning_motsynth(image):
    model = load_model()
    transformEval = presets.DetectionPresetEval()
    image_tensor = transformEval(image, None)[0]
    prediction = model([image_tensor])[0]
    image_w_bbox = add_bbox(image_tensor, prediction, 0.85)
    torchvision.io.write_png(image_w_bbox, "custom_out.png")
    return "custom_out.png"


def detect_with_resnet50Model_baseline(image):
    model = load_model(baseline=True)
    transformEval = presets.DetectionPresetEval()
    image_tensor = transformEval(image, None)[0]
    prediction = model([image_tensor])[0]
    image_w_bbox = add_bbox(image_tensor, prediction, 0.85)
    torchvision.io.write_png(image_w_bbox, "baseline_out.png")
    return "baseline_out.png"


title = "Performance comparision of Faster R-CNN for people detection with syntetic data"
description = "<p style='text-align: center'>Performance comparision of Faster R-CNN models for people detecion using MOTSynth and MOT17"
examples = [[osp.join(MOTCHA_ROOT, "MOT17", "train",
                      "MOT17-09-DPM", "img1", "000001.jpg")]]


io_baseline = gr.Interface(detect_with_resnet50Model_baseline, gr.Image(type="pil"), gr.Image(
    type="file", shape=(1920, 1080), label="FasterR-CNN_Resnet50_COCO"))

io_custom = gr.Interface(detect_with_resnet50Model_finetuning_motsynth, gr.Image(type="pil"), gr.Image(
    type="file", shape=(1920, 1080), label="FasterR-CNN_Resnet50_FinteTuning_MOTSynth"))

gr.Parallel(io_baseline, io_custom, title=title,
            description=description, examples=examples).launch(enable_queue=True)
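The demo above runs the same inference path twice (COCO baseline vs. MOTSynth fine-tuned model) and draws boxes with `add_bbox`. For reference, a minimal offline sketch of that path without the Gradio UI; it assumes it is run from the repo root, that the fine-tuned checkpoint exists under `OUTPUT_DIR/detection_logs/fasterrcnn_training/`, and that an example MOT17 frame is available:

```python
# Minimal offline sketch (assumptions noted above): one image through the same
# pipeline the Gradio demo uses, saved as custom_out.png.
import os.path as osp
import torch
import torchvision
from PIL import Image
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2, FastRCNNPredictor
from configs.path_cfg import MOTCHA_ROOT, OUTPUT_DIR
from src.detection.graph_utils import add_bbox
from src.detection.vision import presets

# rebuild the 2-class model and load the fine-tuned checkpoint (same as load_model)
model = fasterrcnn_resnet50_fpn_v2()
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
ckpt = torch.load(osp.join(OUTPUT_DIR, "detection_logs", "fasterrcnn_training", "checkpoint.pth"),
                  map_location="cpu")
model.load_state_dict(ckpt["model"])
model.eval()

image = Image.open(osp.join(MOTCHA_ROOT, "MOT17", "train",
                            "MOT17-09-DPM", "img1", "000001.jpg"))
tensor = presets.DetectionPresetEval()(image, None)[0]
with torch.no_grad():
    prediction = model([tensor])[0]
torchvision.io.write_png(add_bbox(tensor, prediction, 0.85), "custom_out.png")
```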
configs/__init__.py
ADDED
File without changes
configs/path_cfg.py
ADDED
@@ -0,0 +1,19 @@
import os
import sys
import os

IN_COLAB = False
if 'COLAB_GPU' in os.environ:
    IN_COLAB = True

cwd = os.getcwd()

if(IN_COLAB):
    MOTSYNTH_ROOT = '/content/gdrive/MyDrive/CVCS/storage/MOTSynth'
    MOTCHA_ROOT = '/content/gdrive/MyDrive/CVCS/storage/MOTChallenge'
    OUTPUT_DIR = '/content/gdrive/MyDrive/CVCS/storage/motsynth_output'
else:
    # windows config
    MOTSYNTH_ROOT = cwd + '\storage\MOTSynth'
    MOTCHA_ROOT = cwd + '\storage\MOTChallenge'
    OUTPUT_DIR = cwd + '\storage\motsynth_output'
deps/win/conda_environment.yml
ADDED
@@ -0,0 +1,42 @@
name: cvcspw
channels:
  - defaults
dependencies:
  - ca-certificates=2022.07.19=haa95532_0
  - certifi=2022.6.15=py38haa95532_0
  - openssl=1.1.1q=h2bbff1b_0
  - pip=22.1.2=py38haa95532_0
  - python=3.8.13=h6244533_0
  - setuptools=63.4.1=py38haa95532_0
  - sqlite=3.39.2=h2bbff1b_0
  - vc=14.2=h21ff451_1
  - vs2015_runtime=14.27.29016=h5e58377_2
  - wheel=0.37.1=pyhd3eb1b0_0
  - wincertstore=0.2=py38haa95532_2
  - pip:
    - charset-normalizer==2.1.1
    - coloredlogs==15.0.1
    - cycler==0.11.0
    - fonttools==4.37.1
    - humanfriendly==10.0
    - idna==3.3
    - kiwisolver==1.4.4
    - matplotlib==3.5.3
    - numpy==1.23.2
    - packaging==21.3
    - pandas==1.4.4
    - pillow==9.2.0
    - pycocotools==2.0.4
    - pyparsing==3.0.9
    - pyreadline3==3.4.1
    - python-dateutil==2.8.2
    - pytz==2022.2.1
    - requests==2.28.1
    - seaborn==0.12.0
    - six==1.16.0
    - torch==1.12.1+cu116
    - torchaudio==0.12.1+cu116
    - torchvision==0.13.1+cu116
    - typing-extensions==4.3.0
    - urllib3==1.26.12
prefix: C:\Users\matte\anaconda3\envs\cvcspw
deps/win/conda_requirements.txt
ADDED
@@ -0,0 +1,39 @@
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: win-64
ca-certificates=2022.07.19=haa95532_0
certifi=2022.6.15=py38haa95532_0
charset-normalizer=2.1.1=pypi_0
coloredlogs=15.0.1=pypi_0
cycler=0.11.0=pypi_0
fonttools=4.37.1=pypi_0
humanfriendly=10.0=pypi_0
idna=3.3=pypi_0
kiwisolver=1.4.4=pypi_0
matplotlib=3.5.3=pypi_0
numpy=1.23.2=pypi_0
openssl=1.1.1q=h2bbff1b_0
packaging=21.3=pypi_0
pandas=1.4.4=pypi_0
pillow=9.2.0=pypi_0
pip=22.1.2=py38haa95532_0
pycocotools=2.0.4=pypi_0
pyparsing=3.0.9=pypi_0
pyreadline3=3.4.1=pypi_0
python=3.8.13=h6244533_0
python-dateutil=2.8.2=pypi_0
pytz=2022.2.1=pypi_0
requests=2.28.1=pypi_0
seaborn=0.12.0=pypi_0
setuptools=63.4.1=py38haa95532_0
six=1.16.0=pypi_0
sqlite=3.39.2=h2bbff1b_0
torch=1.12.1+cu116=pypi_0
torchaudio=0.12.1+cu116=pypi_0
torchvision=0.13.1+cu116=pypi_0
typing-extensions=4.3.0=pypi_0
urllib3=1.26.12=pypi_0
vc=14.2=h21ff451_1
vs2015_runtime=14.27.29016=h5e58377_2
wheel=0.37.1=pyhd3eb1b0_0
wincertstore=0.2=py38haa95532_2
deps/win/pip_requirements.txt
ADDED
@@ -0,0 +1,30 @@
certifi==2022.6.15
charset-normalizer==2.1.1
coloredlogs==15.0.1
cycler==0.11.0
fonttools==4.37.1
humanfriendly==10.0
idna==3.3
kiwisolver==1.4.4
matplotlib==3.5.3
numpy==1.23.2
packaging==21.3
pandas==1.4.4
Pillow==9.2.0
pip==22.1.2
pycocotools==2.0.4
pyparsing==3.0.9
pyreadline3==3.4.1
python-dateutil==2.8.2
pytz==2022.2.1
requests==2.28.1
seaborn==0.12.0
setuptools==63.4.1
six==1.16.0
torch==1.12.1
torchaudio==0.12.1
torchvision==0.13.1
typing_extensions==4.3.0
urllib3==1.26.12
wheel==0.37.1
wincertstore==0.2
notebook/colab/detector_show.ipynb
ADDED
The diff for this file is too large to render.
notebook/colab/train_detector.ipynb
ADDED
The diff for this file is too large to render.
scripts/evaluate_detector.sh
ADDED
@@ -0,0 +1,3 @@
#!/bin/sh

python -m tools.train_detector --model-eval "d://cvcspw/storage/motsynth_output/detection_logs/fasterrcnn_training/checkpoint.pth" --test-only
scripts/inference_detector.sh
ADDED
@@ -0,0 +1,3 @@
#!/bin/sh

python -m tools.inference_detector --model-path ./storage/motsynth_output/detection_logs/fasterrcnn_training_2/checkpoint.pth
scripts/train_detector.sh
ADDED
@@ -0,0 +1,3 @@
#!/bin/sh

python -m tools.train_detector
src/__init__.py
ADDED
File without changes
src/detection/__init__.py
ADDED
File without changes
src/detection/graph_utils.py
ADDED
@@ -0,0 +1,87 @@
import numpy as np
import torchvision.transforms.functional as F
from torchvision import transforms
from typing import DefaultDict
import matplotlib.pyplot as plt
import matplotlib
import torch
import logging
from torchvision.utils import draw_bounding_boxes
matplotlib.style.use('ggplot')
logging.getLogger('matplotlib').setLevel(logging.CRITICAL)
logging.getLogger('PIL').setLevel(logging.CRITICAL)


def save_plot(train_loss_list, label, output_dir):
    """
    Function to save the loss plot to disk.
    """
    # Loss plots.
    plt.figure(figsize=(10, 7))
    plt.plot(
        train_loss_list, linestyle='-',
        label=label
    )
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(f"{output_dir}/{label}.png")


def save_train_loss_plot(train_loss_dict: DefaultDict, output_dir):
    """
    Function to save the loss plots to disk.
    """
    for key in train_loss_dict.keys():
        save_plot(train_loss_dict[key], key, output_dir)


def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(nrows=len(imgs), ncols=1,
                            figsize=(45, 21), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img)
        img = np.asarray(img)
        axs[i, 0].imshow(img)
        axs[i, 0].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()


def plot_img_tensor(img_tensor):
    transforms.ToPILImage()(img_tensor).show()


def show_img(data_loader, model, device, th=0.7):
    for imgs, target in data_loader:
        with torch.no_grad():
            prediction = model([imgs[0].to(device)])[0]
        plot_img_tensor(add_bbox(imgs[0], prediction, th))
        plot_img_tensor(add_bbox(imgs[0], target[0]['boxes']))
        break


def add_bbox(img, output, th=None):
    img_canvas = img.clone()
    img_canvas = torch.clip(img*255, 0, 255)
    img_canvas = img_canvas.type(torch.uint8)

    if th == None:
        img_with_bbbox = draw_bounding_boxes(
            img_canvas, boxes=output, width=4)
    else:
        mask = (output["scores"] > th) & (output["labels"] == 1)
        scores_list = [score for score in (
            output["scores"][mask]).tolist()]
        labels_list = [str(label) for label in (
            output["labels"][mask]).tolist()]
        labels = ["person" for label in labels_list if label == "1"]
        assert len(labels) == len(scores_list) == len(labels_list)

        for i in range(0, len(labels)):
            labels[i] = f"{labels[i]}:{scores_list[i]:.3f}"
        img_with_bbbox = draw_bounding_boxes(
            img_canvas, boxes=output["boxes"][mask], labels=labels, width=4)
    return img_with_bbbox
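As a usage note for the plotting helpers above: `save_train_loss_plot` expects a dict of loss curves keyed like the `losses_dict` returned by `engine.train_one_epoch`, and writes one PNG per key. A minimal sketch, with placeholder numbers that are for illustration only:

```python
# Minimal sketch with placeholder values: one curve per key, one PNG per key.
from src.detection.graph_utils import save_train_loss_plot

history = {
    "loss": [1.9, 1.3, 1.1],             # dummy values, illustration only
    "loss_classifier": [0.8, 0.5, 0.4],  # dummy values, illustration only
}
# The output directory is assumed to exist; writes loss.png and loss_classifier.png.
save_train_loss_plot(history, "storage/motsynth_output")
```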
src/detection/model_factory.py
ADDED
@@ -0,0 +1,54 @@
import logging
import torch
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2, FasterRCNN, FastRCNNPredictor
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.resnet import ResNet50_Weights

logger = logging.getLogger(__name__)


def set_seeds(seed: int = 42):
    """Sets random sets for torch operations.

    Args:
        seed (int, optional): Random seed to set. Defaults to 42.
    """
    # Set the seed for general torch operations
    torch.manual_seed(seed)
    # Set the seed for CUDA torch operations (ones that happen on the GPU)
    torch.cuda.manual_seed(seed)


class ModelFactory:
    @staticmethod
    def get_model(name, weights, backbone, backbone_weights, trainable_backbone_layers):
        logger.debug(f"get_model -> model:{name}")

        if name == "fasterrcnn_resnet50_fpn":
            # backbone = backbone
            model_weights = FasterRCNN_ResNet50_FPN_V2_Weights[weights]
            model_backbone_weights = ResNet50_Weights[backbone_weights]
            # trainable_backbone_layers = 1
            model: FasterRCNN = fasterrcnn_resnet50_fpn_v2(
                weights=model_weights, backbone_name=backbone, weights_backbone=model_backbone_weights, trainable_backbone_layers=trainable_backbone_layers)

            # for param in model.rpn.parameters():
            #     param.requires_grad = False
            # for param in model.roi_heads.parameters():
            #     param.requires_grad = False
            # for param in model.backbone.fpn.parameters():
            #     param.requires_grad = False

            set_seeds()

            num_classes = 2  # 1 class (person) + background
            in_features = model.roi_heads.box_predictor.cls_score.in_features
            model.roi_heads.box_predictor = FastRCNNPredictor(
                in_features, num_classes)

        else:
            logger.error(
                "Please, provide a valid model as argument. Select one of the following: fasterrcnn_resnet50_fpn.")
            raise ValueError(name)

        return model
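For reference, the core of the factory is the standard torchvision head-replacement pattern: build `fasterrcnn_resnet50_fpn_v2` and swap its box predictor for a 2-class head (person + background). A minimal sketch of that step in isolation, using only standard torchvision calls and not the factory's full argument handling:

```python
# Minimal sketch of the fine-tuning pattern the factory applies (no weight-enum
# lookup, no seeding, no backbone arguments).
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2, FastRCNNPredictor

model = fasterrcnn_resnet50_fpn_v2(weights="DEFAULT")  # COCO-pretrained detector
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=2)  # person + background
```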
src/detection/mot_dataset.py
ADDED
@@ -0,0 +1,48 @@

import numpy as np
from src.detection.vision.coco_utils import ConvertCocoPolysToMask, CocoDetection, _coco_remove_images_without_annotations
from src.detection.vision.transforms import Compose


class UpdateIsCrowd(object):
    def __init__(self, min_size, min_vis=0.2):
        self.min_size = min_size
        self.min_vis = min_vis

    def __call__(self, image, target):
        for i, ann in enumerate(target['annotations']):
            bbox = ann['bbox']
            bbox_too_small = max(bbox[-1], bbox[-2]) < self.min_size

            if 'vis' in ann:
                vis = ann['vis']

            elif 'keypoints' in ann:
                vis = (np.array(ann['keypoints'])[2::3] == 2).mean().round(2)

            else:
                raise RuntimeError(
                    "The given annotations have no visibility measure. Are you sure you want to proceed?")

            not_vis = vis < self.min_vis
            target['annotations'][i]['iscrowd'] = max(
                ann['iscrowd'], int(bbox_too_small), int(not_vis))

        return image, target


def get_mot_dataset(img_folder, ann_file, transforms, min_size=25, min_vis=0.2):
    t = [UpdateIsCrowd(min_size=min_size, min_vis=min_vis),
         ConvertCocoPolysToMask()]

    if transforms is not None:
        t.append(transforms)
    transforms = Compose(t)

    dataset = CocoDetection(img_folder=img_folder,
                            ann_file=ann_file,
                            transforms=transforms)

    dataset = _coco_remove_images_without_annotations(dataset)

    return dataset


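For context, a minimal sketch of how `get_mot_dataset` can be called. The image folder and annotation file below are illustrative assumptions (the real files come from the `tools/anns` scripts), and the eval preset is the same one the Gradio demo uses:

```python
# Minimal sketch; the frame folder and "train_split.json" are hypothetical paths.
import os.path as osp
from configs.path_cfg import MOTSYNTH_ROOT
from src.detection.mot_dataset import get_mot_dataset
from src.detection.vision import presets

dataset = get_mot_dataset(
    img_folder=osp.join(MOTSYNTH_ROOT, "frames"),
    ann_file=osp.join(MOTSYNTH_ROOT, "comb_annotations", "train_split.json"),
    transforms=presets.DetectionPresetEval(),  # converts PIL images to float tensors
    min_size=25,   # boxes smaller than this are flagged as iscrowd
    min_vis=0.2,   # annotations less visible than this are flagged as iscrowd
)
image, target = dataset[0]  # tensor image + COCO-style target dict
```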
src/detection/vision/README.md
ADDED
@@ -0,0 +1,88 @@
# Object detection reference training scripts

This folder contains reference training scripts for object detection.
They serve as a log of how to train specific models, to provide baseline
training and evaluation scripts to quickly bootstrap research.

To execute the example commands below you must install the following:

```
cython
pycocotools
matplotlib
```

You must modify the following flags:

`--data-path=/path/to/coco/dataset`

`--nproc_per_node=<number_of_gpus_available>`

Except otherwise noted, all models have been trained on 8x V100 GPUs.

### Faster R-CNN ResNet-50 FPN
```
torchrun --nproc_per_node=8 train.py\
    --dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\
    --lr-steps 16 22 --aspect-ratio-group-factor 3
```

### Faster R-CNN MobileNetV3-Large FPN
```
torchrun --nproc_per_node=8 train.py\
    --dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\
    --lr-steps 16 22 --aspect-ratio-group-factor 3
```

### Faster R-CNN MobileNetV3-Large 320 FPN
```
torchrun --nproc_per_node=8 train.py\
    --dataset coco --model fasterrcnn_mobilenet_v3_large_320_fpn --epochs 26\
    --lr-steps 16 22 --aspect-ratio-group-factor 3
```

### FCOS ResNet-50 FPN
```
torchrun --nproc_per_node=8 train.py\
    --dataset coco --model fcos_resnet50_fpn --epochs 26\
    --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp
```

### RetinaNet
```
torchrun --nproc_per_node=8 train.py\
    --dataset coco --model retinanet_resnet50_fpn --epochs 26\
    --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
```

### SSD300 VGG16
```
torchrun --nproc_per_node=8 train.py\
    --dataset coco --model ssd300_vgg16 --epochs 120\
    --lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\
    --weight-decay 0.0005 --data-augmentation ssd
```

### SSDlite320 MobileNetV3-Large
```
torchrun --nproc_per_node=8 train.py\
    --dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\
    --aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\
    --weight-decay 0.00004 --data-augmentation ssdlite
```


### Mask R-CNN
```
torchrun --nproc_per_node=8 train.py\
    --dataset coco --model maskrcnn_resnet50_fpn --epochs 26\
    --lr-steps 16 22 --aspect-ratio-group-factor 3
```


### Keypoint R-CNN
```
torchrun --nproc_per_node=8 train.py\
    --dataset coco_kp --model keypointrcnn_resnet50_fpn --epochs 46\
    --lr-steps 36 43 --aspect-ratio-group-factor 3
```
src/detection/vision/__init__.py
ADDED
File without changes
src/detection/vision/coco_eval.py
ADDED
@@ -0,0 +1,194 @@
import copy
import io
from contextlib import redirect_stdout

import numpy as np
import pycocotools.mask as mask_util
import torch
from . import utils
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval


class CocoEvaluator:
    def __init__(self, coco_gt, iou_types):
        assert isinstance(iou_types, (list, tuple))
        coco_gt = copy.deepcopy(coco_gt)
        self.coco_gt = coco_gt

        self.iou_types = iou_types
        self.coco_eval = {}
        for iou_type in iou_types:
            self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)

        self.img_ids = []
        self.eval_imgs = {k: [] for k in iou_types}

    def update(self, predictions):
        img_ids = list(np.unique(list(predictions.keys())))
        self.img_ids.extend(img_ids)

        for iou_type in self.iou_types:
            results = self.prepare(predictions, iou_type)
            with redirect_stdout(io.StringIO()):
                coco_dt = COCO.loadRes(
                    self.coco_gt, results) if results else COCO()
            coco_eval = self.coco_eval[iou_type]

            coco_eval.cocoDt = coco_dt
            coco_eval.params.imgIds = list(img_ids)
            img_ids, eval_imgs = evaluate(coco_eval)

            self.eval_imgs[iou_type].append(eval_imgs)

    def synchronize_between_processes(self):
        for iou_type in self.iou_types:
            self.eval_imgs[iou_type] = np.concatenate(
                self.eval_imgs[iou_type], 2)
            create_common_coco_eval(
                self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])

    def accumulate(self):
        for coco_eval in self.coco_eval.values():
            coco_eval.accumulate()

    def summarize(self):
        for iou_type, coco_eval in self.coco_eval.items():
            print(f"IoU metric: {iou_type}")
            coco_eval.summarize()

    def prepare(self, predictions, iou_type):
        if iou_type == "bbox":
            return self.prepare_for_coco_detection(predictions)
        if iou_type == "segm":
            return self.prepare_for_coco_segmentation(predictions)
        if iou_type == "keypoints":
            return self.prepare_for_coco_keypoint(predictions)
        raise ValueError(f"Unknown iou type {iou_type}")

    def prepare_for_coco_detection(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "bbox": box,
                        "score": scores[k],
                    }
                    for k, box in enumerate(boxes)
                ]
            )
        return coco_results

    def prepare_for_coco_segmentation(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            scores = prediction["scores"]
            labels = prediction["labels"]
            masks = prediction["masks"]

            masks = masks > 0.5

            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            rles = [
                mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] for mask in masks
            ]
            for rle in rles:
                rle["counts"] = rle["counts"].decode("utf-8")

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "segmentation": rle,
                        "score": scores[k],
                    }
                    for k, rle in enumerate(rles)
                ]
            )
        return coco_results

    def prepare_for_coco_keypoint(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            keypoints = prediction["keypoints"]
            keypoints = keypoints.flatten(start_dim=1).tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "keypoints": keypoint,
                        "score": scores[k],
                    }
                    for k, keypoint in enumerate(keypoints)
                ]
            )
        return coco_results


def convert_to_xywh(boxes):
    xmin, ymin, xmax, ymax = boxes.unbind(1)
    return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)


def merge(img_ids, eval_imgs):
    all_img_ids = utils.all_gather(img_ids)
    all_eval_imgs = utils.all_gather(eval_imgs)

    merged_img_ids = []
    for p in all_img_ids:
        merged_img_ids.extend(p)

    merged_eval_imgs = []
    for p in all_eval_imgs:
        merged_eval_imgs.append(p)

    merged_img_ids = np.array(merged_img_ids)
    merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)

    # keep only unique (and in sorted order) images
    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
    merged_eval_imgs = merged_eval_imgs[..., idx]

    return merged_img_ids, merged_eval_imgs


def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
    img_ids, eval_imgs = merge(img_ids, eval_imgs)
    img_ids = list(img_ids)
    eval_imgs = list(eval_imgs.flatten())

    coco_eval.evalImgs = eval_imgs
    coco_eval.params.imgIds = img_ids
    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)


def evaluate(imgs):
    with redirect_stdout(io.StringIO()):
        imgs.evaluate()
    return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))
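For reference, a minimal sketch of how the evaluator above is driven. Here `coco_gt` is assumed to be a `pycocotools` `COCO` ground-truth object (for example from `coco_utils.get_coco_api_from_dataset`) and `predictions` a dict mapping image ids to the model's output dicts; `engine.evaluate` shows how both are built from a data loader and a model:

```python
# Minimal sketch; coco_gt and predictions are assumed to exist already.
from src.detection.vision.coco_eval import CocoEvaluator

evaluator = CocoEvaluator(coco_gt, iou_types=["bbox"])
evaluator.update(predictions)              # {image_id: {"boxes", "scores", "labels"}}
evaluator.synchronize_between_processes()  # gathers results (trivial in a single process)
evaluator.accumulate()
evaluator.summarize()                      # prints the COCO AP/AR table
```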
src/detection/vision/coco_utils.py
ADDED
@@ -0,0 +1,263 @@
import copy
import os

import torch
import torch.utils.data
import torchvision
from . import transforms as T
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO


class FilterAndRemapCocoCategories:
    def __init__(self, categories, remap=True):
        self.categories = categories
        self.remap = remap

    def __call__(self, image, target):
        anno = target["annotations"]
        anno = [obj for obj in anno if obj["category_id"] in self.categories]
        if not self.remap:
            target["annotations"] = anno
            return image, target
        anno = copy.deepcopy(anno)
        for obj in anno:
            obj["category_id"] = self.categories.index(obj["category_id"])
        target["annotations"] = anno
        return image, target


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        if isinstance(polygons['counts'], list):
            rles = coco_mask.frPyObjects(polygons, height, width)

        else:
            rles = [polygons]

        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


class ConvertCocoPolysToMask:
    def __call__(self, image, target):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]
        anno = [obj for obj in anno if obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        # masks=None
        if anno and 'segmentation' in anno[0]:
            segmentations = [obj["segmentation"] for obj in anno]

        else:
            segmentations = []

        masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if masks is not None and masks.shape[0] > 0:
            masks = masks[keep]

        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        target["masks"] = masks

        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
        target["area"] = area
        target["iscrowd"] = iscrowd
        #target['vis'] = [obj['vis'] for obj in anno]

        return image, target


def _coco_remove_images_without_annotations(dataset, cat_list=None):
    def _has_only_empty_bbox(anno):
        return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)

    def _count_visible_keypoints(anno):
        return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)

    min_keypoints_per_image = 10

    def _has_valid_annotation(anno):
        # if it's empty, there is no annotation
        if len(anno) == 0:
            return False
        # if all boxes have close to zero area, there is no annotation
        if _has_only_empty_bbox(anno):
            return False
        # keypoints task have a slight different critera for considering
        # if an annotation is valid
        if "keypoints" not in anno[0]:
            return True
        # for keypoint detection tasks, only consider valid images those
        # containing at least min_keypoints_per_image
        if _count_visible_keypoints(anno) >= min_keypoints_per_image:
            return True
        return False

    assert isinstance(dataset, torchvision.datasets.CocoDetection)
    ids = []
    for ds_idx, img_id in enumerate(dataset.ids):
        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = dataset.coco.loadAnns(ann_ids)
        if cat_list:
            anno = [obj for obj in anno if obj["category_id"] in cat_list]
        if _has_valid_annotation(anno):
            ids.append(ds_idx)

    dataset = torch.utils.data.Subset(dataset, ids)
    return dataset


def convert_to_coco_api(ds):
    coco_ds = COCO()
    # annotation IDs need to start at 1, not 0, see torchvision issue #1530
    ann_id = 1
    dataset = {"images": [], "categories": [], "annotations": []}
    categories = set()
    for img_idx in range(len(ds)):
        # find better way to get target
        # targets = ds.get_annotations(img_idx)
        img, targets = ds[img_idx]
        image_id = targets["image_id"].item()
        img_dict = {}
        img_dict["id"] = image_id
        img_dict["height"] = img.shape[-2]
        img_dict["width"] = img.shape[-1]
        dataset["images"].append(img_dict)
        bboxes = targets["boxes"].clone()
        bboxes[:, 2:] -= bboxes[:, :2]
        bboxes = bboxes.tolist()
        labels = targets["labels"].tolist()
        areas = targets["area"].tolist()
        iscrowd = targets["iscrowd"].tolist()
        if "masks" in targets:
            masks = targets["masks"]
            # make masks Fortran contiguous for coco_mask
            masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        if "keypoints" in targets:
            keypoints = targets["keypoints"]
            keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
        num_objs = len(bboxes)
        for i in range(num_objs):
            ann = {}
            ann["image_id"] = image_id
            ann["bbox"] = bboxes[i]
            ann["category_id"] = labels[i]
            categories.add(labels[i])
            ann["area"] = areas[i]
            ann["iscrowd"] = iscrowd[i]
            ann["id"] = ann_id
            if "masks" in targets:
                ann["segmentation"] = coco_mask.encode(masks[i].numpy())
            if "keypoints" in targets:
                ann["keypoints"] = keypoints[i]
                ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
            dataset["annotations"].append(ann)
            ann_id += 1
    dataset["categories"] = [{"id": i} for i in sorted(categories)]
    coco_ds.dataset = dataset
    coco_ds.createIndex()
    return coco_ds


def get_coco_api_from_dataset(dataset):
    for _ in range(10):
        if isinstance(dataset, torchvision.datasets.CocoDetection):
            break
        if isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
    if isinstance(dataset, torchvision.datasets.CocoDetection):
        return dataset.coco
    return convert_to_coco_api(dataset)


class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, transforms):
        super().__init__(img_folder, ann_file)
        self._transforms = transforms

    def __getitem__(self, idx):
        img, target = super().__getitem__(idx)
        image_id = self.ids[idx]
        target = dict(image_id=image_id, annotations=target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target


def get_coco(root, image_set, transforms, mode="instances"):
    anno_file_template = "{}_{}2017.json"
    PATHS = {
        "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
        "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
        # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
    }

    t = [ConvertCocoPolysToMask()]

    if transforms is not None:
        t.append(transforms)
    transforms = T.Compose(t)

    img_folder, ann_file = PATHS[image_set]
    img_folder = os.path.join(root, img_folder)
    ann_file = os.path.join(root, ann_file)

    dataset = CocoDetection(img_folder, ann_file, transforms=transforms)

    if image_set == "train":
        dataset = _coco_remove_images_without_annotations(dataset)

    # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])

    return dataset


def get_coco_kp(root, image_set, transforms):
    return get_coco(root, image_set, transforms, mode="person_keypoints")
src/detection/vision/engine.py
ADDED
@@ -0,0 +1,137 @@
import math
import sys
import time

import torch
import torchvision.models.detection.faster_rcnn
from . import utils
from . import coco_eval
from . import coco_utils


def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(
        window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"
    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    losses_dict = {
        "lr": [],
        "loss": [],
        # loss rpn
        "loss_objectness": [],
        "loss_rpn_box_reg": [],
        # roi heads
        "loss_classifier": [],
        "loss_box_reg": [],
    }
    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        try:
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()}
                       for t in targets]
            with torch.cuda.amp.autocast(enabled=scaler is not None):
                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = utils.reduce_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            loss_value = losses_reduced.item()

            # if problem with loss see below
            if not math.isfinite(loss_value):
                print(f"Loss is {loss_value}, stopping training")
                print(loss_dict_reduced)
                sys.exit(1)

        except Exception as exp:
            print("ERROR", str(exp))
            torch.save({'img': images, 'targets': targets},
                       'error_causing_batch.pth')
            raise RuntimeError

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    for name, meter in metric_logger.meters.items():
        losses_dict[name].append(meter.global_avg)
    return metric_logger, losses_dict


def _get_iou_types(model):
    model_without_ddp = model
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model_without_ddp = model.module
    iou_types = ["bbox"]
    if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
        iou_types.append("segm")
    if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
        iou_types.append("keypoints")
    return iou_types


@ torch.inference_mode()
def evaluate(model, data_loader, device, iou_types=None):
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    cpu_device = torch.device("cpu")
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"
    coco = coco_utils.get_coco_api_from_dataset(data_loader.dataset)
    if iou_types is None:
        iou_types = _get_iou_types(model)
    coco_evaluator = coco_eval.CocoEvaluator(coco, iou_types)
    for images, targets in metric_logger.log_every(data_loader, 100, header):
        images = list(img.to(device) for img in images)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        model_time = time.time()
        outputs = model(images)

        outputs = [{k: v.to(cpu_device) for k, v in t.items()}
                   for t in outputs]
        model_time = time.time() - model_time

        res = {target["image_id"].item(): output for target,
               output in zip(targets, outputs)}
        evaluator_time = time.time()
        coco_evaluator.update(res)
        evaluator_time = time.time() - evaluator_time
        metric_logger.update(model_time=model_time,
                             evaluator_time=evaluator_time)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images and print table with results
    coco_evaluator.accumulate()
    coco_evaluator.summarize()
    torch.set_num_threads(n_threads)
    return coco_evaluator
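Putting the pieces together, a minimal sketch of the training loop that `tools/train_detector.py` is expected to build on top of these helpers. The datasets, model, and hyper-parameters here are illustrative assumptions (`train_dataset`/`val_dataset` from `get_mot_dataset`, `model` from `ModelFactory`), and `utils.collate_fn` is the reference collate function for variable-size detection targets:

```python
# Minimal sketch of a training driver around train_one_epoch / evaluate.
# Assumptions: `model`, `train_dataset`, and `val_dataset` are already built.
import torch
from src.detection.vision import utils
from src.detection.vision.engine import train_one_epoch, evaluate

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=2, shuffle=True, collate_fn=utils.collate_fn)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=1, shuffle=False, collate_fn=utils.collate_fn)

model.to(device)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

for epoch in range(10):  # illustrative epoch count
    _, losses_dict = train_one_epoch(
        model, optimizer, train_loader, device, epoch, print_freq=50)
    evaluate(model, val_loader, device, iou_types=["bbox"])
```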
src/detection/vision/group_by_aspect_ratio.py
ADDED
@@ -0,0 +1,196 @@
import bisect
import copy
import math
from collections import defaultdict
from itertools import repeat, chain

import numpy as np
import torch
import torch.utils.data
import torchvision
from PIL import Image
from torch.utils.data.sampler import BatchSampler, Sampler
from torch.utils.model_zoo import tqdm


def _repeat_to_at_least(iterable, n):
    repeat_times = math.ceil(n / len(iterable))
    repeated = chain.from_iterable(repeat(iterable, repeat_times))
    return list(repeated)


class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield a mini-batch of indices.
    It enforces that the batch only contain elements from the same group.
    It also tries to provide mini-batches which follows an ordering which is
    as close as possible to the ordering from the original sampler.
    Args:
        sampler (Sampler): Base sampler.
        group_ids (list[int]): If the sampler produces indices in range [0, N),
            `group_ids` must be a list of `N` ints which contains the group id of each sample.
            The group ids must be a continuous set of integers starting from
            0, i.e. they must be in the range [0, num_groups).
        batch_size (int): Size of mini-batch.
    """

    def __init__(self, sampler, group_ids, batch_size):
        if not isinstance(sampler, Sampler):
            raise ValueError(f"sampler should be an instance of torch.utils.data.Sampler, but got sampler={sampler}")
        self.sampler = sampler
        self.group_ids = group_ids
        self.batch_size = batch_size

    def __iter__(self):
        buffer_per_group = defaultdict(list)
        samples_per_group = defaultdict(list)

        num_batches = 0
        for idx in self.sampler:
            group_id = self.group_ids[idx]
            buffer_per_group[group_id].append(idx)
            samples_per_group[group_id].append(idx)
            if len(buffer_per_group[group_id]) == self.batch_size:
                yield buffer_per_group[group_id]
                num_batches += 1
                del buffer_per_group[group_id]
            assert len(buffer_per_group[group_id]) < self.batch_size

        # now we have run out of elements that satisfy
        # the group criteria, let's return the remaining
        # elements so that the size of the sampler is
        # deterministic
        expected_num_batches = len(self)
        num_remaining = expected_num_batches - num_batches
        if num_remaining > 0:
            # for the remaining batches, take first the buffers with largest number
            # of elements
            for group_id, _ in sorted(buffer_per_group.items(), key=lambda x: len(x[1]), reverse=True):
                remaining = self.batch_size - len(buffer_per_group[group_id])
                samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining)
                buffer_per_group[group_id].extend(samples_from_group_id[:remaining])
                assert len(buffer_per_group[group_id]) == self.batch_size
                yield buffer_per_group[group_id]
                num_remaining -= 1
                if num_remaining == 0:
                    break
        assert num_remaining == 0

    def __len__(self):
        return len(self.sampler) // self.batch_size


def _compute_aspect_ratios_slow(dataset, indices=None):
    print(
        "Your dataset doesn't support the fast path for "
        "computing the aspect ratios, so will iterate over "
        "the full dataset and load every image instead. "
        "This might take some time..."
    )
    if indices is None:
        indices = range(len(dataset))

    class SubsetSampler(Sampler):
        def __init__(self, indices):
            self.indices = indices

        def __iter__(self):
            return iter(self.indices)

        def __len__(self):
            return len(self.indices)

    sampler = SubsetSampler(indices)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        sampler=sampler,
        num_workers=14,  # you might want to increase it for faster processing
        collate_fn=lambda x: x[0],
    )
    aspect_ratios = []
    with tqdm(total=len(dataset)) as pbar:
        for _i, (img, _) in enumerate(data_loader):
            pbar.update(1)
            height, width = img.shape[-2:]
            aspect_ratio = float(width) / float(height)
            aspect_ratios.append(aspect_ratio)
    return aspect_ratios


def _compute_aspect_ratios_custom_dataset(dataset, indices=None):
    if indices is None:
        indices = range(len(dataset))
    aspect_ratios = []
    for i in indices:
        height, width = dataset.get_height_and_width(i)
        aspect_ratio = float(width) / float(height)
        aspect_ratios.append(aspect_ratio)
    return aspect_ratios


def _compute_aspect_ratios_coco_dataset(dataset, indices=None):
    if indices is None:
        indices = range(len(dataset))
    aspect_ratios = []
    for i in indices:
        img_info = dataset.coco.imgs[dataset.ids[i]]
        aspect_ratio = float(img_info["width"]) / float(img_info["height"])
        aspect_ratios.append(aspect_ratio)
    return aspect_ratios


def _compute_aspect_ratios_voc_dataset(dataset, indices=None):
    if indices is None:
        indices = range(len(dataset))
    aspect_ratios = []
    for i in indices:
        # this doesn't load the data into memory, because PIL loads it lazily
        width, height = Image.open(dataset.images[i]).size
        aspect_ratio = float(width) / float(height)
        aspect_ratios.append(aspect_ratio)
    return aspect_ratios


def _compute_aspect_ratios_subset_dataset(dataset, indices=None):
    if indices is None:
        indices = range(len(dataset))

    ds_indices = [dataset.indices[i] for i in indices]
    return compute_aspect_ratios(dataset.dataset, ds_indices)


def compute_aspect_ratios(dataset, indices=None):
    if hasattr(dataset, "get_height_and_width"):
        return _compute_aspect_ratios_custom_dataset(dataset, indices)

    if isinstance(dataset, torchvision.datasets.CocoDetection):
        return _compute_aspect_ratios_coco_dataset(dataset, indices)

    if isinstance(dataset, torchvision.datasets.VOCDetection):
        return _compute_aspect_ratios_voc_dataset(dataset, indices)

    if isinstance(dataset, torch.utils.data.Subset):
        return _compute_aspect_ratios_subset_dataset(dataset, indices)

    # slow path
    return _compute_aspect_ratios_slow(dataset, indices)


def _quantize(x, bins):
    bins = copy.deepcopy(bins)
    bins = sorted(bins)
    quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
    return quantized


def create_aspect_ratio_groups(dataset, k=0):
    aspect_ratios = compute_aspect_ratios(dataset)
    bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0]
    groups = _quantize(aspect_ratios, bins)
    # count number of elements per group
    counts = np.unique(groups, return_counts=True)[1]
|
193 |
+
fbins = [0] + bins + [np.inf]
|
194 |
+
print(f"Using {fbins} as bins for aspect ratio quantization")
|
195 |
+
print(f"Count of instances per bin: {counts}")
|
196 |
+
return groups
|
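For orientation (illustrative, not part of the committed file): the sampler above is normally combined with create_aspect_ratio_groups so that each mini-batch contains images of similar aspect ratio. A minimal sketch, assuming the package is importable as src.detection.vision and that `dataset` is any detection dataset returning (image, target) pairs:

import torch
from src.detection.vision.group_by_aspect_ratio import (
    GroupedBatchSampler, create_aspect_ratio_groups)
from src.detection.vision import utils

train_sampler = torch.utils.data.RandomSampler(dataset)   # `dataset` assumed to exist
group_ids = create_aspect_ratio_groups(dataset, k=3)      # 2*k + 1 aspect-ratio bins
batch_sampler = GroupedBatchSampler(train_sampler, group_ids, batch_size=2)
data_loader = torch.utils.data.DataLoader(
    dataset, batch_sampler=batch_sampler, num_workers=4,
    collate_fn=utils.collate_fn)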
src/detection/vision/mot_data.py
ADDED
@@ -0,0 +1,370 @@
1 |
+
import configparser
|
2 |
+
import csv
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import pickle
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import pycocotools.mask as rletools
|
9 |
+
import torch
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
|
13 |
+
class MOTObjDetect(torch.utils.data.Dataset):
|
14 |
+
""" Data class for the Multiple Object Tracking Dataset
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, root, transforms=None, vis_threshold=0.25,
|
18 |
+
split_seqs=None, frame_range_start=0.0, frame_range_end=1.0):
|
19 |
+
self.root = root
|
20 |
+
self.transforms = transforms
|
21 |
+
self._vis_threshold = vis_threshold
|
22 |
+
self._classes = ('background', 'pedestrian')
|
23 |
+
self._img_paths = []
|
24 |
+
self._split_seqs = split_seqs
|
25 |
+
|
26 |
+
self.mots_gts = {}
|
27 |
+
for f in sorted(os.listdir(root)):
|
28 |
+
path = os.path.join(root, f)
|
29 |
+
|
30 |
+
if not os.path.isdir(path):
|
31 |
+
continue
|
32 |
+
|
33 |
+
if split_seqs is not None and f not in split_seqs:
|
34 |
+
continue
|
35 |
+
|
36 |
+
config_file = os.path.join(path, 'seqinfo.ini')
|
37 |
+
|
38 |
+
assert os.path.exists(config_file), \
|
39 |
+
'Path does not exist: {}'.format(config_file)
|
40 |
+
|
41 |
+
config = configparser.ConfigParser()
|
42 |
+
config.read(config_file)
|
43 |
+
seq_len = int(config['Sequence']['seqLength'])
|
44 |
+
im_ext = config['Sequence']['imExt']
|
45 |
+
im_dir = config['Sequence']['imDir']
|
46 |
+
|
47 |
+
img_dir = os.path.join(path, im_dir)
|
48 |
+
|
49 |
+
start_frame = int(frame_range_start * seq_len)
|
50 |
+
end_frame = int(frame_range_end * seq_len)
|
51 |
+
|
52 |
+
# for i in range(seq_len):
|
53 |
+
for i in range(start_frame, end_frame):
|
54 |
+
img_path = os.path.join(img_dir, f"{i + 1:06d}{im_ext}")
|
55 |
+
assert os.path.exists(
|
56 |
+
img_path), f'Path does not exist: {img_path}'
|
57 |
+
self._img_paths.append(img_path)
|
58 |
+
|
59 |
+
# print(len(self._img_paths))
|
60 |
+
|
61 |
+
if self.has_masks:
|
62 |
+
gt_file = os.path.join(
|
63 |
+
os.path.dirname(img_dir), 'gt', 'gt.txt')
|
64 |
+
self.mots_gts[gt_file] = load_mots_gt(gt_file)
|
65 |
+
|
66 |
+
def __str__(self):
|
67 |
+
if self._split_seqs is None:
|
68 |
+
return self.root
|
69 |
+
return f"{self.root}/{self._split_seqs}"
|
70 |
+
|
71 |
+
@property
|
72 |
+
def num_classes(self):
|
73 |
+
return len(self._classes)
|
74 |
+
|
75 |
+
def _get_annotation(self, idx):
|
76 |
+
"""
|
77 |
+
"""
|
78 |
+
|
79 |
+
if 'test' in self.root:
|
80 |
+
num_objs = 0
|
81 |
+
boxes = torch.zeros((num_objs, 4), dtype=torch.float32)
|
82 |
+
|
83 |
+
return {'boxes': boxes,
|
84 |
+
'labels': torch.ones((num_objs,), dtype=torch.int64),
|
85 |
+
'image_id': torch.tensor([idx]),
|
86 |
+
'area': (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
|
87 |
+
'iscrowd': torch.zeros((num_objs,), dtype=torch.int64),
|
88 |
+
'visibilities': torch.zeros((num_objs), dtype=torch.float32)}
|
89 |
+
|
90 |
+
img_path = self._img_paths[idx]
|
91 |
+
file_index = int(os.path.basename(img_path).split('.')[0])
|
92 |
+
|
93 |
+
gt_file = os.path.join(os.path.dirname(
|
94 |
+
os.path.dirname(img_path)), 'gt', 'gt.txt')
|
95 |
+
|
96 |
+
assert os.path.exists(gt_file), \
|
97 |
+
'GT file does not exist: {}'.format(gt_file)
|
98 |
+
|
99 |
+
bounding_boxes = []
|
100 |
+
|
101 |
+
if self.has_masks:
|
102 |
+
mask_objects_per_frame = self.mots_gts[gt_file][file_index]
|
103 |
+
masks = []
|
104 |
+
for mask_object in mask_objects_per_frame:
|
105 |
+
# class_id = 1 is car
|
106 |
+
# class_id = 2 is pedestrian
|
107 |
+
# class_id = 10 IGNORE
|
108 |
+
if mask_object.class_id in [1, 10] or not rletools.area(mask_object.mask):
|
109 |
+
continue
|
110 |
+
|
111 |
+
bbox = rletools.toBbox(mask_object.mask)
|
112 |
+
x1, y1, w, h = [int(c) for c in bbox]
|
113 |
+
|
114 |
+
bb = {}
|
115 |
+
bb['bb_left'] = x1
|
116 |
+
bb['bb_top'] = y1
|
117 |
+
bb['bb_width'] = w
|
118 |
+
bb['bb_height'] = h
|
119 |
+
|
120 |
+
# print(bb, rletools.area(mask_object.mask))
|
121 |
+
|
122 |
+
bb['visibility'] = 1.0
|
123 |
+
bb['track_id'] = mask_object.track_id
|
124 |
+
|
125 |
+
masks.append(rletools.decode(mask_object.mask))
|
126 |
+
bounding_boxes.append(bb)
|
127 |
+
else:
|
128 |
+
with open(gt_file, "r") as inf:
|
129 |
+
reader = csv.reader(inf, delimiter=',')
|
130 |
+
for row in reader:
|
131 |
+
visibility = float(row[8])
|
132 |
+
|
133 |
+
if int(row[0]) == file_index and int(row[6]) == 1 and int(row[7]) == 1 and visibility and visibility >= self._vis_threshold:
|
134 |
+
bb = {}
|
135 |
+
bb['bb_left'] = int(row[2])
|
136 |
+
bb['bb_top'] = int(row[3])
|
137 |
+
bb['bb_width'] = int(row[4])
|
138 |
+
bb['bb_height'] = int(row[5])
|
139 |
+
bb['visibility'] = float(row[8])
|
140 |
+
bb['track_id'] = int(row[1])
|
141 |
+
|
142 |
+
bounding_boxes.append(bb)
|
143 |
+
|
144 |
+
num_objs = len(bounding_boxes)
|
145 |
+
|
146 |
+
boxes = torch.zeros((num_objs, 4), dtype=torch.float32)
|
147 |
+
visibilities = torch.zeros((num_objs), dtype=torch.float32)
|
148 |
+
track_ids = torch.zeros((num_objs), dtype=torch.long)
|
149 |
+
|
150 |
+
for i, bb in enumerate(bounding_boxes):
|
151 |
+
# Make pixel indexes 0-based, should already be 0-based (or not)
|
152 |
+
x1 = bb['bb_left'] # - 1
|
153 |
+
y1 = bb['bb_top'] # - 1
|
154 |
+
# This -1 accounts for the width (width of 1 x1=x2)
|
155 |
+
x2 = x1 + bb['bb_width'] # - 1
|
156 |
+
y2 = y1 + bb['bb_height'] # - 1
|
157 |
+
|
158 |
+
boxes[i, 0] = x1
|
159 |
+
boxes[i, 1] = y1
|
160 |
+
boxes[i, 2] = x2
|
161 |
+
boxes[i, 3] = y2
|
162 |
+
visibilities[i] = bb['visibility']
|
163 |
+
track_ids[i] = bb['track_id']
|
164 |
+
|
165 |
+
annos = {'boxes': boxes,
|
166 |
+
'labels': torch.ones((num_objs,), dtype=torch.int64),
|
167 |
+
'image_id': torch.tensor([idx]),
|
168 |
+
'area': (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
|
169 |
+
'iscrowd': torch.zeros((num_objs,), dtype=torch.int64),
|
170 |
+
'visibilities': visibilities,
|
171 |
+
'track_ids': track_ids, }
|
172 |
+
|
173 |
+
if self.has_masks:
|
174 |
+
# annos['masks'] = torch.tensor(masks, dtype=torch.uint8)
|
175 |
+
annos['masks'] = torch.from_numpy(np.stack(masks))
|
176 |
+
return annos
|
177 |
+
|
178 |
+
@property
|
179 |
+
def has_masks(self):
|
180 |
+
return '/MOTS20/' in self.root
|
181 |
+
|
182 |
+
def __getitem__(self, idx):
|
183 |
+
# load images ad masks
|
184 |
+
img_path = self._img_paths[idx]
|
185 |
+
# mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
|
186 |
+
img = Image.open(img_path).convert("RGB")
|
187 |
+
|
188 |
+
target = self._get_annotation(idx)
|
189 |
+
|
190 |
+
if self.transforms is not None:
|
191 |
+
img, target = self.transforms(img, target)
|
192 |
+
|
193 |
+
return img, target
|
194 |
+
|
195 |
+
def __len__(self):
|
196 |
+
return len(self._img_paths)
|
197 |
+
|
198 |
+
def write_results_files(self, results, output_dir):
|
199 |
+
"""Write the detections in the format for MOT17Det sumbission
|
200 |
+
|
201 |
+
all_boxes[image] = N x 5 array of detections in (x1, y1, x2, y2, score)
|
202 |
+
|
203 |
+
Each file contains these lines:
|
204 |
+
<frame>, <id>, <bb_left>, <bb_top>, <bb_width>, <bb_height>, <conf>, <x>, <y>, <z>
|
205 |
+
|
206 |
+
Files to sumbit:
|
207 |
+
./MOT17-01.txt
|
208 |
+
./MOT17-02.txt
|
209 |
+
./MOT17-03.txt
|
210 |
+
./MOT17-04.txt
|
211 |
+
./MOT17-05.txt
|
212 |
+
./MOT17-06.txt
|
213 |
+
./MOT17-07.txt
|
214 |
+
./MOT17-08.txt
|
215 |
+
./MOT17-09.txt
|
216 |
+
./MOT17-10.txt
|
217 |
+
./MOT17-11.txt
|
218 |
+
./MOT17-12.txt
|
219 |
+
./MOT17-13.txt
|
220 |
+
./MOT17-14.txt
|
221 |
+
"""
|
222 |
+
|
223 |
+
#format_str = "{}, -1, {}, {}, {}, {}, {}, -1, -1, -1"
|
224 |
+
|
225 |
+
files = {}
|
226 |
+
for image_id, res in results.items():
|
227 |
+
path = self._img_paths[image_id]
|
228 |
+
img1, name = osp.split(path)
|
229 |
+
# get image number out of name
|
230 |
+
frame = int(name.split('.')[0])
|
231 |
+
# smth like /train/MOT17-09-FRCNN or /train/MOT17-09
|
232 |
+
tmp = osp.dirname(img1)
|
233 |
+
# get the folder name of the sequence and split it
|
234 |
+
tmp = osp.basename(tmp).split('-')
|
235 |
+
# Now get the output name of the file
|
236 |
+
out = tmp[0]+'-'+tmp[1]+'.txt'
|
237 |
+
outfile = osp.join(output_dir, out)
|
238 |
+
|
239 |
+
# check if out in keys and create empty list if not
|
240 |
+
if outfile not in files.keys():
|
241 |
+
files[outfile] = []
|
242 |
+
|
243 |
+
if 'masks' in res:
|
244 |
+
delimiter = ' '
|
245 |
+
# print(torch.unique(res['masks'][0]))
|
246 |
+
# > 0.5 #res['masks'].bool()
|
247 |
+
masks = res['masks'].squeeze(dim=1)
|
248 |
+
|
249 |
+
index_map = torch.arange(masks.size(0))[:, None, None]
|
250 |
+
index_map = index_map.expand_as(masks)
|
251 |
+
|
252 |
+
masks = torch.logical_and(
|
253 |
+
# remove background
|
254 |
+
masks > 0.5,
|
255 |
+
# remove overlapp by largest probablity
|
256 |
+
index_map == masks.argmax(dim=0)
|
257 |
+
)
|
258 |
+
for res_i in range(len(masks)):
|
259 |
+
track_id = -1
|
260 |
+
if 'track_ids' in res:
|
261 |
+
track_id = res['track_ids'][res_i].item()
|
262 |
+
mask = masks[res_i]
|
263 |
+
mask = np.asfortranarray(mask)
|
264 |
+
|
265 |
+
rle_mask = rletools.encode(mask)
|
266 |
+
|
267 |
+
files[outfile].append(
|
268 |
+
[frame,
|
269 |
+
track_id,
|
270 |
+
2, # class pedestrian
|
271 |
+
mask.shape[0],
|
272 |
+
mask.shape[1],
|
273 |
+
rle_mask['counts'].decode(encoding='UTF-8')])
|
274 |
+
else:
|
275 |
+
delimiter = ','
|
276 |
+
for res_i in range(len(res['boxes'])):
|
277 |
+
track_id = -1
|
278 |
+
if 'track_ids' in res:
|
279 |
+
track_id = res['track_ids'][res_i].item()
|
280 |
+
box = res['boxes'][res_i]
|
281 |
+
score = res['scores'][res_i]
|
282 |
+
|
283 |
+
x1 = box[0].item()
|
284 |
+
y1 = box[1].item()
|
285 |
+
x2 = box[2].item()
|
286 |
+
y2 = box[3].item()
|
287 |
+
|
288 |
+
out = [frame, track_id, x1, y1, x2 - x1,
|
289 |
+
y2 - y1, score.item(), -1, -1, -1]
|
290 |
+
|
291 |
+
if 'keypoints' in res:
|
292 |
+
out.extend(res['keypoints'][res_i]
|
293 |
+
[:, :2].flatten().tolist())
|
294 |
+
out.extend(res['keypoints_scores']
|
295 |
+
[res_i].flatten().tolist())
|
296 |
+
|
297 |
+
files[outfile].append(out)
|
298 |
+
|
299 |
+
for k, v in files.items():
|
300 |
+
with open(k, "w") as of:
|
301 |
+
writer = csv.writer(of, delimiter=delimiter)
|
302 |
+
for d in v:
|
303 |
+
writer.writerow(d)
|
304 |
+
|
305 |
+
|
306 |
+
class SegmentedObject:
|
307 |
+
"""
|
308 |
+
Helper class for segmentation objects.
|
309 |
+
"""
|
310 |
+
|
311 |
+
def __init__(self, mask: dict, class_id: int, track_id: int, full_bbox=None) -> None:
|
312 |
+
self.mask = mask
|
313 |
+
self.class_id = class_id
|
314 |
+
self.track_id = track_id
|
315 |
+
self.full_bbox = full_bbox
|
316 |
+
|
317 |
+
|
318 |
+
def load_mots_gt(path: str) -> dict:
|
319 |
+
"""Load MOTS ground truth from path."""
|
320 |
+
objects_per_frame = {}
|
321 |
+
track_ids_per_frame = {} # Check that no frame contains two objects with same id
|
322 |
+
combined_mask_per_frame = {} # Check that no frame contains overlapping masks
|
323 |
+
|
324 |
+
with open(path, "r") as gt_file:
|
325 |
+
for line in gt_file:
|
326 |
+
line = line.strip()
|
327 |
+
fields = line.split(" ")
|
328 |
+
|
329 |
+
frame = int(fields[0])
|
330 |
+
if frame not in objects_per_frame:
|
331 |
+
objects_per_frame[frame] = []
|
332 |
+
# if frame not in track_ids_per_frame:
|
333 |
+
# track_ids_per_frame[frame] = set()
|
334 |
+
# if int(fields[1]) in track_ids_per_frame[frame]:
|
335 |
+
# assert False, f"Multiple objects with track id {fields[1]} in frame {fields[0]}"
|
336 |
+
# else:
|
337 |
+
# track_ids_per_frame[frame].add(int(fields[1]))
|
338 |
+
|
339 |
+
class_id = int(fields[2])
|
340 |
+
if not (class_id == 1 or class_id == 2 or class_id == 10):
|
341 |
+
assert False, "Unknown object class " + fields[2]
|
342 |
+
|
343 |
+
mask = {
|
344 |
+
'size': [int(fields[3]), int(fields[4])],
|
345 |
+
'counts': fields[5].encode(encoding='UTF-8')}
|
346 |
+
if frame not in combined_mask_per_frame:
|
347 |
+
combined_mask_per_frame[frame] = mask
|
348 |
+
elif rletools.area(rletools.merge([
|
349 |
+
combined_mask_per_frame[frame], mask],
|
350 |
+
intersect=True)):
|
351 |
+
assert False, "Objects with overlapping masks in frame " + \
|
352 |
+
fields[0]
|
353 |
+
else:
|
354 |
+
combined_mask_per_frame[frame] = rletools.merge(
|
355 |
+
[combined_mask_per_frame[frame], mask],
|
356 |
+
intersect=False)
|
357 |
+
|
358 |
+
full_bbox = None
|
359 |
+
if len(fields) == 10:
|
360 |
+
full_bbox = [int(fields[6]), int(fields[7]),
|
361 |
+
int(fields[8]), int(fields[9])]
|
362 |
+
|
363 |
+
objects_per_frame[frame].append(SegmentedObject(
|
364 |
+
mask,
|
365 |
+
class_id,
|
366 |
+
int(fields[1]),
|
367 |
+
full_bbox
|
368 |
+
))
|
369 |
+
|
370 |
+
return objects_per_frame
|
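As a quick illustration (not part of the commit; paths and sequence names are placeholders): MOTObjDetect defined above is a standard torch Dataset following the MOTChallenge folder layout (seqinfo.ini plus gt/gt.txt per sequence), so it can be instantiated and indexed directly; the target dict carries boxes, labels, visibilities and track ids.

from src.detection.vision.mot_data import MOTObjDetect
from src.detection.vision.presets import DetectionPresetTrain

# hypothetical data root following the MOTChallenge layout
dataset = MOTObjDetect(
    root="./storage/MOTChallenge/MOT17/train",
    transforms=DetectionPresetTrain(data_augmentation="hflip"),
    split_seqs=["MOT17-02-FRCNN", "MOT17-04-FRCNN"])
img, target = dataset[0]
print(len(dataset), target["boxes"].shape)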
src/detection/vision/presets.py
ADDED
@@ -0,0 +1,48 @@
import torch
from . import transforms as T


class DetectionPresetTrain:
    def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
        if data_augmentation == "hflip":
            self.transforms = T.Compose(
                [
                    T.RandomHorizontalFlip(p=hflip_prob),
                    T.PILToTensor(),
                    T.ConvertImageDtype(torch.float),
                ]
            )
        elif data_augmentation == "ssd":
            self.transforms = T.Compose(
                [
                    T.RandomPhotometricDistort(),
                    T.RandomZoomOut(fill=list(mean)),
                    T.RandomIoUCrop(),
                    T.RandomHorizontalFlip(p=hflip_prob),
                    T.PILToTensor(),
                    T.ConvertImageDtype(torch.float),
                ]
            )
        elif data_augmentation == "ssdlite":
            self.transforms = T.Compose(
                [
                    T.RandomIoUCrop(),
                    T.RandomHorizontalFlip(p=hflip_prob),
                    T.PILToTensor(),
                    T.ConvertImageDtype(torch.float),
                ]
            )
        else:
            raise ValueError(
                f'Unknown data augmentation policy "{data_augmentation}"')

    def __call__(self, img, target):
        return self.transforms(img, target)


class DetectionPresetEval:
    def __init__(self):
        self.transforms = T.ToTensor()

    def __call__(self, img, target):
        return self.transforms(img, target)
src/detection/vision/transforms.py
ADDED
@@ -0,0 +1,284 @@
1 |
+
from typing import List, Tuple, Dict, Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torchvision
|
5 |
+
from torch import nn, Tensor
|
6 |
+
from torchvision.transforms import functional as F
|
7 |
+
from torchvision.transforms import transforms as T
|
8 |
+
|
9 |
+
|
10 |
+
def _flip_coco_person_keypoints(kps, width):
|
11 |
+
flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
|
12 |
+
flipped_data = kps[:, flip_inds]
|
13 |
+
flipped_data[..., 0] = width - flipped_data[..., 0]
|
14 |
+
# Maintain COCO convention that if visibility == 0, then x, y = 0
|
15 |
+
inds = flipped_data[..., 2] == 0
|
16 |
+
flipped_data[inds] = 0
|
17 |
+
return flipped_data
|
18 |
+
|
19 |
+
|
20 |
+
class Compose:
|
21 |
+
def __init__(self, transforms):
|
22 |
+
self.transforms = transforms
|
23 |
+
|
24 |
+
def __call__(self, image, target):
|
25 |
+
for t in self.transforms:
|
26 |
+
image, target = t(image, target)
|
27 |
+
return image, target
|
28 |
+
|
29 |
+
|
30 |
+
class RandomHorizontalFlip(T.RandomHorizontalFlip):
|
31 |
+
def forward(
|
32 |
+
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
|
33 |
+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
34 |
+
if torch.rand(1) < self.p:
|
35 |
+
image = F.hflip(image)
|
36 |
+
if target is not None:
|
37 |
+
width, _ = F.get_image_size(image)
|
38 |
+
target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
|
39 |
+
if "masks" in target:
|
40 |
+
target["masks"] = target["masks"].flip(-1)
|
41 |
+
if "keypoints" in target:
|
42 |
+
keypoints = target["keypoints"]
|
43 |
+
keypoints = _flip_coco_person_keypoints(keypoints, width)
|
44 |
+
target["keypoints"] = keypoints
|
45 |
+
return image, target
|
46 |
+
|
47 |
+
|
48 |
+
class ToTensor(nn.Module):
|
49 |
+
def forward(
|
50 |
+
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
|
51 |
+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
52 |
+
image = F.pil_to_tensor(image)
|
53 |
+
image = F.convert_image_dtype(image)
|
54 |
+
return image, target
|
55 |
+
|
56 |
+
|
57 |
+
class PILToTensor(nn.Module):
|
58 |
+
def forward(
|
59 |
+
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
|
60 |
+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
61 |
+
image = F.pil_to_tensor(image)
|
62 |
+
return image, target
|
63 |
+
|
64 |
+
|
65 |
+
class ConvertImageDtype(nn.Module):
|
66 |
+
def __init__(self, dtype: torch.dtype) -> None:
|
67 |
+
super().__init__()
|
68 |
+
self.dtype = dtype
|
69 |
+
|
70 |
+
def forward(
|
71 |
+
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
|
72 |
+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
73 |
+
image = F.convert_image_dtype(image, self.dtype)
|
74 |
+
return image, target
|
75 |
+
|
76 |
+
|
77 |
+
class RandomIoUCrop(nn.Module):
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
min_scale: float = 0.3,
|
81 |
+
max_scale: float = 1.0,
|
82 |
+
min_aspect_ratio: float = 0.5,
|
83 |
+
max_aspect_ratio: float = 2.0,
|
84 |
+
sampler_options: Optional[List[float]] = None,
|
85 |
+
trials: int = 40,
|
86 |
+
):
|
87 |
+
super().__init__()
|
88 |
+
# Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
|
89 |
+
self.min_scale = min_scale
|
90 |
+
self.max_scale = max_scale
|
91 |
+
self.min_aspect_ratio = min_aspect_ratio
|
92 |
+
self.max_aspect_ratio = max_aspect_ratio
|
93 |
+
if sampler_options is None:
|
94 |
+
sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
|
95 |
+
self.options = sampler_options
|
96 |
+
self.trials = trials
|
97 |
+
|
98 |
+
def forward(
|
99 |
+
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
|
100 |
+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
101 |
+
if target is None:
|
102 |
+
raise ValueError("The targets can't be None for this transform.")
|
103 |
+
|
104 |
+
if isinstance(image, torch.Tensor):
|
105 |
+
if image.ndimension() not in {2, 3}:
|
106 |
+
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
|
107 |
+
elif image.ndimension() == 2:
|
108 |
+
image = image.unsqueeze(0)
|
109 |
+
|
110 |
+
orig_w, orig_h = F.get_image_size(image)
|
111 |
+
|
112 |
+
while True:
|
113 |
+
# sample an option
|
114 |
+
idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
|
115 |
+
min_jaccard_overlap = self.options[idx]
|
116 |
+
if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option
|
117 |
+
return image, target
|
118 |
+
|
119 |
+
for _ in range(self.trials):
|
120 |
+
# check the aspect ratio limitations
|
121 |
+
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
|
122 |
+
new_w = int(orig_w * r[0])
|
123 |
+
new_h = int(orig_h * r[1])
|
124 |
+
aspect_ratio = new_w / new_h
|
125 |
+
if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
|
126 |
+
continue
|
127 |
+
|
128 |
+
# check for 0 area crops
|
129 |
+
r = torch.rand(2)
|
130 |
+
left = int((orig_w - new_w) * r[0])
|
131 |
+
top = int((orig_h - new_h) * r[1])
|
132 |
+
right = left + new_w
|
133 |
+
bottom = top + new_h
|
134 |
+
if left == right or top == bottom:
|
135 |
+
continue
|
136 |
+
|
137 |
+
# check for any valid boxes with centers within the crop area
|
138 |
+
cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2])
|
139 |
+
cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3])
|
140 |
+
is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
|
141 |
+
if not is_within_crop_area.any():
|
142 |
+
continue
|
143 |
+
|
144 |
+
# check at least 1 box with jaccard limitations
|
145 |
+
boxes = target["boxes"][is_within_crop_area]
|
146 |
+
ious = torchvision.ops.boxes.box_iou(
|
147 |
+
boxes, torch.tensor([[left, top, right, bottom]], dtype=boxes.dtype, device=boxes.device)
|
148 |
+
)
|
149 |
+
if ious.max() < min_jaccard_overlap:
|
150 |
+
continue
|
151 |
+
|
152 |
+
# keep only valid boxes and perform cropping
|
153 |
+
target["boxes"] = boxes
|
154 |
+
target["labels"] = target["labels"][is_within_crop_area]
|
155 |
+
target["boxes"][:, 0::2] -= left
|
156 |
+
target["boxes"][:, 1::2] -= top
|
157 |
+
target["boxes"][:, 0::2].clamp_(min=0, max=new_w)
|
158 |
+
target["boxes"][:, 1::2].clamp_(min=0, max=new_h)
|
159 |
+
image = F.crop(image, top, left, new_h, new_w)
|
160 |
+
|
161 |
+
return image, target
|
162 |
+
|
163 |
+
|
164 |
+
class RandomZoomOut(nn.Module):
|
165 |
+
def __init__(
|
166 |
+
self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
|
167 |
+
):
|
168 |
+
super().__init__()
|
169 |
+
if fill is None:
|
170 |
+
fill = [0.0, 0.0, 0.0]
|
171 |
+
self.fill = fill
|
172 |
+
self.side_range = side_range
|
173 |
+
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
|
174 |
+
raise ValueError(f"Invalid canvas side range provided {side_range}.")
|
175 |
+
self.p = p
|
176 |
+
|
177 |
+
@torch.jit.unused
|
178 |
+
def _get_fill_value(self, is_pil):
|
179 |
+
# type: (bool) -> int
|
180 |
+
# We fake the type to make it work on JIT
|
181 |
+
return tuple(int(x) for x in self.fill) if is_pil else 0
|
182 |
+
|
183 |
+
def forward(
|
184 |
+
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
|
185 |
+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
186 |
+
if isinstance(image, torch.Tensor):
|
187 |
+
if image.ndimension() not in {2, 3}:
|
188 |
+
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
|
189 |
+
elif image.ndimension() == 2:
|
190 |
+
image = image.unsqueeze(0)
|
191 |
+
|
192 |
+
if torch.rand(1) >= self.p:
|
193 |
+
return image, target
|
194 |
+
|
195 |
+
orig_w, orig_h = F.get_image_size(image)
|
196 |
+
|
197 |
+
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
|
198 |
+
canvas_width = int(orig_w * r)
|
199 |
+
canvas_height = int(orig_h * r)
|
200 |
+
|
201 |
+
r = torch.rand(2)
|
202 |
+
left = int((canvas_width - orig_w) * r[0])
|
203 |
+
top = int((canvas_height - orig_h) * r[1])
|
204 |
+
right = canvas_width - (left + orig_w)
|
205 |
+
bottom = canvas_height - (top + orig_h)
|
206 |
+
|
207 |
+
if torch.jit.is_scripting():
|
208 |
+
fill = 0
|
209 |
+
else:
|
210 |
+
fill = self._get_fill_value(F._is_pil_image(image))
|
211 |
+
|
212 |
+
image = F.pad(image, [left, top, right, bottom], fill=fill)
|
213 |
+
if isinstance(image, torch.Tensor):
|
214 |
+
# PyTorch's pad supports only integers on fill. So we need to overwrite the colour
|
215 |
+
v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
|
216 |
+
image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h) :, :] = image[
|
217 |
+
..., :, (left + orig_w) :
|
218 |
+
] = v
|
219 |
+
|
220 |
+
if target is not None:
|
221 |
+
target["boxes"][:, 0::2] += left
|
222 |
+
target["boxes"][:, 1::2] += top
|
223 |
+
|
224 |
+
return image, target
|
225 |
+
|
226 |
+
|
227 |
+
class RandomPhotometricDistort(nn.Module):
|
228 |
+
def __init__(
|
229 |
+
self,
|
230 |
+
contrast: Tuple[float] = (0.5, 1.5),
|
231 |
+
saturation: Tuple[float] = (0.5, 1.5),
|
232 |
+
hue: Tuple[float] = (-0.05, 0.05),
|
233 |
+
brightness: Tuple[float] = (0.875, 1.125),
|
234 |
+
p: float = 0.5,
|
235 |
+
):
|
236 |
+
super().__init__()
|
237 |
+
self._brightness = T.ColorJitter(brightness=brightness)
|
238 |
+
self._contrast = T.ColorJitter(contrast=contrast)
|
239 |
+
self._hue = T.ColorJitter(hue=hue)
|
240 |
+
self._saturation = T.ColorJitter(saturation=saturation)
|
241 |
+
self.p = p
|
242 |
+
|
243 |
+
def forward(
|
244 |
+
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
|
245 |
+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
246 |
+
if isinstance(image, torch.Tensor):
|
247 |
+
if image.ndimension() not in {2, 3}:
|
248 |
+
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
|
249 |
+
elif image.ndimension() == 2:
|
250 |
+
image = image.unsqueeze(0)
|
251 |
+
|
252 |
+
r = torch.rand(7)
|
253 |
+
|
254 |
+
if r[0] < self.p:
|
255 |
+
image = self._brightness(image)
|
256 |
+
|
257 |
+
contrast_before = r[1] < 0.5
|
258 |
+
if contrast_before:
|
259 |
+
if r[2] < self.p:
|
260 |
+
image = self._contrast(image)
|
261 |
+
|
262 |
+
if r[3] < self.p:
|
263 |
+
image = self._saturation(image)
|
264 |
+
|
265 |
+
if r[4] < self.p:
|
266 |
+
image = self._hue(image)
|
267 |
+
|
268 |
+
if not contrast_before:
|
269 |
+
if r[5] < self.p:
|
270 |
+
image = self._contrast(image)
|
271 |
+
|
272 |
+
if r[6] < self.p:
|
273 |
+
channels = F.get_image_num_channels(image)
|
274 |
+
permutation = torch.randperm(channels)
|
275 |
+
|
276 |
+
is_pil = F._is_pil_image(image)
|
277 |
+
if is_pil:
|
278 |
+
image = F.pil_to_tensor(image)
|
279 |
+
image = F.convert_image_dtype(image)
|
280 |
+
image = image[..., permutation, :, :]
|
281 |
+
if is_pil:
|
282 |
+
image = F.to_pil_image(image)
|
283 |
+
|
284 |
+
return image, target
|
src/detection/vision/utils.py
ADDED
@@ -0,0 +1,282 @@
1 |
+
import datetime
|
2 |
+
import errno
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
from collections import defaultdict, deque
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.distributed as dist
|
9 |
+
|
10 |
+
|
11 |
+
class SmoothedValue:
|
12 |
+
"""Track a series of values and provide access to smoothed values over a
|
13 |
+
window or the global series average.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, window_size=20, fmt=None):
|
17 |
+
if fmt is None:
|
18 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
19 |
+
self.deque = deque(maxlen=window_size)
|
20 |
+
self.total = 0.0
|
21 |
+
self.count = 0
|
22 |
+
self.fmt = fmt
|
23 |
+
|
24 |
+
def update(self, value, n=1):
|
25 |
+
self.deque.append(value)
|
26 |
+
self.count += n
|
27 |
+
self.total += value * n
|
28 |
+
|
29 |
+
def synchronize_between_processes(self):
|
30 |
+
"""
|
31 |
+
Warning: does not synchronize the deque!
|
32 |
+
"""
|
33 |
+
if not is_dist_avail_and_initialized():
|
34 |
+
return
|
35 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
36 |
+
dist.barrier()
|
37 |
+
dist.all_reduce(t)
|
38 |
+
t = t.tolist()
|
39 |
+
self.count = int(t[0])
|
40 |
+
self.total = t[1]
|
41 |
+
|
42 |
+
@property
|
43 |
+
def median(self):
|
44 |
+
d = torch.tensor(list(self.deque))
|
45 |
+
return d.median().item()
|
46 |
+
|
47 |
+
@property
|
48 |
+
def avg(self):
|
49 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
50 |
+
return d.mean().item()
|
51 |
+
|
52 |
+
@property
|
53 |
+
def global_avg(self):
|
54 |
+
return self.total / self.count
|
55 |
+
|
56 |
+
@property
|
57 |
+
def max(self):
|
58 |
+
return max(self.deque)
|
59 |
+
|
60 |
+
@property
|
61 |
+
def value(self):
|
62 |
+
return self.deque[-1]
|
63 |
+
|
64 |
+
def __str__(self):
|
65 |
+
return self.fmt.format(
|
66 |
+
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
67 |
+
)
|
68 |
+
|
69 |
+
|
70 |
+
def all_gather(data):
|
71 |
+
"""
|
72 |
+
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
73 |
+
Args:
|
74 |
+
data: any picklable object
|
75 |
+
Returns:
|
76 |
+
list[data]: list of data gathered from each rank
|
77 |
+
"""
|
78 |
+
world_size = get_world_size()
|
79 |
+
if world_size == 1:
|
80 |
+
return [data]
|
81 |
+
data_list = [None] * world_size
|
82 |
+
dist.all_gather_object(data_list, data)
|
83 |
+
return data_list
|
84 |
+
|
85 |
+
|
86 |
+
def reduce_dict(input_dict, average=True):
|
87 |
+
"""
|
88 |
+
Args:
|
89 |
+
input_dict (dict): all the values will be reduced
|
90 |
+
average (bool): whether to do average or sum
|
91 |
+
Reduce the values in the dictionary from all processes so that all processes
|
92 |
+
have the averaged results. Returns a dict with the same fields as
|
93 |
+
input_dict, after reduction.
|
94 |
+
"""
|
95 |
+
world_size = get_world_size()
|
96 |
+
if world_size < 2:
|
97 |
+
return input_dict
|
98 |
+
with torch.inference_mode():
|
99 |
+
names = []
|
100 |
+
values = []
|
101 |
+
# sort the keys so that they are consistent across processes
|
102 |
+
for k in sorted(input_dict.keys()):
|
103 |
+
names.append(k)
|
104 |
+
values.append(input_dict[k])
|
105 |
+
values = torch.stack(values, dim=0)
|
106 |
+
dist.all_reduce(values)
|
107 |
+
if average:
|
108 |
+
values /= world_size
|
109 |
+
reduced_dict = {k: v for k, v in zip(names, values)}
|
110 |
+
return reduced_dict
|
111 |
+
|
112 |
+
|
113 |
+
class MetricLogger:
|
114 |
+
def __init__(self, delimiter="\t"):
|
115 |
+
self.meters = defaultdict(SmoothedValue)
|
116 |
+
self.delimiter = delimiter
|
117 |
+
|
118 |
+
def update(self, **kwargs):
|
119 |
+
for k, v in kwargs.items():
|
120 |
+
if isinstance(v, torch.Tensor):
|
121 |
+
v = v.item()
|
122 |
+
assert isinstance(v, (float, int))
|
123 |
+
self.meters[k].update(v)
|
124 |
+
|
125 |
+
def __getattr__(self, attr):
|
126 |
+
if attr in self.meters:
|
127 |
+
return self.meters[attr]
|
128 |
+
if attr in self.__dict__:
|
129 |
+
return self.__dict__[attr]
|
130 |
+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
|
131 |
+
|
132 |
+
def __str__(self):
|
133 |
+
loss_str = []
|
134 |
+
for name, meter in self.meters.items():
|
135 |
+
loss_str.append(f"{name}: {str(meter)}")
|
136 |
+
return self.delimiter.join(loss_str)
|
137 |
+
|
138 |
+
def synchronize_between_processes(self):
|
139 |
+
for meter in self.meters.values():
|
140 |
+
meter.synchronize_between_processes()
|
141 |
+
|
142 |
+
def add_meter(self, name, meter):
|
143 |
+
self.meters[name] = meter
|
144 |
+
|
145 |
+
def log_every(self, iterable, print_freq, header=None):
|
146 |
+
i = 0
|
147 |
+
if not header:
|
148 |
+
header = ""
|
149 |
+
start_time = time.time()
|
150 |
+
end = time.time()
|
151 |
+
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
152 |
+
data_time = SmoothedValue(fmt="{avg:.4f}")
|
153 |
+
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
154 |
+
if torch.cuda.is_available():
|
155 |
+
log_msg = self.delimiter.join(
|
156 |
+
[
|
157 |
+
header,
|
158 |
+
"[{0" + space_fmt + "}/{1}]",
|
159 |
+
"eta: {eta}",
|
160 |
+
"{meters}",
|
161 |
+
"time: {time}",
|
162 |
+
"data: {data}",
|
163 |
+
"max mem: {memory:.0f}",
|
164 |
+
]
|
165 |
+
)
|
166 |
+
else:
|
167 |
+
log_msg = self.delimiter.join(
|
168 |
+
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
|
169 |
+
)
|
170 |
+
MB = 1024.0 * 1024.0
|
171 |
+
for obj in iterable:
|
172 |
+
data_time.update(time.time() - end)
|
173 |
+
yield obj
|
174 |
+
iter_time.update(time.time() - end)
|
175 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
176 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
177 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
178 |
+
if torch.cuda.is_available():
|
179 |
+
print(
|
180 |
+
log_msg.format(
|
181 |
+
i,
|
182 |
+
len(iterable),
|
183 |
+
eta=eta_string,
|
184 |
+
meters=str(self),
|
185 |
+
time=str(iter_time),
|
186 |
+
data=str(data_time),
|
187 |
+
memory=torch.cuda.max_memory_allocated() / MB,
|
188 |
+
)
|
189 |
+
)
|
190 |
+
else:
|
191 |
+
print(
|
192 |
+
log_msg.format(
|
193 |
+
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
|
194 |
+
)
|
195 |
+
)
|
196 |
+
i += 1
|
197 |
+
end = time.time()
|
198 |
+
total_time = time.time() - start_time
|
199 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
200 |
+
print(f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)")
|
201 |
+
|
202 |
+
|
203 |
+
def collate_fn(batch):
|
204 |
+
return tuple(zip(*batch))
|
205 |
+
|
206 |
+
|
207 |
+
def mkdir(path):
|
208 |
+
try:
|
209 |
+
os.makedirs(path, exist_ok=True)
|
210 |
+
except OSError as e:
|
211 |
+
if e.errno != errno.EEXIST:
|
212 |
+
raise
|
213 |
+
|
214 |
+
|
215 |
+
def setup_for_distributed(is_master):
|
216 |
+
"""
|
217 |
+
This function disables printing when not in master process
|
218 |
+
"""
|
219 |
+
import builtins as __builtin__
|
220 |
+
|
221 |
+
builtin_print = __builtin__.print
|
222 |
+
|
223 |
+
def print(*args, **kwargs):
|
224 |
+
force = kwargs.pop("force", False)
|
225 |
+
if is_master or force:
|
226 |
+
builtin_print(*args, **kwargs)
|
227 |
+
|
228 |
+
__builtin__.print = print
|
229 |
+
|
230 |
+
|
231 |
+
def is_dist_avail_and_initialized():
|
232 |
+
if not dist.is_available():
|
233 |
+
return False
|
234 |
+
if not dist.is_initialized():
|
235 |
+
return False
|
236 |
+
return True
|
237 |
+
|
238 |
+
|
239 |
+
def get_world_size():
|
240 |
+
if not is_dist_avail_and_initialized():
|
241 |
+
return 1
|
242 |
+
return dist.get_world_size()
|
243 |
+
|
244 |
+
|
245 |
+
def get_rank():
|
246 |
+
if not is_dist_avail_and_initialized():
|
247 |
+
return 0
|
248 |
+
return dist.get_rank()
|
249 |
+
|
250 |
+
|
251 |
+
def is_main_process():
|
252 |
+
return get_rank() == 0
|
253 |
+
|
254 |
+
|
255 |
+
def save_on_master(*args, **kwargs):
|
256 |
+
if is_main_process():
|
257 |
+
torch.save(*args, **kwargs)
|
258 |
+
|
259 |
+
|
260 |
+
def init_distributed_mode(args):
|
261 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
262 |
+
args.rank = int(os.environ["RANK"])
|
263 |
+
args.world_size = int(os.environ["WORLD_SIZE"])
|
264 |
+
args.gpu = int(os.environ["LOCAL_RANK"])
|
265 |
+
elif "SLURM_PROCID" in os.environ:
|
266 |
+
args.rank = int(os.environ["SLURM_PROCID"])
|
267 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
268 |
+
else:
|
269 |
+
print("Not using distributed mode")
|
270 |
+
args.distributed = False
|
271 |
+
return
|
272 |
+
|
273 |
+
args.distributed = True
|
274 |
+
|
275 |
+
torch.cuda.set_device(args.gpu)
|
276 |
+
args.dist_backend = "nccl"
|
277 |
+
print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
|
278 |
+
torch.distributed.init_process_group(
|
279 |
+
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
280 |
+
)
|
281 |
+
torch.distributed.barrier()
|
282 |
+
setup_for_distributed(args.rank == 0)
|
tools/__init__.py
ADDED
File without changes
tools/anns/combine_anns.py
ADDED
@@ -0,0 +1,87 @@
import os
import os.path as osp
import json
import argparse
import tqdm


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--motsynth-path', help="Directory path containing the 'annotations' directory with .json files")
    parser.add_argument(
        '--save-path', help='Root directory in which the new annotation files will be stored. If not provided, motsynth-root will be used')
    parser.add_argument('--save-dir', default='comb_annotations',
                        help="Name of directory within 'save-path' in which the combined annotation files will be stored")
    parser.add_argument('--subsample', default=20, type=int,
                        help="Frame subsampling rate. If e.g. 10 is selected, then we will select 1 in 10 frames")
    parser.add_argument('--split', default='train',
                        help="Name of split (i.e. set of sequences being merged) being used. A file named '{args.split}.txt' needs to exist in the splits dir")
    parser.add_argument(
        '--name', help="Name of the split file that will be generated. If not provided, the split name will be used")

    args = parser.parse_args()

    if args.save_path is None:
        args.save_path = args.motsynth_path

    if args.name is None:
        args.name = args.split

    assert args.subsample > 0, "Argument '--subsample' needs to be a positive integer. Set it to 1 to use every frame"

    return args


def read_split_file(path):
    with open(path, 'r') as file:
        seq_list = file.read().splitlines()

    return seq_list


def main(args):
    # Determine which sequences to use
    seqs = [seq.zfill(3) for seq in read_split_file(osp.join(
        osp.dirname(os.path.abspath(__file__)), 'splits', f'{args.split}.txt'))]
    comb_anns = {'images': [], 'annotations': [],
                 'categories': None, 'info': {}}

    for seq in tqdm.tqdm(seqs):
        ann_path = osp.join(args.motsynth_path, 'annotations', f'{seq}.json')
        with open(ann_path) as f:
            seq_ann = json.load(f)

        # Subsample images and annotations if needed
        if args.subsample > 1:
            seq_ann['images'] = [{**img, **seq_ann['info']} for img in seq_ann['images'] if (
                (img['frame_n'] - 1) % args.subsample) == 0]  # -1 bc in the paper this was 0-based
            img_ids = [img['id'] for img in seq_ann['images']]
            seq_ann['annotations'] = [
                ann for ann in seq_ann['annotations'] if ann['image_id'] in img_ids]

        comb_anns['images'].extend(seq_ann['images'])
        comb_anns['annotations'].extend(seq_ann['annotations'])
        comb_anns['info'][seq] = seq_ann['info']

    if len(seqs) > 0:
        comb_anns['categories'] = seq_ann['categories']
        comb_anns['licenses'] = seq_ann['categories']

    # Sanity check:
    img_ids = [img['id'] for img in comb_anns['images']]
    ann_ids = [ann['id'] for ann in comb_anns['annotations']]
    assert len(img_ids) == len(set(img_ids))
    assert len(ann_ids) == len(set(ann_ids))

    # Save the new annotations file
    comb_anns_dir = osp.join(args.save_path, args.save_dir)
    os.makedirs(comb_anns_dir, exist_ok=True)
    comb_anns_path = osp.join(comb_anns_dir, f"{args.name}.json")
    with open(comb_anns_path, 'w') as json_file:
        json.dump(comb_anns, json_file)


if __name__ == '__main__':
    args = parse_args()
    main(args)
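One detail of the script above that is easy to misread is the subsampling rule: a frame is kept when (frame_n - 1) is a multiple of --subsample, i.e. frame_n is treated as 1-based in the per-sequence .json files. A tiny self-contained check (illustrative only):

subsample = 20
kept = [n for n in range(1, 101) if (n - 1) % subsample == 0]
assert kept == [1, 21, 41, 61, 81]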
tools/anns/generate_mot_format_files.py
ADDED
@@ -0,0 +1,73 @@
import pandas as pd
import numpy as np

import os.path as osp
import os
import json

import tqdm
import argparse

from generate_mots_format_files import save_seqinfo


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--motsynth-path', help="Directory path containing the 'annotations' directory with .json files")
    parser.add_argument('--save-path', help='Root directory in which the new annotation files will be stored. If not provided, motsynth-root will be used')
    parser.add_argument('--save-dir', default='mot_annotations', help="Name of directory within 'save-path' in which MOT annotation files will be stored")
    args = parser.parse_args()

    if args.save_path is None:
        args.save_path = args.motsynth_path

    return args


def main(args):
    ann_dir = osp.join(args.motsynth_path, 'annotations')
    mot_ann_dir = osp.join(args.save_path, args.save_dir)
    seqs = [f'{seq_num:03}' for seq_num in range(768) if seq_num not in (629, 757, 524, 652)]

    for seq in tqdm.tqdm(seqs):
        ann_path = osp.join(ann_dir, f'{seq}.json')
        with open(ann_path) as f:
            seq_ann = json.load(f)

        rows = []
        img_id2frame = {im['id']: im['frame_n'] for im in seq_ann['images']}

        for ann in seq_ann['annotations']:
            # We compute the 3D location as the mid point between both feet keypoints in 3D
            kps = np.array(ann['keypoints_3d']).reshape(-1, 4)
            feet_pos_3d = kps[[-1, -4], :3].mean(axis=0).round(4)

            row = {'frame': img_id2frame[ann['image_id']],  # STARTS AT 0!!!
                   'id': ann['ped_id'],
                   'bb_left': ann['bbox'][0] + 1,  # Make it 1-based??
                   'bb_top': ann['bbox'][1] + 1,
                   'bb_width': ann['bbox'][2],
                   'bb_height': ann['bbox'][3],
                   'conf': 1 - ann['iscrowd'],
                   'class': 1 if ann['iscrowd'] == 0 else 8,  # Class 8 means distractor. It is the one used by TrackEval as 'iscrowd'
                   # We compute visibility as the proportion of visible keypoints
                   'vis': (np.array(ann['keypoints'])[2::3] == 2).mean().round(2),
                   'x': feet_pos_3d[0],
                   'y': feet_pos_3d[1],
                   'z': feet_pos_3d[2]}

            rows.append(row)

        # Save gt.txt file
        # Format in https://github.com/dendorferpatrick/MOTChallengeEvalKit/tree/master/MOT
        mot_ann = pd.DataFrame(rows, columns=['frame', 'id', 'bb_left', 'bb_top', 'bb_width', 'bb_height', 'conf', 'class', 'vis', 'x', 'y', 'z'])
        gt_dir = osp.join(mot_ann_dir, seq, 'gt')
        os.makedirs(gt_dir, exist_ok=True)
        mot_ann.to_csv(osp.join(gt_dir, 'gt.txt'), header=None, index=None, sep=',')

        # Save seqinfo.ini
        seqinfo_path = osp.join(mot_ann_dir, seq, 'seqinfo.ini')
        save_seqinfo(seqinfo_path, info=seq_ann['info'])


if __name__ == '__main__':
    args = parse_args()
    main(args)
tools/anns/generate_mots_format_files.py
ADDED
@@ -0,0 +1,102 @@
"""
TODOs:
- argparse
- List sequences by number
- Get rid of asserts
"""

import pandas as pd

import os.path as osp
import os
import json

import configparser

import tqdm
import argparse


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--motsynth-path', help="Directory path containing the 'annotations' directory with .json files")
    parser.add_argument('--save-path', help='Root directory in which the new annotation files will be stored. If not provided, motsynth-root will be used')
    parser.add_argument('--save-dir', default='mots_annotations', help="Name of directory within 'save-path' in which MOTS annotation files will be stored")
    args = parser.parse_args()

    if args.save_path is None:
        args.save_path = args.motsynth_path

    return args


def save_seqinfo(seqinfo_path, info):
    seqinfo = configparser.ConfigParser()
    seqinfo.optionxform = str  # Otherwise capital letters are ignored in keys

    seqinfo['Sequence'] = dict(name=info['seq_name'],
                               frameRate=info['fps'],
                               seqLength=info['sequence_length'],
                               imWidth=info['img_width'],
                               imHeight=info['img_height'],
                               weather=info['weather'],
                               time=info['time'],
                               isNight=info['is_night'],
                               isMoving=info['is_moving'],
                               FOV=info['cam_fov'],
                               imExt='.jpg',
                               fx=1158,
                               fy=1158,
                               cx=960,
                               cy=540)

    with open(seqinfo_path, 'w') as configfile:  # save
        seqinfo.write(configfile, space_around_delimiters=False)


def main(args):
    ann_dir = osp.join(args.motsynth_path, 'annotations')
    mots_ann_dir = osp.join(args.save_path, args.save_dir)

    seqs = [f'{seq_num:03}' for seq_num in range(768) if seq_num not in (629, 757, 524, 652)]

    for seq in tqdm.tqdm(seqs):
        ann_path = osp.join(ann_dir, f'{seq}.json')
        with open(ann_path) as f:
            seq_ann = json.load(f)

        rows = []
        img_id2frame = {im['id']: im['frame_n'] for im in seq_ann['images']}

        for ann in seq_ann['annotations']:
            assert ann['category_id'] == 1
            if ann['area']:  # Include only objects with non-empty masks
                if not ann['iscrowd']:
                    mots_id = 2000 + ann['ped_id']

                else:  # ID = 10000 means that the instance should be ignored during eval.
                    mots_id = 10000

                row = {'time_frame': img_id2frame[ann['image_id']],  # STARTS AT 0!!!
                       'id': mots_id,
                       'class_id': 2,
                       'img_height': ann['segmentation']['size'][0],
                       'img_width': ann['segmentation']['size'][1],
                       'rle': ann['segmentation']['counts']}

                rows.append(row)

        # Save gt.txt file
        # Format in https://www.vision.rwth-aachen.de/page/mots
        mots_ann = pd.DataFrame(rows, columns=['time_frame', 'id', 'class_id', 'img_height', 'img_width', 'rle'])
        gt_dir = osp.join(mots_ann_dir, seq, 'gt')
        os.makedirs(gt_dir, exist_ok=True)
        mots_ann.to_csv(osp.join(gt_dir, 'gt.txt'), header=None, index=None, sep=' ')

        # Save seqinfo.ini
        seqinfo_path = osp.join(mots_ann_dir, seq, 'seqinfo.ini')
        save_seqinfo(seqinfo_path, info=seq_ann['info'])


if __name__ == '__main__':
    args = parse_args()
    main(args)
tools/anns/motcha_to_coco.py
ADDED
@@ -0,0 +1,145 @@
+import os
+import os.path as osp
+import numpy as np
+import json
+
+import argparse
+import configparser
+import datetime
+
+import tqdm
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--data-root', help="Path containing the dataset in a folder")
+    parser.add_argument('--dataset', default='MOT17',
+                        help='Name of the dataset to be used. Should be either MOT17 or MOT20')
+    parser.add_argument(
+        '--save-dir', help='Root directory in which the new annotation files will be stored. If not provided, data-root will be used')
+    parser.add_argument('--split', default='train',
+                        help="Split processed within the dataset. Should be either 'train' or 'test'")
+    parser.add_argument('--save-combined', default=True, action='store_true',
+                        help="Determines whether a separate .json file containing all sequence annotations will be created")
+    parser.add_argument('--subsample', default=20, type=int,
+                        help="Frame subsampling rate: only 1 in every `subsample` frames is kept (default: 20)")
+
+    args = parser.parse_args()
+
+    if args.save_dir is None:
+        args.save_dir = osp.join(args.data_root, 'motcha_coco_annotations')
+
+    return args
+
+
+def get_img_id(dataset, seq, fname):
+    # Dataset num, seq num, frame num
+    return int(f"{dataset[3:5]}{seq.split('-')[1]}{int(fname.split('.')[0]):06}")
+
+
+def read_seqinfo(path):
+    cp = configparser.ConfigParser()
+    cp.read(path)
+    return {'height': int(cp.get('Sequence', 'imHeight')),
+            'width': int(cp.get('Sequence', 'imWidth')),
+            'fps': int(cp.get('Sequence', 'frameRate'))}
+
+
+def main(args):
+    data_path = osp.join(args.data_root, args.dataset, args.split)
+    seqs = os.listdir(data_path)
+
+    if args.save_combined:
+        comb_data = {'info': {'dataset': args.dataset,
+                              'split': args.split,
+                              'creation_date': datetime.datetime.today().strftime('%Y-%m-%d-%H-%M')},
+                     'images': [],
+                     'annotations': [],
+                     'categories': [{'id': 1, 'name': 'person', 'supercategory': 'person'}]}
+
+    for seq in tqdm.tqdm(seqs):
+        if args.dataset.lower() == 'mot17':
+            # Choose an arbitrary set of detections for MOT17; the annotations are the same
+            if not seq.endswith('FRCNN'):
+                continue
+
+        print(f"Processing sequence {seq} in dataset {args.dataset}")
+
+        seq_path = osp.join(data_path, seq)
+        seqinfo_path = osp.join(seq_path, 'seqinfo.ini')
+        gt_path = osp.join(seq_path, 'gt/gt.txt')
+        im_dir = osp.join(seq_path, 'img1')
+
+        if args.dataset.lower() == 'mot17':
+            seq_ = '-'.join(seq.split('-')[:-1])  # Get rid of detector string
+        else:
+            seq_ = seq
+
+        seqinfo = read_seqinfo(seqinfo_path)
+        data = {'info': {'sequence': seq_,
+                         'dataset': args.dataset,
+                         'split': args.split,
+                         'creation_date': datetime.datetime.today().strftime('%Y-%m-%d-%H-%M'),
+                         **seqinfo},
+                'images': [],
+                'annotations': [],
+                'categories': [{'id': 1, 'name': 'person', 'supercategory': 'person'}]}
+
+        # Load bounding box annotations
+        gt = np.loadtxt(gt_path, dtype=np.float32, delimiter=',')
+        # Person-related classes in the MOTChallenge GT:
+        # pedestrian, person on vehicle, static person, distractor, reflection
+        keep_classes = [1, 2, 7, 8, 12]
+        mask = np.isin(gt[:, 7], keep_classes)
+        gt = gt[mask]
+        anns = [{'ped_id': row[1],
+                 'frame_n': row[0],
+                 'category_id': 1,
+                 'id': f"{get_img_id(args.dataset, seq, f'{int(row[0]):06}.jpg')}{int(row_i):010}",
+                 'image_id': get_img_id(args.dataset, seq, f'{int(row[0]):06}.jpg'),
+                 # MOTCha annotations are 1-based
+                 'bbox': [row[2] - 1, row[3] - 1, row[4], row[5]],
+                 'area': row[4] * row[5],
+                 'vis': row[8],
+                 # GT column 6 is the 'consider' flag; ignored boxes become crowd regions
+                 'iscrowd': 1 - row[6]}
+                for row_i, row in enumerate(gt.astype(float)) if row[0] % args.subsample == 0]
+
+        # Load image information
+        all_img_ids = list(set([aa['image_id'] for aa in anns]))
+        imgs = [{'file_name': osp.join(args.dataset, args.split, seq, 'img1', fname),
+                 'height': seqinfo['height'],
+                 'width': seqinfo['width'],
+                 'frame_n': int(fname.split('.')[0]),
+                 'id': get_img_id(args.dataset, seq, fname)}
+                for fname in os.listdir(im_dir) if get_img_id(args.dataset, seq, fname) in all_img_ids]
+        assert len(set([im['id'] for im in imgs])) == len(imgs)
+        data['images'].extend(imgs)
+        assert len(str(imgs[0]['id'])) == len(str(anns[0]['image_id']))
+
+        data['annotations'].extend(anns)
+
+        os.makedirs(args.save_dir, exist_ok=True)
+        fname = f"{args.dataset}_{seq_}.json" if args.dataset not in seq_ else f"{seq_}.json"
+        save_path = osp.join(args.save_dir, fname)
+        with open(save_path, 'w') as f:
+            json.dump(data, f)
+
+        print(f"Saved result at {save_path}")
+
+        if args.save_combined:
+            comb_data['annotations'].extend(anns)
+            comb_data['images'].extend(imgs)
+
+    if args.save_combined:
+        save_path = osp.join(
+            args.save_dir, f"{args.dataset}_{args.split}.json")
+        with open(save_path, 'w') as f:
+            json.dump(comb_data, f)
+
+        print(f"Saved combined result at {save_path}")
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    main(args)
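The image ids built by get_img_id concatenate the dataset digits, the sequence number and a zero-padded frame number. A quick illustration with example values (not taken from any real annotation file):

def get_img_id(dataset, seq, fname):
    # Same encoding as in the script: dataset digits + sequence number + 6-digit frame number
    return int(f"{dataset[3:5]}{seq.split('-')[1]}{int(fname.split('.')[0]):06}")

print(get_img_id('MOT17', 'MOT17-02-FRCNN', '000001.jpg'))  # -> 1702000001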
tools/anns/splits/motsynth_split1.txt
ADDED
@@ -0,0 +1,16 @@
+000
+008
+016
+024
+032
+040
+048
+056
+064
+072
+080
+088
+096
+104
+112
+120
tools/anns/splits/motsynth_split2.txt
ADDED
@@ -0,0 +1,31 @@
+000
+004
+008
+012
+016
+020
+024
+028
+032
+036
+040
+044
+048
+052
+056
+060
+064
+068
+072
+076
+080
+084
+088
+092
+096
+100
+104
+108
+112
+116
+120
tools/anns/splits/motsynth_split3.txt
ADDED
@@ -0,0 +1,62 @@
+000
+002
+004
+006
+008
+010
+012
+014
+016
+018
+020
+022
+024
+026
+028
+030
+032
+034
+036
+038
+040
+042
+044
+046
+048
+050
+052
+054
+056
+058
+060
+062
+064
+066
+068
+070
+072
+074
+076
+078
+080
+082
+084
+086
+088
+090
+092
+094
+096
+098
+100
+102
+104
+106
+108
+110
+112
+114
+116
+118
+120
+122
tools/anns/splits/motsynth_split4.txt
ADDED
@@ -0,0 +1,123 @@
+000
+001
+002
+003
+004
+005
+006
+007
+008
+009
+010
+011
+012
+013
+014
+015
+016
+017
+018
+019
+020
+021
+022
+023
+024
+025
+026
+027
+028
+029
+030
+031
+032
+033
+034
+035
+036
+037
+038
+039
+040
+041
+042
+043
+044
+045
+046
+047
+048
+049
+050
+051
+052
+053
+054
+055
+056
+057
+058
+059
+060
+061
+062
+063
+064
+065
+066
+067
+068
+069
+070
+071
+072
+073
+074
+075
+076
+077
+078
+079
+080
+081
+082
+083
+084
+085
+086
+087
+088
+089
+090
+091
+092
+093
+094
+095
+096
+097
+098
+099
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
tools/anns/store_reid_imgs.py
ADDED
@@ -0,0 +1,84 @@
+import os
+import os.path as osp
+import numpy as np
+import json
+
+import argparse
+
+import tqdm
+from PIL import Image
+from collections import defaultdict
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--ann-path', help=".JSON annotations file path in COCO format")
+    parser.add_argument('--frames-path', help="Root directory containing images")
+    parser.add_argument('--save-dir', help="Directory in which the cropped boxes will be stored. If not provided, a 'reid' folder next to the annotations is used")
+    parser.add_argument('--start-iter', default=0, type=int)
+
+    args = parser.parse_args()
+
+    if args.frames_path is None:
+        args.frames_path = osp.dirname(osp.dirname(args.ann_path))
+
+    if args.save_dir is None:
+        args.save_dir = osp.join(osp.dirname(osp.dirname(args.ann_path)), 'reid')
+
+    return args
+
+
+def crop_box(im, bbox):
+    x1, y1, w, h = bbox
+    x2, y2 = x1 + w, y1 + h
+    return im.crop((x1, y1, x2, y2))
+
+
+def main(args):
+    os.makedirs(args.save_dir, exist_ok=True)
+
+    # Read annotations
+    with open(args.ann_path) as f:
+        anns = json.load(f)
+
+    # Annotation ids are used as file names to store boxes,
+    # therefore they need to be unique
+    ann_ids = [ann['id'] for ann in anns['annotations']]
+    assert len(ann_ids) == len(set(ann_ids))
+    imgid2file = {img['id']: img['file_name'] for img in anns['images']}
+
+    # TODO: this filtering should eventually go; it only skips boxes that were already stored
+    anns['annotations'] = [ann for ann in anns['annotations']
+                           if not osp.exists(osp.join(args.save_dir, f"{ann['id']}.png"))]
+    im2anns = defaultdict(list)
+    for ann in anns['annotations']:
+        im2anns[imgid2file[ann['image_id']]].append(ann)
+
+    for img_file, im_anns in tqdm.tqdm(im2anns.items()):
+        # Read image
+        im_path = osp.join(args.frames_path, img_file)
+        if not osp.exists(im_path):
+            im_path = osp.join(args.frames_path, img_file.replace('rgb/', ''))
+
+        assert osp.exists(im_path)
+        im = Image.open(im_path)
+
+        for ann in im_anns:
+            box_path = osp.join(args.save_dir, f"{ann['id']}.png")
+
+            if osp.exists(box_path):
+                continue
+
+            box_im = crop_box(im, ann['bbox'])
+            box_im.save(box_path)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    main(args)
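crop_box turns a COCO-style [x, y, w, h] box into the (left, upper, right, lower) tuple PIL expects. A self-contained sketch with a placeholder image and a made-up box:

from PIL import Image

im = Image.new('RGB', (1920, 1080))   # placeholder frame instead of a dataset image
x, y, w, h = 100, 200, 50, 120        # example bbox in COCO format
crop = im.crop((x, y, x + w, y + h))
print(crop.size)                      # -> (50, 120)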
tools/anns/to_frames.py
ADDED
@@ -0,0 +1,56 @@
+import argparse
+import os
+import os.path as osp
+import cv2
+import glob
+
+import tqdm
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Get frames from a video')
+    parser.add_argument(
+        '--motsynth-root', help='Directory hosting MOTSynth part directories')
+    args = parser.parse_args()
+
+    video_paths = glob.glob(
+        osp.join(args.motsynth_root, 'MOTSynth_[0-9]/[0-9][0-9][0-9].mp4'))
+
+    frames_dir = os.path.join(args.motsynth_root, "frames")
+    os.makedirs(frames_dir, exist_ok=True)
+
+    print("Start extracting frames...")
+
+    for video_path in tqdm.tqdm(video_paths):
+        vidcap = cv2.VideoCapture(video_path)
+
+        seq_name = osp.basename(video_path).split(".")[0].zfill(3)
+        out_dir = os.path.join(frames_dir, seq_name, 'rgb')
+        os.makedirs(out_dir, exist_ok=True)
+
+        count = 1
+        success = True
+
+        while success:
+            success, image = vidcap.read()
+            if count < 3:
+                # Skip the first two decoded frames
+                count += 1
+                continue
+            if not success or count == 1803:
+                break
+            if count % 200 == 0:
+                print("Extracted frames up to: " +
+                      str(count - 3).zfill(4) + ".jpg")
+            filename = os.path.join(out_dir, str(count - 3).zfill(4) + ".jpg")
+            cv2.imwrite(filename, image)  # save frame as a JPEG file
+            count += 1
+
+    print("Done!")
+
+
+if __name__ == '__main__':
+    main()
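Since the first two decoded frames are skipped and extraction stops at count == 1803, each video yields frames 0000.jpg through 1799.jpg, i.e. 1800 frames per sequence. A quick check of the naming arithmetic used above:

for count in (3, 4, 1802):
    print(count, '->', str(count - 3).zfill(4) + '.jpg')
# 3 -> 0000.jpg, 4 -> 0001.jpg, 1802 -> 1799.jpg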
tools/inference_detector.py
ADDED
@@ -0,0 +1,46 @@
+import torch
+from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2, FastRCNNPredictor
+from configs.path_cfg import OUTPUT_DIR
+from src.detection.vision.engine import evaluate
+from tools.train_detector import create_dataset, create_data_loader, get_transform
+from src.detection.graph_utils import add_bbox, show_img
+import os.path as osp
+import argparse
+
+
+def parse_args(add_help=True):
+    parser = argparse.ArgumentParser(
+        description="Detector inference", add_help=add_help)
+
+    # Path to the model checkpoint used for inference
+    parser.add_argument("--model-path", type=str,
+                        help="Path with model checkpoint used for inference")
+
+    args = parser.parse_args()
+
+    if args.model_path is None:
+        args.model_path = osp.join(
+            OUTPUT_DIR, "detection_logs", "fasterrcnn_training", "checkpoint.pth")
+    return args
+
+
+def main(args):
+    ds_val = create_dataset(
+        "motsynth_val", get_transform(False, "hflip"), "test")
+    data_loader_val = create_data_loader(ds_val, "test", 1, 0)
+
+    device = torch.device("cuda")
+    model = fasterrcnn_resnet50_fpn_v2()
+    in_features = model.roi_heads.box_predictor.cls_score.in_features
+    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
+    checkpoint = torch.load(
+        args.model_path, map_location="cpu")
+    model.load_state_dict(checkpoint["model"])
+    model.eval()
+    model.to(device)
+    show_img(data_loader_val, model, device, 0.8)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
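A stand-alone sketch of the model set-up used here: torchvision's Faster R-CNN v2 with its box predictor swapped for two classes (background + person). The weights and the input below are placeholders, not the trained checkpoint:

import torch
from torchvision.models.detection.faster_rcnn import (
    FastRCNNPredictor, fasterrcnn_resnet50_fpn_v2)

model = fasterrcnn_resnet50_fpn_v2(weights=None)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
model.eval()

with torch.no_grad():
    preds = model([torch.rand(3, 480, 640)])  # list of CHW tensors in [0, 1]
print(preds[0]['boxes'].shape, preds[0]['scores'].shape)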
tools/train_detector.py
ADDED
@@ -0,0 +1,408 @@
+from typing import List
+from configs.path_cfg import MOTSYNTH_ROOT, MOTCHA_ROOT, OUTPUT_DIR
+import datetime
+import os.path as osp
+import os
+import time
+import coloredlogs
+import logging
+from torchinfo import summary
+import torch
+import torch.utils.data
+from src.detection.vision.mot_data import MOTObjDetect
+from src.detection.model_factory import ModelFactory
+from src.detection.graph_utils import save_train_loss_plot
+import src.detection.vision.presets as presets
+import src.detection.vision.utils as utils
+from src.detection.vision.engine import train_one_epoch, evaluate
+from src.detection.vision.group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
+from src.detection.mot_dataset import get_mot_dataset
+import torchvision
+
+from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+
+coloredlogs.install(level='DEBUG')
+logger = logging.getLogger(__name__)
+
+
+def get_args_parser(add_help=True):
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="PyTorch Detection Training", add_help=add_help)
+
+    # Output directory used to save model, plots and summary
+    parser.add_argument("--output-dir", default='fasterrcnn_training',
+                        type=str, help="Path to save outputs (default: fasterrcnn_training)")
+
+    # Dataset params
+    parser.add_argument("--train-dataset", default="motsynth_split1",
+                        type=str, help="Dataset name. Please select one of the following: motsynth_split1, motsynth_split2, motsynth_split3, motsynth_split4, MOT17 (default: motsynth_split1)")
+    parser.add_argument("--val-dataset", default="MOT17",
+                        type=str, help="Dataset name. Please select one of the following: MOT17 (default: MOT17)")
+
+    # Transforms params
+    parser.add_argument(
+        "--data-augmentation", default="hflip", type=str, help="Data augmentation policy (default: hflip)"
+    )
+
+    # Data loader params
+    parser.add_argument(
+        "-b", "--batch-size", default=3, type=int, help="Images per GPU (default: 3)"
+    )
+    parser.add_argument(
+        "-j", "--workers", default=0, type=int, metavar="N", help="Number of data loading workers (default: 0)"
+    )
+    parser.add_argument("--aspect-ratio-group-factor", default=3,
+                        type=int, help="Aspect ratio group factor (default: 3)")
+
+    # Model params
+    parser.add_argument(
+        "--model", default="fasterrcnn_resnet50_fpn", type=str, help="Model name (default: fasterrcnn_resnet50_fpn)")
+    parser.add_argument(
+        "--weights", default="DEFAULT", type=str, help="Model weights (default: DEFAULT)"
+    )
+    parser.add_argument(
+        "--backbone", default='resnet50', type=str, help="Type of backbone (default: resnet50)"
+    )
+    parser.add_argument(
+        "--trainable-backbone-layers", default=3, type=int, help="Number of trainable layers of the backbone (default: 3)"
+    )
+    parser.add_argument(
+        "--backbone-weights", default="DEFAULT", type=str, help="Backbone weights (default: DEFAULT)"
+    )
+
+    # Device param
+    parser.add_argument("--device", default="cuda", type=str,
+                        help="Device (default: cuda)")
+
+    # Test mode params
+    parser.add_argument(
+        "--test-only",
+        dest="test_only",
+        help="Only test the model",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--model-eval", type=str, help="Model path for test-only mode"
+    )
+
+    # Optimizer params
+    parser.add_argument(
+        "--lr",
+        default=0.0025,
+        type=float,
+        help="Learning rate (default: 0.0025)",
+    )
+    parser.add_argument("--momentum", default=0.9,
+                        type=float, metavar="M", help="Momentum (default: 0.9)")
+    parser.add_argument(
+        "--wd",
+        "--weight-decay",
+        default=1e-4,
+        type=float,
+        metavar="W",
+        help="Weight decay (default: 1e-4)",
+        dest="weight_decay",
+    )
+
+    # LR scheduler params
+    parser.add_argument(
+        "--lr-scheduler", default="multisteplr", type=str, help="Name of lr scheduler (default: multisteplr)"
+    )
+    parser.add_argument(
+        "--lr-steps",
+        default=[16, 22],
+        nargs="+",
+        type=int,
+        help="Decrease lr every step-size epochs (multisteplr scheduler only)",
+    )
+    parser.add_argument(
+        "--lr-gamma", default=0.1, type=float, help="Decrease lr by a factor of lr-gamma (multisteplr scheduler only)"
+    )
+
+    # Mixed precision training params
+    parser.add_argument("--amp", action="store_true",
+                        help="Use torch.cuda.amp for mixed precision training")
+
+    # Resume training params
+    parser.add_argument("--resume", default="", type=str,
+                        help="Path of checkpoint")
+
+    # Training params
+    parser.add_argument("--start_epoch", default=0,
+                        type=int, help="Start epoch")
+    parser.add_argument("--epochs", default=30, type=int,
+                        metavar="N", help="Number of total epochs to run")
+    parser.add_argument("--print-freq", default=20,
+                        type=int, help="Print frequency")
+
+    return parser
+
+
+def get_transform(train, data_augmentation):
+    if train:
+        return presets.DetectionPresetTrain(data_augmentation)
+    else:
+        return presets.DetectionPresetEval()
+
+
+def get_motsynth_dataset(ds_name: str, transforms):
+    data_path = osp.join(MOTSYNTH_ROOT, 'comb_annotations', f"{ds_name}.json")
+    dataset = get_mot_dataset(MOTSYNTH_ROOT, data_path, transforms=transforms)
+    return dataset
+
+
+def get_MOT17_dataset(split: str, split_seqs: List, transforms):
+    data_path = osp.join(MOTCHA_ROOT, "MOT17", "train")
+    dataset = MOTObjDetect(
+        data_path, transforms=transforms, split_seqs=split_seqs)
+    return dataset
+
+
+def create_dataset(ds_name: str, transforms, split=None):
+    if ds_name.startswith("motsynth"):
+        return get_motsynth_dataset(ds_name, transforms)
+
+    elif ds_name.startswith("MOT17"):
+        if split == "train":
+            split_seqs = ['MOT17-02-FRCNN', 'MOT17-04-FRCNN',
+                          'MOT17-11-FRCNN', 'MOT17-13-FRCNN']
+        elif split == "test":
+            split_seqs = ['MOT17-09-FRCNN', 'MOT17-10-FRCNN', 'MOT17-05-FRCNN']
+        return get_MOT17_dataset(split, split_seqs, transforms)
+
+    else:
+        logger.error(
+            "Please provide a valid dataset as argument. Select one of the following: motsynth_split1, motsynth_split2, motsynth_split3, motsynth_split4, MOT17.")
+        raise ValueError(ds_name)
+
+
+def create_data_loader(dataset, split: str, batch_size, workers, aspect_ratio_group_factor=-1):
+    data_loader = None
+    if split == "train":
+        # Random sampling on the training dataset
+        train_sampler = torch.utils.data.RandomSampler(dataset)
+        if aspect_ratio_group_factor >= 0:
+            group_ids = create_aspect_ratio_groups(
+                dataset, k=aspect_ratio_group_factor)
+            train_batch_sampler = GroupedBatchSampler(
+                train_sampler, group_ids, batch_size)
+        else:
+            train_batch_sampler = torch.utils.data.BatchSampler(
+                train_sampler, batch_size, drop_last=True)
+        data_loader = torch.utils.data.DataLoader(
+            dataset, batch_sampler=train_batch_sampler, num_workers=workers, collate_fn=utils.collate_fn
+        )
+    elif split == "test":
+        # Sequential sampling on the eval dataset
+        test_sampler = torch.utils.data.SequentialSampler(dataset)
+        data_loader = torch.utils.data.DataLoader(
+            dataset, batch_size=1, sampler=test_sampler, num_workers=workers, collate_fn=utils.collate_fn
+        )
+    return data_loader
+
+
+def create_optimizer(model, lr, momentum, weight_decay):
+    params = [p for p in model.parameters() if p.requires_grad]
+    optimizer = torch.optim.SGD(
+        params, lr=lr, momentum=momentum, weight_decay=weight_decay)
+    return optimizer
+
+
+def create_lr_scheduler(optimizer, lr_scheduler_type, lr_steps, lr_gamma, epochs):
+    if lr_scheduler_type == "multisteplr":
+        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+            optimizer, milestones=lr_steps, gamma=lr_gamma)
+        logger.debug(
+            f"lr_scheduler: {lr_scheduler_type}, milestones: {lr_steps}, gamma: {lr_gamma}")
+
+    elif lr_scheduler_type == "cosineannealinglr":
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer, T_max=epochs)
+        logger.debug(
+            f"lr_scheduler: {lr_scheduler_type}, T_max: {epochs}")
+    else:
+        raise RuntimeError(
+            f"Invalid lr scheduler '{lr_scheduler_type}'. Only MultiStepLR and CosineAnnealingLR are supported."
+        )
+    return lr_scheduler
+
+
+def resume_training(model, optimizer, lr_scheduler, scaler, args):
+    checkpoint = torch.load(args.resume, map_location="cpu")
+
+    model.load_state_dict(checkpoint["model"])
+    optimizer.load_state_dict(checkpoint["optimizer"])
+    lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+    args.start_epoch = checkpoint["epoch"] + 1
+    if args.amp:
+        scaler.load_state_dict(checkpoint["scaler"])
+
+
+def save_model_checkpoint(model, optimizer, lr_scheduler, epoch, scaler, output_dir, args):
+    if output_dir:
+        checkpoint = {
+            "model": model.state_dict(),
+            "optimizer": optimizer.state_dict(),
+            "lr_scheduler": lr_scheduler.state_dict(),
+            "args": args,
+            "epoch": epoch,
+        }
+        if args.amp:
+            checkpoint["scaler"] = scaler.state_dict()
+        utils.save_on_master(checkpoint, os.path.join(
+            output_dir, f"model_{epoch}.pth"))
+        utils.save_on_master(checkpoint, os.path.join(
+            output_dir, "checkpoint.pth"))
+
+
+def save_plots(losses_dict, batch_loss_dict, output_dir):
+    if not losses_dict:
+        for name, metric in batch_loss_dict.items():
+            losses_dict[name] = []
+
+    for name, metric in batch_loss_dict.items():
+        losses_dict[name].extend(metric)
+    save_train_loss_plot(losses_dict, output_dir)
+
+
+def save_model_summary(model, output_dir, batch_size):
+    with open(osp.join(output_dir, "summary.txt"), 'w', encoding="utf-8") as f:
+        print(summary(model,
+                      # (batch_size, color_channels, height, width)
+                      input_size=(batch_size, 3, 1080, 1920),
+                      verbose=0,
+                      col_names=["input_size", "output_size",
+                                 "num_params", "kernel_size", "trainable"],
+                      col_width=20,
+                      row_settings=["var_names"]), file=f)
+
+
+def save_args(output_dir, args):
+    with open(osp.join(output_dir, "args.txt"), 'w', encoding="utf-8") as f:
+        print(args, file=f)
+
+
+def save_evaluate_summary(stats, output_dir):
+    # The standard COCO detection metrics
+    metrics = ["AP", "AP50", "AP75", "APs", "APm", "APl"]
+    results = {
+        metric: float(stats[idx] *
+                      100 if stats[idx] >= 0 else "nan")
+        for idx, metric in enumerate(metrics)
+    }
+    with open(osp.join(output_dir, "evaluate.txt"), 'w', encoding="utf-8") as f:
+        print(results, file=f)
+
+
+def main(args):
+
+    output_dir = None
+    if args.output_dir:
+        output_dir = osp.join(
+            OUTPUT_DIR, 'detection_logs', args.output_dir)
+        utils.mkdir(output_dir)
+        output_plots_dir = osp.join(output_dir, "plots")
+        utils.mkdir(output_plots_dir)
+
+    logger.debug("COMMAND LINE ARGUMENTS")
+    logger.debug(args)
+    save_args(output_dir, args)
+
+    device = torch.device(args.device)
+    logger.debug(f"DEVICE: {device}")
+
+    logger.debug("CREATE DATASETS")
+    ds_train_name = args.train_dataset
+    ds_val_name = args.val_dataset
+    data_augmentation = args.data_augmentation
+
+    dataset_train = create_dataset(
+        ds_train_name, get_transform(True, data_augmentation), "train")
+    dataset_test = create_dataset(
+        ds_val_name, get_transform(False, data_augmentation), "test")
+
+    logger.debug("CREATE DATA LOADERS")
+    batch_size = args.batch_size
+    workers = args.workers
+    aspect_ratio_group_factor = args.aspect_ratio_group_factor
+    data_loader_train = create_data_loader(
+        dataset_train, "train", batch_size, workers, aspect_ratio_group_factor)
+    data_loader_test = create_data_loader(
+        dataset_test, "test", batch_size, workers)
+
+    if args.test_only:
+        logger.debug("TEST ONLY")
+        model = fasterrcnn_resnet50_fpn_v2()
+        in_features = model.roi_heads.box_predictor.cls_score.in_features
+        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
+        checkpoint = torch.load(args.model_eval, map_location="cuda")
+        model.load_state_dict(checkpoint["model"])
+        model.to(device)
+        coco_evaluator = evaluate(model, data_loader_test,
+                                  device=device, iou_types=['bbox'])
+        save_evaluate_summary(
+            coco_evaluator.coco_eval['bbox'].stats, output_dir)
+        return
+
+    logger.debug("CREATE MODEL")
+    model_name = args.model
+    weights = args.weights
+    backbone = args.backbone
+    backbone_weights = args.backbone_weights
+    trainable_backbone_layers = args.trainable_backbone_layers
+    model = ModelFactory.get_model(
+        model_name, weights, backbone, backbone_weights, trainable_backbone_layers)
+    save_model_summary(model, output_dir, batch_size)
+
+    logger.debug("CREATE OPTIMIZER")
+    lr = args.lr
+    momentum = args.momentum
+    weight_decay = args.weight_decay
+    optimizer = create_optimizer(
+        model, lr, momentum, weight_decay)
+
+    logger.debug("CREATE LR SCHEDULER")
+    epochs = args.epochs
+    lr_scheduler_type = args.lr_scheduler.lower()
+    lr_steps = args.lr_steps
+    lr_gamma = args.lr_gamma
+    lr_scheduler = create_lr_scheduler(
+        optimizer, lr_scheduler_type, lr_steps, lr_gamma, epochs)
+
+    logger.debug("CONFIGURE SCALER FOR amp")
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
+    if args.resume:
+        logger.debug("RESUME TRAINING")
+        resume_training(model, optimizer, lr_scheduler,
+                        scaler, args)
+
+    logger.debug("START TRAINING")
+    print_freq = args.print_freq
+    start_epoch = args.start_epoch
+    losses_dict = {}
+    start_time = time.time()
+    for epoch in range(start_epoch, epochs):
+        _, batch_loss_dict = train_one_epoch(model, optimizer, data_loader_train, device,
+                                             epoch, print_freq, scaler)
+        lr_scheduler.step()
+        save_plots(losses_dict, batch_loss_dict,
+                   output_dir=output_plots_dir)
+
+        coco_evaluator = evaluate(model, data_loader_test,
+                                  device=device, iou_types=['bbox'])
+        save_evaluate_summary(
+            coco_evaluator.coco_eval['bbox'].stats, output_dir)
+
+        save_model_checkpoint(
+            model, optimizer, lr_scheduler, epoch, scaler, output_dir, args)
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    logger.debug(f"TRAINING TIME: {total_time_str}")
+
+
+if __name__ == "__main__":
+    args = get_args_parser().parse_args()
+    main(args)
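The trainer is driven by its CLI (the repository also ships scripts/train_detector.sh alongside it). An equivalent programmatic call, sketched with illustrative argument values and dataset roots resolved through configs/path_cfg.py:

from tools.train_detector import get_args_parser, main

args = get_args_parser().parse_args([
    '--train-dataset', 'motsynth_split1',
    '--val-dataset', 'MOT17',
    '--batch-size', '3',
    '--epochs', '10',
    '--output-dir', 'fasterrcnn_training',
])
main(args)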