Spaces:
Running
Running
Initial commit.
Browse files- animeinsseg/__init__.py +708 -0
- animeinsseg/anime_instances.py +301 -0
- animeinsseg/data/__init__.py +2 -0
- animeinsseg/data/dataset.py +929 -0
- animeinsseg/data/maskrefine_dataset.py +235 -0
- animeinsseg/data/metrics.py +348 -0
- animeinsseg/data/paste_methods.py +327 -0
- animeinsseg/data/sampler.py +226 -0
- animeinsseg/data/syndataset.py +213 -0
- animeinsseg/data/transforms.py +299 -0
- animeinsseg/inpainting/__init__.py +0 -0
- animeinsseg/inpainting/ldm_inpaint.py +353 -0
- animeinsseg/inpainting/patch_match.py +203 -0
- animeinsseg/models/__init__.py +7 -0
- animeinsseg/models/animeseg_refine/__init__.py +189 -0
- animeinsseg/models/animeseg_refine/encoders.py +51 -0
- animeinsseg/models/animeseg_refine/isnet.py +645 -0
- animeinsseg/models/animeseg_refine/models.py +0 -0
- animeinsseg/models/animeseg_refine/modnet.py +667 -0
- animeinsseg/models/animeseg_refine/u2net.py +228 -0
- animeinsseg/models/rtmdet_inshead_custom.py +370 -0
- app.py +67 -0
- requirements.txt +15 -0
animeinsseg/__init__.py
ADDED
@@ -0,0 +1,708 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import mmcv, torch
|
2 |
+
from tqdm import tqdm
|
3 |
+
from einops import rearrange
|
4 |
+
import os
|
5 |
+
import os.path as osp
|
6 |
+
import cv2
|
7 |
+
import gc
|
8 |
+
import math
|
9 |
+
|
10 |
+
from .anime_instances import AnimeInstances
|
11 |
+
import numpy as np
|
12 |
+
from typing import List, Tuple, Union, Optional, Callable
|
13 |
+
from mmengine import Config
|
14 |
+
from mmengine.model.utils import revert_sync_batchnorm
|
15 |
+
from mmdet.utils import register_all_modules, get_test_pipeline_cfg
|
16 |
+
from mmdet.apis import init_detector
|
17 |
+
from mmdet.registry import MODELS
|
18 |
+
from mmdet.structures import DetDataSample, SampleList
|
19 |
+
from mmdet.structures.bbox.transforms import scale_boxes, get_box_wh
|
20 |
+
from mmdet.models.dense_heads.rtmdet_ins_head import RTMDetInsHead
|
21 |
+
from pycocotools.coco import COCO
|
22 |
+
from mmcv.transforms import Compose
|
23 |
+
from mmdet.models.detectors.single_stage import SingleStageDetector
|
24 |
+
|
25 |
+
from utils.logger import LOGGER
|
26 |
+
from utils.io_utils import square_pad_resize, find_all_imgs, imglist2grid, mask2rle, dict2json, scaledown_maxsize, resize_pad
|
27 |
+
from utils.constants import DEFAULT_DEVICE, CATEGORIES
|
28 |
+
from utils.booru_tagger import Tagger
|
29 |
+
|
30 |
+
from .models.animeseg_refine import AnimeSegmentation, load_refinenet, get_mask
|
31 |
+
from .models.rtmdet_inshead_custom import RTMDetInsSepBNHeadCustom
|
32 |
+
|
33 |
+
from torchvision.ops.boxes import box_iou
|
34 |
+
import torch.nn.functional as F
|
35 |
+
|
36 |
+
|
37 |
+
def prepare_refine_batch(segmentations: np.ndarray, img: np.ndarray, max_batch_size: int = 4, device: str = 'cpu', input_size: int = 720):
|
38 |
+
|
39 |
+
img, (pt, pb, pl, pr) = resize_pad(img, input_size, pad_value=(0, 0, 0))
|
40 |
+
|
41 |
+
img = img.transpose((2, 0, 1)).astype(np.float32) / 255.
|
42 |
+
|
43 |
+
batch = []
|
44 |
+
num_seg = len(segmentations)
|
45 |
+
|
46 |
+
for ii, seg in enumerate(segmentations):
|
47 |
+
seg, _ = resize_pad(seg, input_size, 0)
|
48 |
+
seg = seg[None, ...]
|
49 |
+
batch.append(np.concatenate((img, seg)))
|
50 |
+
|
51 |
+
if ii == num_seg - 1:
|
52 |
+
yield torch.from_numpy(np.array(batch)).to(device), (pt, pb, pl, pr)
|
53 |
+
elif len(batch) >= max_batch_size:
|
54 |
+
yield torch.from_numpy(np.array(batch)).to(device), (pt, pb, pl, pr)
|
55 |
+
batch = []
|
56 |
+
|
57 |
+
|
58 |
+
VALID_REFINEMETHODS = {'animeseg', 'none'}
|
59 |
+
|
60 |
+
register_all_modules()
|
61 |
+
|
62 |
+
|
63 |
+
def single_image_preprocess(img: Union[str, np.ndarray], pipeline: Compose):
|
64 |
+
if isinstance(img, str):
|
65 |
+
img = mmcv.imread(img)
|
66 |
+
elif not isinstance(img, np.ndarray):
|
67 |
+
raise NotImplementedError
|
68 |
+
|
69 |
+
# img = square_pad_resize(img, 1024)[0]
|
70 |
+
|
71 |
+
data_ = dict(img=img, img_id=0)
|
72 |
+
data_ = pipeline(data_)
|
73 |
+
data_['inputs'] = [data_['inputs']]
|
74 |
+
data_['data_samples'] = [data_['data_samples']]
|
75 |
+
|
76 |
+
return data_, img
|
77 |
+
|
78 |
+
def animeseg_refine(det_pred: DetDataSample, img: np.ndarray, net: AnimeSegmentation, to_rgb=True, input_size: int = 1024):
|
79 |
+
|
80 |
+
num_pred = len(det_pred.pred_instances)
|
81 |
+
if num_pred < 1:
|
82 |
+
return
|
83 |
+
|
84 |
+
with torch.no_grad():
|
85 |
+
if to_rgb:
|
86 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
87 |
+
seg_thr = 0.5
|
88 |
+
mask = get_mask(net, img, s=input_size)[..., 0]
|
89 |
+
mask = (mask > seg_thr)
|
90 |
+
|
91 |
+
ins_masks = det_pred.pred_instances.masks
|
92 |
+
|
93 |
+
if isinstance(ins_masks, torch.Tensor):
|
94 |
+
tensor_device = ins_masks.device
|
95 |
+
tensor_dtype = ins_masks.dtype
|
96 |
+
to_tensor = True
|
97 |
+
ins_masks = ins_masks.cpu().numpy()
|
98 |
+
|
99 |
+
area_original = np.sum(ins_masks, axis=(1, 2))
|
100 |
+
masks_refined = np.bitwise_and(ins_masks, mask[None, ...])
|
101 |
+
area_refined = np.sum(masks_refined, axis=(1, 2))
|
102 |
+
|
103 |
+
for ii in range(num_pred):
|
104 |
+
if area_refined[ii] / area_original[ii] > 0.3:
|
105 |
+
ins_masks[ii] = masks_refined[ii]
|
106 |
+
ins_masks = np.ascontiguousarray(ins_masks)
|
107 |
+
|
108 |
+
# for ii, insm in enumerate(ins_masks):
|
109 |
+
# cv2.imwrite(f'{ii}.png', insm.astype(np.uint8) * 255)
|
110 |
+
|
111 |
+
if to_tensor:
|
112 |
+
ins_masks = torch.from_numpy(ins_masks).to(dtype=tensor_dtype).to(device=tensor_device)
|
113 |
+
|
114 |
+
det_pred.pred_instances.masks = ins_masks
|
115 |
+
# rst = np.concatenate((mask * img + 1 - mask, mask * 255), axis=2).astype(np.uint8)
|
116 |
+
# cv2.imwrite('rst.png', rst)
|
117 |
+
|
118 |
+
|
119 |
+
# def refinenet_forward(det_pred: DetDataSample, img: np.ndarray, net: AnimeSegmentation, to_rgb=True, input_size: int = 1024):
|
120 |
+
|
121 |
+
# num_pred = len(det_pred.pred_instances)
|
122 |
+
# if num_pred < 1:
|
123 |
+
# return
|
124 |
+
|
125 |
+
# with torch.no_grad():
|
126 |
+
# if to_rgb:
|
127 |
+
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
128 |
+
# seg_thr = 0.5
|
129 |
+
|
130 |
+
# h0, w0 = h, w = img.shape[0], img.shape[1]
|
131 |
+
# if h > w:
|
132 |
+
# h, w = input_size, int(input_size * w / h)
|
133 |
+
# else:
|
134 |
+
# h, w = int(input_size * h / w), input_size
|
135 |
+
# ph, pw = input_size - h, input_size - w
|
136 |
+
# tmpImg = np.zeros([s, s, 3], dtype=np.float32)
|
137 |
+
# tmpImg[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h)) / 255
|
138 |
+
# tmpImg = tmpImg.transpose((2, 0, 1))
|
139 |
+
# tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor).to(model.device)
|
140 |
+
# with torch.no_grad():
|
141 |
+
# if use_amp:
|
142 |
+
# with amp.autocast():
|
143 |
+
# pred = model(tmpImg)
|
144 |
+
# pred = pred.to(dtype=torch.float32)
|
145 |
+
# else:
|
146 |
+
# pred = model(tmpImg)
|
147 |
+
# pred = pred[0, :, ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
|
148 |
+
# pred = cv2.resize(pred.cpu().numpy().transpose((1, 2, 0)), (w0, h0))[:, :, np.newaxis]
|
149 |
+
# return pred
|
150 |
+
|
151 |
+
# mask = (mask > seg_thr)
|
152 |
+
|
153 |
+
# ins_masks = det_pred.pred_instances.masks
|
154 |
+
|
155 |
+
# if isinstance(ins_masks, torch.Tensor):
|
156 |
+
# tensor_device = ins_masks.device
|
157 |
+
# tensor_dtype = ins_masks.dtype
|
158 |
+
# to_tensor = True
|
159 |
+
# ins_masks = ins_masks.cpu().numpy()
|
160 |
+
|
161 |
+
# area_original = np.sum(ins_masks, axis=(1, 2))
|
162 |
+
# masks_refined = np.bitwise_and(ins_masks, mask[None, ...])
|
163 |
+
# area_refined = np.sum(masks_refined, axis=(1, 2))
|
164 |
+
|
165 |
+
# for ii in range(num_pred):
|
166 |
+
# if area_refined[ii] / area_original[ii] > 0.3:
|
167 |
+
# ins_masks[ii] = masks_refined[ii]
|
168 |
+
# ins_masks = np.ascontiguousarray(ins_masks)
|
169 |
+
|
170 |
+
# # for ii, insm in enumerate(ins_masks):
|
171 |
+
# # cv2.imwrite(f'{ii}.png', insm.astype(np.uint8) * 255)
|
172 |
+
|
173 |
+
# if to_tensor:
|
174 |
+
# ins_masks = torch.from_numpy(ins_masks).to(dtype=tensor_dtype).to(device=tensor_device)
|
175 |
+
|
176 |
+
# det_pred.pred_instances.masks = ins_masks
|
177 |
+
|
178 |
+
|
179 |
+
def read_imglst_from_txt(filep) -> List[str]:
|
180 |
+
with open(filep, 'r', encoding='utf8') as f:
|
181 |
+
lines = f.read().splitlines()
|
182 |
+
return lines
|
183 |
+
|
184 |
+
|
185 |
+
class AnimeInsSeg:
|
186 |
+
|
187 |
+
def __init__(self, ckpt: str, default_det_size: int = 640, device: str = None,
|
188 |
+
refine_kwargs: dict = {'refine_method': 'refinenet_isnet'},
|
189 |
+
tagger_path: str = 'models/wd-v1-4-swinv2-tagger-v2/model.onnx', mask_thr=0.3) -> None:
|
190 |
+
self.ckpt = ckpt
|
191 |
+
self.default_det_size = default_det_size
|
192 |
+
self.device = DEFAULT_DEVICE if device is None else device
|
193 |
+
|
194 |
+
# init detector in mmdet's way
|
195 |
+
|
196 |
+
ckpt = torch.load(ckpt, map_location='cpu')
|
197 |
+
cfg = Config.fromstring(ckpt['meta']['cfg'].replace('file_client_args', 'backend_args'), file_format='.py')
|
198 |
+
cfg.visualizer = []
|
199 |
+
cfg.vis_backends = {}
|
200 |
+
cfg.default_hooks.pop('visualization')
|
201 |
+
|
202 |
+
|
203 |
+
# self.model: SingleStageDetector = init_detector(cfg, checkpoint=None, device='cpu')
|
204 |
+
model = MODELS.build(cfg.model)
|
205 |
+
model = revert_sync_batchnorm(model)
|
206 |
+
|
207 |
+
self.model = model.to(self.device).eval()
|
208 |
+
self.model.load_state_dict(ckpt['state_dict'], strict=False)
|
209 |
+
self.model = self.model.to(self.device).eval()
|
210 |
+
self.cfg = cfg.copy()
|
211 |
+
|
212 |
+
test_pipeline = get_test_pipeline_cfg(self.cfg.copy())
|
213 |
+
test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
|
214 |
+
test_pipeline = Compose(test_pipeline)
|
215 |
+
self.default_data_pipeline = test_pipeline
|
216 |
+
|
217 |
+
self.refinenet = None
|
218 |
+
self.refinenet_animeseg: AnimeSegmentation = None
|
219 |
+
self.postprocess_refine: Callable = None
|
220 |
+
|
221 |
+
if refine_kwargs is not None:
|
222 |
+
self.set_refine_method(**refine_kwargs)
|
223 |
+
|
224 |
+
self.tagger = None
|
225 |
+
self.tagger_path = tagger_path
|
226 |
+
|
227 |
+
self.mask_thr = mask_thr
|
228 |
+
|
229 |
+
def init_tagger(self, tagger_path: str = None):
|
230 |
+
tagger_path = self.tagger_path if tagger_path is None else tagger_path
|
231 |
+
self.tagger = Tagger(self.tagger_path)
|
232 |
+
|
233 |
+
def infer_tags(self, instances: AnimeInstances, img: np.ndarray, infer_grey: bool = False):
|
234 |
+
if self.tagger is None:
|
235 |
+
self.init_tagger()
|
236 |
+
|
237 |
+
if infer_grey:
|
238 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[..., None][..., [0, 0, 0]]
|
239 |
+
|
240 |
+
num_ins = len(instances)
|
241 |
+
for ii in range(num_ins):
|
242 |
+
bbox = instances.bboxes[ii]
|
243 |
+
mask = instances.masks[ii]
|
244 |
+
if isinstance(bbox, torch.Tensor):
|
245 |
+
bbox = bbox.cpu().numpy()
|
246 |
+
mask = mask.cpu().numpy()
|
247 |
+
bbox = bbox.astype(np.int32)
|
248 |
+
|
249 |
+
crop = img[bbox[1]: bbox[3] + bbox[1], bbox[0]: bbox[2] + bbox[0]].copy()
|
250 |
+
mask = mask[bbox[1]: bbox[3] + bbox[1], bbox[0]: bbox[2] + bbox[0]]
|
251 |
+
crop[mask == 0] = 255
|
252 |
+
tags, character_tags = self.tagger.label_cv2_bgr(crop)
|
253 |
+
exclude_tags = ['simple_background', 'white_background']
|
254 |
+
valid_tags = []
|
255 |
+
for tag in tags:
|
256 |
+
if tag in exclude_tags:
|
257 |
+
continue
|
258 |
+
valid_tags.append(tag)
|
259 |
+
instances.tags[ii] = ' '.join(valid_tags)
|
260 |
+
instances.character_tags[ii] = character_tags
|
261 |
+
|
262 |
+
@torch.no_grad()
|
263 |
+
def infer_embeddings(self, imgs, det_size = None):
|
264 |
+
|
265 |
+
def hijack_bbox_mask_post_process(
|
266 |
+
self,
|
267 |
+
results,
|
268 |
+
mask_feat,
|
269 |
+
cfg,
|
270 |
+
rescale: bool = False,
|
271 |
+
with_nms: bool = True,
|
272 |
+
img_meta: Optional[dict] = None):
|
273 |
+
|
274 |
+
stride = self.prior_generator.strides[0][0]
|
275 |
+
if rescale:
|
276 |
+
assert img_meta.get('scale_factor') is not None
|
277 |
+
scale_factor = [1 / s for s in img_meta['scale_factor']]
|
278 |
+
results.bboxes = scale_boxes(results.bboxes, scale_factor)
|
279 |
+
|
280 |
+
if hasattr(results, 'score_factors'):
|
281 |
+
# TODO: Add sqrt operation in order to be consistent with
|
282 |
+
# the paper.
|
283 |
+
score_factors = results.pop('score_factors')
|
284 |
+
results.scores = results.scores * score_factors
|
285 |
+
|
286 |
+
# filter small size bboxes
|
287 |
+
if cfg.get('min_bbox_size', -1) >= 0:
|
288 |
+
w, h = get_box_wh(results.bboxes)
|
289 |
+
valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
|
290 |
+
if not valid_mask.all():
|
291 |
+
results = results[valid_mask]
|
292 |
+
|
293 |
+
# results.mask_feat = mask_feat
|
294 |
+
return results, mask_feat
|
295 |
+
|
296 |
+
def hijack_detector_predict(self: SingleStageDetector,
|
297 |
+
batch_inputs: torch.Tensor,
|
298 |
+
batch_data_samples: SampleList,
|
299 |
+
rescale: bool = True) -> SampleList:
|
300 |
+
x = self.extract_feat(batch_inputs)
|
301 |
+
|
302 |
+
bbox_head: RTMDetInsSepBNHeadCustom = self.bbox_head
|
303 |
+
old_postprocess = RTMDetInsSepBNHeadCustom._bbox_mask_post_process
|
304 |
+
RTMDetInsSepBNHeadCustom._bbox_mask_post_process = hijack_bbox_mask_post_process
|
305 |
+
# results_list = bbox_head.predict(
|
306 |
+
# x, batch_data_samples, rescale=rescale)
|
307 |
+
|
308 |
+
batch_img_metas = [
|
309 |
+
data_samples.metainfo for data_samples in batch_data_samples
|
310 |
+
]
|
311 |
+
|
312 |
+
outs = bbox_head(x)
|
313 |
+
|
314 |
+
results_list = bbox_head.predict_by_feat(
|
315 |
+
*outs, batch_img_metas=batch_img_metas, rescale=rescale)
|
316 |
+
|
317 |
+
# batch_data_samples = self.add_pred_to_datasample(
|
318 |
+
# batch_data_samples, results_list)
|
319 |
+
|
320 |
+
RTMDetInsSepBNHeadCustom._bbox_mask_post_process = old_postprocess
|
321 |
+
return results_list
|
322 |
+
|
323 |
+
old_predict = SingleStageDetector.predict
|
324 |
+
SingleStageDetector.predict = hijack_detector_predict
|
325 |
+
test_pipeline, imgs, _ = self.prepare_data_pipeline(imgs, det_size)
|
326 |
+
|
327 |
+
if len(imgs) > 1:
|
328 |
+
imgs = tqdm(imgs)
|
329 |
+
model = self.model
|
330 |
+
img = imgs[0]
|
331 |
+
data_, img = test_pipeline(img)
|
332 |
+
data = model.data_preprocessor(data_, False)
|
333 |
+
instance_data, mask_feat = model(**data, mode='predict')[0]
|
334 |
+
SingleStageDetector.predict = old_predict
|
335 |
+
|
336 |
+
# print((instance_data.scores > 0.9).sum())
|
337 |
+
return img, instance_data, mask_feat
|
338 |
+
|
339 |
+
def segment_with_bboxes(self, img, bboxes: torch.Tensor, instance_data, mask_feat: torch.Tensor):
|
340 |
+
# instance_data.bboxes: x1, y1, x2, y2
|
341 |
+
maxidx = torch.argmax(instance_data.scores)
|
342 |
+
bbox = instance_data.bboxes[maxidx].cpu().numpy()
|
343 |
+
p1, p2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
|
344 |
+
tgt_bboxes = instance_data.bboxes
|
345 |
+
|
346 |
+
im_h, im_w = img.shape[:2]
|
347 |
+
long_side = max(im_h, im_w)
|
348 |
+
bbox_head: RTMDetInsSepBNHeadCustom = self.model.bbox_head
|
349 |
+
priors, kernels = instance_data.priors, instance_data.kernels
|
350 |
+
stride = bbox_head.prior_generator.strides[0][0]
|
351 |
+
|
352 |
+
ins_bboxes, ins_segs, scores = [], [], []
|
353 |
+
for bbox in bboxes:
|
354 |
+
bbox = torch.from_numpy(np.array([bbox])).to(tgt_bboxes.dtype).to(tgt_bboxes.device)
|
355 |
+
ioulst = box_iou(bbox, tgt_bboxes).squeeze()
|
356 |
+
matched_idx = torch.argmax(ioulst)
|
357 |
+
|
358 |
+
mask_logits = bbox_head._mask_predict_by_feat_single(
|
359 |
+
mask_feat, kernels[matched_idx][None, ...], priors[matched_idx][None, ...])
|
360 |
+
|
361 |
+
mask_logits = F.interpolate(
|
362 |
+
mask_logits.unsqueeze(0), scale_factor=stride, mode='bilinear')
|
363 |
+
|
364 |
+
mask_logits = F.interpolate(
|
365 |
+
mask_logits,
|
366 |
+
size=[long_side, long_side],
|
367 |
+
mode='bilinear',
|
368 |
+
align_corners=False)[..., :im_h, :im_w]
|
369 |
+
mask = mask_logits.sigmoid().squeeze()
|
370 |
+
mask = mask > 0.5
|
371 |
+
mask = mask.cpu().numpy()
|
372 |
+
ins_segs.append(mask)
|
373 |
+
|
374 |
+
matched_iou_score = ioulst[matched_idx]
|
375 |
+
matched_score = instance_data.scores[matched_idx]
|
376 |
+
scores.append(matched_score.cpu().item())
|
377 |
+
matched_bbox = tgt_bboxes[matched_idx]
|
378 |
+
|
379 |
+
ins_bboxes.append(matched_bbox.cpu().numpy())
|
380 |
+
# p1, p2 = (int(matched_bbox[0]), int(matched_bbox[1])), (int(matched_bbox[2]), int(matched_bbox[3]))
|
381 |
+
|
382 |
+
if len(ins_bboxes) > 0:
|
383 |
+
ins_bboxes = np.array(ins_bboxes).astype(np.int32)
|
384 |
+
ins_bboxes[:, 2:] -= ins_bboxes[:, :2]
|
385 |
+
ins_segs = np.array(ins_segs)
|
386 |
+
instances = AnimeInstances(ins_segs, ins_bboxes, scores)
|
387 |
+
|
388 |
+
self._postprocess_refine(instances, img)
|
389 |
+
drawed = instances.draw_instances(img)
|
390 |
+
# cv2.imshow('drawed', drawed)
|
391 |
+
# cv2.waitKey(0)
|
392 |
+
|
393 |
+
return instances
|
394 |
+
|
395 |
+
def set_detect_size(self, det_size: Union[int, Tuple]):
|
396 |
+
if isinstance(det_size, int):
|
397 |
+
det_size = (det_size, det_size)
|
398 |
+
self.default_data_pipeline.transforms[1].scale = det_size
|
399 |
+
self.default_data_pipeline.transforms[2].size = det_size
|
400 |
+
|
401 |
+
@torch.no_grad()
|
402 |
+
def infer(self, imgs: Union[List, str, np.ndarray],
|
403 |
+
pred_score_thr: float = 0.3,
|
404 |
+
refine_kwargs: dict = None,
|
405 |
+
output_type: str="tensor",
|
406 |
+
det_size: int = None,
|
407 |
+
save_dir: str = '',
|
408 |
+
save_visualization: bool = False,
|
409 |
+
save_annotation: str = '',
|
410 |
+
infer_tags: bool = False,
|
411 |
+
obj_id_start: int = -1,
|
412 |
+
img_id_start: int = -1,
|
413 |
+
verbose: bool = False,
|
414 |
+
infer_grey: bool = False,
|
415 |
+
save_mask_only: bool = False,
|
416 |
+
val_dir=None,
|
417 |
+
max_instances: int = 100,
|
418 |
+
**kwargs) -> Union[List[AnimeInstances], AnimeInstances, None]:
|
419 |
+
|
420 |
+
"""
|
421 |
+
Args:
|
422 |
+
imgs (str, ndarray, Sequence[str/ndarray]):
|
423 |
+
Either image files or loaded images.
|
424 |
+
|
425 |
+
Returns:
|
426 |
+
:obj:`AnimeInstances` or list[:obj:`AnimeInstances`]:
|
427 |
+
If save_annotation or save_annotation, return None.
|
428 |
+
"""
|
429 |
+
|
430 |
+
if det_size is not None:
|
431 |
+
self.set_detect_size(det_size)
|
432 |
+
if refine_kwargs is not None:
|
433 |
+
self.set_refine_method(**refine_kwargs)
|
434 |
+
|
435 |
+
self.set_max_instance(max_instances)
|
436 |
+
|
437 |
+
if isinstance(imgs, str):
|
438 |
+
if imgs.endswith('.txt'):
|
439 |
+
imgs = read_imglst_from_txt(imgs)
|
440 |
+
|
441 |
+
if save_annotation or save_visualization:
|
442 |
+
return self._infer_save_annotations(imgs, pred_score_thr, det_size, save_dir, save_visualization, \
|
443 |
+
save_annotation, infer_tags, obj_id_start, img_id_start, val_dir=val_dir)
|
444 |
+
else:
|
445 |
+
return self._infer_simple(imgs, pred_score_thr, det_size, output_type, infer_tags, verbose=verbose, infer_grey=infer_grey)
|
446 |
+
|
447 |
+
def _det_forward(self, img, test_pipeline, pred_score_thr: float = 0.3) -> Tuple[AnimeInstances, np.ndarray]:
|
448 |
+
data_, img = test_pipeline(img)
|
449 |
+
with torch.no_grad():
|
450 |
+
results: DetDataSample = self.model.test_step(data_)[0]
|
451 |
+
pred_instances = results.pred_instances
|
452 |
+
pred_instances = pred_instances[pred_instances.scores > pred_score_thr]
|
453 |
+
if len(pred_instances) < 1:
|
454 |
+
return AnimeInstances(), img
|
455 |
+
|
456 |
+
del data_
|
457 |
+
|
458 |
+
bboxes = pred_instances.bboxes.to(torch.int32)
|
459 |
+
bboxes[:, 2:] -= bboxes[:, :2]
|
460 |
+
masks = pred_instances.masks
|
461 |
+
scores = pred_instances.scores
|
462 |
+
return AnimeInstances(masks, bboxes, scores), img
|
463 |
+
|
464 |
+
def _infer_simple(self, imgs: Union[List, str, np.ndarray],
|
465 |
+
pred_score_thr: float = 0.3,
|
466 |
+
det_size: int = None,
|
467 |
+
output_type: str = "tensor",
|
468 |
+
infer_tags: bool = False,
|
469 |
+
infer_grey: bool = False,
|
470 |
+
verbose: bool = False) -> Union[DetDataSample, List[DetDataSample]]:
|
471 |
+
|
472 |
+
if isinstance(imgs, List):
|
473 |
+
return_list = True
|
474 |
+
else:
|
475 |
+
return_list = False
|
476 |
+
|
477 |
+
assert output_type in {'tensor', 'numpy'}
|
478 |
+
|
479 |
+
test_pipeline, imgs, _ = self.prepare_data_pipeline(imgs, det_size)
|
480 |
+
predictions = []
|
481 |
+
|
482 |
+
if len(imgs) > 1:
|
483 |
+
imgs = tqdm(imgs)
|
484 |
+
|
485 |
+
for img in imgs:
|
486 |
+
instances, img = self._det_forward(img, test_pipeline, pred_score_thr)
|
487 |
+
# drawed = instances.draw_instances(img)
|
488 |
+
# cv2.imwrite('drawed.jpg', drawed)
|
489 |
+
self.postprocess_results(instances, img)
|
490 |
+
# drawed = instances.draw_instances(img)
|
491 |
+
# cv2.imwrite('drawed_post.jpg', drawed)
|
492 |
+
|
493 |
+
if infer_tags:
|
494 |
+
self.infer_tags(instances, img, infer_grey)
|
495 |
+
|
496 |
+
if output_type == 'numpy':
|
497 |
+
instances.to_numpy()
|
498 |
+
|
499 |
+
predictions.append(instances)
|
500 |
+
|
501 |
+
if return_list:
|
502 |
+
return predictions
|
503 |
+
else:
|
504 |
+
return predictions[0]
|
505 |
+
|
506 |
+
def _infer_save_annotations(self, imgs: Union[List, str, np.ndarray],
|
507 |
+
pred_score_thr: float = 0.3,
|
508 |
+
det_size: int = None,
|
509 |
+
save_dir: str = '',
|
510 |
+
save_visualization: bool = False,
|
511 |
+
save_annotation: str = '',
|
512 |
+
infer_tags: bool = False,
|
513 |
+
obj_id_start: int = 100000000000,
|
514 |
+
img_id_start: int = 100000000000,
|
515 |
+
save_mask_only: bool = False,
|
516 |
+
val_dir = None,
|
517 |
+
**kwargs) -> None:
|
518 |
+
|
519 |
+
coco_api = None
|
520 |
+
if isinstance(imgs, str) and imgs.endswith('.json'):
|
521 |
+
coco_api = COCO(imgs)
|
522 |
+
|
523 |
+
if val_dir is None:
|
524 |
+
val_dir = osp.join(osp.dirname(osp.dirname(imgs)), 'val')
|
525 |
+
imgs = coco_api.getImgIds()
|
526 |
+
imgp2ids = {}
|
527 |
+
imgps, coco_imgmetas = [], []
|
528 |
+
for imgid in imgs:
|
529 |
+
imeta = coco_api.loadImgs(imgid)[0]
|
530 |
+
imgname = imeta['file_name']
|
531 |
+
imgp = osp.join(val_dir, imgname)
|
532 |
+
imgp2ids[imgp] = imgid
|
533 |
+
imgps.append(imgp)
|
534 |
+
coco_imgmetas.append(imeta)
|
535 |
+
imgs = imgps
|
536 |
+
|
537 |
+
test_pipeline, imgs, target_dir = self.prepare_data_pipeline(imgs, det_size)
|
538 |
+
if save_dir == '':
|
539 |
+
save_dir = osp.join(target_dir, \
|
540 |
+
osp.basename(self.ckpt).replace('.ckpt', '').replace('.pth', '').replace('.pt', ''))
|
541 |
+
|
542 |
+
if not osp.exists(save_dir):
|
543 |
+
os.makedirs(save_dir)
|
544 |
+
|
545 |
+
det_annotations = []
|
546 |
+
image_meta = []
|
547 |
+
obj_id = obj_id_start + 1
|
548 |
+
image_id = img_id_start + 1
|
549 |
+
|
550 |
+
for ii, img in enumerate(tqdm(imgs)):
|
551 |
+
# prepare data
|
552 |
+
if isinstance(img, str):
|
553 |
+
img_name = osp.basename(img)
|
554 |
+
else:
|
555 |
+
img_name = f'{ii}'.zfill(12) + '.jpg'
|
556 |
+
|
557 |
+
if coco_api is not None:
|
558 |
+
image_id = imgp2ids[img]
|
559 |
+
|
560 |
+
try:
|
561 |
+
instances, img = self._det_forward(img, test_pipeline, pred_score_thr)
|
562 |
+
except Exception as e:
|
563 |
+
raise e
|
564 |
+
if isinstance(e, torch.cuda.OutOfMemoryError):
|
565 |
+
gc.collect()
|
566 |
+
torch.cuda.empty_cache()
|
567 |
+
torch.cuda.ipc_collect()
|
568 |
+
try:
|
569 |
+
instances, img = self._det_forward(img, test_pipeline, pred_score_thr)
|
570 |
+
except:
|
571 |
+
LOGGER.warning(f'cuda out of memory: {img_name}')
|
572 |
+
if isinstance(img, str):
|
573 |
+
img = cv2.imread(img)
|
574 |
+
instances = None
|
575 |
+
|
576 |
+
if instances is not None:
|
577 |
+
self.postprocess_results(instances, img)
|
578 |
+
|
579 |
+
if infer_tags:
|
580 |
+
self.infer_tags(instances, img)
|
581 |
+
|
582 |
+
if save_visualization:
|
583 |
+
out_file = osp.join(save_dir, img_name)
|
584 |
+
self.save_visualization(out_file, img, instances)
|
585 |
+
|
586 |
+
if save_annotation:
|
587 |
+
im_h, im_w = img.shape[:2]
|
588 |
+
image_meta.append({
|
589 |
+
"id": image_id,"height": im_h,"width": im_w,
|
590 |
+
"file_name": img_name, "id": image_id
|
591 |
+
})
|
592 |
+
if instances is not None:
|
593 |
+
for ii in range(len(instances)):
|
594 |
+
segmentation = instances.masks[ii].squeeze().cpu().numpy().astype(np.uint8)
|
595 |
+
area = segmentation.sum()
|
596 |
+
segmentation *= 255
|
597 |
+
if save_mask_only:
|
598 |
+
cv2.imwrite(osp.join(save_dir, 'mask_' + str(ii).zfill(3) + '_' +img_name+'.png'), segmentation)
|
599 |
+
else:
|
600 |
+
score = instances.scores[ii]
|
601 |
+
if isinstance(score, torch.Tensor):
|
602 |
+
score = score.item()
|
603 |
+
score = float(score)
|
604 |
+
bbox = instances.bboxes[ii].cpu().numpy()
|
605 |
+
bbox = bbox.astype(np.float32).tolist()
|
606 |
+
segmentation = mask2rle(segmentation)
|
607 |
+
tag_string = instances.tags[ii]
|
608 |
+
tag_string_character = instances.character_tags[ii]
|
609 |
+
det_annotations.append({'id': obj_id, 'category_id': 0, 'iscrowd': 0, 'score': score,
|
610 |
+
'segmentation': segmentation, 'image_id': image_id, 'area': area,
|
611 |
+
'tag_string': tag_string, 'tag_string_character': tag_string_character, 'bbox': bbox
|
612 |
+
})
|
613 |
+
obj_id += 1
|
614 |
+
image_id += 1
|
615 |
+
|
616 |
+
if save_annotation != '' and not save_mask_only:
|
617 |
+
det_meta = {"info": {},"licenses": [], "images": image_meta,
|
618 |
+
"annotations": det_annotations, "categories": CATEGORIES}
|
619 |
+
detp = save_annotation
|
620 |
+
dict2json(det_meta, detp)
|
621 |
+
LOGGER.info(f'annotations saved to {detp}')
|
622 |
+
|
623 |
+
def set_refine_method(self, refine_method: str = 'none', refine_size: int = 720):
|
624 |
+
if refine_method == 'none':
|
625 |
+
self.postprocess_refine = None
|
626 |
+
elif refine_method == 'animeseg':
|
627 |
+
if self.refinenet_animeseg is None:
|
628 |
+
self.refinenet_animeseg = load_refinenet(refine_method)
|
629 |
+
self.postprocess_refine = lambda det_pred, img: \
|
630 |
+
animeseg_refine(det_pred, img, self.refinenet_animeseg, True, refine_size)
|
631 |
+
elif refine_method == 'refinenet_isnet':
|
632 |
+
if self.refinenet is None:
|
633 |
+
self.refinenet = load_refinenet(refine_method)
|
634 |
+
self.postprocess_refine = self._postprocess_refine
|
635 |
+
else:
|
636 |
+
raise NotImplementedError(f'Invalid refine method: {refine_method}')
|
637 |
+
|
638 |
+
def _postprocess_refine(self, instances: AnimeInstances, img: np.ndarray, refine_size: int = 720, max_refine_batch: int = 4, **kwargs):
|
639 |
+
|
640 |
+
if instances.is_empty:
|
641 |
+
return
|
642 |
+
|
643 |
+
segs = instances.masks
|
644 |
+
is_tensor = instances.is_tensor
|
645 |
+
if is_tensor:
|
646 |
+
segs = segs.cpu().numpy()
|
647 |
+
segs = segs.astype(np.float32)
|
648 |
+
im_h, im_w = img.shape[:2]
|
649 |
+
|
650 |
+
masks = []
|
651 |
+
with torch.no_grad():
|
652 |
+
for batch, (pt, pb, pl, pr) in prepare_refine_batch(segs, img, max_refine_batch, self.device, refine_size):
|
653 |
+
preds = self.refinenet(batch)[0][0].sigmoid()
|
654 |
+
if pb == 0:
|
655 |
+
pb = -im_h
|
656 |
+
if pr == 0:
|
657 |
+
pr = -im_w
|
658 |
+
preds = preds[..., pt: -pb, pl: -pr]
|
659 |
+
preds = torch.nn.functional.interpolate(preds, (im_h, im_w), mode='bilinear', align_corners=True)
|
660 |
+
masks.append(preds.cpu()[:, 0])
|
661 |
+
|
662 |
+
masks = (torch.concat(masks, dim=0) > self.mask_thr).to(self.device)
|
663 |
+
if not is_tensor:
|
664 |
+
masks = masks.cpu().numpy()
|
665 |
+
instances.masks = masks
|
666 |
+
|
667 |
+
|
668 |
+
def prepare_data_pipeline(self, imgs: Union[str, np.ndarray, List], det_size: int) -> Tuple[Compose, List, str]:
|
669 |
+
|
670 |
+
if det_size is None:
|
671 |
+
det_size = self.default_det_size
|
672 |
+
|
673 |
+
target_dir = './workspace/output'
|
674 |
+
# cast imgs to a list of np.ndarray or image_file_path if necessary
|
675 |
+
if isinstance(imgs, str):
|
676 |
+
if osp.isdir(imgs):
|
677 |
+
target_dir = imgs
|
678 |
+
imgs = find_all_imgs(imgs, abs_path=True)
|
679 |
+
elif osp.isfile(imgs):
|
680 |
+
target_dir = osp.dirname(imgs)
|
681 |
+
imgs = [imgs]
|
682 |
+
elif isinstance(imgs, np.ndarray) or isinstance(imgs, str):
|
683 |
+
imgs = [imgs]
|
684 |
+
elif isinstance(imgs, List):
|
685 |
+
if len(imgs) > 0:
|
686 |
+
if isinstance(imgs[0], np.ndarray) or isinstance(imgs[0], str):
|
687 |
+
pass
|
688 |
+
else:
|
689 |
+
raise NotImplementedError
|
690 |
+
else:
|
691 |
+
raise NotImplementedError
|
692 |
+
|
693 |
+
test_pipeline = lambda img: single_image_preprocess(img, pipeline=self.default_data_pipeline)
|
694 |
+
return test_pipeline, imgs, target_dir
|
695 |
+
|
696 |
+
def save_visualization(self, out_file: str, img: np.ndarray, instances: AnimeInstances):
|
697 |
+
drawed = instances.draw_instances(img)
|
698 |
+
mmcv.imwrite(drawed, out_file)
|
699 |
+
|
700 |
+
def postprocess_results(self, results: DetDataSample, img: np.ndarray) -> None:
|
701 |
+
if self.postprocess_refine is not None:
|
702 |
+
self.postprocess_refine(results, img)
|
703 |
+
|
704 |
+
def set_mask_threshold(self, mask_thr: float):
|
705 |
+
self.model.bbox_head.test_cfg['mask_thr_binary'] = mask_thr
|
706 |
+
|
707 |
+
def set_max_instance(self, num_ins):
|
708 |
+
self.model.bbox_head.test_cfg['max_per_img'] = num_ins
|
animeinsseg/anime_instances.py
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
from typing import List, Union, Tuple
|
4 |
+
import torch
|
5 |
+
from utils.constants import COLOR_PALETTE
|
6 |
+
from utils.constants import get_color
|
7 |
+
import cv2
|
8 |
+
|
9 |
+
def tags2multilines(tags: Union[str, List], lw, tf, max_width):
|
10 |
+
if isinstance(tags, str):
|
11 |
+
taglist = tags.split(' ')
|
12 |
+
else:
|
13 |
+
taglist = tags
|
14 |
+
|
15 |
+
sz = cv2.getTextSize(' ', 0, lw / 3, tf)
|
16 |
+
line_height = sz[0][1]
|
17 |
+
line_width = 0
|
18 |
+
if len(taglist) > 0:
|
19 |
+
lines = [taglist[0]]
|
20 |
+
if len(taglist) > 1:
|
21 |
+
for t in taglist[1:]:
|
22 |
+
textl = len(t) * line_height
|
23 |
+
if line_width + line_height + textl > max_width:
|
24 |
+
lines.append(t)
|
25 |
+
line_width = 0
|
26 |
+
else:
|
27 |
+
line_width = line_width + line_height + textl
|
28 |
+
lines[-1] = lines[-1] + ' ' + t
|
29 |
+
return lines, line_height
|
30 |
+
|
31 |
+
class AnimeInstances:
|
32 |
+
|
33 |
+
def __init__(self,
|
34 |
+
masks: Union[np.ndarray, torch.Tensor ]= None,
|
35 |
+
bboxes: Union[np.ndarray, torch.Tensor ] = None,
|
36 |
+
scores: Union[np.ndarray, torch.Tensor ] = None,
|
37 |
+
tags: List[str] = None, character_tags: List[str] = None) -> None:
|
38 |
+
self.masks = masks
|
39 |
+
self.tags = tags
|
40 |
+
self.bboxes = bboxes
|
41 |
+
|
42 |
+
|
43 |
+
if scores is None:
|
44 |
+
scores = [1.] * len(self)
|
45 |
+
if self.is_numpy:
|
46 |
+
scores = np.array(scores)
|
47 |
+
elif self.is_tensor:
|
48 |
+
scores = torch.tensor(scores)
|
49 |
+
|
50 |
+
self.scores = scores
|
51 |
+
|
52 |
+
if tags is None:
|
53 |
+
self.tags = [''] * len(self)
|
54 |
+
self.character_tags = [''] * len(self)
|
55 |
+
else:
|
56 |
+
self.tags = tags
|
57 |
+
self.character_tags = character_tags
|
58 |
+
|
59 |
+
@property
|
60 |
+
def is_cuda(self):
|
61 |
+
if isinstance(self.masks, torch.Tensor) and self.masks.is_cuda:
|
62 |
+
return True
|
63 |
+
else:
|
64 |
+
return False
|
65 |
+
|
66 |
+
@property
|
67 |
+
def is_tensor(self):
|
68 |
+
if self.is_empty:
|
69 |
+
return False
|
70 |
+
else:
|
71 |
+
return isinstance(self.masks, torch.Tensor)
|
72 |
+
|
73 |
+
@property
|
74 |
+
def is_numpy(self):
|
75 |
+
if self.is_empty:
|
76 |
+
return True
|
77 |
+
else:
|
78 |
+
return isinstance(self.masks, np.ndarray)
|
79 |
+
|
80 |
+
@property
|
81 |
+
def is_empty(self):
|
82 |
+
return self.masks is None or len(self.masks) == 0\
|
83 |
+
|
84 |
+
def remove_duplicated(self):
|
85 |
+
|
86 |
+
num_masks = len(self)
|
87 |
+
if num_masks < 2:
|
88 |
+
return
|
89 |
+
|
90 |
+
need_cvt = False
|
91 |
+
if self.is_numpy:
|
92 |
+
need_cvt = True
|
93 |
+
self.to_tensor()
|
94 |
+
|
95 |
+
mask_areas = torch.Tensor([mask.sum() for mask in self.masks])
|
96 |
+
sids = torch.argsort(mask_areas, descending=True)
|
97 |
+
sids = sids.cpu().numpy().tolist()
|
98 |
+
mask_areas = mask_areas[sids]
|
99 |
+
masks = self.masks[sids]
|
100 |
+
bboxes = self.bboxes[sids]
|
101 |
+
tags = [self.tags[sid] for sid in sids]
|
102 |
+
scores = self.scores[sids]
|
103 |
+
|
104 |
+
canvas = masks[0]
|
105 |
+
|
106 |
+
valid_ids: List = np.arange(num_masks).tolist()
|
107 |
+
for ii, mask in enumerate(masks[1:]):
|
108 |
+
|
109 |
+
mask_id = ii + 1
|
110 |
+
canvas_and = torch.bitwise_and(canvas, mask)
|
111 |
+
|
112 |
+
and_area = canvas_and.sum()
|
113 |
+
mask_area = mask_areas[mask_id]
|
114 |
+
|
115 |
+
if and_area / mask_area > 0.8:
|
116 |
+
valid_ids.remove(mask_id)
|
117 |
+
elif mask_id != num_masks - 1:
|
118 |
+
canvas = torch.bitwise_or(canvas, mask)
|
119 |
+
|
120 |
+
sids = valid_ids
|
121 |
+
self.masks = masks[sids]
|
122 |
+
self.bboxes = bboxes[sids]
|
123 |
+
self.tags = [tags[sid] for sid in sids]
|
124 |
+
self.scores = scores[sids]
|
125 |
+
|
126 |
+
if need_cvt:
|
127 |
+
self.to_numpy()
|
128 |
+
|
129 |
+
# sids =
|
130 |
+
|
131 |
+
def draw_instances(self,
|
132 |
+
img: np.ndarray,
|
133 |
+
draw_bbox: bool = True,
|
134 |
+
draw_ins_mask: bool = True,
|
135 |
+
draw_ins_contour: bool = True,
|
136 |
+
draw_tags: bool = False,
|
137 |
+
draw_indices: List = None,
|
138 |
+
mask_alpha: float = 0.4):
|
139 |
+
|
140 |
+
mask_alpha = 0.75
|
141 |
+
|
142 |
+
|
143 |
+
drawed = img.copy()
|
144 |
+
|
145 |
+
if self.is_empty:
|
146 |
+
return drawed
|
147 |
+
|
148 |
+
im_h, im_w = img.shape[:2]
|
149 |
+
|
150 |
+
mask_shape = self.masks[0].shape
|
151 |
+
if mask_shape[0] != im_h or mask_shape[1] != im_w:
|
152 |
+
drawed = cv2.resize(drawed, (mask_shape[1], mask_shape[0]), interpolation=cv2.INTER_AREA)
|
153 |
+
im_h, im_w = mask_shape[0], mask_shape[1]
|
154 |
+
|
155 |
+
if draw_indices is None:
|
156 |
+
draw_indices = list(range(len(self)))
|
157 |
+
ins_dict = {'mask': [], 'tags': [], 'score': [], 'bbox': [], 'character_tags': []}
|
158 |
+
colors = []
|
159 |
+
for idx in draw_indices:
|
160 |
+
ins = self.get_instance(idx, out_type='numpy')
|
161 |
+
for key, data in ins.items():
|
162 |
+
ins_dict[key].append(data)
|
163 |
+
colors.append(get_color(idx))
|
164 |
+
|
165 |
+
if draw_bbox:
|
166 |
+
lw = max(round(sum(drawed.shape) / 2 * 0.003), 2)
|
167 |
+
for color, bbox in zip(colors, ins_dict['bbox']):
|
168 |
+
p1, p2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2] + bbox[0]), int(bbox[3] + bbox[1]))
|
169 |
+
cv2.rectangle(drawed, p1, p2, color, thickness=lw, lineType=cv2.LINE_AA)
|
170 |
+
|
171 |
+
if draw_ins_mask:
|
172 |
+
drawed = drawed.astype(np.float32)
|
173 |
+
for color, mask in zip(colors, ins_dict['mask']):
|
174 |
+
p = mask.astype(np.float32)
|
175 |
+
blend_mask = np.full((im_h, im_w, 3), color, dtype=np.float32)
|
176 |
+
alpha_msk = (mask_alpha * p)[..., None]
|
177 |
+
alpha_ori = 1 - alpha_msk
|
178 |
+
drawed = drawed * alpha_ori + alpha_msk * blend_mask
|
179 |
+
drawed = drawed.astype(np.uint8)
|
180 |
+
|
181 |
+
if draw_tags:
|
182 |
+
lw = max(round(sum(drawed.shape) / 2 * 0.002), 2)
|
183 |
+
tf = max(lw - 1, 1)
|
184 |
+
for color, tags, bbox in zip(colors, ins_dict['tags'], ins_dict['bbox']):
|
185 |
+
if not tags:
|
186 |
+
continue
|
187 |
+
lines, line_height = tags2multilines(tags, lw, tf, bbox[2])
|
188 |
+
for ii, l in enumerate(lines):
|
189 |
+
xy = (bbox[0], bbox[1] + line_height + int(line_height * 1.2 * ii))
|
190 |
+
cv2.putText(drawed, l, xy, 0, lw / 3, color, thickness=tf, lineType=cv2.LINE_AA)
|
191 |
+
|
192 |
+
# cv2.imshow('canvas', drawed)
|
193 |
+
# cv2.waitKey(0)
|
194 |
+
return drawed
|
195 |
+
|
196 |
+
|
197 |
+
def cuda(self):
|
198 |
+
if self.is_empty:
|
199 |
+
return self
|
200 |
+
self.to_tensor(device='cuda')
|
201 |
+
return self
|
202 |
+
|
203 |
+
def cpu(self):
|
204 |
+
if not self.is_tensor or not self.is_cuda:
|
205 |
+
return self
|
206 |
+
self.masks = self.masks.cpu()
|
207 |
+
self.scores = self.scores.cpu()
|
208 |
+
self.bboxes = self.bboxes.cpu()
|
209 |
+
return self
|
210 |
+
|
211 |
+
def to_tensor(self, device: str = 'cpu'):
|
212 |
+
if self.is_empty:
|
213 |
+
return self
|
214 |
+
elif self.is_tensor and self.masks.device == device:
|
215 |
+
return self
|
216 |
+
self.masks = torch.from_numpy(self.masks).to(device)
|
217 |
+
self.bboxes = torch.from_numpy(self.bboxes).to(device)
|
218 |
+
self.scores = torch.from_numpy(self.scores ).to(device)
|
219 |
+
return self
|
220 |
+
|
221 |
+
def to_numpy(self):
|
222 |
+
if self.is_numpy:
|
223 |
+
return self
|
224 |
+
if self.is_cuda:
|
225 |
+
self.masks = self.masks.cpu().numpy()
|
226 |
+
self.scores = self.scores.cpu().numpy()
|
227 |
+
self.bboxes = self.bboxes.cpu().numpy()
|
228 |
+
else:
|
229 |
+
self.masks = self.masks.numpy()
|
230 |
+
self.scores = self.scores.numpy()
|
231 |
+
self.bboxes = self.bboxes.numpy()
|
232 |
+
return self
|
233 |
+
|
234 |
+
def get_instance(self, ins_idx: int, out_type: str = None, device: str = None):
|
235 |
+
mask = self.masks[ins_idx]
|
236 |
+
tags = self.tags[ins_idx]
|
237 |
+
character_tags = self.character_tags[ins_idx]
|
238 |
+
bbox = self.bboxes[ins_idx]
|
239 |
+
score = self.scores[ins_idx]
|
240 |
+
if out_type is not None:
|
241 |
+
if out_type == 'numpy' and not self.is_numpy:
|
242 |
+
mask = mask.cpu().numpy()
|
243 |
+
bbox = bbox.cpu().numpy()
|
244 |
+
score = score.cpu().numpy()
|
245 |
+
if out_type == 'tensor' and not self.is_tensor:
|
246 |
+
mask = torch.from_numpy(mask)
|
247 |
+
bbox = torch.from_numpy(bbox)
|
248 |
+
score = torch.from_numpy(score)
|
249 |
+
if isinstance(mask, torch.Tensor) and device is not None and mask.device != device:
|
250 |
+
mask = mask.to(device)
|
251 |
+
bbox = bbox.to(device)
|
252 |
+
score = score.to(device)
|
253 |
+
|
254 |
+
return {
|
255 |
+
'mask': mask,
|
256 |
+
'tags': tags,
|
257 |
+
'character_tags': character_tags,
|
258 |
+
'bbox': bbox,
|
259 |
+
'score': score
|
260 |
+
}
|
261 |
+
|
262 |
+
def __len__(self):
|
263 |
+
if self.is_empty:
|
264 |
+
return 0
|
265 |
+
else:
|
266 |
+
return len(self.masks)
|
267 |
+
|
268 |
+
def resize(self, h, w, mode = 'area'):
|
269 |
+
if self.is_empty:
|
270 |
+
return
|
271 |
+
if self.is_tensor:
|
272 |
+
masks = self.masks.to(torch.float).unsqueeze(1)
|
273 |
+
oh, ow = masks.shape[2], masks.shape[3]
|
274 |
+
hs, ws = h / oh, w / ow
|
275 |
+
bboxes = self.bboxes.float()
|
276 |
+
bboxes[:, ::2] *= hs
|
277 |
+
bboxes[:, 1::2] *= ws
|
278 |
+
self.bboxes = torch.round(bboxes).int()
|
279 |
+
masks = torch.nn.functional.interpolate(masks, (h, w), mode=mode)
|
280 |
+
self.masks = masks.squeeze(1) > 0.3
|
281 |
+
|
282 |
+
def compose_masks(self, output_type=None):
|
283 |
+
if self.is_empty:
|
284 |
+
return None
|
285 |
+
else:
|
286 |
+
mask = self.masks[0]
|
287 |
+
if len(self.masks) > 1:
|
288 |
+
for m in self.masks[1:]:
|
289 |
+
if self.is_numpy:
|
290 |
+
mask = np.logical_or(mask, m)
|
291 |
+
else:
|
292 |
+
mask = torch.logical_or(mask, m)
|
293 |
+
if output_type is not None:
|
294 |
+
if output_type == 'numpy' and not self.is_numpy:
|
295 |
+
mask = mask.cpu().numpy()
|
296 |
+
if output_type == 'tensor' and not self.is_tensor:
|
297 |
+
mask = torch.from_numpy(mask)
|
298 |
+
return mask
|
299 |
+
|
300 |
+
|
301 |
+
|
animeinsseg/data/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# from .dataset import *
|
2 |
+
# from .syndataset import *
|
animeinsseg/data/dataset.py
ADDED
@@ -0,0 +1,929 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
import numpy as np
|
3 |
+
from typing import List, Optional, Sequence, Tuple, Union
|
4 |
+
import copy
|
5 |
+
from time import time
|
6 |
+
import mmcv
|
7 |
+
from mmcv.transforms import to_tensor
|
8 |
+
from mmdet.datasets.transforms import LoadAnnotations, RandomCrop, PackDetInputs, Mosaic, CachedMosaic, CachedMixUp, FilterAnnotations
|
9 |
+
from mmdet.structures.mask import BitmapMasks, PolygonMasks
|
10 |
+
from mmdet.datasets import CocoDataset
|
11 |
+
from mmdet.registry import DATASETS, TRANSFORMS
|
12 |
+
from numpy import random
|
13 |
+
from mmdet.structures.bbox import autocast_box_type, BaseBoxes
|
14 |
+
from mmengine.structures import InstanceData, PixelData
|
15 |
+
from mmdet.structures import DetDataSample
|
16 |
+
from utils.io_utils import bbox_overlap_xy
|
17 |
+
from utils.logger import LOGGER
|
18 |
+
|
19 |
+
@DATASETS.register_module()
|
20 |
+
class AnimeMangaMixedDataset(CocoDataset):
|
21 |
+
|
22 |
+
def __init__(self, animeins_root: str = None, animeins_annfile: str = None, manga109_annfile: str = None, manga109_root: str = None, *args, **kwargs) -> None:
|
23 |
+
self.animeins_annfile = animeins_annfile
|
24 |
+
self.animeins_root = animeins_root
|
25 |
+
self.manga109_annfile = manga109_annfile
|
26 |
+
self.manga109_root = manga109_root
|
27 |
+
self.cat_ids = []
|
28 |
+
self.cat_img_map = {}
|
29 |
+
super().__init__(*args, **kwargs)
|
30 |
+
LOGGER.info(f'total num data: {len(self.data_list)}')
|
31 |
+
|
32 |
+
|
33 |
+
def parse_data_info(self, raw_data_info: dict, data_prefix: str) -> Union[dict, List[dict]]:
|
34 |
+
"""Parse raw annotation to target format.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
raw_data_info (dict): Raw data information load from ``ann_file``
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
Union[dict, List[dict]]: Parsed annotation.
|
41 |
+
"""
|
42 |
+
img_info = raw_data_info['raw_img_info']
|
43 |
+
ann_info = raw_data_info['raw_ann_info']
|
44 |
+
|
45 |
+
data_info = {}
|
46 |
+
|
47 |
+
# TODO: need to change data_prefix['img'] to data_prefix['img_path']
|
48 |
+
img_path = osp.join(data_prefix, img_info['file_name'])
|
49 |
+
if self.data_prefix.get('seg', None):
|
50 |
+
seg_map_path = osp.join(
|
51 |
+
self.data_prefix['seg'],
|
52 |
+
img_info['file_name'].rsplit('.', 1)[0] + self.seg_map_suffix)
|
53 |
+
else:
|
54 |
+
seg_map_path = None
|
55 |
+
data_info['img_path'] = img_path
|
56 |
+
data_info['img_id'] = img_info['img_id']
|
57 |
+
data_info['seg_map_path'] = seg_map_path
|
58 |
+
data_info['height'] = img_info['height']
|
59 |
+
data_info['width'] = img_info['width']
|
60 |
+
|
61 |
+
instances = []
|
62 |
+
for i, ann in enumerate(ann_info):
|
63 |
+
instance = {}
|
64 |
+
|
65 |
+
if ann.get('ignore', False):
|
66 |
+
continue
|
67 |
+
x1, y1, w, h = ann['bbox']
|
68 |
+
inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
|
69 |
+
inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
|
70 |
+
if inter_w * inter_h == 0:
|
71 |
+
continue
|
72 |
+
if ann['area'] <= 0 or w < 1 or h < 1:
|
73 |
+
continue
|
74 |
+
if ann['category_id'] not in self.cat_ids:
|
75 |
+
continue
|
76 |
+
bbox = [x1, y1, x1 + w, y1 + h]
|
77 |
+
|
78 |
+
if ann.get('iscrowd', False):
|
79 |
+
instance['ignore_flag'] = 1
|
80 |
+
else:
|
81 |
+
instance['ignore_flag'] = 0
|
82 |
+
instance['bbox'] = bbox
|
83 |
+
instance['bbox_label'] = self.cat2label[ann['category_id']]
|
84 |
+
|
85 |
+
if ann.get('segmentation', None):
|
86 |
+
instance['mask'] = ann['segmentation']
|
87 |
+
|
88 |
+
instances.append(instance)
|
89 |
+
data_info['instances'] = instances
|
90 |
+
return data_info
|
91 |
+
|
92 |
+
|
93 |
+
def load_data_list(self) -> List[dict]:
|
94 |
+
data_lst = []
|
95 |
+
if self.manga109_root is not None:
|
96 |
+
data_lst += self._data_list(self.manga109_annfile, osp.join(self.manga109_root, 'images'))
|
97 |
+
# if len(data_lst) > 8000:
|
98 |
+
# data_lst = data_lst[:500]
|
99 |
+
LOGGER.info(f'num data from manga109: {len(data_lst)}')
|
100 |
+
if self.animeins_root is not None:
|
101 |
+
animeins_annfile = osp.join(self.animeins_root, self.animeins_annfile)
|
102 |
+
data_prefix = osp.join(self.animeins_root, self.data_prefix['img'])
|
103 |
+
anime_lst = self._data_list(animeins_annfile, data_prefix)
|
104 |
+
# if len(anime_lst) > 8000:
|
105 |
+
# anime_lst = anime_lst[:500]
|
106 |
+
data_lst += anime_lst
|
107 |
+
LOGGER.info(f'num data from animeins: {len(data_lst)}')
|
108 |
+
return data_lst
|
109 |
+
|
110 |
+
def _data_list(self, annfile: str, data_prefix: str) -> List[dict]:
|
111 |
+
"""Load annotations from an annotation file named as ``ann_file``
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
List[dict]: A list of annotation.
|
115 |
+
""" # noqa: E501
|
116 |
+
with self.file_client.get_local_path(annfile) as local_path:
|
117 |
+
self.coco = self.COCOAPI(local_path)
|
118 |
+
# The order of returned `cat_ids` will not
|
119 |
+
# change with the order of the `classes`
|
120 |
+
self.cat_ids = self.coco.get_cat_ids(
|
121 |
+
cat_names=self.metainfo['classes'])
|
122 |
+
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
|
123 |
+
cat_img_map = copy.deepcopy(self.coco.cat_img_map)
|
124 |
+
for key, val in cat_img_map.items():
|
125 |
+
if key in self.cat_img_map:
|
126 |
+
self.cat_img_map[key] += val
|
127 |
+
else:
|
128 |
+
self.cat_img_map[key] = val
|
129 |
+
|
130 |
+
img_ids = self.coco.get_img_ids()
|
131 |
+
data_list = []
|
132 |
+
total_ann_ids = []
|
133 |
+
for img_id in img_ids:
|
134 |
+
raw_img_info = self.coco.load_imgs([img_id])[0]
|
135 |
+
raw_img_info['img_id'] = img_id
|
136 |
+
|
137 |
+
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
|
138 |
+
raw_ann_info = self.coco.load_anns(ann_ids)
|
139 |
+
total_ann_ids.extend(ann_ids)
|
140 |
+
|
141 |
+
parsed_data_info = self.parse_data_info({
|
142 |
+
'raw_ann_info':
|
143 |
+
raw_ann_info,
|
144 |
+
'raw_img_info':
|
145 |
+
raw_img_info
|
146 |
+
}, data_prefix)
|
147 |
+
data_list.append(parsed_data_info)
|
148 |
+
if self.ANN_ID_UNIQUE:
|
149 |
+
assert len(set(total_ann_ids)) == len(
|
150 |
+
total_ann_ids
|
151 |
+
), f"Annotation ids in '{annfile}' are not unique!"
|
152 |
+
|
153 |
+
del self.coco
|
154 |
+
|
155 |
+
return data_list
|
156 |
+
|
157 |
+
|
158 |
+
|
159 |
+
@TRANSFORMS.register_module()
|
160 |
+
class LoadAnnotationsNoSegs(LoadAnnotations):
|
161 |
+
|
162 |
+
def _process_masks(self, results: dict) -> list:
|
163 |
+
"""Process gt_masks and filter invalid polygons.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
list: Processed gt_masks.
|
170 |
+
"""
|
171 |
+
gt_masks = []
|
172 |
+
gt_ignore_flags = []
|
173 |
+
gt_ignore_mask_flags = []
|
174 |
+
for instance in results.get('instances', []):
|
175 |
+
gt_mask = instance['mask']
|
176 |
+
ignore_mask = False
|
177 |
+
# If the annotation of segmentation mask is invalid,
|
178 |
+
# ignore the whole instance.
|
179 |
+
if isinstance(gt_mask, list):
|
180 |
+
gt_mask = [
|
181 |
+
np.array(polygon) for polygon in gt_mask
|
182 |
+
if len(polygon) % 2 == 0 and len(polygon) >= 6
|
183 |
+
]
|
184 |
+
if len(gt_mask) == 0:
|
185 |
+
# ignore this instance and set gt_mask to a fake mask
|
186 |
+
instance['ignore_flag'] = 1
|
187 |
+
gt_mask = [np.zeros(6)]
|
188 |
+
elif not self.poly2mask:
|
189 |
+
# `PolygonMasks` requires a ploygon of format List[np.array],
|
190 |
+
# other formats are invalid.
|
191 |
+
instance['ignore_flag'] = 1
|
192 |
+
gt_mask = [np.zeros(6)]
|
193 |
+
elif isinstance(gt_mask, dict) and \
|
194 |
+
not (gt_mask.get('counts') is not None and
|
195 |
+
gt_mask.get('size') is not None and
|
196 |
+
isinstance(gt_mask['counts'], (list, str))):
|
197 |
+
# if gt_mask is a dict, it should include `counts` and `size`,
|
198 |
+
# so that `BitmapMasks` can uncompressed RLE
|
199 |
+
# instance['ignore_flag'] = 1
|
200 |
+
ignore_mask = True
|
201 |
+
gt_mask = [np.zeros(6)]
|
202 |
+
gt_masks.append(gt_mask)
|
203 |
+
# re-process gt_ignore_flags
|
204 |
+
gt_ignore_flags.append(instance['ignore_flag'])
|
205 |
+
gt_ignore_mask_flags.append(ignore_mask)
|
206 |
+
results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
|
207 |
+
results['gt_ignore_mask_flags'] = np.array(gt_ignore_mask_flags, dtype=bool)
|
208 |
+
return gt_masks
|
209 |
+
|
210 |
+
def _load_masks(self, results: dict) -> None:
|
211 |
+
"""Private function to load mask annotations.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
|
215 |
+
"""
|
216 |
+
h, w = results['ori_shape']
|
217 |
+
gt_masks = self._process_masks(results)
|
218 |
+
if self.poly2mask:
|
219 |
+
p2masks = []
|
220 |
+
if len(gt_masks) > 0:
|
221 |
+
for ins, mask, ignore_mask in zip(results['instances'], gt_masks, results['gt_ignore_mask_flags']):
|
222 |
+
bbox = [int(c) for c in ins['bbox']]
|
223 |
+
if ignore_mask:
|
224 |
+
m = np.zeros((h, w), dtype=np.uint8)
|
225 |
+
m[bbox[1]:bbox[3], bbox[0]: bbox[2]] = 255
|
226 |
+
# m[bbox[1]:bbox[3], bbox[0]: bbox[2]]
|
227 |
+
p2masks.append(m)
|
228 |
+
else:
|
229 |
+
p2masks.append(self._poly2mask(mask, h, w))
|
230 |
+
# import cv2
|
231 |
+
# # cv2.imwrite('tmp_mask.png', p2masks[-1] * 255)
|
232 |
+
# cv2.imwrite('tmp_img.png', results['img'])
|
233 |
+
# cv2.imwrite('tmp_bbox.png', m * 225)
|
234 |
+
# print(p2masks[-1].shape, p2masks[-1].dtype)
|
235 |
+
gt_masks = BitmapMasks(p2masks, h, w)
|
236 |
+
else:
|
237 |
+
# fake polygon masks will be ignored in `PackDetInputs`
|
238 |
+
gt_masks = PolygonMasks([mask for mask in gt_masks], h, w)
|
239 |
+
results['gt_masks'] = gt_masks
|
240 |
+
|
241 |
+
def transform(self, results: dict) -> dict:
|
242 |
+
"""Function to load multiple types annotations.
|
243 |
+
|
244 |
+
Args:
|
245 |
+
results (dict): Result dict from :obj:``mmengine.BaseDataset``.
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
dict: The dict contains loaded bounding box, label and
|
249 |
+
semantic segmentation.
|
250 |
+
"""
|
251 |
+
|
252 |
+
if self.with_bbox:
|
253 |
+
self._load_bboxes(results)
|
254 |
+
if self.with_label:
|
255 |
+
self._load_labels(results)
|
256 |
+
if self.with_mask:
|
257 |
+
self._load_masks(results)
|
258 |
+
if self.with_seg:
|
259 |
+
self._load_seg_map(results)
|
260 |
+
|
261 |
+
return results
|
262 |
+
|
263 |
+
|
264 |
+
|
265 |
+
@TRANSFORMS.register_module()
|
266 |
+
class PackDetIputsNoSeg(PackDetInputs):
|
267 |
+
|
268 |
+
mapping_table = {
|
269 |
+
'gt_bboxes': 'bboxes',
|
270 |
+
'gt_bboxes_labels': 'labels',
|
271 |
+
'gt_ignore_mask_flags': 'ignore_mask',
|
272 |
+
'gt_masks': 'masks'
|
273 |
+
}
|
274 |
+
|
275 |
+
def transform(self, results: dict) -> dict:
|
276 |
+
"""Method to pack the input data.
|
277 |
+
|
278 |
+
Args:
|
279 |
+
results (dict): Result dict from the data pipeline.
|
280 |
+
|
281 |
+
Returns:
|
282 |
+
dict:
|
283 |
+
|
284 |
+
- 'inputs' (obj:`torch.Tensor`): The forward data of models.
|
285 |
+
- 'data_sample' (obj:`DetDataSample`): The annotation info of the
|
286 |
+
sample.
|
287 |
+
"""
|
288 |
+
packed_results = dict()
|
289 |
+
if 'img' in results:
|
290 |
+
img = results['img']
|
291 |
+
if len(img.shape) < 3:
|
292 |
+
img = np.expand_dims(img, -1)
|
293 |
+
img = np.ascontiguousarray(img.transpose(2, 0, 1))
|
294 |
+
packed_results['inputs'] = to_tensor(img)
|
295 |
+
|
296 |
+
if 'gt_ignore_flags' in results:
|
297 |
+
valid_idx = np.where(results['gt_ignore_flags'] == 0)[0]
|
298 |
+
ignore_idx = np.where(results['gt_ignore_flags'] == 1)[0]
|
299 |
+
|
300 |
+
data_sample = DetDataSample()
|
301 |
+
instance_data = InstanceData()
|
302 |
+
ignore_instance_data = InstanceData()
|
303 |
+
|
304 |
+
for key in self.mapping_table.keys():
|
305 |
+
if key not in results:
|
306 |
+
continue
|
307 |
+
if key == 'gt_masks' or isinstance(results[key], BaseBoxes):
|
308 |
+
if 'gt_ignore_flags' in results:
|
309 |
+
instance_data[
|
310 |
+
self.mapping_table[key]] = results[key][valid_idx]
|
311 |
+
ignore_instance_data[
|
312 |
+
self.mapping_table[key]] = results[key][ignore_idx]
|
313 |
+
else:
|
314 |
+
instance_data[self.mapping_table[key]] = results[key]
|
315 |
+
else:
|
316 |
+
if 'gt_ignore_flags' in results:
|
317 |
+
instance_data[self.mapping_table[key]] = to_tensor(
|
318 |
+
results[key][valid_idx])
|
319 |
+
ignore_instance_data[self.mapping_table[key]] = to_tensor(
|
320 |
+
results[key][ignore_idx])
|
321 |
+
else:
|
322 |
+
instance_data[self.mapping_table[key]] = to_tensor(
|
323 |
+
results[key])
|
324 |
+
data_sample.gt_instances = instance_data
|
325 |
+
data_sample.ignored_instances = ignore_instance_data
|
326 |
+
|
327 |
+
if 'proposals' in results:
|
328 |
+
proposals = InstanceData(
|
329 |
+
bboxes=to_tensor(results['proposals']),
|
330 |
+
scores=to_tensor(results['proposals_scores']))
|
331 |
+
data_sample.proposals = proposals
|
332 |
+
|
333 |
+
if 'gt_seg_map' in results:
|
334 |
+
gt_sem_seg_data = dict(
|
335 |
+
sem_seg=to_tensor(results['gt_seg_map'][None, ...].copy()))
|
336 |
+
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
|
337 |
+
|
338 |
+
img_meta = {}
|
339 |
+
for key in self.meta_keys:
|
340 |
+
assert key in results, f'`{key}` is not found in `results`, ' \
|
341 |
+
f'the valid keys are {list(results)}.'
|
342 |
+
img_meta[key] = results[key]
|
343 |
+
|
344 |
+
data_sample.set_metainfo(img_meta)
|
345 |
+
packed_results['data_samples'] = data_sample
|
346 |
+
|
347 |
+
return packed_results
|
348 |
+
|
349 |
+
|
350 |
+
|
351 |
+
def translate_bitmapmask(bitmap_masks: BitmapMasks,
|
352 |
+
out_shape,
|
353 |
+
offset_x,
|
354 |
+
offset_y,):
|
355 |
+
|
356 |
+
if len(bitmap_masks.masks) == 0:
|
357 |
+
translated_masks = np.empty((0, *out_shape), dtype=np.uint8)
|
358 |
+
else:
|
359 |
+
masks = bitmap_masks.masks
|
360 |
+
out_h, out_w = out_shape
|
361 |
+
mask_h, mask_w = masks.shape[1:]
|
362 |
+
|
363 |
+
translated_masks = np.zeros((masks.shape[0], *out_shape),
|
364 |
+
dtype=masks.dtype)
|
365 |
+
|
366 |
+
ix, iy = bbox_overlap_xy([0, 0, out_w, out_h], [offset_x, offset_y, mask_w, mask_h])
|
367 |
+
if ix > 2 and iy > 2:
|
368 |
+
if offset_x > 0:
|
369 |
+
mx1 = 0
|
370 |
+
tx1 = offset_x
|
371 |
+
else:
|
372 |
+
mx1 = -offset_x
|
373 |
+
tx1 = 0
|
374 |
+
mx2 = min(out_w - offset_x, mask_w)
|
375 |
+
tx2 = tx1 + mx2 - mx1
|
376 |
+
|
377 |
+
if offset_y > 0:
|
378 |
+
my1 = 0
|
379 |
+
ty1 = offset_y
|
380 |
+
else:
|
381 |
+
my1 = -offset_y
|
382 |
+
ty1 = 0
|
383 |
+
my2 = min(out_h - offset_y, mask_h)
|
384 |
+
ty2 = ty1 + my2 - my1
|
385 |
+
|
386 |
+
translated_masks[:, ty1: ty2, tx1: tx2] = \
|
387 |
+
masks[:, my1: my2, mx1: mx2]
|
388 |
+
|
389 |
+
return BitmapMasks(translated_masks, *out_shape)
|
390 |
+
|
391 |
+
|
392 |
+
@TRANSFORMS.register_module()
|
393 |
+
class CachedMosaicNoSeg(CachedMosaic):
|
394 |
+
|
395 |
+
@autocast_box_type()
|
396 |
+
def transform(self, results: dict) -> dict:
|
397 |
+
|
398 |
+
"""Mosaic transform function.
|
399 |
+
|
400 |
+
Args:
|
401 |
+
results (dict): Result dict.
|
402 |
+
|
403 |
+
Returns:
|
404 |
+
dict: Updated result dict.
|
405 |
+
"""
|
406 |
+
# cache and pop images
|
407 |
+
self.results_cache.append(copy.deepcopy(results))
|
408 |
+
if len(self.results_cache) > self.max_cached_images:
|
409 |
+
if self.random_pop:
|
410 |
+
index = random.randint(0, len(self.results_cache) - 1)
|
411 |
+
else:
|
412 |
+
index = 0
|
413 |
+
self.results_cache.pop(index)
|
414 |
+
|
415 |
+
if len(self.results_cache) <= 4:
|
416 |
+
return results
|
417 |
+
|
418 |
+
if random.uniform(0, 1) > self.prob:
|
419 |
+
return results
|
420 |
+
indices = self.get_indexes(self.results_cache)
|
421 |
+
mix_results = [copy.deepcopy(self.results_cache[i]) for i in indices]
|
422 |
+
|
423 |
+
# TODO: refactor mosaic to reuse these code.
|
424 |
+
mosaic_bboxes = []
|
425 |
+
mosaic_bboxes_labels = []
|
426 |
+
mosaic_ignore_flags = []
|
427 |
+
mosaic_masks = []
|
428 |
+
mosaic_ignore_mask_flags = []
|
429 |
+
with_mask = True if 'gt_masks' in results else False
|
430 |
+
|
431 |
+
if len(results['img'].shape) == 3:
|
432 |
+
mosaic_img = np.full(
|
433 |
+
(int(self.img_scale[1] * 2), int(self.img_scale[0] * 2), 3),
|
434 |
+
self.pad_val,
|
435 |
+
dtype=results['img'].dtype)
|
436 |
+
else:
|
437 |
+
mosaic_img = np.full(
|
438 |
+
(int(self.img_scale[1] * 2), int(self.img_scale[0] * 2)),
|
439 |
+
self.pad_val,
|
440 |
+
dtype=results['img'].dtype)
|
441 |
+
|
442 |
+
# mosaic center x, y
|
443 |
+
center_x = int(
|
444 |
+
random.uniform(*self.center_ratio_range) * self.img_scale[0])
|
445 |
+
center_y = int(
|
446 |
+
random.uniform(*self.center_ratio_range) * self.img_scale[1])
|
447 |
+
center_position = (center_x, center_y)
|
448 |
+
|
449 |
+
loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
|
450 |
+
|
451 |
+
n_manga = 0
|
452 |
+
for i, loc in enumerate(loc_strs):
|
453 |
+
if loc == 'top_left':
|
454 |
+
results_patch = copy.deepcopy(results)
|
455 |
+
else:
|
456 |
+
results_patch = copy.deepcopy(mix_results[i - 1])
|
457 |
+
|
458 |
+
is_manga = results_patch['img_id'] > 900000000
|
459 |
+
if is_manga:
|
460 |
+
n_manga += 1
|
461 |
+
if n_manga > 3:
|
462 |
+
continue
|
463 |
+
im_h, im_w = results_patch['img'].shape[:2]
|
464 |
+
if im_w > im_h and random.random() < 0.75:
|
465 |
+
results_patch = hcrop(results_patch, (im_h, im_w // 2), True)
|
466 |
+
|
467 |
+
img_i = results_patch['img']
|
468 |
+
h_i, w_i = img_i.shape[:2]
|
469 |
+
# keep_ratio resize
|
470 |
+
scale_ratio_i = min(self.img_scale[1] / h_i,
|
471 |
+
self.img_scale[0] / w_i)
|
472 |
+
img_i = mmcv.imresize(
|
473 |
+
img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)))
|
474 |
+
|
475 |
+
# compute the combine parameters
|
476 |
+
paste_coord, crop_coord = self._mosaic_combine(
|
477 |
+
loc, center_position, img_i.shape[:2][::-1])
|
478 |
+
x1_p, y1_p, x2_p, y2_p = paste_coord
|
479 |
+
x1_c, y1_c, x2_c, y2_c = crop_coord
|
480 |
+
|
481 |
+
# crop and paste image
|
482 |
+
mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c]
|
483 |
+
|
484 |
+
# adjust coordinate
|
485 |
+
gt_bboxes_i = results_patch['gt_bboxes']
|
486 |
+
gt_bboxes_labels_i = results_patch['gt_bboxes_labels']
|
487 |
+
gt_ignore_flags_i = results_patch['gt_ignore_flags']
|
488 |
+
gt_ignore_mask_i = results_patch['gt_ignore_mask_flags']
|
489 |
+
|
490 |
+
padw = x1_p - x1_c
|
491 |
+
padh = y1_p - y1_c
|
492 |
+
gt_bboxes_i.rescale_([scale_ratio_i, scale_ratio_i])
|
493 |
+
gt_bboxes_i.translate_([padw, padh])
|
494 |
+
mosaic_bboxes.append(gt_bboxes_i)
|
495 |
+
mosaic_bboxes_labels.append(gt_bboxes_labels_i)
|
496 |
+
mosaic_ignore_flags.append(gt_ignore_flags_i)
|
497 |
+
mosaic_ignore_mask_flags.append(gt_ignore_mask_i)
|
498 |
+
if with_mask and results_patch.get('gt_masks', None) is not None:
|
499 |
+
|
500 |
+
gt_masks_i = results_patch['gt_masks']
|
501 |
+
gt_masks_i = gt_masks_i.rescale(float(scale_ratio_i))
|
502 |
+
|
503 |
+
gt_masks_i = translate_bitmapmask(gt_masks_i,
|
504 |
+
out_shape=(int(self.img_scale[0] * 2),
|
505 |
+
int(self.img_scale[1] * 2)),
|
506 |
+
offset_x=padw, offset_y=padh)
|
507 |
+
|
508 |
+
# gt_masks_i = gt_masks_i.translate(
|
509 |
+
# out_shape=(int(self.img_scale[0] * 2),
|
510 |
+
# int(self.img_scale[1] * 2)),
|
511 |
+
# offset=padw,
|
512 |
+
# direction='horizontal')
|
513 |
+
# gt_masks_i = gt_masks_i.translate(
|
514 |
+
# out_shape=(int(self.img_scale[0] * 2),
|
515 |
+
# int(self.img_scale[1] * 2)),
|
516 |
+
# offset=padh,
|
517 |
+
# direction='vertical')
|
518 |
+
mosaic_masks.append(gt_masks_i)
|
519 |
+
|
520 |
+
mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0)
|
521 |
+
mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0)
|
522 |
+
mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0)
|
523 |
+
mosaic_ignore_mask_flags = np.concatenate(mosaic_ignore_mask_flags, 0)
|
524 |
+
|
525 |
+
if self.bbox_clip_border:
|
526 |
+
mosaic_bboxes.clip_([2 * self.img_scale[1], 2 * self.img_scale[0]])
|
527 |
+
# remove outside bboxes
|
528 |
+
inside_inds = mosaic_bboxes.is_inside(
|
529 |
+
[2 * self.img_scale[1], 2 * self.img_scale[0]]).numpy()
|
530 |
+
|
531 |
+
mosaic_bboxes = mosaic_bboxes[inside_inds]
|
532 |
+
mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds]
|
533 |
+
mosaic_ignore_flags = mosaic_ignore_flags[inside_inds]
|
534 |
+
mosaic_ignore_mask_flags = mosaic_ignore_mask_flags[inside_inds]
|
535 |
+
|
536 |
+
results['img'] = mosaic_img
|
537 |
+
results['img_shape'] = mosaic_img.shape
|
538 |
+
results['gt_bboxes'] = mosaic_bboxes
|
539 |
+
results['gt_bboxes_labels'] = mosaic_bboxes_labels
|
540 |
+
results['gt_ignore_flags'] = mosaic_ignore_flags
|
541 |
+
results['gt_ignore_mask_flags'] = mosaic_ignore_mask_flags
|
542 |
+
|
543 |
+
|
544 |
+
if with_mask:
|
545 |
+
total_instances = len(inside_inds)
|
546 |
+
assert total_instances == np.array([m.masks.shape[0] for m in mosaic_masks]).sum()
|
547 |
+
if total_instances > 10:
|
548 |
+
masks = np.empty((inside_inds.sum(), mosaic_masks[0].height, mosaic_masks[0].width), dtype=np.uint8)
|
549 |
+
msk_idx = 0
|
550 |
+
mmsk_idx = 0
|
551 |
+
for m in mosaic_masks:
|
552 |
+
for ii in range(m.masks.shape[0]):
|
553 |
+
if inside_inds[msk_idx]:
|
554 |
+
masks[mmsk_idx] = m.masks[ii]
|
555 |
+
mmsk_idx += 1
|
556 |
+
msk_idx += 1
|
557 |
+
results['gt_masks'] = BitmapMasks(masks, mosaic_masks[0].height, mosaic_masks[0].width)
|
558 |
+
else:
|
559 |
+
mosaic_masks = mosaic_masks[0].cat(mosaic_masks)
|
560 |
+
results['gt_masks'] = mosaic_masks[inside_inds]
|
561 |
+
# assert np.all(results['gt_masks'].masks == masks) and results['gt_masks'].masks.shape == masks.shape
|
562 |
+
|
563 |
+
# assert inside_inds.sum() == results['gt_masks'].masks.shape[0]
|
564 |
+
return results
|
565 |
+
|
566 |
+
@TRANSFORMS.register_module()
|
567 |
+
class FilterAnnotationsNoSeg(FilterAnnotations):
|
568 |
+
|
569 |
+
def __init__(self,
|
570 |
+
min_gt_bbox_wh: Tuple[int, int] = (1, 1),
|
571 |
+
min_gt_mask_area: int = 1,
|
572 |
+
by_box: bool = True,
|
573 |
+
by_mask: bool = False,
|
574 |
+
keep_empty: bool = True) -> None:
|
575 |
+
# TODO: add more filter options
|
576 |
+
assert by_box or by_mask
|
577 |
+
self.min_gt_bbox_wh = min_gt_bbox_wh
|
578 |
+
self.min_gt_mask_area = min_gt_mask_area
|
579 |
+
self.by_box = by_box
|
580 |
+
self.by_mask = by_mask
|
581 |
+
self.keep_empty = keep_empty
|
582 |
+
|
583 |
+
@autocast_box_type()
|
584 |
+
def transform(self, results: dict) -> Union[dict, None]:
|
585 |
+
"""Transform function to filter annotations.
|
586 |
+
|
587 |
+
Args:
|
588 |
+
results (dict): Result dict.
|
589 |
+
|
590 |
+
Returns:
|
591 |
+
dict: Updated result dict.
|
592 |
+
"""
|
593 |
+
assert 'gt_bboxes' in results
|
594 |
+
gt_bboxes = results['gt_bboxes']
|
595 |
+
if gt_bboxes.shape[0] == 0:
|
596 |
+
return results
|
597 |
+
|
598 |
+
tests = []
|
599 |
+
if self.by_box:
|
600 |
+
tests.append(
|
601 |
+
((gt_bboxes.widths > self.min_gt_bbox_wh[0]) &
|
602 |
+
(gt_bboxes.heights > self.min_gt_bbox_wh[1])).numpy())
|
603 |
+
|
604 |
+
if self.by_mask:
|
605 |
+
assert 'gt_masks' in results
|
606 |
+
gt_masks = results['gt_masks']
|
607 |
+
tests.append(gt_masks.areas >= self.min_gt_mask_area)
|
608 |
+
|
609 |
+
keep = tests[0]
|
610 |
+
for t in tests[1:]:
|
611 |
+
keep = keep & t
|
612 |
+
|
613 |
+
# if not keep.any():
|
614 |
+
# if self.keep_empty:
|
615 |
+
# return None
|
616 |
+
|
617 |
+
assert len(results['gt_ignore_flags']) == len(results['gt_ignore_mask_flags'])
|
618 |
+
keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_masks', 'gt_ignore_flags', 'gt_ignore_mask_flags')
|
619 |
+
for key in keys:
|
620 |
+
if key in results:
|
621 |
+
try:
|
622 |
+
results[key] = results[key][keep]
|
623 |
+
except Exception as e:
|
624 |
+
raise e
|
625 |
+
|
626 |
+
return results
|
627 |
+
|
628 |
+
|
629 |
+
def hcrop(results: dict, crop_size: Tuple[int, int],
|
630 |
+
allow_negative_crop: bool) -> Union[dict, None]:
|
631 |
+
|
632 |
+
assert crop_size[0] > 0 and crop_size[1] > 0
|
633 |
+
img = results['img']
|
634 |
+
offset_h, offset_w = 0, random.choice([0, crop_size[1]])
|
635 |
+
crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
|
636 |
+
crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
|
637 |
+
|
638 |
+
# Record the homography matrix for the RandomCrop
|
639 |
+
homography_matrix = np.array(
|
640 |
+
[[1, 0, -offset_w], [0, 1, -offset_h], [0, 0, 1]],
|
641 |
+
dtype=np.float32)
|
642 |
+
if results.get('homography_matrix', None) is None:
|
643 |
+
results['homography_matrix'] = homography_matrix
|
644 |
+
else:
|
645 |
+
results['homography_matrix'] = homography_matrix @ results[
|
646 |
+
'homography_matrix']
|
647 |
+
|
648 |
+
# crop the image
|
649 |
+
img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
|
650 |
+
img_shape = img.shape
|
651 |
+
results['img'] = img
|
652 |
+
results['img_shape'] = img_shape
|
653 |
+
|
654 |
+
# crop bboxes accordingly and clip to the image boundary
|
655 |
+
if results.get('gt_bboxes', None) is not None:
|
656 |
+
bboxes = results['gt_bboxes']
|
657 |
+
bboxes.translate_([-offset_w, -offset_h])
|
658 |
+
bboxes.clip_(img_shape[:2])
|
659 |
+
valid_inds = bboxes.is_inside(img_shape[:2]).numpy()
|
660 |
+
# If the crop does not contain any gt-bbox area and
|
661 |
+
# allow_negative_crop is False, skip this image.
|
662 |
+
if (not valid_inds.any() and not allow_negative_crop):
|
663 |
+
return None
|
664 |
+
|
665 |
+
results['gt_bboxes'] = bboxes[valid_inds]
|
666 |
+
|
667 |
+
if results.get('gt_ignore_flags', None) is not None:
|
668 |
+
results['gt_ignore_flags'] = \
|
669 |
+
results['gt_ignore_flags'][valid_inds]
|
670 |
+
|
671 |
+
if results.get('gt_ignore_mask_flags', None) is not None:
|
672 |
+
results['gt_ignore_mask_flags'] = \
|
673 |
+
results['gt_ignore_mask_flags'][valid_inds]
|
674 |
+
|
675 |
+
if results.get('gt_bboxes_labels', None) is not None:
|
676 |
+
results['gt_bboxes_labels'] = \
|
677 |
+
results['gt_bboxes_labels'][valid_inds]
|
678 |
+
|
679 |
+
if results.get('gt_masks', None) is not None:
|
680 |
+
results['gt_masks'] = results['gt_masks'][
|
681 |
+
valid_inds.nonzero()[0]].crop(
|
682 |
+
np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
|
683 |
+
results['gt_bboxes'] = results['gt_masks'].get_bboxes(
|
684 |
+
type(results['gt_bboxes']))
|
685 |
+
|
686 |
+
# crop semantic seg
|
687 |
+
if results.get('gt_seg_map', None) is not None:
|
688 |
+
results['gt_seg_map'] = results['gt_seg_map'][crop_y1:crop_y2,
|
689 |
+
crop_x1:crop_x2]
|
690 |
+
|
691 |
+
return results
|
692 |
+
|
693 |
+
|
694 |
+
@TRANSFORMS.register_module()
|
695 |
+
class RandomCropNoSeg(RandomCrop):
|
696 |
+
|
697 |
+
def _crop_data(self, results: dict, crop_size: Tuple[int, int],
|
698 |
+
allow_negative_crop: bool) -> Union[dict, None]:
|
699 |
+
|
700 |
+
assert crop_size[0] > 0 and crop_size[1] > 0
|
701 |
+
img = results['img']
|
702 |
+
margin_h = max(img.shape[0] - crop_size[0], 0)
|
703 |
+
margin_w = max(img.shape[1] - crop_size[1], 0)
|
704 |
+
offset_h, offset_w = self._rand_offset((margin_h, margin_w))
|
705 |
+
crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
|
706 |
+
crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
|
707 |
+
|
708 |
+
# Record the homography matrix for the RandomCrop
|
709 |
+
homography_matrix = np.array(
|
710 |
+
[[1, 0, -offset_w], [0, 1, -offset_h], [0, 0, 1]],
|
711 |
+
dtype=np.float32)
|
712 |
+
if results.get('homography_matrix', None) is None:
|
713 |
+
results['homography_matrix'] = homography_matrix
|
714 |
+
else:
|
715 |
+
results['homography_matrix'] = homography_matrix @ results[
|
716 |
+
'homography_matrix']
|
717 |
+
|
718 |
+
# crop the image
|
719 |
+
img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
|
720 |
+
img_shape = img.shape
|
721 |
+
results['img'] = img
|
722 |
+
results['img_shape'] = img_shape
|
723 |
+
|
724 |
+
# crop bboxes accordingly and clip to the image boundary
|
725 |
+
if results.get('gt_bboxes', None) is not None:
|
726 |
+
bboxes = results['gt_bboxes']
|
727 |
+
bboxes.translate_([-offset_w, -offset_h])
|
728 |
+
if self.bbox_clip_border:
|
729 |
+
bboxes.clip_(img_shape[:2])
|
730 |
+
valid_inds = bboxes.is_inside(img_shape[:2]).numpy()
|
731 |
+
# If the crop does not contain any gt-bbox area and
|
732 |
+
# allow_negative_crop is False, skip this image.
|
733 |
+
if (not valid_inds.any() and not allow_negative_crop):
|
734 |
+
return None
|
735 |
+
|
736 |
+
results['gt_bboxes'] = bboxes[valid_inds]
|
737 |
+
|
738 |
+
if results.get('gt_ignore_flags', None) is not None:
|
739 |
+
results['gt_ignore_flags'] = \
|
740 |
+
results['gt_ignore_flags'][valid_inds]
|
741 |
+
|
742 |
+
if results.get('gt_ignore_mask_flags', None) is not None:
|
743 |
+
results['gt_ignore_mask_flags'] = \
|
744 |
+
results['gt_ignore_mask_flags'][valid_inds]
|
745 |
+
|
746 |
+
if results.get('gt_bboxes_labels', None) is not None:
|
747 |
+
results['gt_bboxes_labels'] = \
|
748 |
+
results['gt_bboxes_labels'][valid_inds]
|
749 |
+
|
750 |
+
if results.get('gt_masks', None) is not None:
|
751 |
+
results['gt_masks'] = results['gt_masks'][
|
752 |
+
valid_inds.nonzero()[0]].crop(
|
753 |
+
np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
|
754 |
+
if self.recompute_bbox:
|
755 |
+
results['gt_bboxes'] = results['gt_masks'].get_bboxes(
|
756 |
+
type(results['gt_bboxes']))
|
757 |
+
|
758 |
+
# crop semantic seg
|
759 |
+
if results.get('gt_seg_map', None) is not None:
|
760 |
+
results['gt_seg_map'] = results['gt_seg_map'][crop_y1:crop_y2,
|
761 |
+
crop_x1:crop_x2]
|
762 |
+
|
763 |
+
return results
|
764 |
+
|
765 |
+
|
766 |
+
|
767 |
+
@TRANSFORMS.register_module()
|
768 |
+
class CachedMixUpNoSeg(CachedMixUp):
|
769 |
+
|
770 |
+
@autocast_box_type()
|
771 |
+
def transform(self, results: dict) -> dict:
|
772 |
+
"""MixUp transform function.
|
773 |
+
|
774 |
+
Args:
|
775 |
+
results (dict): Result dict.
|
776 |
+
|
777 |
+
Returns:
|
778 |
+
dict: Updated result dict.
|
779 |
+
"""
|
780 |
+
# cache and pop images
|
781 |
+
self.results_cache.append(copy.deepcopy(results))
|
782 |
+
if len(self.results_cache) > self.max_cached_images:
|
783 |
+
if self.random_pop:
|
784 |
+
index = random.randint(0, len(self.results_cache) - 1)
|
785 |
+
else:
|
786 |
+
index = 0
|
787 |
+
self.results_cache.pop(index)
|
788 |
+
|
789 |
+
if len(self.results_cache) <= 1:
|
790 |
+
return results
|
791 |
+
|
792 |
+
if random.uniform(0, 1) > self.prob:
|
793 |
+
return results
|
794 |
+
|
795 |
+
index = self.get_indexes(self.results_cache)
|
796 |
+
retrieve_results = copy.deepcopy(self.results_cache[index])
|
797 |
+
|
798 |
+
# TODO: refactor mixup to reuse these code.
|
799 |
+
if retrieve_results['gt_bboxes'].shape[0] == 0:
|
800 |
+
# empty bbox
|
801 |
+
return results
|
802 |
+
|
803 |
+
retrieve_img = retrieve_results['img']
|
804 |
+
with_mask = True if 'gt_masks' in results else False
|
805 |
+
|
806 |
+
jit_factor = random.uniform(*self.ratio_range)
|
807 |
+
is_filp = random.uniform(0, 1) > self.flip_ratio
|
808 |
+
|
809 |
+
if len(retrieve_img.shape) == 3:
|
810 |
+
out_img = np.ones(
|
811 |
+
(self.dynamic_scale[1], self.dynamic_scale[0], 3),
|
812 |
+
dtype=retrieve_img.dtype) * self.pad_val
|
813 |
+
else:
|
814 |
+
out_img = np.ones(
|
815 |
+
self.dynamic_scale[::-1],
|
816 |
+
dtype=retrieve_img.dtype) * self.pad_val
|
817 |
+
|
818 |
+
# 1. keep_ratio resize
|
819 |
+
scale_ratio = min(self.dynamic_scale[1] / retrieve_img.shape[0],
|
820 |
+
self.dynamic_scale[0] / retrieve_img.shape[1])
|
821 |
+
retrieve_img = mmcv.imresize(
|
822 |
+
retrieve_img, (int(retrieve_img.shape[1] * scale_ratio),
|
823 |
+
int(retrieve_img.shape[0] * scale_ratio)))
|
824 |
+
|
825 |
+
# 2. paste
|
826 |
+
out_img[:retrieve_img.shape[0], :retrieve_img.shape[1]] = retrieve_img
|
827 |
+
|
828 |
+
# 3. scale jit
|
829 |
+
scale_ratio *= jit_factor
|
830 |
+
out_img = mmcv.imresize(out_img, (int(out_img.shape[1] * jit_factor),
|
831 |
+
int(out_img.shape[0] * jit_factor)))
|
832 |
+
|
833 |
+
# 4. flip
|
834 |
+
if is_filp:
|
835 |
+
out_img = out_img[:, ::-1, :]
|
836 |
+
|
837 |
+
# 5. random crop
|
838 |
+
ori_img = results['img']
|
839 |
+
origin_h, origin_w = out_img.shape[:2]
|
840 |
+
target_h, target_w = ori_img.shape[:2]
|
841 |
+
padded_img = np.ones((max(origin_h, target_h), max(
|
842 |
+
origin_w, target_w), 3)) * self.pad_val
|
843 |
+
padded_img = padded_img.astype(np.uint8)
|
844 |
+
padded_img[:origin_h, :origin_w] = out_img
|
845 |
+
|
846 |
+
x_offset, y_offset = 0, 0
|
847 |
+
if padded_img.shape[0] > target_h:
|
848 |
+
y_offset = random.randint(0, padded_img.shape[0] - target_h)
|
849 |
+
if padded_img.shape[1] > target_w:
|
850 |
+
x_offset = random.randint(0, padded_img.shape[1] - target_w)
|
851 |
+
padded_cropped_img = padded_img[y_offset:y_offset + target_h,
|
852 |
+
x_offset:x_offset + target_w]
|
853 |
+
|
854 |
+
# 6. adjust bbox
|
855 |
+
retrieve_gt_bboxes = retrieve_results['gt_bboxes']
|
856 |
+
retrieve_gt_bboxes.rescale_([scale_ratio, scale_ratio])
|
857 |
+
if with_mask:
|
858 |
+
retrieve_gt_masks = retrieve_results['gt_masks'].rescale(
|
859 |
+
scale_ratio)
|
860 |
+
|
861 |
+
if self.bbox_clip_border:
|
862 |
+
retrieve_gt_bboxes.clip_([origin_h, origin_w])
|
863 |
+
|
864 |
+
if is_filp:
|
865 |
+
retrieve_gt_bboxes.flip_([origin_h, origin_w],
|
866 |
+
direction='horizontal')
|
867 |
+
if with_mask:
|
868 |
+
retrieve_gt_masks = retrieve_gt_masks.flip()
|
869 |
+
|
870 |
+
# 7. filter
|
871 |
+
cp_retrieve_gt_bboxes = retrieve_gt_bboxes.clone()
|
872 |
+
cp_retrieve_gt_bboxes.translate_([-x_offset, -y_offset])
|
873 |
+
if with_mask:
|
874 |
+
|
875 |
+
retrieve_gt_masks = translate_bitmapmask(retrieve_gt_masks,
|
876 |
+
out_shape=(target_h, target_w),
|
877 |
+
offset_x=-x_offset, offset_y=-y_offset)
|
878 |
+
|
879 |
+
# retrieve_gt_masks = retrieve_gt_masks.translate(
|
880 |
+
# out_shape=(target_h, target_w),
|
881 |
+
# offset=-x_offset,
|
882 |
+
# direction='horizontal')
|
883 |
+
# retrieve_gt_masks = retrieve_gt_masks.translate(
|
884 |
+
# out_shape=(target_h, target_w),
|
885 |
+
# offset=-y_offset,
|
886 |
+
# direction='vertical')
|
887 |
+
|
888 |
+
if self.bbox_clip_border:
|
889 |
+
cp_retrieve_gt_bboxes.clip_([target_h, target_w])
|
890 |
+
|
891 |
+
# 8. mix up
|
892 |
+
ori_img = ori_img.astype(np.float32)
|
893 |
+
mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img.astype(np.float32)
|
894 |
+
|
895 |
+
retrieve_gt_bboxes_labels = retrieve_results['gt_bboxes_labels']
|
896 |
+
retrieve_gt_ignore_flags = retrieve_results['gt_ignore_flags']
|
897 |
+
retrieve_gt_ignore_mask_flags = retrieve_results['gt_ignore_mask_flags']
|
898 |
+
|
899 |
+
mixup_gt_bboxes = cp_retrieve_gt_bboxes.cat(
|
900 |
+
(results['gt_bboxes'], cp_retrieve_gt_bboxes), dim=0)
|
901 |
+
mixup_gt_bboxes_labels = np.concatenate(
|
902 |
+
(results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0)
|
903 |
+
mixup_gt_ignore_flags = np.concatenate(
|
904 |
+
(results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0)
|
905 |
+
mixup_gt_ignore_mask_flags = np.concatenate(
|
906 |
+
(results['gt_ignore_mask_flags'], retrieve_gt_ignore_mask_flags), axis=0)
|
907 |
+
|
908 |
+
if with_mask:
|
909 |
+
mixup_gt_masks = retrieve_gt_masks.cat(
|
910 |
+
[results['gt_masks'], retrieve_gt_masks])
|
911 |
+
|
912 |
+
# remove outside bbox
|
913 |
+
inside_inds = mixup_gt_bboxes.is_inside([target_h, target_w]).numpy()
|
914 |
+
mixup_gt_bboxes = mixup_gt_bboxes[inside_inds]
|
915 |
+
mixup_gt_bboxes_labels = mixup_gt_bboxes_labels[inside_inds]
|
916 |
+
mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds]
|
917 |
+
mixup_gt_ignore_mask_flags = mixup_gt_ignore_mask_flags[inside_inds]
|
918 |
+
if with_mask:
|
919 |
+
mixup_gt_masks = mixup_gt_masks[inside_inds]
|
920 |
+
|
921 |
+
results['img'] = mixup_img.astype(np.uint8)
|
922 |
+
results['img_shape'] = mixup_img.shape
|
923 |
+
results['gt_bboxes'] = mixup_gt_bboxes
|
924 |
+
results['gt_bboxes_labels'] = mixup_gt_bboxes_labels
|
925 |
+
results['gt_ignore_flags'] = mixup_gt_ignore_flags
|
926 |
+
results['gt_ignore_mask_flags'] = mixup_gt_ignore_mask_flags
|
927 |
+
if with_mask:
|
928 |
+
results['gt_masks'] = mixup_gt_masks
|
929 |
+
return results
|
animeinsseg/data/maskrefine_dataset.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import albumentations as A
|
2 |
+
|
3 |
+
from torch.utils.data import Dataset, DataLoader
|
4 |
+
import pycocotools.mask as maskUtils
|
5 |
+
from pycocotools.coco import COCO
|
6 |
+
import random
|
7 |
+
import os.path as osp
|
8 |
+
import cv2
|
9 |
+
import numpy as np
|
10 |
+
from scipy.ndimage import distance_transform_bf, distance_transform_edt, distance_transform_cdt
|
11 |
+
|
12 |
+
|
13 |
+
def is_grey(img: np.ndarray):
|
14 |
+
if len(img.shape) == 3 and img.shape[2] == 3:
|
15 |
+
return False
|
16 |
+
else:
|
17 |
+
return True
|
18 |
+
|
19 |
+
|
20 |
+
def square_pad_resize(img: np.ndarray, tgt_size: int, pad_value = (0, 0, 0)):
|
21 |
+
h, w = img.shape[:2]
|
22 |
+
pad_h, pad_w = 0, 0
|
23 |
+
|
24 |
+
# make square image
|
25 |
+
if w < h:
|
26 |
+
pad_w = h - w
|
27 |
+
w += pad_w
|
28 |
+
elif h < w:
|
29 |
+
pad_h = w - h
|
30 |
+
h += pad_h
|
31 |
+
|
32 |
+
pad_size = tgt_size - h
|
33 |
+
if pad_size > 0:
|
34 |
+
pad_h += pad_size
|
35 |
+
pad_w += pad_size
|
36 |
+
|
37 |
+
if pad_h > 0 or pad_w > 0:
|
38 |
+
c = 1
|
39 |
+
if is_grey(img):
|
40 |
+
if isinstance(pad_value, tuple):
|
41 |
+
pad_value = pad_value[0]
|
42 |
+
else:
|
43 |
+
if isinstance(pad_value, int):
|
44 |
+
pad_value = (pad_value, pad_value, pad_value)
|
45 |
+
|
46 |
+
img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=pad_value)
|
47 |
+
|
48 |
+
resize_ratio = tgt_size / img.shape[0]
|
49 |
+
if resize_ratio < 1:
|
50 |
+
img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_AREA)
|
51 |
+
elif resize_ratio > 1:
|
52 |
+
img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_LINEAR)
|
53 |
+
|
54 |
+
return img, resize_ratio, pad_h, pad_w
|
55 |
+
|
56 |
+
|
57 |
+
class MaskRefineDataset(Dataset):
|
58 |
+
|
59 |
+
def __init__(self,
|
60 |
+
refine_ann_path: str,
|
61 |
+
data_root: str,
|
62 |
+
load_instance_mask: bool = True,
|
63 |
+
aug_ins_prob: float = 0.,
|
64 |
+
ins_rect_prob: float = 0.,
|
65 |
+
output_size: int = 720,
|
66 |
+
augmentation: bool = False,
|
67 |
+
with_distance: bool = False):
|
68 |
+
self.load_instance_mask = load_instance_mask
|
69 |
+
self.ann_util = COCO(refine_ann_path)
|
70 |
+
self.img_ids = self.ann_util.getImgIds()
|
71 |
+
self.set_load_method(load_instance_mask)
|
72 |
+
self.data_root = data_root
|
73 |
+
|
74 |
+
self.ins_rect_prob = ins_rect_prob
|
75 |
+
self.aug_ins_prob = aug_ins_prob
|
76 |
+
self.augmentation = augmentation
|
77 |
+
if augmentation:
|
78 |
+
transform = [
|
79 |
+
A.OpticalDistortion(),
|
80 |
+
A.HorizontalFlip(),
|
81 |
+
A.CLAHE(),
|
82 |
+
A.Posterize(),
|
83 |
+
A.CropAndPad(percent=0.1, p=0.3, pad_mode=cv2.BORDER_CONSTANT, pad_cval=0, pad_cval_mask=0, keep_size=True),
|
84 |
+
A.RandomContrast(),
|
85 |
+
A.Rotate(30, p=0.3, mask_value=0, border_mode=cv2.BORDER_CONSTANT)
|
86 |
+
]
|
87 |
+
self._aug_transform = A.Compose(transform)
|
88 |
+
else:
|
89 |
+
self._aug_transform = None
|
90 |
+
|
91 |
+
self.output_size = output_size
|
92 |
+
self.with_distance = with_distance
|
93 |
+
|
94 |
+
def set_output_size(self, size: int):
|
95 |
+
self.output_size = size
|
96 |
+
|
97 |
+
def set_load_method(self, load_instance_mask: bool):
|
98 |
+
if load_instance_mask:
|
99 |
+
self._load_mask = self._load_with_instance
|
100 |
+
else:
|
101 |
+
self._load_mask = self._load_without_instance
|
102 |
+
|
103 |
+
def __getitem__(self, idx: int):
|
104 |
+
img_id = self.img_ids[idx]
|
105 |
+
img_meta = self.ann_util.imgs[img_id]
|
106 |
+
img_path = osp.join(self.data_root, img_meta['file_name'])
|
107 |
+
img = cv2.imread(img_path)
|
108 |
+
|
109 |
+
annids = self.ann_util.getAnnIds([img_id])
|
110 |
+
if len(annids) > 0:
|
111 |
+
ann = random.choice(annids)
|
112 |
+
ann = self.ann_util.anns[ann]
|
113 |
+
assert ann['image_id'] == img_id
|
114 |
+
else:
|
115 |
+
ann = None
|
116 |
+
|
117 |
+
return self._load_mask(img, ann)
|
118 |
+
|
119 |
+
def transform(self, img: np.ndarray, mask: np.ndarray, ins_seg: np.ndarray = None) -> dict:
|
120 |
+
if ins_seg is not None:
|
121 |
+
use_seg = True
|
122 |
+
else:
|
123 |
+
use_seg = False
|
124 |
+
|
125 |
+
if self.augmentation:
|
126 |
+
masks = [mask]
|
127 |
+
if use_seg:
|
128 |
+
masks.append(ins_seg)
|
129 |
+
data = self._aug_transform(image=img, masks=masks)
|
130 |
+
img = data['image']
|
131 |
+
masks = data['masks']
|
132 |
+
mask = masks[0]
|
133 |
+
if use_seg:
|
134 |
+
ins_seg = masks[1]
|
135 |
+
|
136 |
+
img = square_pad_resize(img, self.output_size, random.randint(0, 255))[0]
|
137 |
+
mask = square_pad_resize(mask, self.output_size, 0)[0]
|
138 |
+
if ins_seg is not None:
|
139 |
+
ins_seg = square_pad_resize(ins_seg, self.output_size, 0)[0]
|
140 |
+
|
141 |
+
img = (img.astype(np.float32) / 255.).transpose((2, 0, 1))
|
142 |
+
mask = mask[None, ...]
|
143 |
+
|
144 |
+
|
145 |
+
if use_seg:
|
146 |
+
ins_seg = ins_seg[None, ...]
|
147 |
+
img = np.concatenate((img, ins_seg), axis=0)
|
148 |
+
|
149 |
+
data = {'img': img, 'mask': mask}
|
150 |
+
if self.with_distance:
|
151 |
+
dist = distance_transform_edt(mask[0])
|
152 |
+
dist_max = dist.max()
|
153 |
+
if dist_max != 0:
|
154 |
+
dist = 1 - dist / dist_max
|
155 |
+
# diff_mat = cv2.bitwise_xor(mask[0], ins_seg[0])
|
156 |
+
# dist = dist + diff_mat + 0.2
|
157 |
+
dist = dist + 0.2
|
158 |
+
dist = dist.size / (dist.sum() + 1) * dist
|
159 |
+
dist = np.clip(dist, 0, 20)
|
160 |
+
else:
|
161 |
+
dist = np.ones_like(dist)
|
162 |
+
# print(dist.max(), dist.min())
|
163 |
+
data['dist_weight'] = dist[None, ...]
|
164 |
+
return data
|
165 |
+
|
166 |
+
def _load_with_instance(self, img: np.ndarray, ann: dict):
|
167 |
+
if ann is None:
|
168 |
+
mask = np.zeros(img.shape[:2], dtype=np.float32)
|
169 |
+
ins_seg = mask
|
170 |
+
else:
|
171 |
+
mask = maskUtils.decode(ann['segmentation']).astype(np.float32)
|
172 |
+
if self.augmentation and random.random() < self.ins_rect_prob:
|
173 |
+
ins_seg = np.zeros_like(mask)
|
174 |
+
bbox = [int(b) for b in ann['bbox']]
|
175 |
+
ins_seg[bbox[1]: bbox[1] + bbox[3], bbox[0]: bbox[0] + bbox[2]] = 1
|
176 |
+
elif len(ann['pred_segmentations']) > 0:
|
177 |
+
ins_seg = random.choice(ann['pred_segmentations'])
|
178 |
+
ins_seg = maskUtils.decode(ins_seg).astype(np.float32)
|
179 |
+
else:
|
180 |
+
ins_seg = mask
|
181 |
+
if self.augmentation and random.random() < self.aug_ins_prob:
|
182 |
+
ksize = random.choice([1, 3, 5, 7])
|
183 |
+
ksize = ksize * 2 + 1
|
184 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ksize=(ksize, ksize))
|
185 |
+
if random.random() < 0.5:
|
186 |
+
ins_seg = cv2.dilate(ins_seg, kernel)
|
187 |
+
else:
|
188 |
+
ins_seg = cv2.erode(ins_seg, kernel)
|
189 |
+
|
190 |
+
return self.transform(img, mask, ins_seg)
|
191 |
+
|
192 |
+
def _load_without_instance(self, img: np.ndarray, ann: dict):
|
193 |
+
if ann is None:
|
194 |
+
mask = np.zeros(img.shape[:2], dtype=np.float32)
|
195 |
+
else:
|
196 |
+
mask = maskUtils.decode(ann['segmentation']).astype(np.float32)
|
197 |
+
return self.transform(img, mask)
|
198 |
+
|
199 |
+
def __len__(self):
|
200 |
+
return len(self.img_ids)
|
201 |
+
|
202 |
+
|
203 |
+
if __name__ == '__main__':
|
204 |
+
ann_path = r'workspace/test_syndata/annotations/refine_train.json'
|
205 |
+
data_root = r'workspace/test_syndata/train'
|
206 |
+
|
207 |
+
ann_path = r'workspace/test_syndata/annotations/refine_train.json'
|
208 |
+
data_root = r'workspace/test_syndata/train'
|
209 |
+
aug_ins_prob = 0.5
|
210 |
+
load_instance_mask = True
|
211 |
+
ins_rect_prob = 0.25
|
212 |
+
output_size = 640
|
213 |
+
augmentation = True
|
214 |
+
|
215 |
+
random.seed(0)
|
216 |
+
|
217 |
+
md = MaskRefineDataset(ann_path, data_root, load_instance_mask, aug_ins_prob, ins_rect_prob, output_size, augmentation, with_distance=True)
|
218 |
+
|
219 |
+
dl = DataLoader(md, batch_size=1, shuffle=False, persistent_workers=True,
|
220 |
+
num_workers=1, pin_memory=True)
|
221 |
+
for data in dl:
|
222 |
+
img = data['img'].cpu().numpy()
|
223 |
+
img = (img[0, :3].transpose((1, 2, 0)) * 255).astype(np.uint8)
|
224 |
+
mask = (data['mask'].cpu().numpy()[0][0] * 255).astype(np.uint8)
|
225 |
+
if load_instance_mask:
|
226 |
+
ins = (data['img'].cpu().numpy()[0][3] * 255).astype(np.uint8)
|
227 |
+
cv2.imshow('ins', ins)
|
228 |
+
dist = data['dist_weight'].cpu().numpy()[0][0]
|
229 |
+
dist = (dist / dist.max() * 255).astype(np.uint8)
|
230 |
+
cv2.imshow('img', img)
|
231 |
+
cv2.imshow('mask', mask)
|
232 |
+
cv2.imshow('dist_weight', dist)
|
233 |
+
cv2.waitKey(0)
|
234 |
+
|
235 |
+
# cv2.imwrite('')
|
animeinsseg/data/metrics.py
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import datetime
|
3 |
+
import itertools
|
4 |
+
import os.path as osp
|
5 |
+
import tempfile
|
6 |
+
from collections import OrderedDict
|
7 |
+
from typing import Dict, List, Optional, Sequence, Union
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from mmengine.evaluator import BaseMetric
|
12 |
+
from mmengine.fileio import FileClient, dump, load
|
13 |
+
from mmengine.logging import MMLogger
|
14 |
+
from terminaltables import AsciiTable
|
15 |
+
|
16 |
+
from mmdet.datasets.api_wrappers import COCO, COCOeval
|
17 |
+
from mmdet.registry import METRICS
|
18 |
+
from mmdet.structures.mask import encode_mask_results
|
19 |
+
# from ..functional import eval_recalls
|
20 |
+
from mmdet.evaluation.metrics import CocoMetric
|
21 |
+
|
22 |
+
|
23 |
+
@METRICS.register_module()
|
24 |
+
class AnimeMangaMetric(CocoMetric):
|
25 |
+
|
26 |
+
def __init__(self,
|
27 |
+
manga109_annfile=None,
|
28 |
+
animeins_annfile=None,
|
29 |
+
ann_file: Optional[str] = None,
|
30 |
+
metric: Union[str, List[str]] = 'bbox',
|
31 |
+
classwise: bool = False,
|
32 |
+
proposal_nums: Sequence[int] = (100, 300, 1000),
|
33 |
+
iou_thrs: Optional[Union[float, Sequence[float]]] = None,
|
34 |
+
metric_items: Optional[Sequence[str]] = None,
|
35 |
+
format_only: bool = False,
|
36 |
+
outfile_prefix: Optional[str] = None,
|
37 |
+
file_client_args: dict = dict(backend='disk'),
|
38 |
+
collect_device: str = 'cpu',
|
39 |
+
prefix: Optional[str] = None,
|
40 |
+
sort_categories: bool = False) -> None:
|
41 |
+
|
42 |
+
super().__init__(ann_file, metric, classwise, proposal_nums, iou_thrs, metric_items, format_only, outfile_prefix, file_client_args, collect_device, prefix, sort_categories)
|
43 |
+
|
44 |
+
self.manga109_img_ids = set()
|
45 |
+
if manga109_annfile is not None:
|
46 |
+
with self.file_client.get_local_path(manga109_annfile) as local_path:
|
47 |
+
self._manga109_coco_api = COCO(local_path)
|
48 |
+
if sort_categories:
|
49 |
+
# 'categories' list in objects365_train.json and
|
50 |
+
# objects365_val.json is inconsistent, need sort
|
51 |
+
# list(or dict) before get cat_ids.
|
52 |
+
cats = self._manga109_coco_api.cats
|
53 |
+
sorted_cats = {i: cats[i] for i in sorted(cats)}
|
54 |
+
self._manga109_coco_api.cats = sorted_cats
|
55 |
+
categories = self._manga109_coco_api.dataset['categories']
|
56 |
+
sorted_categories = sorted(
|
57 |
+
categories, key=lambda i: i['id'])
|
58 |
+
self._manga109_coco_api.dataset['categories'] = sorted_categories
|
59 |
+
self.manga109_img_ids = set(self._manga109_coco_api.get_img_ids())
|
60 |
+
else:
|
61 |
+
self._manga109_coco_api = None
|
62 |
+
|
63 |
+
self.animeins_img_ids = set()
|
64 |
+
if animeins_annfile is not None:
|
65 |
+
with self.file_client.get_local_path(animeins_annfile) as local_path:
|
66 |
+
self._animeins_coco_api = COCO(local_path)
|
67 |
+
if sort_categories:
|
68 |
+
# 'categories' list in objects365_train.json and
|
69 |
+
# objects365_val.json is inconsistent, need sort
|
70 |
+
# list(or dict) before get cat_ids.
|
71 |
+
cats = self._animeins_coco_api.cats
|
72 |
+
sorted_cats = {i: cats[i] for i in sorted(cats)}
|
73 |
+
self._animeins_coco_api.cats = sorted_cats
|
74 |
+
categories = self._animeins_coco_api.dataset['categories']
|
75 |
+
sorted_categories = sorted(
|
76 |
+
categories, key=lambda i: i['id'])
|
77 |
+
self._animeins_coco_api.dataset['categories'] = sorted_categories
|
78 |
+
self.animeins_img_ids = set(self._animeins_coco_api.get_img_ids())
|
79 |
+
else:
|
80 |
+
self._animeins_coco_api = None
|
81 |
+
|
82 |
+
if self._animeins_coco_api is not None:
|
83 |
+
self._coco_api = self._animeins_coco_api
|
84 |
+
else:
|
85 |
+
self._coco_api = self._manga109_coco_api
|
86 |
+
|
87 |
+
|
88 |
+
def compute_metrics(self, results: list) -> Dict[str, float]:
|
89 |
+
|
90 |
+
# split gt and prediction list
|
91 |
+
gts, preds = zip(*results)
|
92 |
+
|
93 |
+
manga109_gts, animeins_gts = [], []
|
94 |
+
manga109_preds, animeins_preds = [], []
|
95 |
+
for gt, pred in zip(gts, preds):
|
96 |
+
if gt['img_id'] in self.manga109_img_ids:
|
97 |
+
manga109_gts.append(gt)
|
98 |
+
manga109_preds.append(pred)
|
99 |
+
else:
|
100 |
+
animeins_gts.append(gt)
|
101 |
+
animeins_preds.append(pred)
|
102 |
+
|
103 |
+
tmp_dir = None
|
104 |
+
if self.outfile_prefix is None:
|
105 |
+
tmp_dir = tempfile.TemporaryDirectory()
|
106 |
+
outfile_prefix = osp.join(tmp_dir.name, 'results')
|
107 |
+
else:
|
108 |
+
outfile_prefix = self.outfile_prefix
|
109 |
+
|
110 |
+
eval_results = OrderedDict()
|
111 |
+
|
112 |
+
if len(manga109_gts) > 0:
|
113 |
+
metrics = []
|
114 |
+
for m in self.metrics:
|
115 |
+
if m != 'segm':
|
116 |
+
metrics.append(m)
|
117 |
+
|
118 |
+
self.cat_ids = self._manga109_coco_api.get_cat_ids(cat_names=self.dataset_meta['classes'])
|
119 |
+
self.img_ids = self._manga109_coco_api.get_img_ids()
|
120 |
+
rst = self._compute_metrics(metrics, self._manga109_coco_api, manga109_preds, outfile_prefix, tmp_dir)
|
121 |
+
for key, item in rst.items():
|
122 |
+
eval_results['manga109_'+key] = item
|
123 |
+
|
124 |
+
if len(animeins_gts) > 0:
|
125 |
+
self.cat_ids = self._animeins_coco_api.get_cat_ids(cat_names=self.dataset_meta['classes'])
|
126 |
+
self.img_ids = self._animeins_coco_api.get_img_ids()
|
127 |
+
rst = self._compute_metrics(self.metrics, self._animeins_coco_api, animeins_preds, outfile_prefix, tmp_dir)
|
128 |
+
for key, item in rst.items():
|
129 |
+
eval_results['animeins_'+key] = item
|
130 |
+
|
131 |
+
return eval_results
|
132 |
+
|
133 |
+
def results2json(self, results: Sequence[dict],
|
134 |
+
outfile_prefix: str) -> dict:
|
135 |
+
"""Dump the detection results to a COCO style json file.
|
136 |
+
|
137 |
+
There are 3 types of results: proposals, bbox predictions, mask
|
138 |
+
predictions, and they have different data types. This method will
|
139 |
+
automatically recognize the type, and dump them to json files.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
results (Sequence[dict]): Testing results of the
|
143 |
+
dataset.
|
144 |
+
outfile_prefix (str): The filename prefix of the json files. If the
|
145 |
+
prefix is "somepath/xxx", the json files will be named
|
146 |
+
"somepath/xxx.bbox.json", "somepath/xxx.segm.json",
|
147 |
+
"somepath/xxx.proposal.json".
|
148 |
+
|
149 |
+
Returns:
|
150 |
+
dict: Possible keys are "bbox", "segm", "proposal", and
|
151 |
+
values are corresponding filenames.
|
152 |
+
"""
|
153 |
+
bbox_json_results = []
|
154 |
+
segm_json_results = [] if 'masks' in results[0] else None
|
155 |
+
for idx, result in enumerate(results):
|
156 |
+
image_id = result.get('img_id', idx)
|
157 |
+
labels = result['labels']
|
158 |
+
bboxes = result['bboxes']
|
159 |
+
scores = result['scores']
|
160 |
+
# bbox results
|
161 |
+
for i, label in enumerate(labels):
|
162 |
+
data = dict()
|
163 |
+
data['image_id'] = image_id
|
164 |
+
data['bbox'] = self.xyxy2xywh(bboxes[i])
|
165 |
+
data['score'] = float(scores[i])
|
166 |
+
data['category_id'] = self.cat_ids[label]
|
167 |
+
bbox_json_results.append(data)
|
168 |
+
|
169 |
+
if segm_json_results is None:
|
170 |
+
continue
|
171 |
+
|
172 |
+
# segm results
|
173 |
+
masks = result['masks']
|
174 |
+
mask_scores = result.get('mask_scores', scores)
|
175 |
+
for i, label in enumerate(labels):
|
176 |
+
data = dict()
|
177 |
+
data['image_id'] = image_id
|
178 |
+
data['bbox'] = self.xyxy2xywh(bboxes[i])
|
179 |
+
data['score'] = float(mask_scores[i])
|
180 |
+
data['category_id'] = self.cat_ids[label]
|
181 |
+
if isinstance(masks[i]['counts'], bytes):
|
182 |
+
masks[i]['counts'] = masks[i]['counts'].decode()
|
183 |
+
data['segmentation'] = masks[i]
|
184 |
+
segm_json_results.append(data)
|
185 |
+
|
186 |
+
logger: MMLogger = MMLogger.get_current_instance()
|
187 |
+
logger.info('dumping predictions ... ')
|
188 |
+
result_files = dict()
|
189 |
+
result_files['bbox'] = f'{outfile_prefix}.bbox.json'
|
190 |
+
result_files['proposal'] = f'{outfile_prefix}.bbox.json'
|
191 |
+
dump(bbox_json_results, result_files['bbox'])
|
192 |
+
|
193 |
+
if segm_json_results is not None:
|
194 |
+
result_files['segm'] = f'{outfile_prefix}.segm.json'
|
195 |
+
dump(segm_json_results, result_files['segm'])
|
196 |
+
|
197 |
+
return result_files
|
198 |
+
|
199 |
+
def _compute_metrics(self, metrics, tgt_api, preds, outfile_prefix, tmp_dir):
|
200 |
+
logger: MMLogger = MMLogger.get_current_instance()
|
201 |
+
|
202 |
+
result_files = self.results2json(preds, outfile_prefix)
|
203 |
+
|
204 |
+
eval_results = OrderedDict()
|
205 |
+
if self.format_only:
|
206 |
+
logger.info('results are saved in '
|
207 |
+
f'{osp.dirname(outfile_prefix)}')
|
208 |
+
return eval_results
|
209 |
+
|
210 |
+
for metric in metrics:
|
211 |
+
logger.info(f'Evaluating {metric}...')
|
212 |
+
|
213 |
+
# TODO: May refactor fast_eval_recall to an independent metric?
|
214 |
+
# fast eval recall
|
215 |
+
if metric == 'proposal_fast':
|
216 |
+
ar = self.fast_eval_recall(
|
217 |
+
preds, self.proposal_nums, self.iou_thrs, logger=logger)
|
218 |
+
log_msg = []
|
219 |
+
for i, num in enumerate(self.proposal_nums):
|
220 |
+
eval_results[f'AR@{num}'] = ar[i]
|
221 |
+
log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
|
222 |
+
log_msg = ''.join(log_msg)
|
223 |
+
logger.info(log_msg)
|
224 |
+
continue
|
225 |
+
|
226 |
+
# evaluate proposal, bbox and segm
|
227 |
+
iou_type = 'bbox' if metric == 'proposal' else metric
|
228 |
+
if metric not in result_files:
|
229 |
+
raise KeyError(f'{metric} is not in results')
|
230 |
+
try:
|
231 |
+
predictions = load(result_files[metric])
|
232 |
+
if iou_type == 'segm':
|
233 |
+
# Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa
|
234 |
+
# When evaluating mask AP, if the results contain bbox,
|
235 |
+
# cocoapi will use the box area instead of the mask area
|
236 |
+
# for calculating the instance area. Though the overall AP
|
237 |
+
# is not affected, this leads to different
|
238 |
+
# small/medium/large mask AP results.
|
239 |
+
for x in predictions:
|
240 |
+
x.pop('bbox')
|
241 |
+
coco_dt = tgt_api.loadRes(predictions)
|
242 |
+
|
243 |
+
except IndexError:
|
244 |
+
logger.error(
|
245 |
+
'The testing results of the whole dataset is empty.')
|
246 |
+
break
|
247 |
+
|
248 |
+
coco_eval = COCOeval(tgt_api, coco_dt, iou_type)
|
249 |
+
|
250 |
+
coco_eval.params.catIds = self.cat_ids
|
251 |
+
coco_eval.params.imgIds = self.img_ids
|
252 |
+
coco_eval.params.maxDets = list(self.proposal_nums)
|
253 |
+
coco_eval.params.iouThrs = self.iou_thrs
|
254 |
+
|
255 |
+
# mapping of cocoEval.stats
|
256 |
+
coco_metric_names = {
|
257 |
+
'mAP': 0,
|
258 |
+
'mAP_50': 1,
|
259 |
+
'mAP_75': 2,
|
260 |
+
'mAP_s': 3,
|
261 |
+
'mAP_m': 4,
|
262 |
+
'mAP_l': 5,
|
263 |
+
'AR@100': 6,
|
264 |
+
'AR@300': 7,
|
265 |
+
'AR@1000': 8,
|
266 |
+
'AR_s@1000': 9,
|
267 |
+
'AR_m@1000': 10,
|
268 |
+
'AR_l@1000': 11
|
269 |
+
}
|
270 |
+
metric_items = self.metric_items
|
271 |
+
if metric_items is not None:
|
272 |
+
for metric_item in metric_items:
|
273 |
+
if metric_item not in coco_metric_names:
|
274 |
+
raise KeyError(
|
275 |
+
f'metric item "{metric_item}" is not supported')
|
276 |
+
|
277 |
+
if metric == 'proposal':
|
278 |
+
coco_eval.params.useCats = 0
|
279 |
+
coco_eval.evaluate()
|
280 |
+
coco_eval.accumulate()
|
281 |
+
coco_eval.summarize()
|
282 |
+
if metric_items is None:
|
283 |
+
metric_items = [
|
284 |
+
'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
|
285 |
+
'AR_m@1000', 'AR_l@1000'
|
286 |
+
]
|
287 |
+
|
288 |
+
for item in metric_items:
|
289 |
+
val = float(
|
290 |
+
f'{coco_eval.stats[coco_metric_names[item]]:.3f}')
|
291 |
+
eval_results[item] = val
|
292 |
+
else:
|
293 |
+
coco_eval.evaluate()
|
294 |
+
coco_eval.accumulate()
|
295 |
+
coco_eval.summarize()
|
296 |
+
if self.classwise: # Compute per-category AP
|
297 |
+
# Compute per-category AP
|
298 |
+
# from https://github.com/facebookresearch/detectron2/
|
299 |
+
precisions = coco_eval.eval['precision']
|
300 |
+
# precision: (iou, recall, cls, area range, max dets)
|
301 |
+
assert len(self.cat_ids) == precisions.shape[2]
|
302 |
+
|
303 |
+
results_per_category = []
|
304 |
+
for idx, cat_id in enumerate(self.cat_ids):
|
305 |
+
# area range index 0: all area ranges
|
306 |
+
# max dets index -1: typically 100 per image
|
307 |
+
nm = tgt_api.loadCats(cat_id)[0]
|
308 |
+
precision = precisions[:, :, idx, 0, -1]
|
309 |
+
precision = precision[precision > -1]
|
310 |
+
if precision.size:
|
311 |
+
ap = np.mean(precision)
|
312 |
+
else:
|
313 |
+
ap = float('nan')
|
314 |
+
results_per_category.append(
|
315 |
+
(f'{nm["name"]}', f'{round(ap, 3)}'))
|
316 |
+
eval_results[f'{nm["name"]}_precision'] = round(ap, 3)
|
317 |
+
|
318 |
+
num_columns = min(6, len(results_per_category) * 2)
|
319 |
+
results_flatten = list(
|
320 |
+
itertools.chain(*results_per_category))
|
321 |
+
headers = ['category', 'AP'] * (num_columns // 2)
|
322 |
+
results_2d = itertools.zip_longest(*[
|
323 |
+
results_flatten[i::num_columns]
|
324 |
+
for i in range(num_columns)
|
325 |
+
])
|
326 |
+
table_data = [headers]
|
327 |
+
table_data += [result for result in results_2d]
|
328 |
+
table = AsciiTable(table_data)
|
329 |
+
logger.info('\n' + table.table)
|
330 |
+
|
331 |
+
if metric_items is None:
|
332 |
+
metric_items = [
|
333 |
+
'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
|
334 |
+
]
|
335 |
+
|
336 |
+
for metric_item in metric_items:
|
337 |
+
key = f'{metric}_{metric_item}'
|
338 |
+
val = coco_eval.stats[coco_metric_names[metric_item]]
|
339 |
+
eval_results[key] = float(f'{round(val, 3)}')
|
340 |
+
|
341 |
+
ap = coco_eval.stats[:6]
|
342 |
+
logger.info(f'{metric}_mAP_copypaste: {ap[0]:.3f} '
|
343 |
+
f'{ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
|
344 |
+
f'{ap[4]:.3f} {ap[5]:.3f}')
|
345 |
+
|
346 |
+
if tmp_dir is not None:
|
347 |
+
tmp_dir.cleanup()
|
348 |
+
return eval_results
|
animeinsseg/data/paste_methods.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from typing import List, Union, Tuple, Dict
|
3 |
+
import random
|
4 |
+
from PIL import Image
|
5 |
+
import cv2
|
6 |
+
import os.path as osp
|
7 |
+
from tqdm import tqdm
|
8 |
+
from panopticapi.utils import rgb2id, id2rgb
|
9 |
+
from time import time
|
10 |
+
import traceback
|
11 |
+
|
12 |
+
from utils.io_utils import bbox_overlap_area
|
13 |
+
from utils.logger import LOGGER
|
14 |
+
from utils.constants import COLOR_PALETTE
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
class PartitionTree:
|
19 |
+
|
20 |
+
def __init__(self, bleft: int, btop: int, bright: int, bbottom: int, parent = None) -> None:
|
21 |
+
self.left: PartitionTree = None
|
22 |
+
self.right: PartitionTree = None
|
23 |
+
self.top: PartitionTree = None
|
24 |
+
self.bottom: PartitionTree = None
|
25 |
+
|
26 |
+
if bright < bleft:
|
27 |
+
bright = bleft
|
28 |
+
if bbottom < btop:
|
29 |
+
bbottom = btop
|
30 |
+
|
31 |
+
self.bleft = bleft
|
32 |
+
self.bright = bright
|
33 |
+
self.btop = btop
|
34 |
+
self.bbottom = bbottom
|
35 |
+
self.parent: PartitionTree = parent
|
36 |
+
|
37 |
+
def is_leaf(self):
|
38 |
+
return self.left is None
|
39 |
+
|
40 |
+
def new_partition(self, new_rect: List):
|
41 |
+
self.left = PartitionTree(self.bleft, self.btop, new_rect[0], self.bbottom, self)
|
42 |
+
self.top = PartitionTree(self.bleft, self.btop, self.bright, new_rect[1], self)
|
43 |
+
self.right = PartitionTree(new_rect[2], self.btop, self.bright, self.bbottom, self)
|
44 |
+
self.bottom = PartitionTree(self.bleft, new_rect[3], self.bright, self.bbottom, self)
|
45 |
+
if self.parent is not None:
|
46 |
+
self.root_update_rect(new_rect)
|
47 |
+
|
48 |
+
def root_update_rect(self, rect):
|
49 |
+
root = self.get_root()
|
50 |
+
root.update_child_rect(rect)
|
51 |
+
|
52 |
+
def update_child_rect(self, rect: List):
|
53 |
+
if self.is_leaf():
|
54 |
+
self.update_from_rect(rect)
|
55 |
+
else:
|
56 |
+
self.left.update_child_rect(rect)
|
57 |
+
self.right.update_child_rect(rect)
|
58 |
+
self.top.update_child_rect(rect)
|
59 |
+
self.bottom.update_child_rect(rect)
|
60 |
+
|
61 |
+
def get_root(self):
|
62 |
+
if self.parent is not None:
|
63 |
+
return self.parent.get_root()
|
64 |
+
else:
|
65 |
+
return self
|
66 |
+
|
67 |
+
|
68 |
+
def update_from_rect(self, rect: List):
|
69 |
+
if not self.is_leaf():
|
70 |
+
return
|
71 |
+
ix = min(self.bright, rect[2]) - max(self.bleft, rect[0])
|
72 |
+
iy = min(self.bbottom, rect[3]) - max(self.btop, rect[1])
|
73 |
+
if not (ix > 0 and iy > 0):
|
74 |
+
return
|
75 |
+
|
76 |
+
new_ltrb0 = np.array([self.bleft, self.btop, self.bright, self.bbottom])
|
77 |
+
new_ltrb1 = new_ltrb0.copy()
|
78 |
+
|
79 |
+
if rect[0] > self.bleft and rect[0] < self.bright:
|
80 |
+
new_ltrb0[2] = rect[0]
|
81 |
+
else:
|
82 |
+
new_ltrb0[0] = rect[2]
|
83 |
+
|
84 |
+
if rect[1] > self.btop and rect[1] < self.bbottom:
|
85 |
+
new_ltrb1[3]= rect[1]
|
86 |
+
else:
|
87 |
+
new_ltrb1[1] = rect[3]
|
88 |
+
|
89 |
+
if (new_ltrb0[2:] - new_ltrb0[:2]).prod() > (new_ltrb1[2:] - new_ltrb1[:2]).prod():
|
90 |
+
self.bleft, self.btop, self.bright, self.bbottom = new_ltrb0
|
91 |
+
else:
|
92 |
+
self.bleft, self.btop, self.bright, self.bbottom = new_ltrb1
|
93 |
+
|
94 |
+
@property
|
95 |
+
def width(self) -> int:
|
96 |
+
return self.bright - self.bleft
|
97 |
+
|
98 |
+
@property
|
99 |
+
def height(self) -> int:
|
100 |
+
return self.bbottom - self.btop
|
101 |
+
|
102 |
+
def prefer_partition(self, tgt_h: int, tgt_w: int):
|
103 |
+
if self.is_leaf():
|
104 |
+
return self, min(self.width / tgt_w, 1.2) * min(self.height / tgt_h, 1.2)
|
105 |
+
else:
|
106 |
+
lp, ls = self.left.prefer_partition(tgt_h, tgt_w)
|
107 |
+
rp, rs = self.right.prefer_partition(tgt_h, tgt_w)
|
108 |
+
tp, ts = self.top.prefer_partition(tgt_h, tgt_w)
|
109 |
+
bp, bs = self.bottom.prefer_partition(tgt_h, tgt_w)
|
110 |
+
preferp = [(p, s) for s, p in sorted(zip([ls, rs, ts, bs],[lp, rp, tp, bp]), key=lambda pair: pair[0], reverse=True)][0]
|
111 |
+
return preferp
|
112 |
+
|
113 |
+
def new_random_pos(self, fg_h: int, fg_w: int, im_h: int, im_w: int, random_sample: bool = False):
|
114 |
+
extx, exty = int(fg_w / 3), int(fg_h / 3)
|
115 |
+
extxb, extyb = int(fg_w / 10), int(fg_h / 10)
|
116 |
+
region_w, region_h = self.width + extx, self.height + exty
|
117 |
+
downscale_ratio = max(min(region_w / fg_w, region_h / fg_h), 0.8)
|
118 |
+
if downscale_ratio < 1:
|
119 |
+
fg_h = int(downscale_ratio * fg_h)
|
120 |
+
fg_w = int(downscale_ratio * fg_w)
|
121 |
+
|
122 |
+
max_x, max_y = self.bright + extx - fg_w, self.bbottom + exty - fg_h
|
123 |
+
max_x = min(im_w+extxb-fg_w, max_x)
|
124 |
+
max_y = min(im_h+extyb-fg_h, max_y)
|
125 |
+
min_x = max(min(self.bright + extx - fg_w, self.bleft - extx), -extx)
|
126 |
+
min_x = max(-extxb, min_x)
|
127 |
+
min_y = max(min(self.bbottom + exty - fg_h, self.btop - exty), -exty)
|
128 |
+
min_y = max(-extyb, min_y)
|
129 |
+
px, py = min_x, min_y
|
130 |
+
if min_x < max_x:
|
131 |
+
if random_sample:
|
132 |
+
px = random.randint(min_x, max_x)
|
133 |
+
else:
|
134 |
+
px = int((min_x + max_x) / 2)
|
135 |
+
if min_y < max_y:
|
136 |
+
if random_sample:
|
137 |
+
py = random.randint(min_y, max_y)
|
138 |
+
else:
|
139 |
+
py = int((min_y + max_y) / 2)
|
140 |
+
return px, py, downscale_ratio
|
141 |
+
|
142 |
+
def drawpartition(self, image: np.ndarray, color = None):
|
143 |
+
if color is None:
|
144 |
+
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
|
145 |
+
if not self.is_leaf():
|
146 |
+
cv2.rectangle(image, (self.bleft, self.btop), (self.bright, self.bbottom), color, 2)
|
147 |
+
if not self.is_leaf():
|
148 |
+
c = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
|
149 |
+
self.left.drawpartition(image, c)
|
150 |
+
self.right.drawpartition(image, c)
|
151 |
+
self.top.drawpartition(image, c)
|
152 |
+
self.bottom.drawpartition(image, c)
|
153 |
+
|
154 |
+
|
155 |
+
def paste_one_fg(fg_pil: Image, bg: Image, segments: np.ndarray, px: int, py: int, seg_color: Tuple, cal_area=True):
|
156 |
+
|
157 |
+
fg_h, fg_w = fg_pil.height, fg_pil.width
|
158 |
+
im_h, im_w = bg.height, bg.width
|
159 |
+
|
160 |
+
bg.paste(fg_pil, (px, py), mask=fg_pil)
|
161 |
+
|
162 |
+
|
163 |
+
bgx1, bgx2, bgy1, bgy2 = px, px+fg_w, py, py+fg_h
|
164 |
+
fgx1, fgx2, fgy1, fgy2 = 0, fg_w, 0, fg_h
|
165 |
+
if bgx1 < 0:
|
166 |
+
fgx1 = -bgx1
|
167 |
+
bgx1 = 0
|
168 |
+
if bgy1 < 0:
|
169 |
+
fgy1 = -bgy1
|
170 |
+
bgy1 = 0
|
171 |
+
if bgx2 > im_w:
|
172 |
+
fgx2 = im_w - bgx2
|
173 |
+
bgx2 = im_w
|
174 |
+
if bgy2 > im_h:
|
175 |
+
fgy2 = im_h - bgy2
|
176 |
+
bgy2 = im_h
|
177 |
+
|
178 |
+
fg_mask = np.array(fg_pil)[fgy1: fgy2, fgx1: fgx2, 3] > 30
|
179 |
+
segments[bgy1: bgy2, bgx1: bgx2][np.where(fg_mask)] = seg_color
|
180 |
+
|
181 |
+
if cal_area:
|
182 |
+
area = fg_mask.sum()
|
183 |
+
else:
|
184 |
+
area = 1
|
185 |
+
bbox = [bgx1, bgy1, bgx2-bgx1, bgy2-bgy1]
|
186 |
+
return area, bbox, [bgx1, bgy1, bgx2, bgy2]
|
187 |
+
|
188 |
+
|
189 |
+
def partition_paste(fg_list, bg: Image):
|
190 |
+
segments_info = []
|
191 |
+
|
192 |
+
fg_list.sort(key = lambda x: x['image'].shape[0] * x['image'].shape[1], reverse=True)
|
193 |
+
pnode: PartitionTree = None
|
194 |
+
im_h, im_w = bg.height, bg.width
|
195 |
+
|
196 |
+
ptree = PartitionTree(0, 0, bg.width, bg.height)
|
197 |
+
|
198 |
+
segments = np.zeros((im_h, im_w, 3), np.uint8)
|
199 |
+
for ii, fg_dict in enumerate(fg_list):
|
200 |
+
fg = fg_dict['image']
|
201 |
+
fg_h, fg_w = fg.shape[:2]
|
202 |
+
pnode, _ = ptree.prefer_partition(fg_h, fg_w)
|
203 |
+
px, py, downscale_ratio = pnode.new_random_pos(fg_h, fg_w, im_h, im_w, True)
|
204 |
+
|
205 |
+
fg_pil = Image.fromarray(fg)
|
206 |
+
if downscale_ratio < 1:
|
207 |
+
fg_pil = fg_pil.resize((int(fg_w * downscale_ratio), int(fg_h * downscale_ratio)), resample=Image.Resampling.LANCZOS)
|
208 |
+
# fg_h, fg_w = fg_pil.height, fg_pil.width
|
209 |
+
|
210 |
+
seg_color = COLOR_PALETTE[ii]
|
211 |
+
area, bbox, xyxy = paste_one_fg(fg_pil, bg, segments, px,py, seg_color, cal_area=False)
|
212 |
+
pnode.new_partition(xyxy)
|
213 |
+
|
214 |
+
segments_info.append({
|
215 |
+
'id': rgb2id(seg_color),
|
216 |
+
'bbox': bbox,
|
217 |
+
'area': area
|
218 |
+
})
|
219 |
+
|
220 |
+
return segments_info, segments
|
221 |
+
# if downscale_ratio < 1:
|
222 |
+
# fg_pil = fg_pil.resize((int(fg_w * downscale_ratio), int(fg_h * downscale_ratio)), resample=Image.Resampling.LANCZOS)
|
223 |
+
# fg_h, fg_w = fg_pil.height, fg_pil.width
|
224 |
+
|
225 |
+
|
226 |
+
def gen_fg_regbboxes(fg_list: List[Dict], tgt_size: int, min_overlap=0.15, max_overlap=0.8):
|
227 |
+
|
228 |
+
def _sample_y(h):
|
229 |
+
y = (tgt_size - h) // 2
|
230 |
+
if y > 0:
|
231 |
+
yrange = min(y, h // 4)
|
232 |
+
y += random.randint(-yrange, yrange)
|
233 |
+
return y
|
234 |
+
else:
|
235 |
+
return 0
|
236 |
+
|
237 |
+
shape_list = []
|
238 |
+
depth_list = []
|
239 |
+
|
240 |
+
|
241 |
+
for fg_dict in fg_list:
|
242 |
+
shape_list.append(fg_dict['image'].shape[:2])
|
243 |
+
|
244 |
+
shape_list = np.array(shape_list)
|
245 |
+
depth_list = np.random.random(len(fg_list))
|
246 |
+
depth_list[shape_list[..., 1] > 0.6 * tgt_size] += 1
|
247 |
+
|
248 |
+
# num_fg = len(fg_list)
|
249 |
+
# grid_sample = random.random() < 0.4 or num_fg > 6
|
250 |
+
# grid_sample = grid_sample and num_fg < 9 and num_fg > 3
|
251 |
+
# grid_sample = False
|
252 |
+
# if grid_sample:
|
253 |
+
# grid_pos = np.arange(9)
|
254 |
+
# np.random.shuffle(grid_pos)
|
255 |
+
# grid_pos = grid_pos[: num_fg]
|
256 |
+
# grid_x = grid_pos % 3
|
257 |
+
# grid_y = grid_pos // 3
|
258 |
+
|
259 |
+
# else:
|
260 |
+
pos_list = [[0, _sample_y(shape_list[0][0])]]
|
261 |
+
pre_overlap = 0
|
262 |
+
for ii, ((h, w), d) in enumerate(zip(shape_list[1:], depth_list[1:])):
|
263 |
+
(preh, prew), predepth, (prex, prey) = shape_list[ii], depth_list[ii], pos_list[ii]
|
264 |
+
|
265 |
+
isfg = d < predepth
|
266 |
+
y = _sample_y(h)
|
267 |
+
x = prex+prew
|
268 |
+
if isfg:
|
269 |
+
min_x = max_x = x
|
270 |
+
if pre_overlap < max_overlap:
|
271 |
+
min_x -= (max_overlap - pre_overlap) * prew
|
272 |
+
min_x = int(min_x)
|
273 |
+
if pre_overlap < min_overlap:
|
274 |
+
max_x -= (min_overlap - pre_overlap) * prew
|
275 |
+
max_x = int(max_x)
|
276 |
+
x = random.randint(min_x, max_x)
|
277 |
+
pre_overlap = 0
|
278 |
+
else:
|
279 |
+
overlap = random.uniform(min_overlap, max_overlap)
|
280 |
+
x -= int(overlap * w)
|
281 |
+
area = h * w
|
282 |
+
overlap_area = bbox_overlap_area([x, y, w, h], [prex, prey, prew, preh])
|
283 |
+
pre_overlap = overlap_area / area
|
284 |
+
|
285 |
+
pos_list.append([x, y])
|
286 |
+
|
287 |
+
pos_list = np.array(pos_list)
|
288 |
+
last_x2 = pos_list[-1][0] + shape_list[-1][1]
|
289 |
+
valid_shiftx = tgt_size - last_x2
|
290 |
+
if valid_shiftx > 0:
|
291 |
+
shiftx = random.randint(0, valid_shiftx)
|
292 |
+
pos_list[:, 0] += shiftx
|
293 |
+
else:
|
294 |
+
pos_list[:, 0] += valid_shiftx // 2
|
295 |
+
|
296 |
+
for pos, fg_dict, depth in zip(pos_list, fg_list, depth_list):
|
297 |
+
fg_dict['pos'] = pos
|
298 |
+
fg_dict['depth'] = depth
|
299 |
+
fg_list.sort(key=lambda x: x['depth'], reverse=True)
|
300 |
+
|
301 |
+
|
302 |
+
|
303 |
+
def regular_paste(fg_list, bg: Image, regen_bboxes=False):
|
304 |
+
segments_info = []
|
305 |
+
im_h, im_w = bg.height, bg.width
|
306 |
+
|
307 |
+
if regen_bboxes:
|
308 |
+
random.shuffle(fg_list)
|
309 |
+
gen_fg_regbboxes(fg_list, im_h)
|
310 |
+
|
311 |
+
segments = np.zeros((im_h, im_w, 3), np.uint8)
|
312 |
+
for ii, fg_dict in enumerate(fg_list):
|
313 |
+
fg = fg_dict['image']
|
314 |
+
|
315 |
+
px, py = fg_dict.pop('pos')
|
316 |
+
fg_pil = Image.fromarray(fg)
|
317 |
+
|
318 |
+
seg_color = COLOR_PALETTE[ii]
|
319 |
+
area, bbox, xyxy = paste_one_fg(fg_pil, bg, segments, px,py, seg_color, cal_area=True)
|
320 |
+
|
321 |
+
segments_info.append({
|
322 |
+
'id': rgb2id(seg_color),
|
323 |
+
'bbox': bbox,
|
324 |
+
'area': area
|
325 |
+
})
|
326 |
+
|
327 |
+
return segments_info, segments
|
animeinsseg/data/sampler.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from random import choice as rchoice
|
3 |
+
from random import randint
|
4 |
+
import random
|
5 |
+
import cv2, traceback, imageio
|
6 |
+
import os.path as osp
|
7 |
+
|
8 |
+
from typing import Optional, List, Union, Tuple, Dict
|
9 |
+
from utils.io_utils import imread_nogrey_rgb, json2dict
|
10 |
+
from .transforms import rotate_image
|
11 |
+
from utils.logger import LOGGER
|
12 |
+
|
13 |
+
|
14 |
+
class NameSampler:
|
15 |
+
|
16 |
+
def __init__(self, name_prob_dict, sample_num=2048) -> None:
|
17 |
+
self.name_prob_dict = name_prob_dict
|
18 |
+
self._id2name = list(name_prob_dict.keys())
|
19 |
+
self.sample_ids = []
|
20 |
+
|
21 |
+
total_prob = 0.
|
22 |
+
for ii, (_, prob) in enumerate(name_prob_dict.items()):
|
23 |
+
tgt_num = int(prob * sample_num)
|
24 |
+
total_prob += prob
|
25 |
+
if tgt_num > 0:
|
26 |
+
self.sample_ids += [ii] * tgt_num
|
27 |
+
|
28 |
+
nsamples = len(self.sample_ids)
|
29 |
+
assert prob <= 1
|
30 |
+
if prob < 1 and nsamples < sample_num:
|
31 |
+
self.sample_ids += [len(self._id2name)] * (sample_num - nsamples)
|
32 |
+
self._id2name.append('_')
|
33 |
+
|
34 |
+
def sample(self) -> str:
|
35 |
+
return self._id2name[rchoice(self.sample_ids)]
|
36 |
+
|
37 |
+
|
38 |
+
class PossionSampler:
|
39 |
+
def __init__(self, lam=3, min_val=1, max_val=8) -> None:
|
40 |
+
self._distr = np.random.poisson(lam, 1024)
|
41 |
+
invalid = np.where(np.logical_or(self._distr<min_val, self._distr > max_val))
|
42 |
+
self._distr[invalid] = np.random.randint(min_val, max_val, len(invalid[0]))
|
43 |
+
|
44 |
+
def sample(self) -> int:
|
45 |
+
return rchoice(self._distr)
|
46 |
+
|
47 |
+
|
48 |
+
class NormalSampler:
|
49 |
+
def __init__(self, loc=0.33, std=0.2, min_scale=0.15, max_scale=0.85, scalar=1, to_int = True):
|
50 |
+
s = np.random.normal(loc, std, 4096)
|
51 |
+
valid = np.where(np.logical_and(s>min_scale, s<max_scale))
|
52 |
+
self._distr = s[valid] * scalar
|
53 |
+
if to_int:
|
54 |
+
self._distr = self._distr.astype(np.int32)
|
55 |
+
|
56 |
+
def sample(self) -> int:
|
57 |
+
return rchoice(self._distr)
|
58 |
+
|
59 |
+
|
60 |
+
class PersonBBoxSampler:
|
61 |
+
|
62 |
+
def __init__(self, sample_path: Union[str, List]='data/cocoperson_bbox_samples.json', fg_info_list: List = None, fg_transform=None, is_train=True) -> None:
|
63 |
+
if isinstance(sample_path, str):
|
64 |
+
sample_path = [sample_path]
|
65 |
+
self.bbox_list = []
|
66 |
+
for sp in sample_path:
|
67 |
+
bboxlist = json2dict(sp)
|
68 |
+
for bboxes in bboxlist:
|
69 |
+
if isinstance(bboxes, dict):
|
70 |
+
bboxes = bboxes['bboxes']
|
71 |
+
bboxes = np.array(bboxes)
|
72 |
+
bboxes[:, [0, 1]] -= bboxes[:, [0, 1]].min(axis=0)
|
73 |
+
self.bbox_list.append(bboxes)
|
74 |
+
|
75 |
+
self.fg_info_list = fg_info_list
|
76 |
+
self.fg_transform = fg_transform
|
77 |
+
self.is_train = is_train
|
78 |
+
|
79 |
+
def sample(self, tgt_size: int, scale_range=(1, 1), size_thres=(0.02, 0.85)) -> List[np.ndarray]:
|
80 |
+
bboxes_normalized = rchoice(self.bbox_list)
|
81 |
+
if scale_range[0] != 1 or scale_range[1] != 1:
|
82 |
+
bbox_scale = random.uniform(scale_range[0], scale_range[1])
|
83 |
+
else:
|
84 |
+
bbox_scale = 1
|
85 |
+
bboxes = (bboxes_normalized * tgt_size * bbox_scale).astype(np.int32)
|
86 |
+
|
87 |
+
xyxy_array = np.copy(bboxes)
|
88 |
+
xyxy_array[:, [2, 3]] += xyxy_array[:, [0, 1]]
|
89 |
+
x_max, y_max = xyxy_array[:, 2].max(), xyxy_array[:, 3].max()
|
90 |
+
|
91 |
+
x_shift = tgt_size - x_max
|
92 |
+
x_shift = randint(0, x_shift) if x_shift > 0 else 0
|
93 |
+
y_shift = tgt_size - y_max
|
94 |
+
y_shift = randint(0, y_shift) if y_shift > 0 else 0
|
95 |
+
|
96 |
+
bboxes[:, [0, 1]] += [x_shift, y_shift]
|
97 |
+
valid_bboxes = []
|
98 |
+
max_size = size_thres[1] * tgt_size
|
99 |
+
min_size = size_thres[0] * tgt_size
|
100 |
+
for bbox in bboxes:
|
101 |
+
w = min(bbox[2], tgt_size - bbox[0])
|
102 |
+
h = min(bbox[3], tgt_size - bbox[1])
|
103 |
+
if max(h, w) < max_size and min(h, w) > min_size:
|
104 |
+
valid_bboxes.append(bbox)
|
105 |
+
return valid_bboxes
|
106 |
+
|
107 |
+
def sample_matchfg(self, tgt_size: int):
|
108 |
+
while True:
|
109 |
+
bboxes = self.sample(tgt_size, (1.1, 1.8))
|
110 |
+
if len(bboxes) > 0:
|
111 |
+
break
|
112 |
+
MIN_FG_SIZE = 20
|
113 |
+
num_fg = len(bboxes)
|
114 |
+
rotate = 20 if self.is_train else 15
|
115 |
+
fgs = random_load_nfg(num_fg, self.fg_info_list, random_rotate_prob=0.33, random_rotate=rotate)
|
116 |
+
assert len(fgs) == num_fg
|
117 |
+
|
118 |
+
bboxes.sort(key=lambda x: x[2] / x[3])
|
119 |
+
fgs.sort(key=lambda x: x['asp_ratio'])
|
120 |
+
|
121 |
+
for fg, bbox in zip(fgs, bboxes):
|
122 |
+
x, y, w, h = bbox
|
123 |
+
img = fg['image']
|
124 |
+
im_h, im_w = img.shape[:2]
|
125 |
+
if im_h < h and im_w < w:
|
126 |
+
scale = min(h / im_h, w / im_w)
|
127 |
+
new_h, new_w = int(scale * im_h), int(scale * im_w)
|
128 |
+
img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
|
129 |
+
else:
|
130 |
+
scale_h, scale_w = min(1, h / im_h), min(1, w / im_w)
|
131 |
+
scale = (scale_h + scale_w) / 2
|
132 |
+
if scale < 1:
|
133 |
+
new_h, new_w = max(int(scale * im_h), MIN_FG_SIZE), max(int(scale * im_w), MIN_FG_SIZE)
|
134 |
+
img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
135 |
+
|
136 |
+
if self.fg_transform is not None:
|
137 |
+
img = self.fg_transform(image=img)['image']
|
138 |
+
|
139 |
+
im_h, im_w = img.shape[:2]
|
140 |
+
fg['image'] = img
|
141 |
+
px = int(x + w / 2 - im_w / 2)
|
142 |
+
py = int(y + h / 2 - im_h / 2)
|
143 |
+
fg['pos'] = (px, py)
|
144 |
+
|
145 |
+
random.shuffle(fgs)
|
146 |
+
|
147 |
+
slist, llist = [], []
|
148 |
+
large_size = int(tgt_size * 0.55)
|
149 |
+
for fg in fgs:
|
150 |
+
if max(fg['image'].shape[:2]) > large_size:
|
151 |
+
llist.append(fg)
|
152 |
+
else:
|
153 |
+
slist.append(fg)
|
154 |
+
return llist + slist
|
155 |
+
|
156 |
+
|
157 |
+
def random_load_nfg(num_fg: int, fg_info_list: List[Union[Dict, str]], random_rotate=0, random_rotate_prob=0.):
|
158 |
+
fgs = []
|
159 |
+
while len(fgs) < num_fg:
|
160 |
+
fg, fginfo = random_load_valid_fg(fg_info_list)
|
161 |
+
if random.random() < random_rotate_prob:
|
162 |
+
rotate_deg = randint(-random_rotate, random_rotate)
|
163 |
+
fg = rotate_image(fg, rotate_deg, alpha_crop=True)
|
164 |
+
|
165 |
+
asp_ratio = fg.shape[1] / fg.shape[0]
|
166 |
+
fgs.append({'image': fg, 'asp_ratio': asp_ratio, 'fginfo': fginfo})
|
167 |
+
while len(fgs) < num_fg and random.random() < 0.12:
|
168 |
+
fgs.append({'image': fg, 'asp_ratio': asp_ratio, 'fginfo': fginfo})
|
169 |
+
|
170 |
+
return fgs
|
171 |
+
|
172 |
+
|
173 |
+
def random_load_valid_fg(fg_info_list: List[Union[Dict, str]]) -> Tuple[np.ndarray, Dict]:
|
174 |
+
while True:
|
175 |
+
item = fginfo = rchoice(fg_info_list)
|
176 |
+
|
177 |
+
file_path = fginfo['file_path']
|
178 |
+
if 'root_dir' in fginfo and fginfo['root_dir']:
|
179 |
+
file_path = osp.join(fginfo['root_dir'], file_path)
|
180 |
+
|
181 |
+
try:
|
182 |
+
fg = imageio.imread(file_path)
|
183 |
+
except:
|
184 |
+
LOGGER.error(traceback.format_exc())
|
185 |
+
LOGGER.error(f'invalid fg: {file_path}')
|
186 |
+
fg_info_list.remove(item)
|
187 |
+
continue
|
188 |
+
|
189 |
+
c = 1
|
190 |
+
if len(fg.shape) == 3:
|
191 |
+
c = fg.shape[-1]
|
192 |
+
if c != 4:
|
193 |
+
LOGGER.warning(f'fg {file_path} doesnt have alpha channel')
|
194 |
+
fg_info_list.remove(item)
|
195 |
+
else:
|
196 |
+
if 'xyxy' in fginfo:
|
197 |
+
x1, y1, x2, y2 = fginfo['xyxy']
|
198 |
+
else:
|
199 |
+
oh, ow = fg.shape[:2]
|
200 |
+
ksize = 5
|
201 |
+
mask = cv2.blur(fg[..., 3], (ksize,ksize))
|
202 |
+
_, mask = cv2.threshold(mask, 20, 255, cv2.THRESH_BINARY)
|
203 |
+
|
204 |
+
x1, y1, w, h = cv2.boundingRect(cv2.findNonZero(mask))
|
205 |
+
x2, y2 = x1 + w, y1 + h
|
206 |
+
if oh - h > 15 or ow - w > 15:
|
207 |
+
crop = True
|
208 |
+
else:
|
209 |
+
x1 = y1 = 0
|
210 |
+
x2, y2 = ow, oh
|
211 |
+
|
212 |
+
fginfo['xyxy'] = [x1, y1, x2, y2]
|
213 |
+
fg = fg[y1: y2, x1: x2]
|
214 |
+
return fg, fginfo
|
215 |
+
|
216 |
+
|
217 |
+
def random_load_valid_bg(bg_list: List[str]) -> np.ndarray:
|
218 |
+
while True:
|
219 |
+
try:
|
220 |
+
bgp = rchoice(bg_list)
|
221 |
+
return imread_nogrey_rgb(bgp)
|
222 |
+
except:
|
223 |
+
LOGGER.error(traceback.format_exc())
|
224 |
+
LOGGER.error(f'invalid bg: {bgp}')
|
225 |
+
bg_list.remove(bgp)
|
226 |
+
continue
|
animeinsseg/data/syndataset.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from typing import List, Union, Tuple, Dict
|
3 |
+
import random
|
4 |
+
from PIL import Image
|
5 |
+
import cv2
|
6 |
+
import imageio, os
|
7 |
+
import os.path as osp
|
8 |
+
from tqdm import tqdm
|
9 |
+
from panopticapi.utils import rgb2id
|
10 |
+
import traceback
|
11 |
+
|
12 |
+
from utils.io_utils import mask2rle, dict2json, fgbg_hist_matching
|
13 |
+
from utils.logger import LOGGER
|
14 |
+
from utils.constants import CATEGORIES, IMAGE_ID_ZFILL
|
15 |
+
from .transforms import get_fg_transforms, get_bg_transforms, quantize_image, resize2height, rotate_image
|
16 |
+
from .sampler import random_load_valid_bg, random_load_valid_fg, NameSampler, NormalSampler, PossionSampler, PersonBBoxSampler
|
17 |
+
from .paste_methods import regular_paste, partition_paste
|
18 |
+
|
19 |
+
|
20 |
+
def syn_animecoco_dataset(
|
21 |
+
bg_list: List, fg_info_list: List[Dict], dataset_save_dir: str, policy: str='train',
|
22 |
+
tgt_size=640, syn_num_multiplier=2.5, regular_paste_prob=0.4, person_paste_prob=0.4,
|
23 |
+
max_syn_num=-1, image_id_start=0, obj_id_start=0, hist_match_prob=0.2, quantize_prob=0.25):
|
24 |
+
|
25 |
+
LOGGER.info(f'syn data policy: {policy}')
|
26 |
+
LOGGER.info(f'background: {len(bg_list)} foreground: {len(fg_info_list)}')
|
27 |
+
|
28 |
+
numfg_sampler = PossionSampler(min_val=1, max_val=9, lam=2.5)
|
29 |
+
numfg_regpaste_sampler = PossionSampler(min_val=2, max_val=9, lam=3.5)
|
30 |
+
regpaste_size_sampler = NormalSampler(scalar=tgt_size, to_int=True, max_scale=0.75)
|
31 |
+
color_correction_sampler = NameSampler({'hist_match': hist_match_prob, 'quantize': quantize_prob}, )
|
32 |
+
paste_method_sampler = NameSampler({'regular': regular_paste_prob, 'personbbox': person_paste_prob,
|
33 |
+
'partition': 1-regular_paste_prob-person_paste_prob})
|
34 |
+
|
35 |
+
fg_transform = get_fg_transforms(tgt_size, transform_variant=policy)
|
36 |
+
fg_distort_transform = get_fg_transforms(tgt_size, transform_variant='distort_only')
|
37 |
+
bg_transform = get_bg_transforms('train', tgt_size)
|
38 |
+
|
39 |
+
image_id = image_id_start + 1
|
40 |
+
obj_id = obj_id_start + 1
|
41 |
+
|
42 |
+
det_annotations, image_meta = [], []
|
43 |
+
|
44 |
+
syn_num = int(syn_num_multiplier * len(fg_info_list))
|
45 |
+
if max_syn_num > 0:
|
46 |
+
syn_num = max_syn_num
|
47 |
+
|
48 |
+
ann_save_dir = osp.join(dataset_save_dir, 'annotations')
|
49 |
+
image_save_dir = osp.join(dataset_save_dir, policy)
|
50 |
+
|
51 |
+
if not osp.exists(image_save_dir):
|
52 |
+
os.makedirs(image_save_dir)
|
53 |
+
if not osp.exists(ann_save_dir):
|
54 |
+
os.makedirs(ann_save_dir)
|
55 |
+
|
56 |
+
is_train = policy == 'train'
|
57 |
+
if is_train:
|
58 |
+
jpg_save_quality = [75, 85, 95]
|
59 |
+
else:
|
60 |
+
jpg_save_quality = [95]
|
61 |
+
|
62 |
+
if isinstance(fg_info_list[0], str):
|
63 |
+
for ii, fgp in enumerate(fg_info_list):
|
64 |
+
if isinstance(fgp, str):
|
65 |
+
fg_info_list[ii] = {'file_path': fgp, 'tag_string': [], 'danbooru': False, 'category_id': 0}
|
66 |
+
|
67 |
+
if person_paste_prob > 0:
|
68 |
+
personbbox_sampler = PersonBBoxSampler(
|
69 |
+
'data/cocoperson_bbox_samples.json', fg_info_list,
|
70 |
+
fg_transform=fg_distort_transform if is_train else None, is_train=is_train)
|
71 |
+
|
72 |
+
total = tqdm(range(syn_num))
|
73 |
+
for fin in total:
|
74 |
+
try:
|
75 |
+
paste_method = paste_method_sampler.sample()
|
76 |
+
|
77 |
+
fgs = []
|
78 |
+
if paste_method == 'regular':
|
79 |
+
num_fg = numfg_regpaste_sampler.sample()
|
80 |
+
size = regpaste_size_sampler.sample()
|
81 |
+
while len(fgs) < num_fg:
|
82 |
+
tgt_height = int(random.uniform(0.7, 1.2) * size)
|
83 |
+
fg, fginfo = random_load_valid_fg(fg_info_list)
|
84 |
+
fg = resize2height(fg, tgt_height)
|
85 |
+
if is_train:
|
86 |
+
fg = fg_distort_transform(image=fg)['image']
|
87 |
+
rotate_deg = random.randint(-40, 40)
|
88 |
+
else:
|
89 |
+
rotate_deg = random.randint(-30, 30)
|
90 |
+
if random.random() < 0.3:
|
91 |
+
fg = rotate_image(fg, rotate_deg, alpha_crop=True)
|
92 |
+
fgs.append({'image': fg, 'fginfo': fginfo})
|
93 |
+
while len(fgs) < num_fg and random.random() < 0.15:
|
94 |
+
fgs.append({'image': fg, 'fginfo': fginfo})
|
95 |
+
elif paste_method == 'personbbox':
|
96 |
+
fgs = personbbox_sampler.sample_matchfg(tgt_size)
|
97 |
+
else:
|
98 |
+
num_fg = numfg_sampler.sample()
|
99 |
+
fgs = []
|
100 |
+
for ii in range(num_fg):
|
101 |
+
fg, fginfo = random_load_valid_fg(fg_info_list)
|
102 |
+
fg = fg_transform(image=fg)['image']
|
103 |
+
h, w = fg.shape[:2]
|
104 |
+
if num_fg > 6:
|
105 |
+
downscale = min(tgt_size / 2.5 / w, tgt_size / 2.5 / h)
|
106 |
+
if downscale < 1:
|
107 |
+
fg = cv2.resize(fg, (int(w * downscale), int(h * downscale)), interpolation=cv2.INTER_AREA)
|
108 |
+
fgs.append({'image': fg, 'fginfo': fginfo})
|
109 |
+
|
110 |
+
bg = random_load_valid_bg(bg_list)
|
111 |
+
bg = bg_transform(image=bg)['image']
|
112 |
+
|
113 |
+
color_correct = color_correction_sampler.sample()
|
114 |
+
|
115 |
+
if color_correct == 'hist_match':
|
116 |
+
fgbg_hist_matching(fgs, bg)
|
117 |
+
|
118 |
+
bg: Image = Image.fromarray(bg)
|
119 |
+
|
120 |
+
if paste_method == 'regular':
|
121 |
+
segments_info, segments = regular_paste(fgs, bg, regen_bboxes=True)
|
122 |
+
elif paste_method == 'personbbox':
|
123 |
+
segments_info, segments = regular_paste(fgs, bg, regen_bboxes=False)
|
124 |
+
elif paste_method == 'partition':
|
125 |
+
segments_info, segments = partition_paste(fgs, bg, )
|
126 |
+
else:
|
127 |
+
print(f'invalid paste method: {paste_method}')
|
128 |
+
raise NotImplementedError
|
129 |
+
|
130 |
+
image = np.array(bg)
|
131 |
+
if color_correct == 'quantize':
|
132 |
+
mask = cv2.inRange(segments, np.array([0,0,0]), np.array([0,0,0]))
|
133 |
+
# cv2.imshow("mask", mask)
|
134 |
+
image = quantize_image(image, random.choice([12, 16, 32]), 'kmeans', mask=mask)[0]
|
135 |
+
|
136 |
+
# postprocess & check if instance is valid
|
137 |
+
for ii, segi in enumerate(segments_info):
|
138 |
+
if segi['area'] == 0:
|
139 |
+
continue
|
140 |
+
x, y, w, h = segi['bbox']
|
141 |
+
x2, y2 = x+w, y+h
|
142 |
+
c = segments[y: y2, x: x2]
|
143 |
+
pan_png = rgb2id(c)
|
144 |
+
cmask = (pan_png == segi['id'])
|
145 |
+
area = cmask.sum()
|
146 |
+
|
147 |
+
if paste_method != 'partition' and \
|
148 |
+
area / (fgs[ii]['image'][..., 3] > 30).sum() < 0.25:
|
149 |
+
# cv2.imshow('im', fgs[ii]['image'])
|
150 |
+
# cv2.imshow('mask', fgs[ii]['image'][..., 3])
|
151 |
+
# cv2.imshow('seg', segments)
|
152 |
+
# cv2.waitKey(0)
|
153 |
+
cmask_ids = np.where(cmask)
|
154 |
+
segments[y: y2, x: x2][cmask_ids] = 0
|
155 |
+
image[y: y2, x: x2][cmask_ids] = (127, 127, 127)
|
156 |
+
continue
|
157 |
+
|
158 |
+
cmask = cmask.astype(np.uint8) * 255
|
159 |
+
dx, dy, w, h = cv2.boundingRect(cv2.findNonZero(cmask))
|
160 |
+
_bbox = [dx + x, dy + y, w, h]
|
161 |
+
|
162 |
+
seg = cv2.copyMakeBorder(cmask, y, tgt_size-y2, x, tgt_size-x2, cv2.BORDER_CONSTANT) > 0
|
163 |
+
assert seg.shape[0] == tgt_size and seg.shape[1] == tgt_size
|
164 |
+
segmentation = mask2rle(seg)
|
165 |
+
|
166 |
+
det_annotations.append({
|
167 |
+
'id': obj_id,
|
168 |
+
'category_id': fgs[ii]['fginfo']['category_id'],
|
169 |
+
'iscrowd': 0,
|
170 |
+
'segmentation': segmentation,
|
171 |
+
'image_id': image_id,
|
172 |
+
'area': area,
|
173 |
+
'tag_string': fgs[ii]['fginfo']['tag_string'],
|
174 |
+
'tag_string_character': fgs[ii]['fginfo']['tag_string_character'],
|
175 |
+
'bbox': [float(c) for c in _bbox]
|
176 |
+
})
|
177 |
+
|
178 |
+
obj_id += 1
|
179 |
+
# cv2.imshow('c', cv2.cvtColor(c, cv2.COLOR_RGB2BGR))
|
180 |
+
# cv2.imshow('cmask', cmask)
|
181 |
+
# cv2.waitKey(0)
|
182 |
+
|
183 |
+
image_id_str = str(image_id).zfill(IMAGE_ID_ZFILL)
|
184 |
+
image_file_name = image_id_str + '.jpg'
|
185 |
+
image_meta.append({
|
186 |
+
"id": image_id,"height": tgt_size,"width": tgt_size, "file_name": image_file_name, "id": image_id
|
187 |
+
})
|
188 |
+
|
189 |
+
# LOGGER.info(f'paste method: {paste_method} color correct: {color_correct}')
|
190 |
+
# cv2.imshow('image', cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
|
191 |
+
# cv2.imshow('segments', cv2.cvtColor(segments, cv2.COLOR_RGB2BGR))
|
192 |
+
# cv2.waitKey(0)
|
193 |
+
|
194 |
+
imageio.imwrite(osp.join(image_save_dir, image_file_name), image, quality=random.choice(jpg_save_quality))
|
195 |
+
image_id += 1
|
196 |
+
|
197 |
+
except:
|
198 |
+
LOGGER.error(traceback.format_exc())
|
199 |
+
continue
|
200 |
+
|
201 |
+
det_meta = {
|
202 |
+
"info": {},
|
203 |
+
"licenses": [],
|
204 |
+
"images": image_meta,
|
205 |
+
"annotations": det_annotations,
|
206 |
+
"categories": CATEGORIES
|
207 |
+
}
|
208 |
+
|
209 |
+
detp = osp.join(ann_save_dir, f'det_{policy}.json')
|
210 |
+
dict2json(det_meta, detp)
|
211 |
+
LOGGER.info(f'annotations saved to {detp}')
|
212 |
+
|
213 |
+
return image_id, obj_id
|
animeinsseg/data/transforms.py
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import albumentations as A
|
2 |
+
from albumentations import DualIAATransform, to_tuple
|
3 |
+
import imgaug.augmenters as iaa
|
4 |
+
import cv2
|
5 |
+
from tqdm import tqdm
|
6 |
+
from sklearn.cluster import KMeans
|
7 |
+
from sklearn.metrics import pairwise_distances_argmin
|
8 |
+
from sklearn.utils import shuffle
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
class IAAAffine2(DualIAATransform):
|
12 |
+
"""Place a regular grid of points on the input and randomly move the neighbourhood of these point around
|
13 |
+
via affine transformations.
|
14 |
+
Note: This class introduce interpolation artifacts to mask if it has values other than {0;1}
|
15 |
+
Args:
|
16 |
+
p (float): probability of applying the transform. Default: 0.5.
|
17 |
+
Targets:
|
18 |
+
image, mask
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
scale=(0.7, 1.3),
|
24 |
+
translate_percent=None,
|
25 |
+
translate_px=None,
|
26 |
+
rotate=0.0,
|
27 |
+
shear=(-0.1, 0.1),
|
28 |
+
order=1,
|
29 |
+
cval=0,
|
30 |
+
mode="reflect",
|
31 |
+
always_apply=False,
|
32 |
+
p=0.5,
|
33 |
+
):
|
34 |
+
super(IAAAffine2, self).__init__(always_apply, p)
|
35 |
+
self.scale = dict(x=scale, y=scale)
|
36 |
+
self.translate_percent = to_tuple(translate_percent, 0)
|
37 |
+
self.translate_px = to_tuple(translate_px, 0)
|
38 |
+
self.rotate = to_tuple(rotate)
|
39 |
+
self.shear = dict(x=shear, y=shear)
|
40 |
+
self.order = order
|
41 |
+
self.cval = cval
|
42 |
+
self.mode = mode
|
43 |
+
|
44 |
+
@property
|
45 |
+
def processor(self):
|
46 |
+
return iaa.Affine(
|
47 |
+
self.scale,
|
48 |
+
self.translate_percent,
|
49 |
+
self.translate_px,
|
50 |
+
self.rotate,
|
51 |
+
self.shear,
|
52 |
+
self.order,
|
53 |
+
self.cval,
|
54 |
+
self.mode,
|
55 |
+
)
|
56 |
+
|
57 |
+
def get_transform_init_args_names(self):
|
58 |
+
return ("scale", "translate_percent", "translate_px", "rotate", "shear", "order", "cval", "mode")
|
59 |
+
|
60 |
+
|
61 |
+
class IAAPerspective2(DualIAATransform):
|
62 |
+
"""Perform a random four point perspective transform of the input.
|
63 |
+
Note: This class introduce interpolation artifacts to mask if it has values other than {0;1}
|
64 |
+
Args:
|
65 |
+
scale ((float, float): standard deviation of the normal distributions. These are used to sample
|
66 |
+
the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1).
|
67 |
+
p (float): probability of applying the transform. Default: 0.5.
|
68 |
+
Targets:
|
69 |
+
image, mask
|
70 |
+
"""
|
71 |
+
|
72 |
+
def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5,
|
73 |
+
order=1, cval=0, mode="replicate"):
|
74 |
+
super(IAAPerspective2, self).__init__(always_apply, p)
|
75 |
+
self.scale = to_tuple(scale, 1.0)
|
76 |
+
self.keep_size = keep_size
|
77 |
+
self.cval = cval
|
78 |
+
self.mode = mode
|
79 |
+
|
80 |
+
@property
|
81 |
+
def processor(self):
|
82 |
+
return iaa.PerspectiveTransform(self.scale, keep_size=self.keep_size, mode=self.mode, cval=self.cval)
|
83 |
+
|
84 |
+
def get_transform_init_args_names(self):
|
85 |
+
return ("scale", "keep_size")
|
86 |
+
|
87 |
+
|
88 |
+
def get_bg_transforms(transform_variant, out_size):
|
89 |
+
max_size = int(out_size * 1.2)
|
90 |
+
if transform_variant == 'train':
|
91 |
+
transform = [
|
92 |
+
A.SmallestMaxSize(max_size, always_apply=True, interpolation=cv2.INTER_AREA),
|
93 |
+
A.RandomResizedCrop(out_size, out_size, scale=(0.9, 1.5), p=1, ratio=(0.9, 1.1)),
|
94 |
+
]
|
95 |
+
else:
|
96 |
+
transform = [
|
97 |
+
A.SmallestMaxSize(out_size, always_apply=True),
|
98 |
+
A.RandomCrop(out_size, out_size, True),
|
99 |
+
]
|
100 |
+
return A.Compose(transform)
|
101 |
+
|
102 |
+
|
103 |
+
def get_fg_transforms(out_size, scale_limit=(-0.85, -0.3), transform_variant='train'):
|
104 |
+
if transform_variant == 'train':
|
105 |
+
transform = [
|
106 |
+
A.LongestMaxSize(out_size),
|
107 |
+
A.RandomScale(scale_limit=scale_limit, always_apply=True, interpolation=cv2.INTER_AREA),
|
108 |
+
IAAAffine2(scale=(1, 1),
|
109 |
+
rotate=(-15, 15),
|
110 |
+
shear=(-0.1, 0.1), p=0.3, mode='constant'),
|
111 |
+
IAAPerspective2(scale=(0.0, 0.06), p=0.3, mode='constant'),
|
112 |
+
A.HorizontalFlip(),
|
113 |
+
A.ElasticTransform(alpha=0.3, sigma=15, alpha_affine=15, border_mode=cv2.BORDER_CONSTANT, p=0.3),
|
114 |
+
A.GridDistortion(border_mode=cv2.BORDER_CONSTANT, p=0.3)
|
115 |
+
]
|
116 |
+
elif transform_variant == 'distort_only':
|
117 |
+
transform = [
|
118 |
+
IAAAffine2(scale=(1, 1),
|
119 |
+
shear=(-0.1, 0.1), p=0.3, mode='constant'),
|
120 |
+
IAAPerspective2(scale=(0.0, 0.06), p=0.3, mode='constant'),
|
121 |
+
A.HorizontalFlip(),
|
122 |
+
A.ElasticTransform(alpha=0.3, sigma=15, alpha_affine=15, border_mode=cv2.BORDER_CONSTANT, p=0.3),
|
123 |
+
A.GridDistortion(border_mode=cv2.BORDER_CONSTANT, p=0.3)
|
124 |
+
]
|
125 |
+
else:
|
126 |
+
transform = [
|
127 |
+
A.LongestMaxSize(out_size),
|
128 |
+
A.RandomScale(scale_limit=scale_limit, always_apply=True, interpolation=cv2.INTER_LINEAR)
|
129 |
+
]
|
130 |
+
return A.Compose(transform)
|
131 |
+
|
132 |
+
|
133 |
+
def get_transforms(transform_variant, out_size, to_float=True):
|
134 |
+
if transform_variant == 'distortions':
|
135 |
+
transform = [
|
136 |
+
IAAAffine2(scale=(1, 1.3),
|
137 |
+
rotate=(-20, 20),
|
138 |
+
shear=(-0.1, 0.1), p=1, mode='constant'),
|
139 |
+
IAAPerspective2(scale=(0.0, 0.06), p=0.3, mode='constant'),
|
140 |
+
A.OpticalDistortion(),
|
141 |
+
A.HorizontalFlip(),
|
142 |
+
A.Sharpen(p=0.3),
|
143 |
+
A.CLAHE(),
|
144 |
+
A.GaussNoise(p=0.3),
|
145 |
+
A.Posterize(),
|
146 |
+
A.ElasticTransform(alpha=0.3, sigma=15, alpha_affine=15, border_mode=cv2.BORDER_CONSTANT),
|
147 |
+
]
|
148 |
+
elif transform_variant == 'default':
|
149 |
+
transform = [
|
150 |
+
A.HorizontalFlip(),
|
151 |
+
A.Rotate(20, p=0.3)
|
152 |
+
]
|
153 |
+
elif transform_variant == 'identity':
|
154 |
+
transform = []
|
155 |
+
else:
|
156 |
+
raise ValueError(f'Unexpected transform_variant {transform_variant}')
|
157 |
+
if to_float:
|
158 |
+
transform.append(A.ToFloat())
|
159 |
+
return A.Compose(transform)
|
160 |
+
|
161 |
+
|
162 |
+
def get_template_transforms(transform_variant, out_size, to_float=True):
|
163 |
+
if transform_variant == 'distortions':
|
164 |
+
transform = [
|
165 |
+
A.Cutout(p=0.3, max_w_size=30, max_h_size=30, num_holes=1),
|
166 |
+
IAAAffine2(scale=(1, 1.3),
|
167 |
+
rotate=(-20, 20),
|
168 |
+
shear=(-0.1, 0.1), p=1, mode='constant'),
|
169 |
+
IAAPerspective2(scale=(0.0, 0.06), p=0.3, mode='constant'),
|
170 |
+
A.OpticalDistortion(),
|
171 |
+
A.HorizontalFlip(),
|
172 |
+
A.Sharpen(p=0.3),
|
173 |
+
A.CLAHE(),
|
174 |
+
A.GaussNoise(p=0.3),
|
175 |
+
A.Posterize(),
|
176 |
+
A.ElasticTransform(alpha=0.3, sigma=15, alpha_affine=15, border_mode=cv2.BORDER_CONSTANT),
|
177 |
+
]
|
178 |
+
elif transform_variant == 'identity':
|
179 |
+
transform = []
|
180 |
+
else:
|
181 |
+
raise ValueError(f'Unexpected transform_variant {transform_variant}')
|
182 |
+
if to_float:
|
183 |
+
transform.append(A.ToFloat())
|
184 |
+
return A.Compose(transform)
|
185 |
+
|
186 |
+
|
187 |
+
def rotate_image(mat: np.ndarray, angle: float, alpha_crop: bool = False) -> np.ndarray:
|
188 |
+
"""
|
189 |
+
Rotates an image (angle in degrees) and expands image to avoid cropping
|
190 |
+
# https://stackoverflow.com/questions/43892506/opencv-python-rotate-image-without-cropping-sides
|
191 |
+
"""
|
192 |
+
|
193 |
+
height, width = mat.shape[:2] # image shape has 3 dimensions
|
194 |
+
image_center = (width/2, height/2) # getRotationMatrix2D needs coordinates in reverse order (width, height) compared to shape
|
195 |
+
|
196 |
+
rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.)
|
197 |
+
|
198 |
+
# rotation calculates the cos and sin, taking absolutes of those.
|
199 |
+
abs_cos = abs(rotation_mat[0,0])
|
200 |
+
abs_sin = abs(rotation_mat[0,1])
|
201 |
+
|
202 |
+
# find the new width and height bounds
|
203 |
+
bound_w = int(height * abs_sin + width * abs_cos)
|
204 |
+
bound_h = int(height * abs_cos + width * abs_sin)
|
205 |
+
|
206 |
+
# subtract old image center (bringing image back to origo) and adding the new image center coordinates
|
207 |
+
rotation_mat[0, 2] += bound_w/2 - image_center[0]
|
208 |
+
rotation_mat[1, 2] += bound_h/2 - image_center[1]
|
209 |
+
|
210 |
+
# rotate image with the new bounds and translated rotation matrix
|
211 |
+
rotated_mat = cv2.warpAffine(mat, rotation_mat, (bound_w, bound_h))
|
212 |
+
|
213 |
+
if alpha_crop and len(rotated_mat.shape) == 3 and rotated_mat.shape[-1] == 4:
|
214 |
+
x, y, w, h = cv2.boundingRect(rotated_mat[..., -1])
|
215 |
+
rotated_mat = rotated_mat[y: y+h, x: x+w]
|
216 |
+
|
217 |
+
return rotated_mat
|
218 |
+
|
219 |
+
|
220 |
+
def recreate_image(codebook, labels, w, h):
|
221 |
+
"""Recreate the (compressed) image from the code book & labels"""
|
222 |
+
return (codebook[labels].reshape(w, h, -1) * 255).astype(np.uint8)
|
223 |
+
|
224 |
+
def quantize_image(image: np.ndarray, n_colors: int, method='kmeans', mask=None):
|
225 |
+
# https://scikit-learn.org/stable/auto_examples/cluster/plot_color_quantization.html
|
226 |
+
image = np.array(image, dtype=np.float64) / 255
|
227 |
+
|
228 |
+
if len(image.shape) == 3:
|
229 |
+
w, h, d = tuple(image.shape)
|
230 |
+
else:
|
231 |
+
w, h = image.shape
|
232 |
+
d = 1
|
233 |
+
|
234 |
+
# assert d == 3
|
235 |
+
image_array = image.reshape(-1, d)
|
236 |
+
|
237 |
+
if method == 'kmeans':
|
238 |
+
|
239 |
+
image_array_sample = None
|
240 |
+
if mask is not None:
|
241 |
+
ids = np.where(mask)
|
242 |
+
if len(ids[0]) > 10:
|
243 |
+
bg = image[ids][::2]
|
244 |
+
fg = image[np.where(mask == 0)]
|
245 |
+
max_bg_num = int(fg.shape[0] * 1.5)
|
246 |
+
if bg.shape[0] > max_bg_num:
|
247 |
+
bg = shuffle(bg, random_state=0, n_samples=max_bg_num)
|
248 |
+
image_array_sample = np.concatenate((fg, bg), axis=0)
|
249 |
+
if image_array_sample.shape[0] > 2048:
|
250 |
+
image_array_sample = shuffle(image_array_sample, random_state=0, n_samples=2048)
|
251 |
+
else:
|
252 |
+
image_array_sample = None
|
253 |
+
|
254 |
+
if image_array_sample is None:
|
255 |
+
image_array_sample = shuffle(image_array, random_state=0, n_samples=2048)
|
256 |
+
|
257 |
+
kmeans = KMeans(n_clusters=n_colors, n_init=10, random_state=0).fit(
|
258 |
+
image_array_sample
|
259 |
+
)
|
260 |
+
|
261 |
+
labels = kmeans.predict(image_array)
|
262 |
+
quantized = recreate_image(kmeans.cluster_centers_, labels, w, h)
|
263 |
+
return quantized, kmeans.cluster_centers_, labels
|
264 |
+
|
265 |
+
else:
|
266 |
+
|
267 |
+
codebook_random = shuffle(image_array, random_state=0, n_samples=n_colors)
|
268 |
+
labels_random = pairwise_distances_argmin(codebook_random, image_array, axis=0)
|
269 |
+
|
270 |
+
return [recreate_image(codebook_random, labels_random, w, h)]
|
271 |
+
|
272 |
+
|
273 |
+
def resize2height(img: np.ndarray, height: int):
|
274 |
+
im_h, im_w = img.shape[:2]
|
275 |
+
if im_h > height:
|
276 |
+
interpolation = cv2.INTER_AREA
|
277 |
+
else:
|
278 |
+
interpolation = cv2.INTER_LINEAR
|
279 |
+
if im_h != height:
|
280 |
+
img = cv2.resize(img, (int(height / im_h * im_w), height), interpolation=interpolation)
|
281 |
+
return img
|
282 |
+
|
283 |
+
if __name__ == '__main__':
|
284 |
+
import os.path as osp
|
285 |
+
|
286 |
+
img_path = r'tmp\megumin.png'
|
287 |
+
save_dir = r'tmp'
|
288 |
+
sample_num = 24
|
289 |
+
|
290 |
+
tv = 'distortions'
|
291 |
+
out_size = 224
|
292 |
+
transforms = get_transforms(tv, out_size ,to_float=False)
|
293 |
+
img = cv2.imread(img_path)
|
294 |
+
for idx in tqdm(range(sample_num)):
|
295 |
+
transformed = transforms(image=img)['image']
|
296 |
+
print(transformed.shape)
|
297 |
+
cv2.imwrite(osp.join(save_dir, str(idx)+'-transform.jpg'), transformed)
|
298 |
+
# cv2.waitKey(0)
|
299 |
+
pass
|
animeinsseg/inpainting/__init__.py
ADDED
File without changes
|
animeinsseg/inpainting/ldm_inpaint.py
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from tqdm import tqdm
|
4 |
+
from omegaconf import OmegaConf
|
5 |
+
import safetensors
|
6 |
+
import os
|
7 |
+
import einops
|
8 |
+
import cv2
|
9 |
+
from PIL import Image, ImageFilter, ImageOps
|
10 |
+
from utils.io_utils import resize_pad2divisior
|
11 |
+
import os
|
12 |
+
from utils.io_utils import submit_request, img2b64
|
13 |
+
import json
|
14 |
+
# Debug by Francis
|
15 |
+
# from ldm.util import instantiate_from_config
|
16 |
+
# from ldm.models.diffusion.ddpm import LatentDiffusion
|
17 |
+
# from ldm.models.diffusion.ddim import DDIMSampler
|
18 |
+
# from ldm.modules.diffusionmodules.util import noise_like
|
19 |
+
import io
|
20 |
+
import base64
|
21 |
+
from requests.auth import HTTPBasicAuth
|
22 |
+
|
23 |
+
# Debug by Francis
|
24 |
+
# def create_model(config_path):
|
25 |
+
# config = OmegaConf.load(config_path)
|
26 |
+
# model = instantiate_from_config(config.model).cpu()
|
27 |
+
# return model
|
28 |
+
#
|
29 |
+
# def get_state_dict(d):
|
30 |
+
# return d.get('state_dict', d)
|
31 |
+
#
|
32 |
+
# def load_state_dict(ckpt_path, location='cpu'):
|
33 |
+
# _, extension = os.path.splitext(ckpt_path)
|
34 |
+
# if extension.lower() == ".safetensors":
|
35 |
+
# import safetensors.torch
|
36 |
+
# state_dict = safetensors.torch.load_file(ckpt_path, device=location)
|
37 |
+
# else:
|
38 |
+
# state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
|
39 |
+
# state_dict = get_state_dict(state_dict)
|
40 |
+
# return state_dict
|
41 |
+
#
|
42 |
+
#
|
43 |
+
# def load_ldm_sd(model, path) :
|
44 |
+
# if path.endswith('.safetensor') :
|
45 |
+
# sd = safetensors.torch.load_file(path)
|
46 |
+
# else :
|
47 |
+
# sd = load_state_dict(path)
|
48 |
+
# model.load_state_dict(sd, strict = False)
|
49 |
+
#
|
50 |
+
# def fill_mask_input(image, mask):
|
51 |
+
# """fills masked regions with colors from image using blur. Not extremely effective."""
|
52 |
+
#
|
53 |
+
# image_mod = Image.new('RGBA', (image.width, image.height))
|
54 |
+
#
|
55 |
+
# image_masked = Image.new('RGBa', (image.width, image.height))
|
56 |
+
# image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
|
57 |
+
#
|
58 |
+
# image_masked = image_masked.convert('RGBa')
|
59 |
+
#
|
60 |
+
# for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
|
61 |
+
# blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
|
62 |
+
# for _ in range(repeats):
|
63 |
+
# image_mod.alpha_composite(blurred)
|
64 |
+
#
|
65 |
+
# return image_mod.convert("RGB")
|
66 |
+
#
|
67 |
+
#
|
68 |
+
# def get_inpainting_image_condition(model, image, mask) :
|
69 |
+
# conditioning_mask = np.array(mask.convert("L"))
|
70 |
+
# conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
|
71 |
+
# conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
|
72 |
+
# conditioning_mask = torch.round(conditioning_mask)
|
73 |
+
# conditioning_mask = conditioning_mask.to(device=image.device, dtype=image.dtype)
|
74 |
+
# conditioning_image = torch.lerp(
|
75 |
+
# image,
|
76 |
+
# image * (1.0 - conditioning_mask),
|
77 |
+
# 1
|
78 |
+
# )
|
79 |
+
# conditioning_image = model.get_first_stage_encoding(model.encode_first_stage(conditioning_image))
|
80 |
+
# conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=conditioning_image.shape[-2:])
|
81 |
+
# conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
|
82 |
+
# image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
|
83 |
+
# return image_conditioning
|
84 |
+
#
|
85 |
+
#
|
86 |
+
# class GuidedLDM(LatentDiffusion):
|
87 |
+
# def __init__(self, *args, **kwargs):
|
88 |
+
# super().__init__(*args, **kwargs)
|
89 |
+
#
|
90 |
+
# @torch.no_grad()
|
91 |
+
# def img2img_inpaint(
|
92 |
+
# self,
|
93 |
+
# image: Image.Image,
|
94 |
+
# c_text: str,
|
95 |
+
# uc_text: str,
|
96 |
+
# mask: Image.Image,
|
97 |
+
# ddim_steps = 50,
|
98 |
+
# mask_blur: int = 0,
|
99 |
+
# use_cuda: bool = True,
|
100 |
+
# **kwargs) -> Image.Image :
|
101 |
+
# ddim_sampler = GuidedDDIMSample(self)
|
102 |
+
# if use_cuda :
|
103 |
+
# self.cond_stage_model.cuda()
|
104 |
+
# self.first_stage_model.cuda()
|
105 |
+
# c_text = self.get_learned_conditioning([c_text])
|
106 |
+
# uc_text = self.get_learned_conditioning([uc_text])
|
107 |
+
# cond = {"c_crossattn": [c_text]}
|
108 |
+
# uc_cond = {"c_crossattn": [uc_text]}
|
109 |
+
#
|
110 |
+
# if use_cuda :
|
111 |
+
# device = torch.device('cuda:0')
|
112 |
+
# else :
|
113 |
+
# device = torch.device('cpu')
|
114 |
+
#
|
115 |
+
# image_mask = mask
|
116 |
+
# image_mask = image_mask.convert('L')
|
117 |
+
# image_mask = image_mask.filter(ImageFilter.GaussianBlur(mask_blur))
|
118 |
+
# latent_mask = image_mask
|
119 |
+
# # image = fill_mask_input(image, latent_mask)
|
120 |
+
# # image.save('image_fill.png')
|
121 |
+
# image = np.array(image).astype(np.float32) / 127.5 - 1.0
|
122 |
+
# image = np.moveaxis(image, 2, 0)
|
123 |
+
# image = torch.from_numpy(image).to(device)[None]
|
124 |
+
# init_latent = self.get_first_stage_encoding(self.encode_first_stage(image))
|
125 |
+
# init_mask = latent_mask
|
126 |
+
# latmask = init_mask.convert('RGB').resize((init_latent.shape[3], init_latent.shape[2]))
|
127 |
+
# latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
|
128 |
+
# latmask = latmask[0]
|
129 |
+
# latmask = np.around(latmask)
|
130 |
+
# latmask = np.tile(latmask[None], (4, 1, 1))
|
131 |
+
# nmask = torch.asarray(latmask).to(init_latent.device).float()
|
132 |
+
# init_latent = (1 - nmask) * init_latent + nmask * torch.randn_like(init_latent)
|
133 |
+
#
|
134 |
+
# denoising_strength = 1
|
135 |
+
# if self.model.conditioning_key == 'hybrid' :
|
136 |
+
# image_cdt = get_inpainting_image_condition(self, image, image_mask)
|
137 |
+
# cond["c_concat"] = [image_cdt]
|
138 |
+
# uc_cond["c_concat"] = [image_cdt]
|
139 |
+
#
|
140 |
+
# steps = ddim_steps
|
141 |
+
# t_enc = int(min(denoising_strength, 0.999) * steps)
|
142 |
+
# eta = 0
|
143 |
+
#
|
144 |
+
# noise = torch.randn_like(init_latent)
|
145 |
+
# ddim_sampler.make_schedule(ddim_num_steps=steps, ddim_eta=eta, ddim_discretize="uniform", verbose=False)
|
146 |
+
# x1 = ddim_sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * int(init_latent.shape[0])).to(device), noise=noise)
|
147 |
+
#
|
148 |
+
# if use_cuda :
|
149 |
+
# self.cond_stage_model.cpu()
|
150 |
+
# self.first_stage_model.cpu()
|
151 |
+
#
|
152 |
+
# if use_cuda :
|
153 |
+
# self.model.cuda()
|
154 |
+
# decoded = ddim_sampler.decode(x1, cond,t_enc,init_latent=init_latent,nmask=nmask,unconditional_guidance_scale=7,unconditional_conditioning=uc_cond)
|
155 |
+
# if use_cuda :
|
156 |
+
# self.model.cpu()
|
157 |
+
#
|
158 |
+
# if mask is not None :
|
159 |
+
# decoded = init_latent * (1 - nmask) + decoded * nmask
|
160 |
+
#
|
161 |
+
# if use_cuda :
|
162 |
+
# self.first_stage_model.cuda()
|
163 |
+
# with torch.cuda.amp.autocast(enabled=False):
|
164 |
+
# x_samples = self.decode_first_stage(decoded.to(torch.float32))
|
165 |
+
# if use_cuda :
|
166 |
+
# self.first_stage_model.cpu()
|
167 |
+
# return torch.clip(x_samples, -1, 1)
|
168 |
+
#
|
169 |
+
#
|
170 |
+
#
|
171 |
+
# class GuidedDDIMSample(DDIMSampler) :
|
172 |
+
# def __init__(self, *args, **kwargs):
|
173 |
+
# super().__init__(*args, **kwargs)
|
174 |
+
#
|
175 |
+
# @torch.no_grad()
|
176 |
+
# def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
177 |
+
# temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
178 |
+
# unconditional_guidance_scale=1., unconditional_conditioning=None,
|
179 |
+
# dynamic_threshold=None):
|
180 |
+
# b, *_, device = *x.shape, x.device
|
181 |
+
#
|
182 |
+
# if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
183 |
+
# model_output = self.model.apply_model(x, t, c)
|
184 |
+
# else:
|
185 |
+
# x_in = torch.cat([x] * 2)
|
186 |
+
# t_in = torch.cat([t] * 2)
|
187 |
+
# if isinstance(c, dict):
|
188 |
+
# assert isinstance(unconditional_conditioning, dict)
|
189 |
+
# c_in = dict()
|
190 |
+
# for k in c:
|
191 |
+
# if isinstance(c[k], list):
|
192 |
+
# c_in[k] = [torch.cat([
|
193 |
+
# unconditional_conditioning[k][i],
|
194 |
+
# c[k][i]]) for i in range(len(c[k]))]
|
195 |
+
# else:
|
196 |
+
# c_in[k] = torch.cat([
|
197 |
+
# unconditional_conditioning[k],
|
198 |
+
# c[k]])
|
199 |
+
# elif isinstance(c, list):
|
200 |
+
# c_in = list()
|
201 |
+
# assert isinstance(unconditional_conditioning, list)
|
202 |
+
# for i in range(len(c)):
|
203 |
+
# c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
|
204 |
+
# else:
|
205 |
+
# c_in = torch.cat([unconditional_conditioning, c])
|
206 |
+
# model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
207 |
+
# model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
|
208 |
+
#
|
209 |
+
# e_t = model_output
|
210 |
+
#
|
211 |
+
# alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
212 |
+
# alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
213 |
+
# sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
214 |
+
# sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
215 |
+
# # select parameters corresponding to the currently considered timestep
|
216 |
+
# a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
217 |
+
# a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
218 |
+
# sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
219 |
+
# sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
220 |
+
#
|
221 |
+
# # current prediction for x_0
|
222 |
+
# pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
223 |
+
#
|
224 |
+
# # direction pointing to x_t
|
225 |
+
# dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
226 |
+
# noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
227 |
+
# if noise_dropout > 0.:
|
228 |
+
# noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
229 |
+
# x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
230 |
+
# return x_prev, pred_x0
|
231 |
+
#
|
232 |
+
# @torch.no_grad()
|
233 |
+
# def decode(self, x_latent, cond, t_start, init_latent=None, nmask=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
234 |
+
# use_original_steps=False, callback=None):
|
235 |
+
#
|
236 |
+
# timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
237 |
+
# total_steps = len(timesteps)
|
238 |
+
# timesteps = timesteps[:t_start]
|
239 |
+
#
|
240 |
+
# time_range = np.flip(timesteps)
|
241 |
+
# total_steps = timesteps.shape[0]
|
242 |
+
# print(f"Running Guided DDIM Sampling with {len(timesteps)} timesteps, t_start={t_start}")
|
243 |
+
# iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
244 |
+
# x_dec = x_latent
|
245 |
+
# for i, step in enumerate(iterator):
|
246 |
+
# p = (i + (total_steps - t_start) + 1) / (total_steps)
|
247 |
+
# index = total_steps - i - 1
|
248 |
+
# ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
249 |
+
# if nmask is not None :
|
250 |
+
# noised_input = self.model.q_sample(init_latent.to(x_latent.device), ts.to(x_latent.device))
|
251 |
+
# x_dec = (1 - nmask) * noised_input + nmask * x_dec
|
252 |
+
# x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
253 |
+
# unconditional_guidance_scale=unconditional_guidance_scale,
|
254 |
+
# unconditional_conditioning=unconditional_conditioning)
|
255 |
+
# if callback: callback(i)
|
256 |
+
# return x_dec
|
257 |
+
#
|
258 |
+
#
|
259 |
+
# def ldm_inpaint(model, img, mask, inpaint_size=720, pos_prompt='', neg_prompt = '', use_cuda=True):
|
260 |
+
# img_original = np.copy(img)
|
261 |
+
# im_h, im_w = img.shape[:2]
|
262 |
+
# img_resized, (pad_h, pad_w) = resize_pad2divisior(img, inpaint_size)
|
263 |
+
#
|
264 |
+
# mask_original = np.copy(mask)
|
265 |
+
# mask_original[mask_original < 127] = 0
|
266 |
+
# mask_original[mask_original >= 127] = 1
|
267 |
+
# mask_original = mask_original[:, :, None]
|
268 |
+
# mask, _ = resize_pad2divisior(mask, inpaint_size)
|
269 |
+
#
|
270 |
+
# # cv2.imwrite('img_resized.png', img_resized)
|
271 |
+
# # cv2.imwrite('mask_resized.png', mask)
|
272 |
+
#
|
273 |
+
#
|
274 |
+
# if use_cuda :
|
275 |
+
# with torch.autocast(enabled = True, device_type = 'cuda') :
|
276 |
+
# img = model.img2img_inpaint(
|
277 |
+
# image = Image.fromarray(img_resized),
|
278 |
+
# c_text = pos_prompt,
|
279 |
+
# uc_text = neg_prompt,
|
280 |
+
# mask = Image.fromarray(mask),
|
281 |
+
# use_cuda = True
|
282 |
+
# )
|
283 |
+
# else :
|
284 |
+
# img = model.img2img_inpaint(
|
285 |
+
# image = Image.fromarray(img_resized),
|
286 |
+
# c_text = pos_prompt,
|
287 |
+
# uc_text = neg_prompt,
|
288 |
+
# mask = Image.fromarray(mask),
|
289 |
+
# use_cuda = False
|
290 |
+
# )
|
291 |
+
#
|
292 |
+
# img_inpainted = (einops.rearrange(img, '1 c h w -> h w c').cpu().numpy() * 127.5 + 127.5).astype(np.uint8)
|
293 |
+
# if pad_h != 0:
|
294 |
+
# img_inpainted = img_inpainted[:-pad_h]
|
295 |
+
# if pad_w != 0:
|
296 |
+
# img_inpainted = img_inpainted[:, :-pad_w]
|
297 |
+
#
|
298 |
+
#
|
299 |
+
# if img_inpainted.shape[0] != im_h or img_inpainted.shape[1] != im_w:
|
300 |
+
# img_inpainted = cv2.resize(img_inpainted, (im_w, im_h), interpolation = cv2.INTER_LINEAR)
|
301 |
+
# ans = img_inpainted * mask_original + img_original * (1 - mask_original)
|
302 |
+
# ans = img_inpainted
|
303 |
+
# return ans
|
304 |
+
|
305 |
+
|
306 |
+
|
307 |
+
|
308 |
+
import requests
|
309 |
+
from PIL import Image
|
310 |
+
def ldm_inpaint_webui(
|
311 |
+
img, mask, resolution: int, url: str, prompt: str = '', neg_prompt: str = '',
|
312 |
+
**inpaint_ldm_options):
|
313 |
+
if isinstance(img, np.ndarray):
|
314 |
+
img = Image.fromarray(img)
|
315 |
+
|
316 |
+
im_h, im_w = img.height, img.width
|
317 |
+
|
318 |
+
if img.height > img.width:
|
319 |
+
W = resolution
|
320 |
+
H = (img.height / img.width * resolution) // 32 * 32
|
321 |
+
H = int(H)
|
322 |
+
else:
|
323 |
+
H = resolution
|
324 |
+
W = (img.width / img.height * resolution) // 32 * 32
|
325 |
+
W = int(W)
|
326 |
+
|
327 |
+
auth = None
|
328 |
+
if 'username' in inpaint_ldm_options:
|
329 |
+
username = inpaint_ldm_options.pop('username')
|
330 |
+
password = inpaint_ldm_options.pop('password')
|
331 |
+
auth = HTTPBasicAuth(username, password)
|
332 |
+
|
333 |
+
img_b64 = img2b64(img)
|
334 |
+
mask_b64 = img2b64(mask)
|
335 |
+
data = {
|
336 |
+
"init_images": [img_b64],
|
337 |
+
"mask": mask_b64,
|
338 |
+
"prompt": prompt,
|
339 |
+
"negative_prompt": neg_prompt,
|
340 |
+
"width": W,
|
341 |
+
"height": H,
|
342 |
+
**inpaint_ldm_options,
|
343 |
+
}
|
344 |
+
data = json.dumps(data)
|
345 |
+
|
346 |
+
response = submit_request(url, data, auth=auth)
|
347 |
+
|
348 |
+
inpainted_b64 = response.json()['images'][0]
|
349 |
+
inpainted = Image.open(io.BytesIO(base64.b64decode(inpainted_b64)))
|
350 |
+
if inpainted.height != im_h or inpainted.width != im_w:
|
351 |
+
inpainted = inpainted.resize((im_w, im_h), resample=Image.Resampling.LANCZOS)
|
352 |
+
inpainted = np.array(inpainted)
|
353 |
+
return inpainted
|
animeinsseg/inpainting/patch_match.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# File : patch_match.py
|
4 |
+
# Author : Jiayuan Mao
|
5 |
+
# Email : maojiayuan@gmail.com
|
6 |
+
# Date : 01/09/2020
|
7 |
+
#
|
8 |
+
# Distributed under terms of the MIT license.
|
9 |
+
|
10 |
+
import ctypes, os
|
11 |
+
import os.path as osp
|
12 |
+
from typing import Optional, Union
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
from PIL import Image
|
16 |
+
|
17 |
+
# try:
|
18 |
+
# # If the Jacinle library (https://github.com/vacancy/Jacinle) is present, use its auto_travis feature.
|
19 |
+
# from jacinle.jit.cext import auto_travis
|
20 |
+
# auto_travis(__file__, required_files=['*.so'])
|
21 |
+
# except ImportError as e:
|
22 |
+
# # Otherwise, fall back to the subprocess.
|
23 |
+
# import subprocess
|
24 |
+
# print('Compiling and loading c extensions from "{}".'.format(osp.realpath(osp.dirname(__file__))))
|
25 |
+
# subprocess.check_call(['./travis.sh'], cwd=osp.dirname(__file__))
|
26 |
+
|
27 |
+
|
28 |
+
__all__ = ['set_random_seed', 'set_verbose', 'inpaint', 'inpaint_regularity']
|
29 |
+
|
30 |
+
|
31 |
+
class CShapeT(ctypes.Structure):
|
32 |
+
_fields_ = [
|
33 |
+
('width', ctypes.c_int),
|
34 |
+
('height', ctypes.c_int),
|
35 |
+
('channels', ctypes.c_int),
|
36 |
+
]
|
37 |
+
|
38 |
+
class CMatT(ctypes.Structure):
|
39 |
+
_fields_ = [
|
40 |
+
('data_ptr', ctypes.c_void_p),
|
41 |
+
('shape', CShapeT),
|
42 |
+
('dtype', ctypes.c_int)
|
43 |
+
]
|
44 |
+
|
45 |
+
import sys
|
46 |
+
if sys.platform == 'linux':
|
47 |
+
PMLIB = ctypes.CDLL('data/libs/libpatchmatch_inpaint.so')
|
48 |
+
else:
|
49 |
+
PMLIB = ctypes.CDLL('data/libs/libpatchmatch.dll')
|
50 |
+
|
51 |
+
PMLIB.PM_set_random_seed.argtypes = [ctypes.c_uint]
|
52 |
+
PMLIB.PM_set_verbose.argtypes = [ctypes.c_int]
|
53 |
+
PMLIB.PM_free_pymat.argtypes = [CMatT]
|
54 |
+
PMLIB.PM_inpaint.argtypes = [CMatT, CMatT, ctypes.c_int]
|
55 |
+
PMLIB.PM_inpaint.restype = CMatT
|
56 |
+
PMLIB.PM_inpaint_regularity.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
|
57 |
+
PMLIB.PM_inpaint_regularity.restype = CMatT
|
58 |
+
PMLIB.PM_inpaint2.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int]
|
59 |
+
PMLIB.PM_inpaint2.restype = CMatT
|
60 |
+
PMLIB.PM_inpaint2_regularity.argtypes = [CMatT, CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
|
61 |
+
PMLIB.PM_inpaint2_regularity.restype = CMatT
|
62 |
+
|
63 |
+
|
64 |
+
def set_random_seed(seed: int):
|
65 |
+
PMLIB.PM_set_random_seed(ctypes.c_uint(seed))
|
66 |
+
|
67 |
+
|
68 |
+
def set_verbose(verbose: bool):
|
69 |
+
PMLIB.PM_set_verbose(ctypes.c_int(verbose))
|
70 |
+
|
71 |
+
|
72 |
+
def inpaint(
|
73 |
+
image: Union[np.ndarray, Image.Image],
|
74 |
+
mask: Optional[Union[np.ndarray, Image.Image]] = None,
|
75 |
+
*,
|
76 |
+
global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
|
77 |
+
patch_size: int = 15
|
78 |
+
) -> np.ndarray:
|
79 |
+
"""
|
80 |
+
PatchMatch based inpainting proposed in:
|
81 |
+
|
82 |
+
PatchMatch : A Randomized Correspondence Algorithm for Structural Image Editing
|
83 |
+
C.Barnes, E.Shechtman, A.Finkelstein and Dan B.Goldman
|
84 |
+
SIGGRAPH 2009
|
85 |
+
|
86 |
+
Args:
|
87 |
+
image (Union[np.ndarray, Image.Image]): the input image, should be 3-channel RGB/BGR.
|
88 |
+
mask (Union[np.array, Image.Image], optional): the mask of the hole(s) to be filled, should be 1-channel.
|
89 |
+
If not provided (None), the algorithm will treat all purely white pixels as the holes (255, 255, 255).
|
90 |
+
global_mask (Union[np.array, Image.Image], optional): the target mask of the output image.
|
91 |
+
patch_size (int): the patch size for the inpainting algorithm.
|
92 |
+
|
93 |
+
Return:
|
94 |
+
result (np.ndarray): the repaired image, of the same size as the input image.
|
95 |
+
"""
|
96 |
+
|
97 |
+
if isinstance(image, Image.Image):
|
98 |
+
image = np.array(image)
|
99 |
+
image = np.ascontiguousarray(image)
|
100 |
+
assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
|
101 |
+
|
102 |
+
if mask is None:
|
103 |
+
mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
|
104 |
+
mask = np.ascontiguousarray(mask)
|
105 |
+
else:
|
106 |
+
mask = _canonize_mask_array(mask)
|
107 |
+
|
108 |
+
if global_mask is None:
|
109 |
+
ret_pymat = PMLIB.PM_inpaint(np_to_pymat(image), np_to_pymat(mask), ctypes.c_int(patch_size))
|
110 |
+
else:
|
111 |
+
global_mask = _canonize_mask_array(global_mask)
|
112 |
+
ret_pymat = PMLIB.PM_inpaint2(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), ctypes.c_int(patch_size))
|
113 |
+
|
114 |
+
ret_npmat = pymat_to_np(ret_pymat)
|
115 |
+
PMLIB.PM_free_pymat(ret_pymat)
|
116 |
+
|
117 |
+
return ret_npmat
|
118 |
+
|
119 |
+
|
120 |
+
def inpaint_regularity(
|
121 |
+
image: Union[np.ndarray, Image.Image],
|
122 |
+
mask: Optional[Union[np.ndarray, Image.Image]],
|
123 |
+
ijmap: np.ndarray,
|
124 |
+
*,
|
125 |
+
global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
|
126 |
+
patch_size: int = 15, guide_weight: float = 0.25
|
127 |
+
) -> np.ndarray:
|
128 |
+
if isinstance(image, Image.Image):
|
129 |
+
image = np.array(image)
|
130 |
+
image = np.ascontiguousarray(image)
|
131 |
+
|
132 |
+
assert isinstance(ijmap, np.ndarray) and ijmap.ndim == 3 and ijmap.shape[2] == 3 and ijmap.dtype == 'float32'
|
133 |
+
ijmap = np.ascontiguousarray(ijmap)
|
134 |
+
|
135 |
+
assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
|
136 |
+
if mask is None:
|
137 |
+
mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
|
138 |
+
mask = np.ascontiguousarray(mask)
|
139 |
+
else:
|
140 |
+
mask = _canonize_mask_array(mask)
|
141 |
+
|
142 |
+
|
143 |
+
if global_mask is None:
|
144 |
+
ret_pymat = PMLIB.PM_inpaint_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
|
145 |
+
else:
|
146 |
+
global_mask = _canonize_mask_array(global_mask)
|
147 |
+
ret_pymat = PMLIB.PM_inpaint2_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
|
148 |
+
|
149 |
+
ret_npmat = pymat_to_np(ret_pymat)
|
150 |
+
PMLIB.PM_free_pymat(ret_pymat)
|
151 |
+
|
152 |
+
return ret_npmat
|
153 |
+
|
154 |
+
|
155 |
+
def _canonize_mask_array(mask):
|
156 |
+
if isinstance(mask, Image.Image):
|
157 |
+
mask = np.array(mask)
|
158 |
+
if mask.ndim == 2 and mask.dtype == 'uint8':
|
159 |
+
mask = mask[..., np.newaxis]
|
160 |
+
assert mask.ndim == 3 and mask.shape[2] == 1 and mask.dtype == 'uint8'
|
161 |
+
return np.ascontiguousarray(mask)
|
162 |
+
|
163 |
+
|
164 |
+
dtype_pymat_to_ctypes = [
|
165 |
+
ctypes.c_uint8,
|
166 |
+
ctypes.c_int8,
|
167 |
+
ctypes.c_uint16,
|
168 |
+
ctypes.c_int16,
|
169 |
+
ctypes.c_int32,
|
170 |
+
ctypes.c_float,
|
171 |
+
ctypes.c_double,
|
172 |
+
]
|
173 |
+
|
174 |
+
|
175 |
+
dtype_np_to_pymat = {
|
176 |
+
'uint8': 0,
|
177 |
+
'int8': 1,
|
178 |
+
'uint16': 2,
|
179 |
+
'int16': 3,
|
180 |
+
'int32': 4,
|
181 |
+
'float32': 5,
|
182 |
+
'float64': 6,
|
183 |
+
}
|
184 |
+
|
185 |
+
|
186 |
+
def np_to_pymat(npmat):
|
187 |
+
assert npmat.ndim == 3
|
188 |
+
return CMatT(
|
189 |
+
ctypes.cast(npmat.ctypes.data, ctypes.c_void_p),
|
190 |
+
CShapeT(npmat.shape[1], npmat.shape[0], npmat.shape[2]),
|
191 |
+
dtype_np_to_pymat[str(npmat.dtype)]
|
192 |
+
)
|
193 |
+
|
194 |
+
|
195 |
+
def pymat_to_np(pymat):
|
196 |
+
npmat = np.ctypeslib.as_array(
|
197 |
+
ctypes.cast(pymat.data_ptr, ctypes.POINTER(dtype_pymat_to_ctypes[pymat.dtype])),
|
198 |
+
(pymat.shape.height, pymat.shape.width, pymat.shape.channels)
|
199 |
+
)
|
200 |
+
ret = np.empty(npmat.shape, npmat.dtype)
|
201 |
+
ret[:] = npmat
|
202 |
+
return ret
|
203 |
+
|
animeinsseg/models/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
|
7 |
+
|
animeinsseg/models/animeseg_refine/__init__.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/SkyTNT/anime-segmentation/blob/main/train.py
|
2 |
+
import os
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from pytorch_lightning import Trainer
|
8 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
9 |
+
from torch.utils.data import Dataset, DataLoader
|
10 |
+
import torch.optim as optim
|
11 |
+
import numpy as np
|
12 |
+
import cv2
|
13 |
+
from torch.cuda import amp
|
14 |
+
|
15 |
+
from utils.constants import DEFAULT_DEVICE
|
16 |
+
# from data_loader import create_training_datasets
|
17 |
+
|
18 |
+
|
19 |
+
import pytorch_lightning as pl
|
20 |
+
import warnings
|
21 |
+
|
22 |
+
from .isnet import ISNetDIS, ISNetGTEncoder
|
23 |
+
from .u2net import U2NET, U2NET_full, U2NET_full2, U2NET_lite2
|
24 |
+
from .modnet import MODNet
|
25 |
+
|
26 |
+
# warnings.filterwarnings("ignore")
|
27 |
+
|
28 |
+
def get_net(net_name):
|
29 |
+
if net_name == "isnet":
|
30 |
+
return ISNetDIS()
|
31 |
+
elif net_name == "isnet_is":
|
32 |
+
return ISNetDIS()
|
33 |
+
elif net_name == "isnet_gt":
|
34 |
+
return ISNetGTEncoder()
|
35 |
+
elif net_name == "u2net":
|
36 |
+
return U2NET_full2()
|
37 |
+
elif net_name == "u2netl":
|
38 |
+
return U2NET_lite2()
|
39 |
+
elif net_name == "modnet":
|
40 |
+
return MODNet()
|
41 |
+
raise NotImplemented
|
42 |
+
|
43 |
+
|
44 |
+
def f1_torch(pred, gt):
|
45 |
+
# micro F1-score
|
46 |
+
pred = pred.float().view(pred.shape[0], -1)
|
47 |
+
gt = gt.float().view(gt.shape[0], -1)
|
48 |
+
tp1 = torch.sum(pred * gt, dim=1)
|
49 |
+
tp_fp1 = torch.sum(pred, dim=1)
|
50 |
+
tp_fn1 = torch.sum(gt, dim=1)
|
51 |
+
pred = 1 - pred
|
52 |
+
gt = 1 - gt
|
53 |
+
tp2 = torch.sum(pred * gt, dim=1)
|
54 |
+
tp_fp2 = torch.sum(pred, dim=1)
|
55 |
+
tp_fn2 = torch.sum(gt, dim=1)
|
56 |
+
precision = (tp1 + tp2) / (tp_fp1 + tp_fp2 + 0.0001)
|
57 |
+
recall = (tp1 + tp2) / (tp_fn1 + tp_fn2 + 0.0001)
|
58 |
+
f1 = (1 + 0.3) * precision * recall / (0.3 * precision + recall + 0.0001)
|
59 |
+
return precision, recall, f1
|
60 |
+
|
61 |
+
|
62 |
+
class AnimeSegmentation(pl.LightningModule):
|
63 |
+
|
64 |
+
def __init__(self, net_name):
|
65 |
+
super().__init__()
|
66 |
+
assert net_name in ["isnet_is", "isnet", "isnet_gt", "u2net", "u2netl", "modnet"]
|
67 |
+
self.net = get_net(net_name)
|
68 |
+
if net_name == "isnet_is":
|
69 |
+
self.gt_encoder = get_net("isnet_gt")
|
70 |
+
self.gt_encoder.requires_grad_(False)
|
71 |
+
else:
|
72 |
+
self.gt_encoder = None
|
73 |
+
|
74 |
+
@classmethod
|
75 |
+
def try_load(cls, net_name, ckpt_path, map_location=None):
|
76 |
+
state_dict = torch.load(ckpt_path, map_location=map_location)
|
77 |
+
if "epoch" in state_dict:
|
78 |
+
return cls.load_from_checkpoint(ckpt_path, net_name=net_name, map_location=map_location)
|
79 |
+
else:
|
80 |
+
model = cls(net_name)
|
81 |
+
if any([k.startswith("net.") for k, v in state_dict.items()]):
|
82 |
+
model.load_state_dict(state_dict)
|
83 |
+
else:
|
84 |
+
model.net.load_state_dict(state_dict)
|
85 |
+
return model
|
86 |
+
|
87 |
+
def configure_optimizers(self):
|
88 |
+
optimizer = optim.Adam(self.net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
89 |
+
return optimizer
|
90 |
+
|
91 |
+
def forward(self, x):
|
92 |
+
if isinstance(self.net, ISNetDIS):
|
93 |
+
return self.net(x)[0][0].sigmoid()
|
94 |
+
if isinstance(self.net, ISNetGTEncoder):
|
95 |
+
return self.net(x)[0][0].sigmoid()
|
96 |
+
elif isinstance(self.net, U2NET):
|
97 |
+
return self.net(x)[0].sigmoid()
|
98 |
+
elif isinstance(self.net, MODNet):
|
99 |
+
return self.net(x, True)[2]
|
100 |
+
raise NotImplemented
|
101 |
+
|
102 |
+
def training_step(self, batch, batch_idx):
|
103 |
+
images, labels = batch["image"], batch["label"]
|
104 |
+
if isinstance(self.net, ISNetDIS):
|
105 |
+
ds, dfs = self.net(images)
|
106 |
+
loss_args = [ds, dfs, labels]
|
107 |
+
elif isinstance(self.net, ISNetGTEncoder):
|
108 |
+
ds = self.net(labels)[0]
|
109 |
+
loss_args = [ds, labels]
|
110 |
+
elif isinstance(self.net, U2NET):
|
111 |
+
ds = self.net(images)
|
112 |
+
loss_args = [ds, labels]
|
113 |
+
elif isinstance(self.net, MODNet):
|
114 |
+
trimaps = batch["trimap"]
|
115 |
+
pred_semantic, pred_detail, pred_matte = self.net(images, False)
|
116 |
+
loss_args = [pred_semantic, pred_detail, pred_matte, images, trimaps, labels]
|
117 |
+
else:
|
118 |
+
raise NotImplemented
|
119 |
+
if self.gt_encoder is not None:
|
120 |
+
fs = self.gt_encoder(labels)[1]
|
121 |
+
loss_args.append(fs)
|
122 |
+
|
123 |
+
loss0, loss = self.net.compute_loss(loss_args)
|
124 |
+
self.log_dict({"train/loss": loss, "train/loss_tar": loss0})
|
125 |
+
return loss
|
126 |
+
|
127 |
+
def validation_step(self, batch, batch_idx):
|
128 |
+
images, labels = batch["image"], batch["label"]
|
129 |
+
if isinstance(self.net, ISNetGTEncoder):
|
130 |
+
preds = self.forward(labels)
|
131 |
+
else:
|
132 |
+
preds = self.forward(images)
|
133 |
+
pre, rec, f1, = f1_torch(preds.nan_to_num(nan=0, posinf=1, neginf=0), labels)
|
134 |
+
mae_m = F.l1_loss(preds, labels, reduction="mean")
|
135 |
+
pre_m = pre.mean()
|
136 |
+
rec_m = rec.mean()
|
137 |
+
f1_m = f1.mean()
|
138 |
+
self.log_dict({"val/precision": pre_m, "val/recall": rec_m, "val/f1": f1_m, "val/mae": mae_m}, sync_dist=True)
|
139 |
+
|
140 |
+
|
141 |
+
def get_gt_encoder(train_dataloader, val_dataloader, opt):
|
142 |
+
print("---start train ground truth encoder---")
|
143 |
+
gt_encoder = AnimeSegmentation("isnet_gt")
|
144 |
+
trainer = Trainer(precision=32 if opt.fp32 else 16, accelerator=opt.accelerator,
|
145 |
+
devices=opt.devices, max_epochs=opt.gt_epoch,
|
146 |
+
benchmark=opt.benchmark, accumulate_grad_batches=opt.acc_step,
|
147 |
+
check_val_every_n_epoch=opt.val_epoch, log_every_n_steps=opt.log_step,
|
148 |
+
strategy="ddp_find_unused_parameters_false" if opt.devices > 1 else None,
|
149 |
+
)
|
150 |
+
trainer.fit(gt_encoder, train_dataloader, val_dataloader)
|
151 |
+
return gt_encoder.net
|
152 |
+
|
153 |
+
|
154 |
+
def load_refinenet(refine_method = 'animeseg', device: str = None) -> AnimeSegmentation:
|
155 |
+
if device is None:
|
156 |
+
device = DEFAULT_DEVICE
|
157 |
+
if refine_method == 'animeseg':
|
158 |
+
model = AnimeSegmentation.try_load('isnet_is', 'models/anime-seg/isnetis.ckpt', device)
|
159 |
+
elif refine_method == 'refinenet_isnet':
|
160 |
+
model = ISNetDIS(in_ch=4)
|
161 |
+
sd = torch.load('models/AnimeInstanceSegmentation/refine_last.ckpt', map_location='cpu')
|
162 |
+
# sd = torch.load('models/AnimeInstanceSegmentation/refine_noweight_dist.ckpt', map_location='cpu')
|
163 |
+
# sd = torch.load('models/AnimeInstanceSegmentation/refine_f3loss.ckpt', map_location='cpu')
|
164 |
+
model.load_state_dict(sd)
|
165 |
+
else:
|
166 |
+
raise NotImplementedError
|
167 |
+
return model.eval().to(device)
|
168 |
+
|
169 |
+
def get_mask(model, input_img, use_amp=True, s=640):
|
170 |
+
h0, w0 = h, w = input_img.shape[0], input_img.shape[1]
|
171 |
+
if h > w:
|
172 |
+
h, w = s, int(s * w / h)
|
173 |
+
else:
|
174 |
+
h, w = int(s * h / w), s
|
175 |
+
ph, pw = s - h, s - w
|
176 |
+
tmpImg = np.zeros([s, s, 3], dtype=np.float32)
|
177 |
+
tmpImg[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h)) / 255
|
178 |
+
tmpImg = tmpImg.transpose((2, 0, 1))
|
179 |
+
tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor).to(model.device)
|
180 |
+
with torch.no_grad():
|
181 |
+
if use_amp:
|
182 |
+
with amp.autocast():
|
183 |
+
pred = model(tmpImg)
|
184 |
+
pred = pred.to(dtype=torch.float32)
|
185 |
+
else:
|
186 |
+
pred = model(tmpImg)
|
187 |
+
pred = pred[0, :, ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
|
188 |
+
pred = cv2.resize(pred.cpu().numpy().transpose((1, 2, 0)), (w0, h0))[:, :, np.newaxis]
|
189 |
+
return pred
|
animeinsseg/models/animeseg_refine/encoders.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.utils.checkpoint import checkpoint
|
4 |
+
|
5 |
+
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
|
6 |
+
|
7 |
+
|
8 |
+
class AbstractEncoder(nn.Module):
|
9 |
+
def __init__(self):
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
def encode(self, *args, **kwargs):
|
13 |
+
raise NotImplementedError
|
14 |
+
|
15 |
+
|
16 |
+
class IdentityEncoder(AbstractEncoder):
|
17 |
+
|
18 |
+
def encode(self, x):
|
19 |
+
return x
|
20 |
+
|
21 |
+
|
22 |
+
class ClassEmbedder(nn.Module):
|
23 |
+
def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
|
24 |
+
super().__init__()
|
25 |
+
self.key = key
|
26 |
+
self.embedding = nn.Embedding(n_classes, embed_dim)
|
27 |
+
self.n_classes = n_classes
|
28 |
+
self.ucg_rate = ucg_rate
|
29 |
+
|
30 |
+
def forward(self, batch, key=None, disable_dropout=False):
|
31 |
+
if key is None:
|
32 |
+
key = self.key
|
33 |
+
# this is for use in crossattn
|
34 |
+
c = batch[key][:, None]
|
35 |
+
if self.ucg_rate > 0. and not disable_dropout:
|
36 |
+
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
|
37 |
+
c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
|
38 |
+
c = c.long()
|
39 |
+
c = self.embedding(c)
|
40 |
+
return c
|
41 |
+
|
42 |
+
def get_unconditional_conditioning(self, bs, device="cuda"):
|
43 |
+
uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
|
44 |
+
uc = torch.ones((bs,), device=device) * uc_class
|
45 |
+
uc = {self.key: uc}
|
46 |
+
return uc
|
47 |
+
|
48 |
+
|
49 |
+
class DanbooruEmbedder(AbstractEncoder):
|
50 |
+
def __init__(self):
|
51 |
+
super().__init__()
|
animeinsseg/models/animeseg_refine/isnet.py
ADDED
@@ -0,0 +1,645 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Codes are borrowed from
|
2 |
+
# https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torchvision import models
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
_bce_loss = nn.BCEWithLogitsLoss(reduction="mean")
|
10 |
+
_bce_loss_none = nn.BCEWithLogitsLoss(reduction='none')
|
11 |
+
|
12 |
+
def bce_loss(p, t, weights=None):
|
13 |
+
if weights is None:
|
14 |
+
return _bce_loss(p, t)
|
15 |
+
else:
|
16 |
+
loss = _bce_loss_none(p, t)
|
17 |
+
loss = loss * weights
|
18 |
+
return loss.mean()
|
19 |
+
|
20 |
+
|
21 |
+
_fea_loss = nn.MSELoss(reduction="mean")
|
22 |
+
_fea_loss_none = nn.MSELoss(reduction="none")
|
23 |
+
|
24 |
+
def fea_loss(p, t, weights=None):
|
25 |
+
return _fea_loss(p, t)
|
26 |
+
|
27 |
+
kl_loss = nn.KLDivLoss(reduction="mean")
|
28 |
+
l1_loss = nn.L1Loss(reduction="mean")
|
29 |
+
smooth_l1_loss = nn.SmoothL1Loss(reduction="mean")
|
30 |
+
|
31 |
+
|
32 |
+
def structure_loss(pred, mask):
|
33 |
+
weit = 1+5*torch.abs(F.avg_pool2d(mask, kernel_size=15, stride=1, padding=7)-mask)
|
34 |
+
wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
|
35 |
+
wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3))
|
36 |
+
|
37 |
+
pred = torch.sigmoid(pred)
|
38 |
+
inter = ((pred*mask)*weit).sum(dim=(2,3))
|
39 |
+
union = ((pred+mask)*weit).sum(dim=(2,3))
|
40 |
+
wiou = 1-(inter+1)/(union-inter+1)
|
41 |
+
return (wbce+wiou).mean()
|
42 |
+
|
43 |
+
|
44 |
+
def muti_loss_fusion(preds, target, dist_weight=None, loss0_weight=1.0):
|
45 |
+
loss0 = 0.0
|
46 |
+
loss = 0.0
|
47 |
+
|
48 |
+
for i in range(0, len(preds)):
|
49 |
+
weight = dist_weight if i == 0 else None
|
50 |
+
if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
|
51 |
+
tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
|
52 |
+
loss = loss + structure_loss(preds[i], tmp_target)
|
53 |
+
else:
|
54 |
+
# loss = loss + bce_loss(preds[i], target, weight)
|
55 |
+
loss = loss + structure_loss(preds[i], target)
|
56 |
+
if i == 0:
|
57 |
+
loss *= loss0_weight
|
58 |
+
loss0 = loss
|
59 |
+
return loss0, loss
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE', dist_weight=None, loss0_weight=1.0):
|
64 |
+
loss0 = 0.0
|
65 |
+
loss = 0.0
|
66 |
+
|
67 |
+
for i in range(0, len(preds)):
|
68 |
+
weight = dist_weight if i == 0 else None
|
69 |
+
if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
|
70 |
+
tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
|
71 |
+
# loss = loss + bce_loss(preds[i], tmp_target, weight)
|
72 |
+
loss = loss + structure_loss(preds[i], tmp_target)
|
73 |
+
else:
|
74 |
+
# loss = loss + bce_loss(preds[i], target, weight)
|
75 |
+
loss = loss + structure_loss(preds[i], target)
|
76 |
+
if i == 0:
|
77 |
+
loss *= loss0_weight
|
78 |
+
loss0 = loss
|
79 |
+
|
80 |
+
for i in range(0, len(dfs)):
|
81 |
+
df = dfs[i]
|
82 |
+
fs_i = fs[i]
|
83 |
+
if mode == 'MSE':
|
84 |
+
loss = loss + fea_loss(df, fs_i, dist_weight) ### add the mse loss of features as additional constraints
|
85 |
+
elif mode == 'KL':
|
86 |
+
loss = loss + kl_loss(F.log_softmax(df, dim=1), F.softmax(fs_i, dim=1))
|
87 |
+
elif mode == 'MAE':
|
88 |
+
loss = loss + l1_loss(df, fs_i)
|
89 |
+
elif mode == 'SmoothL1':
|
90 |
+
loss = loss + smooth_l1_loss(df, fs_i)
|
91 |
+
|
92 |
+
return loss0, loss
|
93 |
+
|
94 |
+
|
95 |
+
class REBNCONV(nn.Module):
|
96 |
+
def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
|
97 |
+
super(REBNCONV, self).__init__()
|
98 |
+
|
99 |
+
self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
|
100 |
+
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
101 |
+
self.relu_s1 = nn.ReLU(inplace=True)
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
hx = x
|
105 |
+
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
106 |
+
|
107 |
+
return xout
|
108 |
+
|
109 |
+
|
110 |
+
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
111 |
+
def _upsample_like(src, tar):
|
112 |
+
src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)
|
113 |
+
|
114 |
+
return src
|
115 |
+
|
116 |
+
|
117 |
+
### RSU-7 ###
|
118 |
+
class RSU7(nn.Module):
|
119 |
+
|
120 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
|
121 |
+
super(RSU7, self).__init__()
|
122 |
+
|
123 |
+
self.in_ch = in_ch
|
124 |
+
self.mid_ch = mid_ch
|
125 |
+
self.out_ch = out_ch
|
126 |
+
|
127 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
|
128 |
+
|
129 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
130 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
131 |
+
|
132 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
133 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
134 |
+
|
135 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
136 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
137 |
+
|
138 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
139 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
140 |
+
|
141 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
142 |
+
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
143 |
+
|
144 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
145 |
+
|
146 |
+
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
147 |
+
|
148 |
+
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
149 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
150 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
151 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
152 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
153 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
b, c, h, w = x.shape
|
157 |
+
|
158 |
+
hx = x
|
159 |
+
hxin = self.rebnconvin(hx)
|
160 |
+
|
161 |
+
hx1 = self.rebnconv1(hxin)
|
162 |
+
hx = self.pool1(hx1)
|
163 |
+
|
164 |
+
hx2 = self.rebnconv2(hx)
|
165 |
+
hx = self.pool2(hx2)
|
166 |
+
|
167 |
+
hx3 = self.rebnconv3(hx)
|
168 |
+
hx = self.pool3(hx3)
|
169 |
+
|
170 |
+
hx4 = self.rebnconv4(hx)
|
171 |
+
hx = self.pool4(hx4)
|
172 |
+
|
173 |
+
hx5 = self.rebnconv5(hx)
|
174 |
+
hx = self.pool5(hx5)
|
175 |
+
|
176 |
+
hx6 = self.rebnconv6(hx)
|
177 |
+
|
178 |
+
hx7 = self.rebnconv7(hx6)
|
179 |
+
|
180 |
+
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
|
181 |
+
hx6dup = _upsample_like(hx6d, hx5)
|
182 |
+
|
183 |
+
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
|
184 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
185 |
+
|
186 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
187 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
188 |
+
|
189 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
190 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
191 |
+
|
192 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
193 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
194 |
+
|
195 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
196 |
+
|
197 |
+
return hx1d + hxin
|
198 |
+
|
199 |
+
|
200 |
+
### RSU-6 ###
|
201 |
+
class RSU6(nn.Module):
|
202 |
+
|
203 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
204 |
+
super(RSU6, self).__init__()
|
205 |
+
|
206 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
207 |
+
|
208 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
209 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
210 |
+
|
211 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
212 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
213 |
+
|
214 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
215 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
216 |
+
|
217 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
218 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
219 |
+
|
220 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
221 |
+
|
222 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
223 |
+
|
224 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
225 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
226 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
227 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
228 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
229 |
+
|
230 |
+
def forward(self, x):
|
231 |
+
hx = x
|
232 |
+
|
233 |
+
hxin = self.rebnconvin(hx)
|
234 |
+
|
235 |
+
hx1 = self.rebnconv1(hxin)
|
236 |
+
hx = self.pool1(hx1)
|
237 |
+
|
238 |
+
hx2 = self.rebnconv2(hx)
|
239 |
+
hx = self.pool2(hx2)
|
240 |
+
|
241 |
+
hx3 = self.rebnconv3(hx)
|
242 |
+
hx = self.pool3(hx3)
|
243 |
+
|
244 |
+
hx4 = self.rebnconv4(hx)
|
245 |
+
hx = self.pool4(hx4)
|
246 |
+
|
247 |
+
hx5 = self.rebnconv5(hx)
|
248 |
+
|
249 |
+
hx6 = self.rebnconv6(hx5)
|
250 |
+
|
251 |
+
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
|
252 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
253 |
+
|
254 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
255 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
256 |
+
|
257 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
258 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
259 |
+
|
260 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
261 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
262 |
+
|
263 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
264 |
+
|
265 |
+
return hx1d + hxin
|
266 |
+
|
267 |
+
|
268 |
+
### RSU-5 ###
|
269 |
+
class RSU5(nn.Module):
|
270 |
+
|
271 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
272 |
+
super(RSU5, self).__init__()
|
273 |
+
|
274 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
275 |
+
|
276 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
277 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
278 |
+
|
279 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
280 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
281 |
+
|
282 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
283 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
284 |
+
|
285 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
286 |
+
|
287 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
288 |
+
|
289 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
290 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
291 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
292 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
293 |
+
|
294 |
+
def forward(self, x):
|
295 |
+
hx = x
|
296 |
+
|
297 |
+
hxin = self.rebnconvin(hx)
|
298 |
+
|
299 |
+
hx1 = self.rebnconv1(hxin)
|
300 |
+
hx = self.pool1(hx1)
|
301 |
+
|
302 |
+
hx2 = self.rebnconv2(hx)
|
303 |
+
hx = self.pool2(hx2)
|
304 |
+
|
305 |
+
hx3 = self.rebnconv3(hx)
|
306 |
+
hx = self.pool3(hx3)
|
307 |
+
|
308 |
+
hx4 = self.rebnconv4(hx)
|
309 |
+
|
310 |
+
hx5 = self.rebnconv5(hx4)
|
311 |
+
|
312 |
+
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
|
313 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
314 |
+
|
315 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
316 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
317 |
+
|
318 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
319 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
320 |
+
|
321 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
322 |
+
|
323 |
+
return hx1d + hxin
|
324 |
+
|
325 |
+
|
326 |
+
### RSU-4 ###
|
327 |
+
class RSU4(nn.Module):
|
328 |
+
|
329 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
330 |
+
super(RSU4, self).__init__()
|
331 |
+
|
332 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
333 |
+
|
334 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
335 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
336 |
+
|
337 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
338 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
339 |
+
|
340 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
341 |
+
|
342 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
343 |
+
|
344 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
345 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
346 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
347 |
+
|
348 |
+
def forward(self, x):
|
349 |
+
hx = x
|
350 |
+
|
351 |
+
hxin = self.rebnconvin(hx)
|
352 |
+
|
353 |
+
hx1 = self.rebnconv1(hxin)
|
354 |
+
hx = self.pool1(hx1)
|
355 |
+
|
356 |
+
hx2 = self.rebnconv2(hx)
|
357 |
+
hx = self.pool2(hx2)
|
358 |
+
|
359 |
+
hx3 = self.rebnconv3(hx)
|
360 |
+
|
361 |
+
hx4 = self.rebnconv4(hx3)
|
362 |
+
|
363 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
364 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
365 |
+
|
366 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
367 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
368 |
+
|
369 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
370 |
+
|
371 |
+
return hx1d + hxin
|
372 |
+
|
373 |
+
|
374 |
+
### RSU-4F ###
|
375 |
+
class RSU4F(nn.Module):
|
376 |
+
|
377 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
378 |
+
super(RSU4F, self).__init__()
|
379 |
+
|
380 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
381 |
+
|
382 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
383 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
384 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
|
385 |
+
|
386 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
|
387 |
+
|
388 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
|
389 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
|
390 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
391 |
+
|
392 |
+
def forward(self, x):
|
393 |
+
hx = x
|
394 |
+
|
395 |
+
hxin = self.rebnconvin(hx)
|
396 |
+
|
397 |
+
hx1 = self.rebnconv1(hxin)
|
398 |
+
hx2 = self.rebnconv2(hx1)
|
399 |
+
hx3 = self.rebnconv3(hx2)
|
400 |
+
|
401 |
+
hx4 = self.rebnconv4(hx3)
|
402 |
+
|
403 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
404 |
+
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
|
405 |
+
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
|
406 |
+
|
407 |
+
return hx1d + hxin
|
408 |
+
|
409 |
+
|
410 |
+
class myrebnconv(nn.Module):
|
411 |
+
def __init__(self, in_ch=3,
|
412 |
+
out_ch=1,
|
413 |
+
kernel_size=3,
|
414 |
+
stride=1,
|
415 |
+
padding=1,
|
416 |
+
dilation=1,
|
417 |
+
groups=1):
|
418 |
+
super(myrebnconv, self).__init__()
|
419 |
+
|
420 |
+
self.conv = nn.Conv2d(in_ch,
|
421 |
+
out_ch,
|
422 |
+
kernel_size=kernel_size,
|
423 |
+
stride=stride,
|
424 |
+
padding=padding,
|
425 |
+
dilation=dilation,
|
426 |
+
groups=groups)
|
427 |
+
self.bn = nn.BatchNorm2d(out_ch)
|
428 |
+
self.rl = nn.ReLU(inplace=True)
|
429 |
+
|
430 |
+
def forward(self, x):
|
431 |
+
return self.rl(self.bn(self.conv(x)))
|
432 |
+
|
433 |
+
|
434 |
+
class ISNetGTEncoder(nn.Module):
|
435 |
+
|
436 |
+
def __init__(self, in_ch=1, out_ch=1):
|
437 |
+
super(ISNetGTEncoder, self).__init__()
|
438 |
+
|
439 |
+
self.conv_in = myrebnconv(in_ch, 16, 3, stride=2, padding=1) # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
|
440 |
+
|
441 |
+
self.stage1 = RSU7(16, 16, 64)
|
442 |
+
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
443 |
+
|
444 |
+
self.stage2 = RSU6(64, 16, 64)
|
445 |
+
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
446 |
+
|
447 |
+
self.stage3 = RSU5(64, 32, 128)
|
448 |
+
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
449 |
+
|
450 |
+
self.stage4 = RSU4(128, 32, 256)
|
451 |
+
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
452 |
+
|
453 |
+
self.stage5 = RSU4F(256, 64, 512)
|
454 |
+
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
455 |
+
|
456 |
+
self.stage6 = RSU4F(512, 64, 512)
|
457 |
+
|
458 |
+
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
|
459 |
+
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
|
460 |
+
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
|
461 |
+
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
|
462 |
+
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
|
463 |
+
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
|
464 |
+
|
465 |
+
@staticmethod
|
466 |
+
def compute_loss(args, dist_weight=None):
|
467 |
+
preds, targets = args
|
468 |
+
return muti_loss_fusion(preds, targets, dist_weight)
|
469 |
+
|
470 |
+
def forward(self, x):
|
471 |
+
hx = x
|
472 |
+
|
473 |
+
hxin = self.conv_in(hx)
|
474 |
+
# hx = self.pool_in(hxin)
|
475 |
+
|
476 |
+
# stage 1
|
477 |
+
hx1 = self.stage1(hxin)
|
478 |
+
hx = self.pool12(hx1)
|
479 |
+
|
480 |
+
# stage 2
|
481 |
+
hx2 = self.stage2(hx)
|
482 |
+
hx = self.pool23(hx2)
|
483 |
+
|
484 |
+
# stage 3
|
485 |
+
hx3 = self.stage3(hx)
|
486 |
+
hx = self.pool34(hx3)
|
487 |
+
|
488 |
+
# stage 4
|
489 |
+
hx4 = self.stage4(hx)
|
490 |
+
hx = self.pool45(hx4)
|
491 |
+
|
492 |
+
# stage 5
|
493 |
+
hx5 = self.stage5(hx)
|
494 |
+
hx = self.pool56(hx5)
|
495 |
+
|
496 |
+
# stage 6
|
497 |
+
hx6 = self.stage6(hx)
|
498 |
+
|
499 |
+
# side output
|
500 |
+
d1 = self.side1(hx1)
|
501 |
+
d1 = _upsample_like(d1, x)
|
502 |
+
|
503 |
+
d2 = self.side2(hx2)
|
504 |
+
d2 = _upsample_like(d2, x)
|
505 |
+
|
506 |
+
d3 = self.side3(hx3)
|
507 |
+
d3 = _upsample_like(d3, x)
|
508 |
+
|
509 |
+
d4 = self.side4(hx4)
|
510 |
+
d4 = _upsample_like(d4, x)
|
511 |
+
|
512 |
+
d5 = self.side5(hx5)
|
513 |
+
d5 = _upsample_like(d5, x)
|
514 |
+
|
515 |
+
d6 = self.side6(hx6)
|
516 |
+
d6 = _upsample_like(d6, x)
|
517 |
+
|
518 |
+
# d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
|
519 |
+
|
520 |
+
# return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1, hx2, hx3, hx4, hx5, hx6]
|
521 |
+
return [d1, d2, d3, d4, d5, d6], [hx1, hx2, hx3, hx4, hx5, hx6]
|
522 |
+
|
523 |
+
|
524 |
+
class ISNetDIS(nn.Module):
|
525 |
+
|
526 |
+
def __init__(self, in_ch=3, out_ch=1):
|
527 |
+
super(ISNetDIS, self).__init__()
|
528 |
+
|
529 |
+
self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
|
530 |
+
self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
531 |
+
|
532 |
+
self.stage1 = RSU7(64, 32, 64)
|
533 |
+
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
534 |
+
|
535 |
+
self.stage2 = RSU6(64, 32, 128)
|
536 |
+
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
537 |
+
|
538 |
+
self.stage3 = RSU5(128, 64, 256)
|
539 |
+
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
540 |
+
|
541 |
+
self.stage4 = RSU4(256, 128, 512)
|
542 |
+
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
543 |
+
|
544 |
+
self.stage5 = RSU4F(512, 256, 512)
|
545 |
+
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
546 |
+
|
547 |
+
self.stage6 = RSU4F(512, 256, 512)
|
548 |
+
|
549 |
+
# decoder
|
550 |
+
self.stage5d = RSU4F(1024, 256, 512)
|
551 |
+
self.stage4d = RSU4(1024, 128, 256)
|
552 |
+
self.stage3d = RSU5(512, 64, 128)
|
553 |
+
self.stage2d = RSU6(256, 32, 64)
|
554 |
+
self.stage1d = RSU7(128, 16, 64)
|
555 |
+
|
556 |
+
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
|
557 |
+
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
|
558 |
+
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
|
559 |
+
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
|
560 |
+
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
|
561 |
+
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
|
562 |
+
|
563 |
+
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
|
564 |
+
|
565 |
+
@staticmethod
|
566 |
+
def compute_loss_kl(preds, targets, dfs, fs, mode='MSE'):
|
567 |
+
return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode, loss0_weight=5.0)
|
568 |
+
|
569 |
+
@staticmethod
|
570 |
+
def compute_loss(args, dist_weight=None):
|
571 |
+
if len(args) == 3:
|
572 |
+
ds, dfs, labels = args
|
573 |
+
return muti_loss_fusion(ds, labels, dist_weight, loss0_weight=5.0)
|
574 |
+
else:
|
575 |
+
ds, dfs, labels, fs = args
|
576 |
+
return muti_loss_fusion_kl(ds, labels, dfs, fs, mode="MSE", dist_weight=dist_weight, loss0_weight=5.0)
|
577 |
+
|
578 |
+
def forward(self, x):
|
579 |
+
hx = x
|
580 |
+
|
581 |
+
hxin = self.conv_in(hx)
|
582 |
+
hx = self.pool_in(hxin)
|
583 |
+
|
584 |
+
# stage 1
|
585 |
+
hx1 = self.stage1(hxin)
|
586 |
+
hx = self.pool12(hx1)
|
587 |
+
|
588 |
+
# stage 2
|
589 |
+
hx2 = self.stage2(hx)
|
590 |
+
hx = self.pool23(hx2)
|
591 |
+
|
592 |
+
# stage 3
|
593 |
+
hx3 = self.stage3(hx)
|
594 |
+
hx = self.pool34(hx3)
|
595 |
+
|
596 |
+
# stage 4
|
597 |
+
hx4 = self.stage4(hx)
|
598 |
+
hx = self.pool45(hx4)
|
599 |
+
|
600 |
+
# stage 5
|
601 |
+
hx5 = self.stage5(hx)
|
602 |
+
hx = self.pool56(hx5)
|
603 |
+
|
604 |
+
# stage 6
|
605 |
+
hx6 = self.stage6(hx)
|
606 |
+
hx6up = _upsample_like(hx6, hx5)
|
607 |
+
|
608 |
+
# -------------------- decoder --------------------
|
609 |
+
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
|
610 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
611 |
+
|
612 |
+
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
|
613 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
614 |
+
|
615 |
+
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
|
616 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
617 |
+
|
618 |
+
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
|
619 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
620 |
+
|
621 |
+
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
|
622 |
+
|
623 |
+
# side output
|
624 |
+
d1 = self.side1(hx1d)
|
625 |
+
d1 = _upsample_like(d1, x)
|
626 |
+
|
627 |
+
d2 = self.side2(hx2d)
|
628 |
+
d2 = _upsample_like(d2, x)
|
629 |
+
|
630 |
+
d3 = self.side3(hx3d)
|
631 |
+
d3 = _upsample_like(d3, x)
|
632 |
+
|
633 |
+
d4 = self.side4(hx4d)
|
634 |
+
d4 = _upsample_like(d4, x)
|
635 |
+
|
636 |
+
d5 = self.side5(hx5d)
|
637 |
+
d5 = _upsample_like(d5, x)
|
638 |
+
|
639 |
+
d6 = self.side6(hx6)
|
640 |
+
d6 = _upsample_like(d6, x)
|
641 |
+
|
642 |
+
# d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
|
643 |
+
|
644 |
+
# return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
|
645 |
+
return [d1, d2, d3, d4, d5, d6], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
|
animeinsseg/models/animeseg_refine/models.py
ADDED
File without changes
|
animeinsseg/models/animeseg_refine/modnet.py
ADDED
@@ -0,0 +1,667 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Codes are borrowed from
|
2 |
+
# https://github.com/ZHKKKe/MODNet/blob/master/src/trainer.py
|
3 |
+
# https://github.com/ZHKKKe/MODNet/blob/master/src/models/backbones/mobilenetv2.py
|
4 |
+
# https://github.com/ZHKKKe/MODNet/blob/master/src/models/modnet.py
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import scipy
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import os
|
12 |
+
import math
|
13 |
+
import torch
|
14 |
+
from scipy.ndimage import gaussian_filter
|
15 |
+
|
16 |
+
|
17 |
+
# ----------------------------------------------------------------------------------
|
18 |
+
# Loss Functions
|
19 |
+
# ----------------------------------------------------------------------------------
|
20 |
+
|
21 |
+
|
22 |
+
class GaussianBlurLayer(nn.Module):
|
23 |
+
""" Add Gaussian Blur to a 4D tensors
|
24 |
+
This layer takes a 4D tensor of {N, C, H, W} as input.
|
25 |
+
The Gaussian blur will be performed in given channel number (C) splitly.
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self, channels, kernel_size):
|
29 |
+
"""
|
30 |
+
Arguments:
|
31 |
+
channels (int): Channel for input tensor
|
32 |
+
kernel_size (int): Size of the kernel used in blurring
|
33 |
+
"""
|
34 |
+
|
35 |
+
super(GaussianBlurLayer, self).__init__()
|
36 |
+
self.channels = channels
|
37 |
+
self.kernel_size = kernel_size
|
38 |
+
assert self.kernel_size % 2 != 0
|
39 |
+
|
40 |
+
self.op = nn.Sequential(
|
41 |
+
nn.ReflectionPad2d(math.floor(self.kernel_size / 2)),
|
42 |
+
nn.Conv2d(channels, channels, self.kernel_size,
|
43 |
+
stride=1, padding=0, bias=None, groups=channels)
|
44 |
+
)
|
45 |
+
|
46 |
+
self._init_kernel()
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
"""
|
50 |
+
Arguments:
|
51 |
+
x (torch.Tensor): input 4D tensor
|
52 |
+
Returns:
|
53 |
+
torch.Tensor: Blurred version of the input
|
54 |
+
"""
|
55 |
+
|
56 |
+
if not len(list(x.shape)) == 4:
|
57 |
+
print('\'GaussianBlurLayer\' requires a 4D tensor as input\n')
|
58 |
+
exit()
|
59 |
+
elif not x.shape[1] == self.channels:
|
60 |
+
print('In \'GaussianBlurLayer\', the required channel ({0}) is'
|
61 |
+
'not the same as input ({1})\n'.format(self.channels, x.shape[1]))
|
62 |
+
exit()
|
63 |
+
|
64 |
+
return self.op(x)
|
65 |
+
|
66 |
+
def _init_kernel(self):
|
67 |
+
sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8
|
68 |
+
|
69 |
+
n = np.zeros((self.kernel_size, self.kernel_size))
|
70 |
+
i = math.floor(self.kernel_size / 2)
|
71 |
+
n[i, i] = 1
|
72 |
+
kernel = gaussian_filter(n, sigma)
|
73 |
+
|
74 |
+
for name, param in self.named_parameters():
|
75 |
+
param.data.copy_(torch.from_numpy(kernel))
|
76 |
+
param.requires_grad = False
|
77 |
+
|
78 |
+
|
79 |
+
blurer = GaussianBlurLayer(1, 3)
|
80 |
+
|
81 |
+
|
82 |
+
def loss_func(pred_semantic, pred_detail, pred_matte, image, trimap, gt_matte,
|
83 |
+
semantic_scale=10.0, detail_scale=10.0, matte_scale=1.0):
|
84 |
+
""" loss of MODNet
|
85 |
+
Arguments:
|
86 |
+
blurer: GaussianBlurLayer
|
87 |
+
pred_semantic: model output
|
88 |
+
pred_detail: model output
|
89 |
+
pred_matte: model output
|
90 |
+
image : input RGB image ts pixel values should be normalized
|
91 |
+
trimap : trimap used to calculate the losses
|
92 |
+
its pixel values can be 0, 0.5, or 1
|
93 |
+
(foreground=1, background=0, unknown=0.5)
|
94 |
+
gt_matte: ground truth alpha matte its pixel values are between [0, 1]
|
95 |
+
semantic_scale (float): scale of the semantic loss
|
96 |
+
NOTE: please adjust according to your dataset
|
97 |
+
detail_scale (float): scale of the detail loss
|
98 |
+
NOTE: please adjust according to your dataset
|
99 |
+
matte_scale (float): scale of the matte loss
|
100 |
+
NOTE: please adjust according to your dataset
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
semantic_loss (torch.Tensor): loss of the semantic estimation [Low-Resolution (LR) Branch]
|
104 |
+
detail_loss (torch.Tensor): loss of the detail prediction [High-Resolution (HR) Branch]
|
105 |
+
matte_loss (torch.Tensor): loss of the semantic-detail fusion [Fusion Branch]
|
106 |
+
"""
|
107 |
+
|
108 |
+
trimap = trimap.float()
|
109 |
+
# calculate the boundary mask from the trimap
|
110 |
+
boundaries = (trimap < 0.5) + (trimap > 0.5)
|
111 |
+
|
112 |
+
# calculate the semantic loss
|
113 |
+
gt_semantic = F.interpolate(gt_matte, scale_factor=1 / 16, mode='bilinear')
|
114 |
+
gt_semantic = blurer(gt_semantic)
|
115 |
+
semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
|
116 |
+
semantic_loss = semantic_scale * semantic_loss
|
117 |
+
|
118 |
+
# calculate the detail loss
|
119 |
+
pred_boundary_detail = torch.where(boundaries, trimap, pred_detail.float())
|
120 |
+
gt_detail = torch.where(boundaries, trimap, gt_matte.float())
|
121 |
+
detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail.float()))
|
122 |
+
detail_loss = detail_scale * detail_loss
|
123 |
+
|
124 |
+
# calculate the matte loss
|
125 |
+
pred_boundary_matte = torch.where(boundaries, trimap, pred_matte.float())
|
126 |
+
matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
|
127 |
+
matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \
|
128 |
+
+ 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte)
|
129 |
+
matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss)
|
130 |
+
matte_loss = matte_scale * matte_loss
|
131 |
+
|
132 |
+
return semantic_loss, detail_loss, matte_loss
|
133 |
+
|
134 |
+
|
135 |
+
# ------------------------------------------------------------------------------
|
136 |
+
# Useful functions
|
137 |
+
# ------------------------------------------------------------------------------
|
138 |
+
|
139 |
+
def _make_divisible(v, divisor, min_value=None):
|
140 |
+
if min_value is None:
|
141 |
+
min_value = divisor
|
142 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
143 |
+
# Make sure that round down does not go down by more than 10%.
|
144 |
+
if new_v < 0.9 * v:
|
145 |
+
new_v += divisor
|
146 |
+
return new_v
|
147 |
+
|
148 |
+
|
149 |
+
def conv_bn(inp, oup, stride):
|
150 |
+
return nn.Sequential(
|
151 |
+
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
|
152 |
+
nn.BatchNorm2d(oup),
|
153 |
+
nn.ReLU6(inplace=True)
|
154 |
+
)
|
155 |
+
|
156 |
+
|
157 |
+
def conv_1x1_bn(inp, oup):
|
158 |
+
return nn.Sequential(
|
159 |
+
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
160 |
+
nn.BatchNorm2d(oup),
|
161 |
+
nn.ReLU6(inplace=True)
|
162 |
+
)
|
163 |
+
|
164 |
+
|
165 |
+
# ------------------------------------------------------------------------------
|
166 |
+
# Class of Inverted Residual block
|
167 |
+
# ------------------------------------------------------------------------------
|
168 |
+
|
169 |
+
class InvertedResidual(nn.Module):
|
170 |
+
def __init__(self, inp, oup, stride, expansion, dilation=1):
|
171 |
+
super(InvertedResidual, self).__init__()
|
172 |
+
self.stride = stride
|
173 |
+
assert stride in [1, 2]
|
174 |
+
|
175 |
+
hidden_dim = round(inp * expansion)
|
176 |
+
self.use_res_connect = self.stride == 1 and inp == oup
|
177 |
+
|
178 |
+
if expansion == 1:
|
179 |
+
self.conv = nn.Sequential(
|
180 |
+
# dw
|
181 |
+
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
|
182 |
+
nn.BatchNorm2d(hidden_dim),
|
183 |
+
nn.ReLU6(inplace=True),
|
184 |
+
# pw-linear
|
185 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
186 |
+
nn.BatchNorm2d(oup),
|
187 |
+
)
|
188 |
+
else:
|
189 |
+
self.conv = nn.Sequential(
|
190 |
+
# pw
|
191 |
+
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
|
192 |
+
nn.BatchNorm2d(hidden_dim),
|
193 |
+
nn.ReLU6(inplace=True),
|
194 |
+
# dw
|
195 |
+
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
|
196 |
+
nn.BatchNorm2d(hidden_dim),
|
197 |
+
nn.ReLU6(inplace=True),
|
198 |
+
# pw-linear
|
199 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
200 |
+
nn.BatchNorm2d(oup),
|
201 |
+
)
|
202 |
+
|
203 |
+
def forward(self, x):
|
204 |
+
if self.use_res_connect:
|
205 |
+
return x + self.conv(x)
|
206 |
+
else:
|
207 |
+
return self.conv(x)
|
208 |
+
|
209 |
+
|
210 |
+
# ------------------------------------------------------------------------------
|
211 |
+
# Class of MobileNetV2
|
212 |
+
# ------------------------------------------------------------------------------
|
213 |
+
|
214 |
+
class MobileNetV2(nn.Module):
|
215 |
+
def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000):
|
216 |
+
super(MobileNetV2, self).__init__()
|
217 |
+
self.in_channels = in_channels
|
218 |
+
self.num_classes = num_classes
|
219 |
+
input_channel = 32
|
220 |
+
last_channel = 1280
|
221 |
+
interverted_residual_setting = [
|
222 |
+
# t, c, n, s
|
223 |
+
[1, 16, 1, 1],
|
224 |
+
[expansion, 24, 2, 2],
|
225 |
+
[expansion, 32, 3, 2],
|
226 |
+
[expansion, 64, 4, 2],
|
227 |
+
[expansion, 96, 3, 1],
|
228 |
+
[expansion, 160, 3, 2],
|
229 |
+
[expansion, 320, 1, 1],
|
230 |
+
]
|
231 |
+
|
232 |
+
# building first layer
|
233 |
+
input_channel = _make_divisible(input_channel * alpha, 8)
|
234 |
+
self.last_channel = _make_divisible(last_channel * alpha, 8) if alpha > 1.0 else last_channel
|
235 |
+
self.features = [conv_bn(self.in_channels, input_channel, 2)]
|
236 |
+
|
237 |
+
# building inverted residual blocks
|
238 |
+
for t, c, n, s in interverted_residual_setting:
|
239 |
+
output_channel = _make_divisible(int(c * alpha), 8)
|
240 |
+
for i in range(n):
|
241 |
+
if i == 0:
|
242 |
+
self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t))
|
243 |
+
else:
|
244 |
+
self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t))
|
245 |
+
input_channel = output_channel
|
246 |
+
|
247 |
+
# building last several layers
|
248 |
+
self.features.append(conv_1x1_bn(input_channel, self.last_channel))
|
249 |
+
|
250 |
+
# make it nn.Sequential
|
251 |
+
self.features = nn.Sequential(*self.features)
|
252 |
+
|
253 |
+
# building classifier
|
254 |
+
if self.num_classes is not None:
|
255 |
+
self.classifier = nn.Sequential(
|
256 |
+
nn.Dropout(0.2),
|
257 |
+
nn.Linear(self.last_channel, num_classes),
|
258 |
+
)
|
259 |
+
|
260 |
+
# Initialize weights
|
261 |
+
self._init_weights()
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
# Stage1
|
265 |
+
x = self.features[0](x)
|
266 |
+
x = self.features[1](x)
|
267 |
+
# Stage2
|
268 |
+
x = self.features[2](x)
|
269 |
+
x = self.features[3](x)
|
270 |
+
# Stage3
|
271 |
+
x = self.features[4](x)
|
272 |
+
x = self.features[5](x)
|
273 |
+
x = self.features[6](x)
|
274 |
+
# Stage4
|
275 |
+
x = self.features[7](x)
|
276 |
+
x = self.features[8](x)
|
277 |
+
x = self.features[9](x)
|
278 |
+
x = self.features[10](x)
|
279 |
+
x = self.features[11](x)
|
280 |
+
x = self.features[12](x)
|
281 |
+
x = self.features[13](x)
|
282 |
+
# Stage5
|
283 |
+
x = self.features[14](x)
|
284 |
+
x = self.features[15](x)
|
285 |
+
x = self.features[16](x)
|
286 |
+
x = self.features[17](x)
|
287 |
+
x = self.features[18](x)
|
288 |
+
|
289 |
+
# Classification
|
290 |
+
if self.num_classes is not None:
|
291 |
+
x = x.mean(dim=(2, 3))
|
292 |
+
x = self.classifier(x)
|
293 |
+
|
294 |
+
# Output
|
295 |
+
return x
|
296 |
+
|
297 |
+
def _load_pretrained_model(self, pretrained_file):
|
298 |
+
pretrain_dict = torch.load(pretrained_file, map_location='cpu')
|
299 |
+
model_dict = {}
|
300 |
+
state_dict = self.state_dict()
|
301 |
+
print("[MobileNetV2] Loading pretrained model...")
|
302 |
+
for k, v in pretrain_dict.items():
|
303 |
+
if k in state_dict:
|
304 |
+
model_dict[k] = v
|
305 |
+
else:
|
306 |
+
print(k, "is ignored")
|
307 |
+
state_dict.update(model_dict)
|
308 |
+
self.load_state_dict(state_dict)
|
309 |
+
|
310 |
+
def _init_weights(self):
|
311 |
+
for m in self.modules():
|
312 |
+
if isinstance(m, nn.Conv2d):
|
313 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
314 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
315 |
+
if m.bias is not None:
|
316 |
+
m.bias.data.zero_()
|
317 |
+
elif isinstance(m, nn.BatchNorm2d):
|
318 |
+
m.weight.data.fill_(1)
|
319 |
+
m.bias.data.zero_()
|
320 |
+
elif isinstance(m, nn.Linear):
|
321 |
+
n = m.weight.size(1)
|
322 |
+
m.weight.data.normal_(0, 0.01)
|
323 |
+
m.bias.data.zero_()
|
324 |
+
|
325 |
+
|
326 |
+
class BaseBackbone(nn.Module):
|
327 |
+
""" Superclass of Replaceable Backbone Model for Semantic Estimation
|
328 |
+
"""
|
329 |
+
|
330 |
+
def __init__(self, in_channels):
|
331 |
+
super(BaseBackbone, self).__init__()
|
332 |
+
self.in_channels = in_channels
|
333 |
+
|
334 |
+
self.model = None
|
335 |
+
self.enc_channels = []
|
336 |
+
|
337 |
+
def forward(self, x):
|
338 |
+
raise NotImplementedError
|
339 |
+
|
340 |
+
def load_pretrained_ckpt(self):
|
341 |
+
raise NotImplementedError
|
342 |
+
|
343 |
+
|
344 |
+
class MobileNetV2Backbone(BaseBackbone):
|
345 |
+
""" MobileNetV2 Backbone
|
346 |
+
"""
|
347 |
+
|
348 |
+
def __init__(self, in_channels):
|
349 |
+
super(MobileNetV2Backbone, self).__init__(in_channels)
|
350 |
+
|
351 |
+
self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None)
|
352 |
+
self.enc_channels = [16, 24, 32, 96, 1280]
|
353 |
+
|
354 |
+
def forward(self, x):
|
355 |
+
# x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x)
|
356 |
+
x = self.model.features[0](x)
|
357 |
+
x = self.model.features[1](x)
|
358 |
+
enc2x = x
|
359 |
+
|
360 |
+
# x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x)
|
361 |
+
x = self.model.features[2](x)
|
362 |
+
x = self.model.features[3](x)
|
363 |
+
enc4x = x
|
364 |
+
|
365 |
+
# x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x)
|
366 |
+
x = self.model.features[4](x)
|
367 |
+
x = self.model.features[5](x)
|
368 |
+
x = self.model.features[6](x)
|
369 |
+
enc8x = x
|
370 |
+
|
371 |
+
# x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x)
|
372 |
+
x = self.model.features[7](x)
|
373 |
+
x = self.model.features[8](x)
|
374 |
+
x = self.model.features[9](x)
|
375 |
+
x = self.model.features[10](x)
|
376 |
+
x = self.model.features[11](x)
|
377 |
+
x = self.model.features[12](x)
|
378 |
+
x = self.model.features[13](x)
|
379 |
+
enc16x = x
|
380 |
+
|
381 |
+
# x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x)
|
382 |
+
x = self.model.features[14](x)
|
383 |
+
x = self.model.features[15](x)
|
384 |
+
x = self.model.features[16](x)
|
385 |
+
x = self.model.features[17](x)
|
386 |
+
x = self.model.features[18](x)
|
387 |
+
enc32x = x
|
388 |
+
return [enc2x, enc4x, enc8x, enc16x, enc32x]
|
389 |
+
|
390 |
+
def load_pretrained_ckpt(self):
|
391 |
+
# the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch
|
392 |
+
ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt'
|
393 |
+
if not os.path.exists(ckpt_path):
|
394 |
+
print('cannot find the pretrained mobilenetv2 backbone')
|
395 |
+
exit()
|
396 |
+
|
397 |
+
ckpt = torch.load(ckpt_path)
|
398 |
+
self.model.load_state_dict(ckpt)
|
399 |
+
|
400 |
+
|
401 |
+
SUPPORTED_BACKBONES = {
|
402 |
+
'mobilenetv2': MobileNetV2Backbone,
|
403 |
+
}
|
404 |
+
|
405 |
+
|
406 |
+
# ------------------------------------------------------------------------------
|
407 |
+
# MODNet Basic Modules
|
408 |
+
# ------------------------------------------------------------------------------
|
409 |
+
|
410 |
+
class IBNorm(nn.Module):
|
411 |
+
""" Combine Instance Norm and Batch Norm into One Layer
|
412 |
+
"""
|
413 |
+
|
414 |
+
def __init__(self, in_channels):
|
415 |
+
super(IBNorm, self).__init__()
|
416 |
+
in_channels = in_channels
|
417 |
+
self.bnorm_channels = int(in_channels / 2)
|
418 |
+
self.inorm_channels = in_channels - self.bnorm_channels
|
419 |
+
|
420 |
+
self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
|
421 |
+
self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
|
422 |
+
|
423 |
+
def forward(self, x):
|
424 |
+
bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
|
425 |
+
in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())
|
426 |
+
|
427 |
+
return torch.cat((bn_x, in_x), 1)
|
428 |
+
|
429 |
+
|
430 |
+
class Conv2dIBNormRelu(nn.Module):
|
431 |
+
""" Convolution + IBNorm + ReLu
|
432 |
+
"""
|
433 |
+
|
434 |
+
def __init__(self, in_channels, out_channels, kernel_size,
|
435 |
+
stride=1, padding=0, dilation=1, groups=1, bias=True,
|
436 |
+
with_ibn=True, with_relu=True):
|
437 |
+
super(Conv2dIBNormRelu, self).__init__()
|
438 |
+
|
439 |
+
layers = [
|
440 |
+
nn.Conv2d(in_channels, out_channels, kernel_size,
|
441 |
+
stride=stride, padding=padding, dilation=dilation,
|
442 |
+
groups=groups, bias=bias)
|
443 |
+
]
|
444 |
+
|
445 |
+
if with_ibn:
|
446 |
+
layers.append(IBNorm(out_channels))
|
447 |
+
if with_relu:
|
448 |
+
layers.append(nn.ReLU(inplace=True))
|
449 |
+
|
450 |
+
self.layers = nn.Sequential(*layers)
|
451 |
+
|
452 |
+
def forward(self, x):
|
453 |
+
return self.layers(x)
|
454 |
+
|
455 |
+
|
456 |
+
class SEBlock(nn.Module):
|
457 |
+
""" SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
|
458 |
+
"""
|
459 |
+
|
460 |
+
def __init__(self, in_channels, out_channels, reduction=1):
|
461 |
+
super(SEBlock, self).__init__()
|
462 |
+
self.pool = nn.AdaptiveAvgPool2d(1)
|
463 |
+
self.fc = nn.Sequential(
|
464 |
+
nn.Linear(in_channels, int(in_channels // reduction), bias=False),
|
465 |
+
nn.ReLU(inplace=True),
|
466 |
+
nn.Linear(int(in_channels // reduction), out_channels, bias=False),
|
467 |
+
nn.Sigmoid()
|
468 |
+
)
|
469 |
+
|
470 |
+
def forward(self, x):
|
471 |
+
b, c, _, _ = x.size()
|
472 |
+
w = self.pool(x).view(b, c)
|
473 |
+
w = self.fc(w).view(b, c, 1, 1)
|
474 |
+
|
475 |
+
return x * w.expand_as(x)
|
476 |
+
|
477 |
+
|
478 |
+
# ------------------------------------------------------------------------------
|
479 |
+
# MODNet Branches
|
480 |
+
# ------------------------------------------------------------------------------
|
481 |
+
|
482 |
+
class LRBranch(nn.Module):
|
483 |
+
""" Low Resolution Branch of MODNet
|
484 |
+
"""
|
485 |
+
|
486 |
+
def __init__(self, backbone):
|
487 |
+
super(LRBranch, self).__init__()
|
488 |
+
|
489 |
+
enc_channels = backbone.enc_channels
|
490 |
+
|
491 |
+
self.backbone = backbone
|
492 |
+
self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
|
493 |
+
self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
|
494 |
+
self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)
|
495 |
+
self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False,
|
496 |
+
with_relu=False)
|
497 |
+
|
498 |
+
def forward(self, img, inference):
|
499 |
+
enc_features = self.backbone.forward(img)
|
500 |
+
enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]
|
501 |
+
|
502 |
+
enc32x = self.se_block(enc32x)
|
503 |
+
lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
|
504 |
+
lr16x = self.conv_lr16x(lr16x)
|
505 |
+
lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
|
506 |
+
lr8x = self.conv_lr8x(lr8x)
|
507 |
+
|
508 |
+
pred_semantic = None
|
509 |
+
if not inference:
|
510 |
+
lr = self.conv_lr(lr8x)
|
511 |
+
pred_semantic = torch.sigmoid(lr)
|
512 |
+
|
513 |
+
return pred_semantic, lr8x, [enc2x, enc4x]
|
514 |
+
|
515 |
+
|
516 |
+
class HRBranch(nn.Module):
|
517 |
+
""" High Resolution Branch of MODNet
|
518 |
+
"""
|
519 |
+
|
520 |
+
def __init__(self, hr_channels, enc_channels):
|
521 |
+
super(HRBranch, self).__init__()
|
522 |
+
|
523 |
+
self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
|
524 |
+
self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)
|
525 |
+
|
526 |
+
self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
|
527 |
+
self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)
|
528 |
+
|
529 |
+
self.conv_hr4x = nn.Sequential(
|
530 |
+
Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
|
531 |
+
Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
|
532 |
+
Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
|
533 |
+
)
|
534 |
+
|
535 |
+
self.conv_hr2x = nn.Sequential(
|
536 |
+
Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
|
537 |
+
Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
|
538 |
+
Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
|
539 |
+
Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
|
540 |
+
)
|
541 |
+
|
542 |
+
self.conv_hr = nn.Sequential(
|
543 |
+
Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
|
544 |
+
Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
|
545 |
+
)
|
546 |
+
|
547 |
+
def forward(self, img, enc2x, enc4x, lr8x, inference):
|
548 |
+
img2x = F.interpolate(img, scale_factor=1 / 2, mode='bilinear', align_corners=False)
|
549 |
+
img4x = F.interpolate(img, scale_factor=1 / 4, mode='bilinear', align_corners=False)
|
550 |
+
|
551 |
+
enc2x = self.tohr_enc2x(enc2x)
|
552 |
+
hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))
|
553 |
+
|
554 |
+
enc4x = self.tohr_enc4x(enc4x)
|
555 |
+
hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))
|
556 |
+
|
557 |
+
lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
|
558 |
+
hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))
|
559 |
+
|
560 |
+
hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
|
561 |
+
hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))
|
562 |
+
|
563 |
+
pred_detail = None
|
564 |
+
if not inference:
|
565 |
+
hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)
|
566 |
+
hr = self.conv_hr(torch.cat((hr, img), dim=1))
|
567 |
+
pred_detail = torch.sigmoid(hr)
|
568 |
+
|
569 |
+
return pred_detail, hr2x
|
570 |
+
|
571 |
+
|
572 |
+
class FusionBranch(nn.Module):
|
573 |
+
""" Fusion Branch of MODNet
|
574 |
+
"""
|
575 |
+
|
576 |
+
def __init__(self, hr_channels, enc_channels):
|
577 |
+
super(FusionBranch, self).__init__()
|
578 |
+
self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)
|
579 |
+
|
580 |
+
self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
|
581 |
+
self.conv_f = nn.Sequential(
|
582 |
+
Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
|
583 |
+
Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
|
584 |
+
)
|
585 |
+
|
586 |
+
def forward(self, img, lr8x, hr2x):
|
587 |
+
lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
|
588 |
+
lr4x = self.conv_lr4x(lr4x)
|
589 |
+
lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)
|
590 |
+
|
591 |
+
f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))
|
592 |
+
f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
|
593 |
+
f = self.conv_f(torch.cat((f, img), dim=1))
|
594 |
+
pred_matte = torch.sigmoid(f)
|
595 |
+
|
596 |
+
return pred_matte
|
597 |
+
|
598 |
+
|
599 |
+
# ------------------------------------------------------------------------------
|
600 |
+
# MODNet
|
601 |
+
# ------------------------------------------------------------------------------
|
602 |
+
|
603 |
+
class MODNet(nn.Module):
|
604 |
+
""" Architecture of MODNet
|
605 |
+
"""
|
606 |
+
|
607 |
+
def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=False):
|
608 |
+
super(MODNet, self).__init__()
|
609 |
+
|
610 |
+
self.in_channels = in_channels
|
611 |
+
self.hr_channels = hr_channels
|
612 |
+
self.backbone_arch = backbone_arch
|
613 |
+
self.backbone_pretrained = backbone_pretrained
|
614 |
+
|
615 |
+
self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels)
|
616 |
+
|
617 |
+
self.lr_branch = LRBranch(self.backbone)
|
618 |
+
self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
|
619 |
+
self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)
|
620 |
+
|
621 |
+
for m in self.modules():
|
622 |
+
if isinstance(m, nn.Conv2d):
|
623 |
+
self._init_conv(m)
|
624 |
+
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
|
625 |
+
self._init_norm(m)
|
626 |
+
|
627 |
+
if self.backbone_pretrained:
|
628 |
+
self.backbone.load_pretrained_ckpt()
|
629 |
+
|
630 |
+
def forward(self, img, inference):
|
631 |
+
pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)
|
632 |
+
pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)
|
633 |
+
pred_matte = self.f_branch(img, lr8x, hr2x)
|
634 |
+
|
635 |
+
return pred_semantic, pred_detail, pred_matte
|
636 |
+
|
637 |
+
@staticmethod
|
638 |
+
def compute_loss(args):
|
639 |
+
pred_semantic, pred_detail, pred_matte, image, trimap, gt_matte = args
|
640 |
+
semantic_loss, detail_loss, matte_loss = loss_func(pred_semantic, pred_detail, pred_matte,
|
641 |
+
image, trimap, gt_matte)
|
642 |
+
loss = semantic_loss + detail_loss + matte_loss
|
643 |
+
return matte_loss, loss
|
644 |
+
|
645 |
+
def freeze_norm(self):
|
646 |
+
norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
|
647 |
+
for m in self.modules():
|
648 |
+
for n in norm_types:
|
649 |
+
if isinstance(m, n):
|
650 |
+
m.eval()
|
651 |
+
continue
|
652 |
+
|
653 |
+
def _init_conv(self, conv):
|
654 |
+
nn.init.kaiming_uniform_(
|
655 |
+
conv.weight, a=0, mode='fan_in', nonlinearity='relu')
|
656 |
+
if conv.bias is not None:
|
657 |
+
nn.init.constant_(conv.bias, 0)
|
658 |
+
|
659 |
+
def _init_norm(self, norm):
|
660 |
+
if norm.weight is not None:
|
661 |
+
nn.init.constant_(norm.weight, 1)
|
662 |
+
nn.init.constant_(norm.bias, 0)
|
663 |
+
|
664 |
+
def _apply(self, fn):
|
665 |
+
super(MODNet, self)._apply(fn)
|
666 |
+
blurer._apply(fn) # let blurer's device same as modnet
|
667 |
+
return self
|
animeinsseg/models/animeseg_refine/u2net.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Codes are borrowed from
|
2 |
+
# https://github.com/xuebinqin/U-2-Net/blob/master/model/u2net_refactor.py
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import math
|
8 |
+
|
9 |
+
__all__ = ['U2NET_full', 'U2NET_full2', 'U2NET_lite', 'U2NET_lite2', "U2NET"]
|
10 |
+
|
11 |
+
bce_loss = nn.BCEWithLogitsLoss(reduction='mean')
|
12 |
+
|
13 |
+
|
14 |
+
def _upsample_like(x, size):
|
15 |
+
return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
|
16 |
+
|
17 |
+
|
18 |
+
def _size_map(x, height):
|
19 |
+
# {height: size} for Upsample
|
20 |
+
size = list(x.shape[-2:])
|
21 |
+
sizes = {}
|
22 |
+
for h in range(1, height):
|
23 |
+
sizes[h] = size
|
24 |
+
size = [math.ceil(w / 2) for w in size]
|
25 |
+
return sizes
|
26 |
+
|
27 |
+
|
28 |
+
class REBNCONV(nn.Module):
|
29 |
+
def __init__(self, in_ch=3, out_ch=3, dilate=1):
|
30 |
+
super(REBNCONV, self).__init__()
|
31 |
+
|
32 |
+
self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dilate, dilation=1 * dilate)
|
33 |
+
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
34 |
+
self.relu_s1 = nn.ReLU(inplace=True)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
return self.relu_s1(self.bn_s1(self.conv_s1(x)))
|
38 |
+
|
39 |
+
|
40 |
+
class RSU(nn.Module):
|
41 |
+
def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False):
|
42 |
+
super(RSU, self).__init__()
|
43 |
+
self.name = name
|
44 |
+
self.height = height
|
45 |
+
self.dilated = dilated
|
46 |
+
self._make_layers(height, in_ch, mid_ch, out_ch, dilated)
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
sizes = _size_map(x, self.height)
|
50 |
+
x = self.rebnconvin(x)
|
51 |
+
|
52 |
+
# U-Net like symmetric encoder-decoder structure
|
53 |
+
def unet(x, height=1):
|
54 |
+
if height < self.height:
|
55 |
+
x1 = getattr(self, f'rebnconv{height}')(x)
|
56 |
+
if not self.dilated and height < self.height - 1:
|
57 |
+
x2 = unet(getattr(self, 'downsample')(x1), height + 1)
|
58 |
+
else:
|
59 |
+
x2 = unet(x1, height + 1)
|
60 |
+
|
61 |
+
x = getattr(self, f'rebnconv{height}d')(torch.cat((x2, x1), 1))
|
62 |
+
return _upsample_like(x, sizes[height - 1]) if not self.dilated and height > 1 else x
|
63 |
+
else:
|
64 |
+
return getattr(self, f'rebnconv{height}')(x)
|
65 |
+
|
66 |
+
return x + unet(x)
|
67 |
+
|
68 |
+
def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False):
|
69 |
+
self.add_module('rebnconvin', REBNCONV(in_ch, out_ch))
|
70 |
+
self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
|
71 |
+
|
72 |
+
self.add_module(f'rebnconv1', REBNCONV(out_ch, mid_ch))
|
73 |
+
self.add_module(f'rebnconv1d', REBNCONV(mid_ch * 2, out_ch))
|
74 |
+
|
75 |
+
for i in range(2, height):
|
76 |
+
dilate = 1 if not dilated else 2 ** (i - 1)
|
77 |
+
self.add_module(f'rebnconv{i}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
|
78 |
+
self.add_module(f'rebnconv{i}d', REBNCONV(mid_ch * 2, mid_ch, dilate=dilate))
|
79 |
+
|
80 |
+
dilate = 2 if not dilated else 2 ** (height - 1)
|
81 |
+
self.add_module(f'rebnconv{height}', REBNCONV(mid_ch, mid_ch, dilate=dilate))
|
82 |
+
|
83 |
+
|
84 |
+
class U2NET(nn.Module):
|
85 |
+
def __init__(self, cfgs, out_ch):
|
86 |
+
super(U2NET, self).__init__()
|
87 |
+
self.out_ch = out_ch
|
88 |
+
self._make_layers(cfgs)
|
89 |
+
|
90 |
+
def forward(self, x):
|
91 |
+
sizes = _size_map(x, self.height)
|
92 |
+
maps = [] # storage for maps
|
93 |
+
|
94 |
+
# side saliency map
|
95 |
+
def unet(x, height=1):
|
96 |
+
if height < 6:
|
97 |
+
x1 = getattr(self, f'stage{height}')(x)
|
98 |
+
x2 = unet(getattr(self, 'downsample')(x1), height + 1)
|
99 |
+
x = getattr(self, f'stage{height}d')(torch.cat((x2, x1), 1))
|
100 |
+
side(x, height)
|
101 |
+
return _upsample_like(x, sizes[height - 1]) if height > 1 else x
|
102 |
+
else:
|
103 |
+
x = getattr(self, f'stage{height}')(x)
|
104 |
+
side(x, height)
|
105 |
+
return _upsample_like(x, sizes[height - 1])
|
106 |
+
|
107 |
+
def side(x, h):
|
108 |
+
# side output saliency map (before sigmoid)
|
109 |
+
x = getattr(self, f'side{h}')(x)
|
110 |
+
x = _upsample_like(x, sizes[1])
|
111 |
+
maps.append(x)
|
112 |
+
|
113 |
+
def fuse():
|
114 |
+
# fuse saliency probability maps
|
115 |
+
maps.reverse()
|
116 |
+
x = torch.cat(maps, 1)
|
117 |
+
x = getattr(self, 'outconv')(x)
|
118 |
+
maps.insert(0, x)
|
119 |
+
# return [torch.sigmoid(x) for x in maps]
|
120 |
+
return [x for x in maps]
|
121 |
+
|
122 |
+
unet(x)
|
123 |
+
maps = fuse()
|
124 |
+
return maps
|
125 |
+
|
126 |
+
@staticmethod
|
127 |
+
def compute_loss(args):
|
128 |
+
preds, labels_v = args
|
129 |
+
d0, d1, d2, d3, d4, d5, d6 = preds
|
130 |
+
loss0 = bce_loss(d0, labels_v)
|
131 |
+
loss1 = bce_loss(d1, labels_v)
|
132 |
+
loss2 = bce_loss(d2, labels_v)
|
133 |
+
loss3 = bce_loss(d3, labels_v)
|
134 |
+
loss4 = bce_loss(d4, labels_v)
|
135 |
+
loss5 = bce_loss(d5, labels_v)
|
136 |
+
loss6 = bce_loss(d6, labels_v)
|
137 |
+
|
138 |
+
loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
|
139 |
+
|
140 |
+
return loss0, loss
|
141 |
+
|
142 |
+
def _make_layers(self, cfgs):
|
143 |
+
self.height = int((len(cfgs) + 1) / 2)
|
144 |
+
self.add_module('downsample', nn.MaxPool2d(2, stride=2, ceil_mode=True))
|
145 |
+
for k, v in cfgs.items():
|
146 |
+
# build rsu block
|
147 |
+
self.add_module(k, RSU(v[0], *v[1]))
|
148 |
+
if v[2] > 0:
|
149 |
+
# build side layer
|
150 |
+
self.add_module(f'side{v[0][-1]}', nn.Conv2d(v[2], self.out_ch, 3, padding=1))
|
151 |
+
# build fuse layer
|
152 |
+
self.add_module('outconv', nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1))
|
153 |
+
|
154 |
+
|
155 |
+
def U2NET_full():
|
156 |
+
full = {
|
157 |
+
# cfgs for building RSUs and sides
|
158 |
+
# {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
|
159 |
+
'stage1': ['En_1', (7, 3, 32, 64), -1],
|
160 |
+
'stage2': ['En_2', (6, 64, 32, 128), -1],
|
161 |
+
'stage3': ['En_3', (5, 128, 64, 256), -1],
|
162 |
+
'stage4': ['En_4', (4, 256, 128, 512), -1],
|
163 |
+
'stage5': ['En_5', (4, 512, 256, 512, True), -1],
|
164 |
+
'stage6': ['En_6', (4, 512, 256, 512, True), 512],
|
165 |
+
'stage5d': ['De_5', (4, 1024, 256, 512, True), 512],
|
166 |
+
'stage4d': ['De_4', (4, 1024, 128, 256), 256],
|
167 |
+
'stage3d': ['De_3', (5, 512, 64, 128), 128],
|
168 |
+
'stage2d': ['De_2', (6, 256, 32, 64), 64],
|
169 |
+
'stage1d': ['De_1', (7, 128, 16, 64), 64],
|
170 |
+
}
|
171 |
+
return U2NET(cfgs=full, out_ch=1)
|
172 |
+
|
173 |
+
|
174 |
+
def U2NET_full2():
|
175 |
+
full = {
|
176 |
+
# cfgs for building RSUs and sides
|
177 |
+
# {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
|
178 |
+
'stage1': ['En_1', (8, 3, 32, 64), -1],
|
179 |
+
'stage2': ['En_2', (7, 64, 32, 128), -1],
|
180 |
+
'stage3': ['En_3', (6, 128, 64, 256), -1],
|
181 |
+
'stage4': ['En_4', (5, 256, 128, 512), -1],
|
182 |
+
'stage5': ['En_5', (5, 512, 256, 512, True), -1],
|
183 |
+
'stage6': ['En_6', (5, 512, 256, 512, True), 512],
|
184 |
+
'stage5d': ['De_5', (5, 1024, 256, 512, True), 512],
|
185 |
+
'stage4d': ['De_4', (5, 1024, 128, 256), 256],
|
186 |
+
'stage3d': ['De_3', (6, 512, 64, 128), 128],
|
187 |
+
'stage2d': ['De_2', (7, 256, 32, 64), 64],
|
188 |
+
'stage1d': ['De_1', (8, 128, 16, 64), 64],
|
189 |
+
}
|
190 |
+
return U2NET(cfgs=full, out_ch=1)
|
191 |
+
|
192 |
+
|
193 |
+
def U2NET_lite():
|
194 |
+
lite = {
|
195 |
+
# cfgs for building RSUs and sides
|
196 |
+
# {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
|
197 |
+
'stage1': ['En_1', (7, 3, 16, 64), -1],
|
198 |
+
'stage2': ['En_2', (6, 64, 16, 64), -1],
|
199 |
+
'stage3': ['En_3', (5, 64, 16, 64), -1],
|
200 |
+
'stage4': ['En_4', (4, 64, 16, 64), -1],
|
201 |
+
'stage5': ['En_5', (4, 64, 16, 64, True), -1],
|
202 |
+
'stage6': ['En_6', (4, 64, 16, 64, True), 64],
|
203 |
+
'stage5d': ['De_5', (4, 128, 16, 64, True), 64],
|
204 |
+
'stage4d': ['De_4', (4, 128, 16, 64), 64],
|
205 |
+
'stage3d': ['De_3', (5, 128, 16, 64), 64],
|
206 |
+
'stage2d': ['De_2', (6, 128, 16, 64), 64],
|
207 |
+
'stage1d': ['De_1', (7, 128, 16, 64), 64],
|
208 |
+
}
|
209 |
+
return U2NET(cfgs=lite, out_ch=1)
|
210 |
+
|
211 |
+
|
212 |
+
def U2NET_lite2():
|
213 |
+
lite = {
|
214 |
+
# cfgs for building RSUs and sides
|
215 |
+
# {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
|
216 |
+
'stage1': ['En_1', (8, 3, 16, 64), -1],
|
217 |
+
'stage2': ['En_2', (7, 64, 16, 64), -1],
|
218 |
+
'stage3': ['En_3', (6, 64, 16, 64), -1],
|
219 |
+
'stage4': ['En_4', (5, 64, 16, 64), -1],
|
220 |
+
'stage5': ['En_5', (5, 64, 16, 64, True), -1],
|
221 |
+
'stage6': ['En_6', (5, 64, 16, 64, True), 64],
|
222 |
+
'stage5d': ['De_5', (5, 128, 16, 64, True), 64],
|
223 |
+
'stage4d': ['De_4', (5, 128, 16, 64), 64],
|
224 |
+
'stage3d': ['De_3', (6, 128, 16, 64), 64],
|
225 |
+
'stage2d': ['De_2', (7, 128, 16, 64), 64],
|
226 |
+
'stage1d': ['De_1', (8, 128, 16, 64), 64],
|
227 |
+
}
|
228 |
+
return U2NET(cfgs=lite, out_ch=1)
|
animeinsseg/models/rtmdet_inshead_custom.py
ADDED
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import copy
|
3 |
+
import math
|
4 |
+
from typing import List, Optional, Tuple
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from mmcv.cnn import ConvModule, is_norm
|
10 |
+
from mmcv.ops import batched_nms
|
11 |
+
from mmengine.model import (BaseModule, bias_init_with_prob, constant_init,
|
12 |
+
normal_init)
|
13 |
+
from mmengine.structures import InstanceData
|
14 |
+
from torch import Tensor
|
15 |
+
|
16 |
+
from mmdet.models.layers.transformer import inverse_sigmoid
|
17 |
+
from mmdet.models.utils import (filter_scores_and_topk, multi_apply,
|
18 |
+
select_single_mlvl, sigmoid_geometric_mean)
|
19 |
+
from mmdet.registry import MODELS
|
20 |
+
from mmdet.structures.bbox import (cat_boxes, distance2bbox, get_box_tensor,
|
21 |
+
get_box_wh, scale_boxes)
|
22 |
+
from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean
|
23 |
+
from mmdet.models.dense_heads.rtmdet_head import RTMDetHead
|
24 |
+
from mmdet.models.dense_heads.rtmdet_ins_head import RTMDetInsHead, RTMDetInsSepBNHead, MaskFeatModule
|
25 |
+
|
26 |
+
from mmdet.utils import AvoidCUDAOOM
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
def sthgoeswrong(logits):
|
31 |
+
return torch.any(torch.isnan(logits)) or torch.any(torch.isinf(logits))
|
32 |
+
|
33 |
+
from time import time
|
34 |
+
|
35 |
+
@MODELS.register_module(force=True)
|
36 |
+
class RTMDetInsHeadCustom(RTMDetInsHead):
|
37 |
+
|
38 |
+
def loss_by_feat(self,
|
39 |
+
cls_scores: List[Tensor],
|
40 |
+
bbox_preds: List[Tensor],
|
41 |
+
kernel_preds: List[Tensor],
|
42 |
+
mask_feat: Tensor,
|
43 |
+
batch_gt_instances: InstanceList,
|
44 |
+
batch_img_metas: List[dict],
|
45 |
+
batch_gt_instances_ignore: OptInstanceList = None):
|
46 |
+
"""Compute losses of the head.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
cls_scores (list[Tensor]): Box scores for each scale level
|
50 |
+
Has shape (N, num_anchors * num_classes, H, W)
|
51 |
+
bbox_preds (list[Tensor]): Decoded box for each scale
|
52 |
+
level with shape (N, num_anchors * 4, H, W) in
|
53 |
+
[tl_x, tl_y, br_x, br_y] format.
|
54 |
+
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
|
55 |
+
gt_instance. It usually includes ``bboxes`` and ``labels``
|
56 |
+
attributes.
|
57 |
+
batch_img_metas (list[dict]): Meta information of each image, e.g.,
|
58 |
+
image size, scaling factor, etc.
|
59 |
+
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
|
60 |
+
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
|
61 |
+
data that is ignored during training and testing.
|
62 |
+
Defaults to None.
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
dict[str, Tensor]: A dictionary of loss components.
|
66 |
+
"""
|
67 |
+
num_imgs = len(batch_img_metas)
|
68 |
+
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
|
69 |
+
assert len(featmap_sizes) == self.prior_generator.num_levels
|
70 |
+
|
71 |
+
device = cls_scores[0].device
|
72 |
+
anchor_list, valid_flag_list = self.get_anchors(
|
73 |
+
featmap_sizes, batch_img_metas, device=device)
|
74 |
+
flatten_cls_scores = torch.cat([
|
75 |
+
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
|
76 |
+
self.cls_out_channels)
|
77 |
+
for cls_score in cls_scores
|
78 |
+
], 1)
|
79 |
+
flatten_kernels = torch.cat([
|
80 |
+
kernel_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
|
81 |
+
self.num_gen_params)
|
82 |
+
for kernel_pred in kernel_preds
|
83 |
+
], 1)
|
84 |
+
decoded_bboxes = []
|
85 |
+
for anchor, bbox_pred in zip(anchor_list[0], bbox_preds):
|
86 |
+
anchor = anchor.reshape(-1, 4)
|
87 |
+
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
|
88 |
+
bbox_pred = distance2bbox(anchor, bbox_pred)
|
89 |
+
decoded_bboxes.append(bbox_pred)
|
90 |
+
|
91 |
+
flatten_bboxes = torch.cat(decoded_bboxes, 1)
|
92 |
+
for gt_instances in batch_gt_instances:
|
93 |
+
gt_instances.masks = gt_instances.masks.to_tensor(
|
94 |
+
dtype=torch.bool, device=device)
|
95 |
+
|
96 |
+
cls_reg_targets = self.get_targets(
|
97 |
+
flatten_cls_scores,
|
98 |
+
flatten_bboxes,
|
99 |
+
anchor_list,
|
100 |
+
valid_flag_list,
|
101 |
+
batch_gt_instances,
|
102 |
+
batch_img_metas,
|
103 |
+
batch_gt_instances_ignore=batch_gt_instances_ignore)
|
104 |
+
(anchor_list, labels_list, label_weights_list, bbox_targets_list,
|
105 |
+
assign_metrics_list, sampling_results_list) = cls_reg_targets
|
106 |
+
|
107 |
+
losses_cls, losses_bbox,\
|
108 |
+
cls_avg_factors, bbox_avg_factors = multi_apply(
|
109 |
+
self.loss_by_feat_single,
|
110 |
+
cls_scores,
|
111 |
+
decoded_bboxes,
|
112 |
+
labels_list,
|
113 |
+
label_weights_list,
|
114 |
+
bbox_targets_list,
|
115 |
+
assign_metrics_list,
|
116 |
+
self.prior_generator.strides)
|
117 |
+
|
118 |
+
cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
|
119 |
+
losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls))
|
120 |
+
|
121 |
+
bbox_avg_factor = reduce_mean(
|
122 |
+
sum(bbox_avg_factors)).clamp_(min=1).item()
|
123 |
+
losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
|
124 |
+
|
125 |
+
loss_mask = self.loss_mask_by_feat(mask_feat, flatten_kernels,
|
126 |
+
sampling_results_list,
|
127 |
+
batch_gt_instances)
|
128 |
+
loss = dict(
|
129 |
+
loss_cls=losses_cls, loss_bbox=losses_bbox, loss_mask=loss_mask)
|
130 |
+
|
131 |
+
return loss
|
132 |
+
|
133 |
+
|
134 |
+
def _mask_predict_by_feat_single(self, mask_feat: Tensor, kernels: Tensor,
|
135 |
+
priors: Tensor) -> Tensor:
|
136 |
+
|
137 |
+
ori_maskfeat = mask_feat
|
138 |
+
|
139 |
+
num_inst = priors.shape[0]
|
140 |
+
h, w = mask_feat.size()[-2:]
|
141 |
+
if num_inst < 1:
|
142 |
+
return torch.empty(
|
143 |
+
size=(num_inst, h, w),
|
144 |
+
dtype=mask_feat.dtype,
|
145 |
+
device=mask_feat.device)
|
146 |
+
if len(mask_feat.shape) < 4:
|
147 |
+
mask_feat.unsqueeze(0)
|
148 |
+
|
149 |
+
coord = self.prior_generator.single_level_grid_priors(
|
150 |
+
(h, w), level_idx=0, device=mask_feat.device).reshape(1, -1, 2)
|
151 |
+
num_inst = priors.shape[0]
|
152 |
+
points = priors[:, :2].reshape(-1, 1, 2)
|
153 |
+
strides = priors[:, 2:].reshape(-1, 1, 2)
|
154 |
+
relative_coord = (points - coord).permute(0, 2, 1) / (
|
155 |
+
strides[..., 0].reshape(-1, 1, 1) * 8)
|
156 |
+
relative_coord = relative_coord.reshape(num_inst, 2, h, w)
|
157 |
+
|
158 |
+
mask_feat = torch.cat(
|
159 |
+
[relative_coord,
|
160 |
+
mask_feat.repeat(num_inst, 1, 1, 1)], dim=1)
|
161 |
+
weights, biases = self.parse_dynamic_params(kernels)
|
162 |
+
|
163 |
+
fp16_used = weights[0].dtype == torch.float16
|
164 |
+
|
165 |
+
n_layers = len(weights)
|
166 |
+
x = mask_feat.reshape(1, -1, h, w)
|
167 |
+
for i, (weight, bias) in enumerate(zip(weights, biases)):
|
168 |
+
with torch.cuda.amp.autocast(enabled=False):
|
169 |
+
if fp16_used:
|
170 |
+
weight = weight.to(torch.float32)
|
171 |
+
bias = bias.to(torch.float32)
|
172 |
+
x = F.conv2d(
|
173 |
+
x, weight, bias=bias, stride=1, padding=0, groups=num_inst)
|
174 |
+
if i < n_layers - 1:
|
175 |
+
x = F.relu(x)
|
176 |
+
|
177 |
+
if fp16_used:
|
178 |
+
x = torch.clip(x, -8192, 8192)
|
179 |
+
if sthgoeswrong(x):
|
180 |
+
torch.save({'mask_feat': ori_maskfeat, 'kernels': kernels, 'priors': priors}, 'maskhead_nan_input.pt')
|
181 |
+
raise Exception('Mask Head NaN')
|
182 |
+
|
183 |
+
x = x.reshape(num_inst, h, w)
|
184 |
+
return x
|
185 |
+
|
186 |
+
def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor,
|
187 |
+
sampling_results_list: list,
|
188 |
+
batch_gt_instances: InstanceList) -> Tensor:
|
189 |
+
batch_pos_mask_logits = []
|
190 |
+
pos_gt_masks = []
|
191 |
+
ignore_masks = []
|
192 |
+
for idx, (mask_feat, kernels, sampling_results,
|
193 |
+
gt_instances) in enumerate(
|
194 |
+
zip(mask_feats, flatten_kernels, sampling_results_list,
|
195 |
+
batch_gt_instances)):
|
196 |
+
pos_priors = sampling_results.pos_priors
|
197 |
+
pos_inds = sampling_results.pos_inds
|
198 |
+
pos_kernels = kernels[pos_inds] # n_pos, num_gen_params
|
199 |
+
pos_mask_logits = self._mask_predict_by_feat_single(
|
200 |
+
mask_feat, pos_kernels, pos_priors)
|
201 |
+
if gt_instances.masks.numel() == 0:
|
202 |
+
gt_masks = torch.empty_like(gt_instances.masks)
|
203 |
+
if gt_masks.shape[0] > 0:
|
204 |
+
ignore = torch.zeros(gt_masks.shape[0], dtype=torch.bool).to(device=gt_masks.device)
|
205 |
+
ignore_masks.append(ignore)
|
206 |
+
else:
|
207 |
+
gt_masks = gt_instances.masks[
|
208 |
+
sampling_results.pos_assigned_gt_inds, :]
|
209 |
+
ignore_masks.append(gt_instances.ignore_mask[sampling_results.pos_assigned_gt_inds])
|
210 |
+
batch_pos_mask_logits.append(pos_mask_logits)
|
211 |
+
pos_gt_masks.append(gt_masks)
|
212 |
+
|
213 |
+
pos_gt_masks = torch.cat(pos_gt_masks, 0)
|
214 |
+
batch_pos_mask_logits = torch.cat(batch_pos_mask_logits, 0)
|
215 |
+
ignore_masks = torch.logical_not(torch.cat(ignore_masks, 0))
|
216 |
+
|
217 |
+
pos_gt_masks = pos_gt_masks[ignore_masks]
|
218 |
+
batch_pos_mask_logits = batch_pos_mask_logits[ignore_masks]
|
219 |
+
|
220 |
+
|
221 |
+
# avg_factor
|
222 |
+
num_pos = batch_pos_mask_logits.shape[0]
|
223 |
+
num_pos = reduce_mean(mask_feats.new_tensor([num_pos
|
224 |
+
])).clamp_(min=1).item()
|
225 |
+
|
226 |
+
if batch_pos_mask_logits.shape[0] == 0:
|
227 |
+
return mask_feats.sum() * 0
|
228 |
+
|
229 |
+
scale = self.prior_generator.strides[0][0] // self.mask_loss_stride
|
230 |
+
# upsample pred masks
|
231 |
+
batch_pos_mask_logits = F.interpolate(
|
232 |
+
batch_pos_mask_logits.unsqueeze(0),
|
233 |
+
scale_factor=scale,
|
234 |
+
mode='bilinear',
|
235 |
+
align_corners=False).squeeze(0)
|
236 |
+
# downsample gt masks
|
237 |
+
pos_gt_masks = pos_gt_masks[:, self.mask_loss_stride //
|
238 |
+
2::self.mask_loss_stride,
|
239 |
+
self.mask_loss_stride //
|
240 |
+
2::self.mask_loss_stride]
|
241 |
+
|
242 |
+
loss_mask = self.loss_mask(
|
243 |
+
batch_pos_mask_logits,
|
244 |
+
pos_gt_masks,
|
245 |
+
weight=None,
|
246 |
+
avg_factor=num_pos)
|
247 |
+
|
248 |
+
return loss_mask
|
249 |
+
|
250 |
+
|
251 |
+
@MODELS.register_module()
|
252 |
+
class RTMDetInsSepBNHeadCustom(RTMDetInsSepBNHead):
|
253 |
+
def _mask_predict_by_feat_single(self, mask_feat: Tensor, kernels: Tensor,
|
254 |
+
priors: Tensor) -> Tensor:
|
255 |
+
|
256 |
+
ori_maskfeat = mask_feat
|
257 |
+
|
258 |
+
num_inst = priors.shape[0]
|
259 |
+
h, w = mask_feat.size()[-2:]
|
260 |
+
if num_inst < 1:
|
261 |
+
return torch.empty(
|
262 |
+
size=(num_inst, h, w),
|
263 |
+
dtype=mask_feat.dtype,
|
264 |
+
device=mask_feat.device)
|
265 |
+
if len(mask_feat.shape) < 4:
|
266 |
+
mask_feat.unsqueeze(0)
|
267 |
+
|
268 |
+
coord = self.prior_generator.single_level_grid_priors(
|
269 |
+
(h, w), level_idx=0, device=mask_feat.device).reshape(1, -1, 2)
|
270 |
+
num_inst = priors.shape[0]
|
271 |
+
points = priors[:, :2].reshape(-1, 1, 2)
|
272 |
+
strides = priors[:, 2:].reshape(-1, 1, 2)
|
273 |
+
relative_coord = (points - coord).permute(0, 2, 1) / (
|
274 |
+
strides[..., 0].reshape(-1, 1, 1) * 8)
|
275 |
+
relative_coord = relative_coord.reshape(num_inst, 2, h, w)
|
276 |
+
|
277 |
+
mask_feat = torch.cat(
|
278 |
+
[relative_coord,
|
279 |
+
mask_feat.repeat(num_inst, 1, 1, 1)], dim=1)
|
280 |
+
weights, biases = self.parse_dynamic_params(kernels)
|
281 |
+
|
282 |
+
fp16_used = weights[0].dtype == torch.float16
|
283 |
+
|
284 |
+
n_layers = len(weights)
|
285 |
+
x = mask_feat.reshape(1, -1, h, w)
|
286 |
+
for i, (weight, bias) in enumerate(zip(weights, biases)):
|
287 |
+
with torch.cuda.amp.autocast(enabled=False):
|
288 |
+
if fp16_used:
|
289 |
+
weight = weight.to(torch.float32)
|
290 |
+
bias = bias.to(torch.float32)
|
291 |
+
x = F.conv2d(
|
292 |
+
x, weight, bias=bias, stride=1, padding=0, groups=num_inst)
|
293 |
+
if i < n_layers - 1:
|
294 |
+
x = F.relu(x)
|
295 |
+
|
296 |
+
if fp16_used:
|
297 |
+
x = torch.clip(x, -8192, 8192)
|
298 |
+
if sthgoeswrong(x):
|
299 |
+
torch.save({'mask_feat': ori_maskfeat, 'kernels': kernels, 'priors': priors}, 'maskhead_nan_input.pt')
|
300 |
+
raise Exception('Mask Head NaN')
|
301 |
+
|
302 |
+
x = x.reshape(num_inst, h, w)
|
303 |
+
return x
|
304 |
+
|
305 |
+
@AvoidCUDAOOM.retry_if_cuda_oom
|
306 |
+
def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor,
|
307 |
+
sampling_results_list: list,
|
308 |
+
batch_gt_instances: InstanceList) -> Tensor:
|
309 |
+
batch_pos_mask_logits = []
|
310 |
+
pos_gt_masks = []
|
311 |
+
ignore_masks = []
|
312 |
+
for idx, (mask_feat, kernels, sampling_results,
|
313 |
+
gt_instances) in enumerate(
|
314 |
+
zip(mask_feats, flatten_kernels, sampling_results_list,
|
315 |
+
batch_gt_instances)):
|
316 |
+
pos_priors = sampling_results.pos_priors
|
317 |
+
pos_inds = sampling_results.pos_inds
|
318 |
+
pos_kernels = kernels[pos_inds] # n_pos, num_gen_params
|
319 |
+
pos_mask_logits = self._mask_predict_by_feat_single(
|
320 |
+
mask_feat, pos_kernels, pos_priors)
|
321 |
+
if gt_instances.masks.numel() == 0:
|
322 |
+
gt_masks = torch.empty_like(gt_instances.masks)
|
323 |
+
# if gt_masks.shape[0] > 0:
|
324 |
+
# ignore = torch.zeros(gt_masks.shape[0], dtype=torch.bool).to(device=gt_masks.device)
|
325 |
+
# ignore_masks.append(ignore)
|
326 |
+
else:
|
327 |
+
msk = torch.logical_not(gt_instances.ignore_mask[sampling_results.pos_assigned_gt_inds])
|
328 |
+
gt_masks = gt_instances.masks[
|
329 |
+
sampling_results.pos_assigned_gt_inds, :][msk]
|
330 |
+
pos_mask_logits = pos_mask_logits[msk]
|
331 |
+
# ignore_masks.append(gt_instances.ignore_mask[sampling_results.pos_assigned_gt_inds])
|
332 |
+
batch_pos_mask_logits.append(pos_mask_logits)
|
333 |
+
pos_gt_masks.append(gt_masks)
|
334 |
+
|
335 |
+
pos_gt_masks = torch.cat(pos_gt_masks, 0)
|
336 |
+
batch_pos_mask_logits = torch.cat(batch_pos_mask_logits, 0)
|
337 |
+
# ignore_masks = torch.logical_not(torch.cat(ignore_masks, 0))
|
338 |
+
|
339 |
+
# pos_gt_masks = pos_gt_masks[ignore_masks]
|
340 |
+
# batch_pos_mask_logits = batch_pos_mask_logits[ignore_masks]
|
341 |
+
|
342 |
+
|
343 |
+
# avg_factor
|
344 |
+
num_pos = batch_pos_mask_logits.shape[0]
|
345 |
+
num_pos = reduce_mean(mask_feats.new_tensor([num_pos
|
346 |
+
])).clamp_(min=1).item()
|
347 |
+
|
348 |
+
if batch_pos_mask_logits.shape[0] == 0:
|
349 |
+
return mask_feats.sum() * 0
|
350 |
+
|
351 |
+
scale = self.prior_generator.strides[0][0] // self.mask_loss_stride
|
352 |
+
# upsample pred masks
|
353 |
+
batch_pos_mask_logits = F.interpolate(
|
354 |
+
batch_pos_mask_logits.unsqueeze(0),
|
355 |
+
scale_factor=scale,
|
356 |
+
mode='bilinear',
|
357 |
+
align_corners=False).squeeze(0)
|
358 |
+
# downsample gt masks
|
359 |
+
pos_gt_masks = pos_gt_masks[:, self.mask_loss_stride //
|
360 |
+
2::self.mask_loss_stride,
|
361 |
+
self.mask_loss_stride //
|
362 |
+
2::self.mask_loss_stride]
|
363 |
+
|
364 |
+
loss_mask = self.loss_mask(
|
365 |
+
batch_pos_mask_logits,
|
366 |
+
pos_gt_masks,
|
367 |
+
weight=None,
|
368 |
+
avg_factor=num_pos)
|
369 |
+
|
370 |
+
return loss_mask
|
app.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from animeinsseg import AnimeInsSeg, AnimeInstances
|
8 |
+
from animeinsseg.anime_instances import get_color
|
9 |
+
|
10 |
+
import os
|
11 |
+
|
12 |
+
if not os.path.exists("models"):
|
13 |
+
os.mkdir("models")
|
14 |
+
|
15 |
+
os.system("huggingface-cli lfs-enable-largefiles .")
|
16 |
+
os.system("git clone https://huggingface.co/dreMaz/AnimeInstanceSegmentation models/AnimeInstanceSegmentation")
|
17 |
+
|
18 |
+
ckpt = r'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'
|
19 |
+
|
20 |
+
mask_thres = 0.3
|
21 |
+
instance_thres = 0.3
|
22 |
+
refine_kwargs = {'refine_method': 'refinenet_isnet'} # set to None if not using refinenet
|
23 |
+
# refine_kwargs = None
|
24 |
+
|
25 |
+
net = AnimeInsSeg(ckpt, mask_thr=mask_thres, refine_kwargs=refine_kwargs)
|
26 |
+
|
27 |
+
def fn(image):
|
28 |
+
img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
29 |
+
instances: AnimeInstances = net.infer(
|
30 |
+
img,
|
31 |
+
output_type='numpy',
|
32 |
+
pred_score_thr=instance_thres
|
33 |
+
)
|
34 |
+
|
35 |
+
drawed = img.copy()
|
36 |
+
im_h, im_w = img.shape[:2]
|
37 |
+
|
38 |
+
# instances.bboxes, instances.masks will be None, None if no obj is detected
|
39 |
+
|
40 |
+
for ii, (xywh, mask) in enumerate(zip(instances.bboxes, instances.masks)):
|
41 |
+
color = get_color(ii)
|
42 |
+
|
43 |
+
mask_alpha = 0.5
|
44 |
+
linewidth = max(round(sum(img.shape) / 2 * 0.003), 2)
|
45 |
+
|
46 |
+
# draw bbox
|
47 |
+
p1, p2 = (int(xywh[0]), int(xywh[1])), (int(xywh[2] + xywh[0]), int(xywh[3] + xywh[1]))
|
48 |
+
cv2.rectangle(drawed, p1, p2, color, thickness=linewidth, lineType=cv2.LINE_AA)
|
49 |
+
|
50 |
+
# draw mask
|
51 |
+
p = mask.astype(np.float32)
|
52 |
+
blend_mask = np.full((im_h, im_w, 3), color, dtype=np.float32)
|
53 |
+
alpha_msk = (mask_alpha * p)[..., None]
|
54 |
+
alpha_ori = 1 - alpha_msk
|
55 |
+
drawed = drawed * alpha_ori + alpha_msk * blend_mask
|
56 |
+
|
57 |
+
drawed = drawed.astype(np.uint8)
|
58 |
+
|
59 |
+
return Image.fromarray(drawed[..., ::-1])
|
60 |
+
|
61 |
+
iface = gr.Interface(
|
62 |
+
inputs=gr.Image(type="numpy"),
|
63 |
+
outputs="Image",
|
64 |
+
fn=fn
|
65 |
+
)
|
66 |
+
|
67 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
einops
|
2 |
+
imageio
|
3 |
+
git+https://github.com/cocodataset/panopticapi.git
|
4 |
+
pytorch-lightning
|
5 |
+
albumentations
|
6 |
+
huggingface_hub
|
7 |
+
|
8 |
+
# For Web UI
|
9 |
+
gradio
|
10 |
+
torch
|
11 |
+
torchvision
|
12 |
+
openmim
|
13 |
+
mmengine
|
14 |
+
mmcv>=2.0.0
|
15 |
+
mmdet
|