Spaces:
Runtime error
Runtime error
add utils
Browse files- utils/__init__.py +0 -0
- utils/booru_tagger.py +116 -0
- utils/constants.py +82 -0
- utils/cupy_utils.py +122 -0
- utils/effects.py +182 -0
- utils/env_utils.py +65 -0
- utils/helper_math.h +1449 -0
- utils/io_utils.py +473 -0
- utils/logger.py +20 -0
- utils/mmdet_custom_hooks.py +223 -0
utils/__init__.py
ADDED
File without changes
|
utils/booru_tagger.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gc
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
from onnxruntime import InferenceSession
|
6 |
+
from typing import Tuple, List, Dict
|
7 |
+
from io import BytesIO
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
import cv2
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
def make_square(img, target_size):
|
16 |
+
old_size = img.shape[:2]
|
17 |
+
desired_size = max(old_size)
|
18 |
+
desired_size = max(desired_size, target_size)
|
19 |
+
|
20 |
+
delta_w = desired_size - old_size[1]
|
21 |
+
delta_h = desired_size - old_size[0]
|
22 |
+
top, bottom = delta_h // 2, delta_h - (delta_h // 2)
|
23 |
+
left, right = delta_w // 2, delta_w - (delta_w // 2)
|
24 |
+
|
25 |
+
color = [255, 255, 255]
|
26 |
+
new_im = cv2.copyMakeBorder(
|
27 |
+
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
|
28 |
+
)
|
29 |
+
return new_im
|
30 |
+
|
31 |
+
|
32 |
+
def smart_resize(img, size):
|
33 |
+
# Assumes the image has already gone through make_square
|
34 |
+
if img.shape[0] > size:
|
35 |
+
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
|
36 |
+
elif img.shape[0] < size:
|
37 |
+
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
|
38 |
+
return img
|
39 |
+
|
40 |
+
class Tagger :
|
41 |
+
def __init__(self, filename) -> None:
|
42 |
+
self.model = InferenceSession(filename, providers=['CUDAExecutionProvider'])
|
43 |
+
[root, _] = os.path.split(filename)
|
44 |
+
self.tags = pd.read_csv(os.path.join(root, 'selected_tags.csv') if root else 'selected_tags.csv')
|
45 |
+
|
46 |
+
_, self.height, _, _ = self.model.get_inputs()[0].shape
|
47 |
+
|
48 |
+
characters = self.tags.loc[self.tags['category'] == 4]
|
49 |
+
self.characters = set(characters['name'].values.tolist())
|
50 |
+
|
51 |
+
def label(self, image: Image) -> Dict[str, float] :
|
52 |
+
# alpha to white
|
53 |
+
image = image.convert('RGBA')
|
54 |
+
new_image = Image.new('RGBA', image.size, 'WHITE')
|
55 |
+
new_image.paste(image, mask=image)
|
56 |
+
image = new_image.convert('RGB')
|
57 |
+
image = np.asarray(image)
|
58 |
+
|
59 |
+
# PIL RGB to OpenCV BGR
|
60 |
+
image = image[:, :, ::-1]
|
61 |
+
|
62 |
+
image = make_square(image, self.height)
|
63 |
+
image = smart_resize(image, self.height)
|
64 |
+
image = image.astype(np.float32)
|
65 |
+
image = np.expand_dims(image, 0)
|
66 |
+
|
67 |
+
# evaluate model
|
68 |
+
input_name = self.model.get_inputs()[0].name
|
69 |
+
label_name = self.model.get_outputs()[0].name
|
70 |
+
confidents = self.model.run([label_name], {input_name: image})[0]
|
71 |
+
|
72 |
+
tags = self.tags[:][['name']]
|
73 |
+
tags['confidents'] = confidents[0]
|
74 |
+
|
75 |
+
# first 4 items are for rating (general, sensitive, questionable, explicit)
|
76 |
+
ratings = dict(tags[:4].values)
|
77 |
+
|
78 |
+
# rest are regular tags
|
79 |
+
tags = dict(tags[4:].values)
|
80 |
+
|
81 |
+
tags = {t: v for t, v in tags.items() if v > 0.5}
|
82 |
+
return tags
|
83 |
+
|
84 |
+
def label_cv2_bgr(self, image: np.ndarray) -> Dict[str, float] :
|
85 |
+
# image in BGR u8
|
86 |
+
image = make_square(image, self.height)
|
87 |
+
image = smart_resize(image, self.height)
|
88 |
+
image = image.astype(np.float32)
|
89 |
+
image = np.expand_dims(image, 0)
|
90 |
+
|
91 |
+
# evaluate model
|
92 |
+
input_name = self.model.get_inputs()[0].name
|
93 |
+
label_name = self.model.get_outputs()[0].name
|
94 |
+
confidents = self.model.run([label_name], {input_name: image})[0]
|
95 |
+
|
96 |
+
tags = self.tags[:][['name']]
|
97 |
+
cats = self.tags[:][['category']]
|
98 |
+
tags['confidents'] = confidents[0]
|
99 |
+
|
100 |
+
# first 4 items are for rating (general, sensitive, questionable, explicit)
|
101 |
+
ratings = dict(tags[:4].values)
|
102 |
+
|
103 |
+
# rest are regular tags
|
104 |
+
tags = dict(tags[4:].values)
|
105 |
+
|
106 |
+
tags = [t for t, v in tags.items() if v > 0.5]
|
107 |
+
character_str = []
|
108 |
+
for t in tags:
|
109 |
+
if t in self.characters:
|
110 |
+
character_str.append(t)
|
111 |
+
return tags, character_str
|
112 |
+
|
113 |
+
|
114 |
+
if __name__ == '__main__':
|
115 |
+
modelp = r'models/wd-v1-4-swinv2-tagger-v2/model.onnx'
|
116 |
+
tagger = Tagger(modelp)
|
utils/constants.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
CATEGORIES = [
|
4 |
+
{"id": 0, "name": "object", "isthing": 1}
|
5 |
+
]
|
6 |
+
|
7 |
+
IMAGE_ID_ZFILL = 12
|
8 |
+
|
9 |
+
COLOR_PALETTE = [
|
10 |
+
(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228),
|
11 |
+
(0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30),
|
12 |
+
(100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30),
|
13 |
+
(165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255),
|
14 |
+
(0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255),
|
15 |
+
(199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92),
|
16 |
+
(209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164),
|
17 |
+
(92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0),
|
18 |
+
(174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174),
|
19 |
+
(255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54),
|
20 |
+
(207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51),
|
21 |
+
(74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65),
|
22 |
+
(0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0),
|
23 |
+
(227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161),
|
24 |
+
(163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120),
|
25 |
+
(183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133),
|
26 |
+
(166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62),
|
27 |
+
(65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45),
|
28 |
+
(196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1),
|
29 |
+
(246, 0, 122), (191, 162, 208), (255, 255, 128), (147, 211, 203),
|
30 |
+
(150, 100, 100), (168, 171, 172), (146, 112, 198), (210, 170, 100),
|
31 |
+
(92, 136, 89), (218, 88, 184), (241, 129, 0), (217, 17, 255),
|
32 |
+
(124, 74, 181), (70, 70, 70), (255, 228, 255), (154, 208, 0),
|
33 |
+
(193, 0, 92), (76, 91, 113), (255, 180, 195), (106, 154, 176),
|
34 |
+
(230, 150, 140), (60, 143, 255), (128, 64, 128), (92, 82, 55),
|
35 |
+
(254, 212, 124), (73, 77, 174), (255, 160, 98), (255, 255, 255),
|
36 |
+
(104, 84, 109), (169, 164, 131), (225, 199, 255), (137, 54, 74),
|
37 |
+
(135, 158, 223), (7, 246, 231), (107, 255, 200), (58, 41, 149),
|
38 |
+
(183, 121, 142), (255, 73, 97), (107, 142, 35), (190, 153, 153),
|
39 |
+
(146, 139, 141), (70, 130, 180), (134, 199, 156), (209, 226, 140),
|
40 |
+
(96, 36, 108), (96, 96, 96), (64, 170, 64), (152, 251, 152),
|
41 |
+
(208, 229, 228), (206, 186, 171), (152, 161, 64), (116, 112, 0),
|
42 |
+
(0, 114, 143), (102, 102, 156), (250, 141, 255)
|
43 |
+
]
|
44 |
+
|
45 |
+
class Colors:
|
46 |
+
# Ultralytics color palette https://ultralytics.com/
|
47 |
+
def __init__(self):
|
48 |
+
# hex = matplotlib.colors.TABLEAU_COLORS.values()
|
49 |
+
hexs = ('FF1010', '10FF10', 'FFF010', '100FFF', '0018EC', 'FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
|
50 |
+
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
|
51 |
+
self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
|
52 |
+
self.n = len(self.palette)
|
53 |
+
|
54 |
+
def __call__(self, i, bgr=True):
|
55 |
+
c = self.palette[int(i) % self.n]
|
56 |
+
return (c[2], c[1], c[0]) if bgr else c
|
57 |
+
|
58 |
+
@staticmethod
|
59 |
+
def hex2rgb(h): # rgb order (PIL)
|
60 |
+
return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
|
61 |
+
|
62 |
+
colors = Colors()
|
63 |
+
def get_color(idx):
|
64 |
+
if idx == -1:
|
65 |
+
return 255
|
66 |
+
else:
|
67 |
+
return colors(idx)
|
68 |
+
|
69 |
+
|
70 |
+
MULTIPLE_TAGS = {'2girls', '3girls', '4girls', '5girls', '6+girls', 'multiple_girls',
|
71 |
+
'2boys', '3boys', '4boys', '5boys', '6+boys', 'multiple_boys',
|
72 |
+
'2others', '3others', '4others', '5others', '6+others', 'multiple_others'}
|
73 |
+
|
74 |
+
if hasattr(torch, 'cuda'):
|
75 |
+
DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
76 |
+
else:
|
77 |
+
DEFAULT_DEVICE = 'cpu'
|
78 |
+
|
79 |
+
DEFAULT_DETECTOR_CKPT = 'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'
|
80 |
+
DEFAULT_DEPTHREFINE_CKPT = 'models/AnimeInstanceSegmentation/kenburns_depth_refinenet.ckpt'
|
81 |
+
DEFAULT_INPAINTNET_CKPT = 'models/AnimeInstanceSegmentation/kenburns_inpaintnet.ckpt'
|
82 |
+
DEPTH_ZOE_CKPT = 'models/AnimeInstanceSegmentation/ZoeD_M12_N.pt'
|
utils/cupy_utils.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import os
|
3 |
+
import cupy
|
4 |
+
import os.path as osp
|
5 |
+
import torch
|
6 |
+
|
7 |
+
@cupy.memoize(for_each_device=True)
|
8 |
+
def launch_kernel(strFunction, strKernel):
|
9 |
+
if 'CUDA_HOME' not in os.environ:
|
10 |
+
os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
|
11 |
+
# end
|
12 |
+
# , options=tuple([ '-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include' ])
|
13 |
+
return cupy.RawKernel(strKernel, strFunction)
|
14 |
+
|
15 |
+
|
16 |
+
def preprocess_kernel(strKernel, objVariables):
|
17 |
+
path_to_math_helper = osp.join(osp.dirname(osp.abspath(__file__)), 'helper_math.h')
|
18 |
+
strKernel = '''
|
19 |
+
#include <{{HELPER_PATH}}>
|
20 |
+
|
21 |
+
__device__ __forceinline__ float atomicMin(const float* buffer, float dblValue) {
|
22 |
+
int intValue = __float_as_int(*buffer);
|
23 |
+
|
24 |
+
while (__int_as_float(intValue) > dblValue) {
|
25 |
+
intValue = atomicCAS((int*) (buffer), intValue, __float_as_int(dblValue));
|
26 |
+
}
|
27 |
+
|
28 |
+
return __int_as_float(intValue);
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
__device__ __forceinline__ float atomicMax(const float* buffer, float dblValue) {
|
33 |
+
int intValue = __float_as_int(*buffer);
|
34 |
+
|
35 |
+
while (__int_as_float(intValue) < dblValue) {
|
36 |
+
intValue = atomicCAS((int*) (buffer), intValue, __float_as_int(dblValue));
|
37 |
+
}
|
38 |
+
|
39 |
+
return __int_as_float(intValue);
|
40 |
+
}
|
41 |
+
'''.replace('{{HELPER_PATH}}', path_to_math_helper) + strKernel
|
42 |
+
# end
|
43 |
+
|
44 |
+
for strVariable in objVariables:
|
45 |
+
objValue = objVariables[strVariable]
|
46 |
+
|
47 |
+
if type(objValue) == int:
|
48 |
+
strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
|
49 |
+
|
50 |
+
elif type(objValue) == float:
|
51 |
+
strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
|
52 |
+
|
53 |
+
elif type(objValue) == str:
|
54 |
+
strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)
|
55 |
+
|
56 |
+
# end
|
57 |
+
# end
|
58 |
+
|
59 |
+
while True:
|
60 |
+
objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
|
61 |
+
|
62 |
+
if objMatch is None:
|
63 |
+
break
|
64 |
+
# end
|
65 |
+
|
66 |
+
intArg = int(objMatch.group(2))
|
67 |
+
|
68 |
+
strTensor = objMatch.group(4)
|
69 |
+
intSizes = objVariables[strTensor].size()
|
70 |
+
|
71 |
+
strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
|
72 |
+
# end
|
73 |
+
|
74 |
+
while True:
|
75 |
+
objMatch = re.search('(STRIDE_)([0-4])(\()([^\)]*)(\))', strKernel)
|
76 |
+
|
77 |
+
if objMatch is None:
|
78 |
+
break
|
79 |
+
# end
|
80 |
+
|
81 |
+
intArg = int(objMatch.group(2))
|
82 |
+
|
83 |
+
strTensor = objMatch.group(4)
|
84 |
+
intStrides = objVariables[strTensor].stride()
|
85 |
+
|
86 |
+
strKernel = strKernel.replace(objMatch.group(), str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()))
|
87 |
+
# end
|
88 |
+
|
89 |
+
while True:
|
90 |
+
objMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel)
|
91 |
+
|
92 |
+
if objMatch is None:
|
93 |
+
break
|
94 |
+
# end
|
95 |
+
|
96 |
+
intArgs = int(objMatch.group(2))
|
97 |
+
strArgs = objMatch.group(4).split(',')
|
98 |
+
|
99 |
+
strTensor = strArgs[0]
|
100 |
+
intStrides = objVariables[strTensor].stride()
|
101 |
+
strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')' for intArg in range(intArgs) ]
|
102 |
+
|
103 |
+
strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')')
|
104 |
+
# end
|
105 |
+
|
106 |
+
while True:
|
107 |
+
objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
|
108 |
+
|
109 |
+
if objMatch is None:
|
110 |
+
break
|
111 |
+
# end
|
112 |
+
|
113 |
+
intArgs = int(objMatch.group(2))
|
114 |
+
strArgs = objMatch.group(4).split(',')
|
115 |
+
|
116 |
+
strTensor = strArgs[0]
|
117 |
+
intStrides = objVariables[strTensor].stride()
|
118 |
+
strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')' for intArg in range(intArgs) ]
|
119 |
+
|
120 |
+
strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
|
121 |
+
# end
|
122 |
+
return strKernel
|
utils/effects.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from numba import jit, njit
|
2 |
+
import numpy as np
|
3 |
+
import time
|
4 |
+
import cv2
|
5 |
+
import math
|
6 |
+
from pathlib import Path
|
7 |
+
import os.path as osp
|
8 |
+
import torch
|
9 |
+
from .cupy_utils import launch_kernel, preprocess_kernel
|
10 |
+
import cupy
|
11 |
+
|
12 |
+
def bokeh_filter_cupy(img, depth, dx, dy, im_h, im_w, num_samples=32):
|
13 |
+
blurred = img.clone()
|
14 |
+
n = im_h * im_w
|
15 |
+
|
16 |
+
str_kernel = '''
|
17 |
+
extern "C" __global__ void kernel_bokeh(
|
18 |
+
const int n,
|
19 |
+
const int h,
|
20 |
+
const int w,
|
21 |
+
const int nsamples,
|
22 |
+
const float dx,
|
23 |
+
const float dy,
|
24 |
+
const float* img,
|
25 |
+
const float* depth,
|
26 |
+
float* blurred
|
27 |
+
) {
|
28 |
+
|
29 |
+
const int im_size = min(h, w);
|
30 |
+
const int sample_offset = nsamples / 2;
|
31 |
+
for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n * 3; intIndex += blockDim.x * gridDim.x) {
|
32 |
+
|
33 |
+
const int intSample = intIndex / 3;
|
34 |
+
|
35 |
+
const int c = intIndex % 3;
|
36 |
+
const int y = ( intSample / w) % h;
|
37 |
+
const int x = intSample % w;
|
38 |
+
|
39 |
+
const int flatten_xy = y * w + x;
|
40 |
+
const int fid = flatten_xy * 3 + c;
|
41 |
+
const float d = depth[flatten_xy];
|
42 |
+
|
43 |
+
const float _dx = dx * d;
|
44 |
+
const float _dy = dy * d;
|
45 |
+
float weight = 0;
|
46 |
+
float color = 0;
|
47 |
+
for (int s = 0; s < nsamples; s += 1) {
|
48 |
+
|
49 |
+
const int sp = (s - sample_offset) * im_size;
|
50 |
+
const int x_ = x + int(round(_dx * sp));
|
51 |
+
const int y_ = y + int(round(_dy * sp));
|
52 |
+
|
53 |
+
if ((x_ >= w) | (y_ >= h) | (x_ < 0) | (y_ < 0))
|
54 |
+
continue;
|
55 |
+
|
56 |
+
const int flatten_xy_ = y_ * w + x_;
|
57 |
+
const float w_ = depth[flatten_xy_];
|
58 |
+
weight += w_;
|
59 |
+
const int fid_ = flatten_xy_ * 3 + c;
|
60 |
+
color += img[fid_] * w_;
|
61 |
+
}
|
62 |
+
|
63 |
+
if (weight != 0) {
|
64 |
+
color /= weight;
|
65 |
+
}
|
66 |
+
else {
|
67 |
+
color = img[fid];
|
68 |
+
}
|
69 |
+
|
70 |
+
blurred[fid] = color;
|
71 |
+
|
72 |
+
}
|
73 |
+
|
74 |
+
}
|
75 |
+
'''
|
76 |
+
launch_kernel('kernel_bokeh', str_kernel)(
|
77 |
+
grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
|
78 |
+
block=tuple([ 512, 1, 1 ]),
|
79 |
+
args=[ cupy.int32(n), cupy.int32(im_h), cupy.int32(im_w), \
|
80 |
+
cupy.int32(num_samples), cupy.float32(dx), cupy.float32(dy),
|
81 |
+
img.data_ptr(), depth.data_ptr(), blurred.data_ptr() ]
|
82 |
+
)
|
83 |
+
|
84 |
+
return blurred
|
85 |
+
|
86 |
+
|
87 |
+
def np2flatten_tensor(arr: np.ndarray, to_cuda: bool = True) -> torch.Tensor:
|
88 |
+
c = 1
|
89 |
+
if len(arr.shape) == 3:
|
90 |
+
c = arr.shape[2]
|
91 |
+
else:
|
92 |
+
arr = arr[..., None]
|
93 |
+
arr = arr.transpose((2, 0, 1))[None, ...]
|
94 |
+
t = torch.from_numpy(arr).view(1, c, -1)
|
95 |
+
|
96 |
+
if to_cuda:
|
97 |
+
t = t.cuda()
|
98 |
+
return t
|
99 |
+
|
100 |
+
def ftensor2img(t: torch.Tensor, im_h, im_w):
|
101 |
+
t = t.detach().cpu().numpy().squeeze()
|
102 |
+
c = t.shape[0]
|
103 |
+
t = t.transpose((1, 0)).reshape((im_h, im_w, c))
|
104 |
+
return t
|
105 |
+
|
106 |
+
|
107 |
+
@njit
|
108 |
+
def bokeh_filter(img, depth, dx, dy, num_samples=32):
|
109 |
+
|
110 |
+
sample_offset = num_samples // 2
|
111 |
+
# _scale = 0.0005
|
112 |
+
# depth = depth * _scale
|
113 |
+
|
114 |
+
im_h, im_w = img.shape[0], img.shape[1]
|
115 |
+
im_size = min(im_h, im_w)
|
116 |
+
blured = np.zeros_like(img)
|
117 |
+
for x in range(im_w):
|
118 |
+
for y in range(im_h):
|
119 |
+
d = depth[y, x]
|
120 |
+
_color = np.array([0, 0, 0], dtype=np.float32)
|
121 |
+
_dx = dx * d
|
122 |
+
_dy = dy * d
|
123 |
+
weight = 0
|
124 |
+
for s in range(num_samples):
|
125 |
+
s = (s - sample_offset) * im_size
|
126 |
+
x_ = x + int(round(_dx * s))
|
127 |
+
y_ = y + int(round(_dy * s))
|
128 |
+
if x_ >= im_w or y_ >= im_h or x_ < 0 or y_ < 0:
|
129 |
+
continue
|
130 |
+
_w = depth[y_, x_]
|
131 |
+
weight += _w
|
132 |
+
_color += img[y_, x_] * _w
|
133 |
+
if weight == 0:
|
134 |
+
blured[y, x] = img[y, x]
|
135 |
+
else:
|
136 |
+
blured[y, x] = _color / np.array([weight, weight, weight], dtype=np.float32)
|
137 |
+
|
138 |
+
return blured
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
def bokeh_blur(img, depth, num_samples=32, lightness_factor=10, depth_factor=2, use_cuda=False, focal_plane=None):
|
144 |
+
img = np.ascontiguousarray(img)
|
145 |
+
|
146 |
+
if depth is not None:
|
147 |
+
depth = depth.astype(np.float32)
|
148 |
+
if focal_plane is not None:
|
149 |
+
depth = depth.max() - np.abs(depth - focal_plane)
|
150 |
+
if depth_factor != 1:
|
151 |
+
depth = np.power(depth, depth_factor)
|
152 |
+
depth = depth - depth.min()
|
153 |
+
depth = depth.astype(np.float32) / depth.max()
|
154 |
+
depth = 1 - depth
|
155 |
+
|
156 |
+
img = img.astype(np.float32) / 255
|
157 |
+
img_hightlighted = np.power(img, lightness_factor)
|
158 |
+
|
159 |
+
# img =
|
160 |
+
im_h, im_w = img.shape[:2]
|
161 |
+
PI = math.pi
|
162 |
+
|
163 |
+
_scale = 0.0005
|
164 |
+
depth = depth * _scale
|
165 |
+
|
166 |
+
if use_cuda:
|
167 |
+
img_hightlighted = np2flatten_tensor(img_hightlighted, True)
|
168 |
+
depth = np2flatten_tensor(depth, True)
|
169 |
+
vertical_blured = bokeh_filter_cupy(img_hightlighted, depth, 0, 1, im_h, im_w, num_samples)
|
170 |
+
diag_blured = bokeh_filter_cupy(vertical_blured, depth, math.cos(-PI/6), math.sin(-PI/6), im_h, im_w, num_samples)
|
171 |
+
rhom_blur = bokeh_filter_cupy(diag_blured, depth, math.cos(-PI * 5 /6), math.sin(-PI * 5 /6), im_h, im_w, num_samples)
|
172 |
+
blured = (diag_blured + rhom_blur) / 2
|
173 |
+
blured = ftensor2img(blured, im_h, im_w)
|
174 |
+
else:
|
175 |
+
vertical_blured = bokeh_filter(img_hightlighted, depth, 0, 1, num_samples)
|
176 |
+
diag_blured = bokeh_filter(vertical_blured, depth, math.cos(-PI/6), math.sin(-PI/6), num_samples)
|
177 |
+
rhom_blur = bokeh_filter(diag_blured, depth, math.cos(-PI * 5 /6), math.sin(-PI * 5 /6), num_samples)
|
178 |
+
blured = (diag_blured + rhom_blur) / 2
|
179 |
+
blured = np.power(blured, 1 / lightness_factor)
|
180 |
+
blured = (blured * 255).astype(np.uint8)
|
181 |
+
|
182 |
+
return blured
|
utils/env_utils.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import platform
|
3 |
+
import warnings
|
4 |
+
|
5 |
+
import torch.multiprocessing as mp
|
6 |
+
|
7 |
+
|
8 |
+
def set_multi_processing(
|
9 |
+
mp_start_method: str = "fork", opencv_num_threads: int = 0, distributed: bool = True
|
10 |
+
) -> None:
|
11 |
+
"""Set multi-processing related environment.
|
12 |
+
|
13 |
+
This function is refered from https://github.com/open-mmlab/mmengine/blob/main/mmengine/utils/dl_utils/setup_env.py
|
14 |
+
|
15 |
+
Args:
|
16 |
+
mp_start_method (str): Set the method which should be used to start
|
17 |
+
child processes. Defaults to 'fork'.
|
18 |
+
opencv_num_threads (int): Number of threads for opencv.
|
19 |
+
Defaults to 0.
|
20 |
+
distributed (bool): True if distributed environment.
|
21 |
+
Defaults to False.
|
22 |
+
""" # noqa
|
23 |
+
# set multi-process start method as `fork` to speed up the training
|
24 |
+
if platform.system() != "Windows":
|
25 |
+
current_method = mp.get_start_method(allow_none=True)
|
26 |
+
if current_method is not None and current_method != mp_start_method:
|
27 |
+
warnings.warn(
|
28 |
+
f"Multi-processing start method `{mp_start_method}` is "
|
29 |
+
f"different from the previous setting `{current_method}`."
|
30 |
+
f"It will be force set to `{mp_start_method}`. You can "
|
31 |
+
"change this behavior by changing `mp_start_method` in "
|
32 |
+
"your config."
|
33 |
+
)
|
34 |
+
mp.set_start_method(mp_start_method, force=True)
|
35 |
+
|
36 |
+
try:
|
37 |
+
import cv2
|
38 |
+
|
39 |
+
# disable opencv multithreading to avoid system being overloaded
|
40 |
+
cv2.setNumThreads(opencv_num_threads)
|
41 |
+
except ImportError:
|
42 |
+
pass
|
43 |
+
|
44 |
+
# setup OMP threads
|
45 |
+
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
|
46 |
+
if "OMP_NUM_THREADS" not in os.environ and distributed:
|
47 |
+
omp_num_threads = 1
|
48 |
+
warnings.warn(
|
49 |
+
"Setting OMP_NUM_THREADS environment variable for each process"
|
50 |
+
f" to be {omp_num_threads} in default, to avoid your system "
|
51 |
+
"being overloaded, please further tune the variable for "
|
52 |
+
"optimal performance in your application as needed."
|
53 |
+
)
|
54 |
+
os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)
|
55 |
+
|
56 |
+
# # setup MKL threads
|
57 |
+
if "MKL_NUM_THREADS" not in os.environ and distributed:
|
58 |
+
mkl_num_threads = 1
|
59 |
+
warnings.warn(
|
60 |
+
"Setting MKL_NUM_THREADS environment variable for each process"
|
61 |
+
f" to be {mkl_num_threads} in default, to avoid your system "
|
62 |
+
"being overloaded, please further tune the variable for "
|
63 |
+
"optimal performance in your application as needed."
|
64 |
+
)
|
65 |
+
os.environ["MKL_NUM_THREADS"] = str(mkl_num_threads)
|
utils/helper_math.h
ADDED
@@ -0,0 +1,1449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/**
|
2 |
+
* Copyright 1993-2012 NVIDIA Corporation. All rights reserved.
|
3 |
+
*
|
4 |
+
* Please refer to the NVIDIA end user license agreement (EULA) associated
|
5 |
+
* with this source code for terms and conditions that govern your use of
|
6 |
+
* this software. Any use, reproduction, disclosure, or distribution of
|
7 |
+
* this software and related documentation outside the terms of the EULA
|
8 |
+
* is strictly prohibited.
|
9 |
+
*
|
10 |
+
*/
|
11 |
+
|
12 |
+
/*
|
13 |
+
* This file implements common mathematical operations on vector types
|
14 |
+
* (float3, float4 etc.) since these are not provided as standard by CUDA.
|
15 |
+
*
|
16 |
+
* The syntax is modeled on the Cg standard library.
|
17 |
+
*
|
18 |
+
* This is part of the Helper library includes
|
19 |
+
*
|
20 |
+
* Thanks to Linh Hah for additions and fixes.
|
21 |
+
*/
|
22 |
+
|
23 |
+
#ifndef HELPER_MATH_H
|
24 |
+
#define HELPER_MATH_H
|
25 |
+
|
26 |
+
#include "cuda_runtime.h"
|
27 |
+
|
28 |
+
typedef unsigned int uint;
|
29 |
+
typedef unsigned short ushort;
|
30 |
+
|
31 |
+
#ifndef __CUDACC__
|
32 |
+
#include <math.h>
|
33 |
+
|
34 |
+
////////////////////////////////////////////////////////////////////////////////
|
35 |
+
// host implementations of CUDA functions
|
36 |
+
////////////////////////////////////////////////////////////////////////////////
|
37 |
+
|
38 |
+
inline float fminf(float a, float b)
|
39 |
+
{
|
40 |
+
return a < b ? a : b;
|
41 |
+
}
|
42 |
+
|
43 |
+
inline float fmaxf(float a, float b)
|
44 |
+
{
|
45 |
+
return a > b ? a : b;
|
46 |
+
}
|
47 |
+
|
48 |
+
inline int max(int a, int b)
|
49 |
+
{
|
50 |
+
return a > b ? a : b;
|
51 |
+
}
|
52 |
+
|
53 |
+
inline int min(int a, int b)
|
54 |
+
{
|
55 |
+
return a < b ? a : b;
|
56 |
+
}
|
57 |
+
|
58 |
+
inline float rsqrtf(float x)
|
59 |
+
{
|
60 |
+
return 1.0f / sqrtf(x);
|
61 |
+
}
|
62 |
+
#endif
|
63 |
+
|
64 |
+
////////////////////////////////////////////////////////////////////////////////
|
65 |
+
// constructors
|
66 |
+
////////////////////////////////////////////////////////////////////////////////
|
67 |
+
|
68 |
+
inline __host__ __device__ float2 make_float2(float s)
|
69 |
+
{
|
70 |
+
return make_float2(s, s);
|
71 |
+
}
|
72 |
+
inline __host__ __device__ float2 make_float2(float3 a)
|
73 |
+
{
|
74 |
+
return make_float2(a.x, a.y);
|
75 |
+
}
|
76 |
+
inline __host__ __device__ float2 make_float2(int2 a)
|
77 |
+
{
|
78 |
+
return make_float2(float(a.x), float(a.y));
|
79 |
+
}
|
80 |
+
inline __host__ __device__ float2 make_float2(uint2 a)
|
81 |
+
{
|
82 |
+
return make_float2(float(a.x), float(a.y));
|
83 |
+
}
|
84 |
+
|
85 |
+
inline __host__ __device__ int2 make_int2(int s)
|
86 |
+
{
|
87 |
+
return make_int2(s, s);
|
88 |
+
}
|
89 |
+
inline __host__ __device__ int2 make_int2(int3 a)
|
90 |
+
{
|
91 |
+
return make_int2(a.x, a.y);
|
92 |
+
}
|
93 |
+
inline __host__ __device__ int2 make_int2(uint2 a)
|
94 |
+
{
|
95 |
+
return make_int2(int(a.x), int(a.y));
|
96 |
+
}
|
97 |
+
inline __host__ __device__ int2 make_int2(float2 a)
|
98 |
+
{
|
99 |
+
return make_int2(int(a.x), int(a.y));
|
100 |
+
}
|
101 |
+
|
102 |
+
inline __host__ __device__ uint2 make_uint2(uint s)
|
103 |
+
{
|
104 |
+
return make_uint2(s, s);
|
105 |
+
}
|
106 |
+
inline __host__ __device__ uint2 make_uint2(uint3 a)
|
107 |
+
{
|
108 |
+
return make_uint2(a.x, a.y);
|
109 |
+
}
|
110 |
+
inline __host__ __device__ uint2 make_uint2(int2 a)
|
111 |
+
{
|
112 |
+
return make_uint2(uint(a.x), uint(a.y));
|
113 |
+
}
|
114 |
+
|
115 |
+
inline __host__ __device__ float3 make_float3(float s)
|
116 |
+
{
|
117 |
+
return make_float3(s, s, s);
|
118 |
+
}
|
119 |
+
inline __host__ __device__ float3 make_float3(float2 a)
|
120 |
+
{
|
121 |
+
return make_float3(a.x, a.y, 0.0f);
|
122 |
+
}
|
123 |
+
inline __host__ __device__ float3 make_float3(float2 a, float s)
|
124 |
+
{
|
125 |
+
return make_float3(a.x, a.y, s);
|
126 |
+
}
|
127 |
+
inline __host__ __device__ float3 make_float3(float4 a)
|
128 |
+
{
|
129 |
+
return make_float3(a.x, a.y, a.z);
|
130 |
+
}
|
131 |
+
inline __host__ __device__ float3 make_float3(int3 a)
|
132 |
+
{
|
133 |
+
return make_float3(float(a.x), float(a.y), float(a.z));
|
134 |
+
}
|
135 |
+
inline __host__ __device__ float3 make_float3(uint3 a)
|
136 |
+
{
|
137 |
+
return make_float3(float(a.x), float(a.y), float(a.z));
|
138 |
+
}
|
139 |
+
|
140 |
+
inline __host__ __device__ int3 make_int3(int s)
|
141 |
+
{
|
142 |
+
return make_int3(s, s, s);
|
143 |
+
}
|
144 |
+
inline __host__ __device__ int3 make_int3(int2 a)
|
145 |
+
{
|
146 |
+
return make_int3(a.x, a.y, 0);
|
147 |
+
}
|
148 |
+
inline __host__ __device__ int3 make_int3(int2 a, int s)
|
149 |
+
{
|
150 |
+
return make_int3(a.x, a.y, s);
|
151 |
+
}
|
152 |
+
inline __host__ __device__ int3 make_int3(uint3 a)
|
153 |
+
{
|
154 |
+
return make_int3(int(a.x), int(a.y), int(a.z));
|
155 |
+
}
|
156 |
+
inline __host__ __device__ int3 make_int3(float3 a)
|
157 |
+
{
|
158 |
+
return make_int3(int(a.x), int(a.y), int(a.z));
|
159 |
+
}
|
160 |
+
|
161 |
+
inline __host__ __device__ uint3 make_uint3(uint s)
|
162 |
+
{
|
163 |
+
return make_uint3(s, s, s);
|
164 |
+
}
|
165 |
+
inline __host__ __device__ uint3 make_uint3(uint2 a)
|
166 |
+
{
|
167 |
+
return make_uint3(a.x, a.y, 0);
|
168 |
+
}
|
169 |
+
inline __host__ __device__ uint3 make_uint3(uint2 a, uint s)
|
170 |
+
{
|
171 |
+
return make_uint3(a.x, a.y, s);
|
172 |
+
}
|
173 |
+
inline __host__ __device__ uint3 make_uint3(uint4 a)
|
174 |
+
{
|
175 |
+
return make_uint3(a.x, a.y, a.z);
|
176 |
+
}
|
177 |
+
inline __host__ __device__ uint3 make_uint3(int3 a)
|
178 |
+
{
|
179 |
+
return make_uint3(uint(a.x), uint(a.y), uint(a.z));
|
180 |
+
}
|
181 |
+
|
182 |
+
inline __host__ __device__ float4 make_float4(float s)
|
183 |
+
{
|
184 |
+
return make_float4(s, s, s, s);
|
185 |
+
}
|
186 |
+
inline __host__ __device__ float4 make_float4(float3 a)
|
187 |
+
{
|
188 |
+
return make_float4(a.x, a.y, a.z, 0.0f);
|
189 |
+
}
|
190 |
+
inline __host__ __device__ float4 make_float4(float3 a, float w)
|
191 |
+
{
|
192 |
+
return make_float4(a.x, a.y, a.z, w);
|
193 |
+
}
|
194 |
+
inline __host__ __device__ float4 make_float4(int4 a)
|
195 |
+
{
|
196 |
+
return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
|
197 |
+
}
|
198 |
+
inline __host__ __device__ float4 make_float4(uint4 a)
|
199 |
+
{
|
200 |
+
return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
|
201 |
+
}
|
202 |
+
|
203 |
+
inline __host__ __device__ int4 make_int4(int s)
|
204 |
+
{
|
205 |
+
return make_int4(s, s, s, s);
|
206 |
+
}
|
207 |
+
inline __host__ __device__ int4 make_int4(int3 a)
|
208 |
+
{
|
209 |
+
return make_int4(a.x, a.y, a.z, 0);
|
210 |
+
}
|
211 |
+
inline __host__ __device__ int4 make_int4(int3 a, int w)
|
212 |
+
{
|
213 |
+
return make_int4(a.x, a.y, a.z, w);
|
214 |
+
}
|
215 |
+
inline __host__ __device__ int4 make_int4(uint4 a)
|
216 |
+
{
|
217 |
+
return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
|
218 |
+
}
|
219 |
+
inline __host__ __device__ int4 make_int4(float4 a)
|
220 |
+
{
|
221 |
+
return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
|
222 |
+
}
|
223 |
+
|
224 |
+
|
225 |
+
inline __host__ __device__ uint4 make_uint4(uint s)
|
226 |
+
{
|
227 |
+
return make_uint4(s, s, s, s);
|
228 |
+
}
|
229 |
+
inline __host__ __device__ uint4 make_uint4(uint3 a)
|
230 |
+
{
|
231 |
+
return make_uint4(a.x, a.y, a.z, 0);
|
232 |
+
}
|
233 |
+
inline __host__ __device__ uint4 make_uint4(uint3 a, uint w)
|
234 |
+
{
|
235 |
+
return make_uint4(a.x, a.y, a.z, w);
|
236 |
+
}
|
237 |
+
inline __host__ __device__ uint4 make_uint4(int4 a)
|
238 |
+
{
|
239 |
+
return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w));
|
240 |
+
}
|
241 |
+
|
242 |
+
////////////////////////////////////////////////////////////////////////////////
|
243 |
+
// negate
|
244 |
+
////////////////////////////////////////////////////////////////////////////////
|
245 |
+
|
246 |
+
inline __host__ __device__ float2 operator-(float2 &a)
|
247 |
+
{
|
248 |
+
return make_float2(-a.x, -a.y);
|
249 |
+
}
|
250 |
+
inline __host__ __device__ int2 operator-(int2 &a)
|
251 |
+
{
|
252 |
+
return make_int2(-a.x, -a.y);
|
253 |
+
}
|
254 |
+
inline __host__ __device__ float3 operator-(float3 &a)
|
255 |
+
{
|
256 |
+
return make_float3(-a.x, -a.y, -a.z);
|
257 |
+
}
|
258 |
+
inline __host__ __device__ int3 operator-(int3 &a)
|
259 |
+
{
|
260 |
+
return make_int3(-a.x, -a.y, -a.z);
|
261 |
+
}
|
262 |
+
inline __host__ __device__ float4 operator-(float4 &a)
|
263 |
+
{
|
264 |
+
return make_float4(-a.x, -a.y, -a.z, -a.w);
|
265 |
+
}
|
266 |
+
inline __host__ __device__ int4 operator-(int4 &a)
|
267 |
+
{
|
268 |
+
return make_int4(-a.x, -a.y, -a.z, -a.w);
|
269 |
+
}
|
270 |
+
|
271 |
+
////////////////////////////////////////////////////////////////////////////////
|
272 |
+
// addition
|
273 |
+
////////////////////////////////////////////////////////////////////////////////
|
274 |
+
|
275 |
+
inline __host__ __device__ float2 operator+(float2 a, float2 b)
|
276 |
+
{
|
277 |
+
return make_float2(a.x + b.x, a.y + b.y);
|
278 |
+
}
|
279 |
+
inline __host__ __device__ void operator+=(float2 &a, float2 b)
|
280 |
+
{
|
281 |
+
a.x += b.x;
|
282 |
+
a.y += b.y;
|
283 |
+
}
|
284 |
+
inline __host__ __device__ float2 operator+(float2 a, float b)
|
285 |
+
{
|
286 |
+
return make_float2(a.x + b, a.y + b);
|
287 |
+
}
|
288 |
+
inline __host__ __device__ float2 operator+(float b, float2 a)
|
289 |
+
{
|
290 |
+
return make_float2(a.x + b, a.y + b);
|
291 |
+
}
|
292 |
+
inline __host__ __device__ void operator+=(float2 &a, float b)
|
293 |
+
{
|
294 |
+
a.x += b;
|
295 |
+
a.y += b;
|
296 |
+
}
|
297 |
+
|
298 |
+
inline __host__ __device__ int2 operator+(int2 a, int2 b)
|
299 |
+
{
|
300 |
+
return make_int2(a.x + b.x, a.y + b.y);
|
301 |
+
}
|
302 |
+
inline __host__ __device__ void operator+=(int2 &a, int2 b)
|
303 |
+
{
|
304 |
+
a.x += b.x;
|
305 |
+
a.y += b.y;
|
306 |
+
}
|
307 |
+
inline __host__ __device__ int2 operator+(int2 a, int b)
|
308 |
+
{
|
309 |
+
return make_int2(a.x + b, a.y + b);
|
310 |
+
}
|
311 |
+
inline __host__ __device__ int2 operator+(int b, int2 a)
|
312 |
+
{
|
313 |
+
return make_int2(a.x + b, a.y + b);
|
314 |
+
}
|
315 |
+
inline __host__ __device__ void operator+=(int2 &a, int b)
|
316 |
+
{
|
317 |
+
a.x += b;
|
318 |
+
a.y += b;
|
319 |
+
}
|
320 |
+
|
321 |
+
inline __host__ __device__ uint2 operator+(uint2 a, uint2 b)
|
322 |
+
{
|
323 |
+
return make_uint2(a.x + b.x, a.y + b.y);
|
324 |
+
}
|
325 |
+
inline __host__ __device__ void operator+=(uint2 &a, uint2 b)
|
326 |
+
{
|
327 |
+
a.x += b.x;
|
328 |
+
a.y += b.y;
|
329 |
+
}
|
330 |
+
inline __host__ __device__ uint2 operator+(uint2 a, uint b)
|
331 |
+
{
|
332 |
+
return make_uint2(a.x + b, a.y + b);
|
333 |
+
}
|
334 |
+
inline __host__ __device__ uint2 operator+(uint b, uint2 a)
|
335 |
+
{
|
336 |
+
return make_uint2(a.x + b, a.y + b);
|
337 |
+
}
|
338 |
+
inline __host__ __device__ void operator+=(uint2 &a, uint b)
|
339 |
+
{
|
340 |
+
a.x += b;
|
341 |
+
a.y += b;
|
342 |
+
}
|
343 |
+
|
344 |
+
|
345 |
+
inline __host__ __device__ float3 operator+(float3 a, float3 b)
|
346 |
+
{
|
347 |
+
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
348 |
+
}
|
349 |
+
inline __host__ __device__ void operator+=(float3 &a, float3 b)
|
350 |
+
{
|
351 |
+
a.x += b.x;
|
352 |
+
a.y += b.y;
|
353 |
+
a.z += b.z;
|
354 |
+
}
|
355 |
+
inline __host__ __device__ float3 operator+(float3 a, float b)
|
356 |
+
{
|
357 |
+
return make_float3(a.x + b, a.y + b, a.z + b);
|
358 |
+
}
|
359 |
+
inline __host__ __device__ void operator+=(float3 &a, float b)
|
360 |
+
{
|
361 |
+
a.x += b;
|
362 |
+
a.y += b;
|
363 |
+
a.z += b;
|
364 |
+
}
|
365 |
+
|
366 |
+
inline __host__ __device__ int3 operator+(int3 a, int3 b)
|
367 |
+
{
|
368 |
+
return make_int3(a.x + b.x, a.y + b.y, a.z + b.z);
|
369 |
+
}
|
370 |
+
inline __host__ __device__ void operator+=(int3 &a, int3 b)
|
371 |
+
{
|
372 |
+
a.x += b.x;
|
373 |
+
a.y += b.y;
|
374 |
+
a.z += b.z;
|
375 |
+
}
|
376 |
+
inline __host__ __device__ int3 operator+(int3 a, int b)
|
377 |
+
{
|
378 |
+
return make_int3(a.x + b, a.y + b, a.z + b);
|
379 |
+
}
|
380 |
+
inline __host__ __device__ void operator+=(int3 &a, int b)
|
381 |
+
{
|
382 |
+
a.x += b;
|
383 |
+
a.y += b;
|
384 |
+
a.z += b;
|
385 |
+
}
|
386 |
+
|
387 |
+
inline __host__ __device__ uint3 operator+(uint3 a, uint3 b)
|
388 |
+
{
|
389 |
+
return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z);
|
390 |
+
}
|
391 |
+
inline __host__ __device__ void operator+=(uint3 &a, uint3 b)
|
392 |
+
{
|
393 |
+
a.x += b.x;
|
394 |
+
a.y += b.y;
|
395 |
+
a.z += b.z;
|
396 |
+
}
|
397 |
+
inline __host__ __device__ uint3 operator+(uint3 a, uint b)
|
398 |
+
{
|
399 |
+
return make_uint3(a.x + b, a.y + b, a.z + b);
|
400 |
+
}
|
401 |
+
inline __host__ __device__ void operator+=(uint3 &a, uint b)
|
402 |
+
{
|
403 |
+
a.x += b;
|
404 |
+
a.y += b;
|
405 |
+
a.z += b;
|
406 |
+
}
|
407 |
+
|
408 |
+
inline __host__ __device__ int3 operator+(int b, int3 a)
|
409 |
+
{
|
410 |
+
return make_int3(a.x + b, a.y + b, a.z + b);
|
411 |
+
}
|
412 |
+
inline __host__ __device__ uint3 operator+(uint b, uint3 a)
|
413 |
+
{
|
414 |
+
return make_uint3(a.x + b, a.y + b, a.z + b);
|
415 |
+
}
|
416 |
+
inline __host__ __device__ float3 operator+(float b, float3 a)
|
417 |
+
{
|
418 |
+
return make_float3(a.x + b, a.y + b, a.z + b);
|
419 |
+
}
|
420 |
+
|
421 |
+
inline __host__ __device__ float4 operator+(float4 a, float4 b)
|
422 |
+
{
|
423 |
+
return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
|
424 |
+
}
|
425 |
+
inline __host__ __device__ void operator+=(float4 &a, float4 b)
|
426 |
+
{
|
427 |
+
a.x += b.x;
|
428 |
+
a.y += b.y;
|
429 |
+
a.z += b.z;
|
430 |
+
a.w += b.w;
|
431 |
+
}
|
432 |
+
inline __host__ __device__ float4 operator+(float4 a, float b)
|
433 |
+
{
|
434 |
+
return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
|
435 |
+
}
|
436 |
+
inline __host__ __device__ float4 operator+(float b, float4 a)
|
437 |
+
{
|
438 |
+
return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
|
439 |
+
}
|
440 |
+
inline __host__ __device__ void operator+=(float4 &a, float b)
|
441 |
+
{
|
442 |
+
a.x += b;
|
443 |
+
a.y += b;
|
444 |
+
a.z += b;
|
445 |
+
a.w += b;
|
446 |
+
}
|
447 |
+
|
448 |
+
inline __host__ __device__ int4 operator+(int4 a, int4 b)
|
449 |
+
{
|
450 |
+
return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
|
451 |
+
}
|
452 |
+
inline __host__ __device__ void operator+=(int4 &a, int4 b)
|
453 |
+
{
|
454 |
+
a.x += b.x;
|
455 |
+
a.y += b.y;
|
456 |
+
a.z += b.z;
|
457 |
+
a.w += b.w;
|
458 |
+
}
|
459 |
+
inline __host__ __device__ int4 operator+(int4 a, int b)
|
460 |
+
{
|
461 |
+
return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
|
462 |
+
}
|
463 |
+
inline __host__ __device__ int4 operator+(int b, int4 a)
|
464 |
+
{
|
465 |
+
return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
|
466 |
+
}
|
467 |
+
inline __host__ __device__ void operator+=(int4 &a, int b)
|
468 |
+
{
|
469 |
+
a.x += b;
|
470 |
+
a.y += b;
|
471 |
+
a.z += b;
|
472 |
+
a.w += b;
|
473 |
+
}
|
474 |
+
|
475 |
+
inline __host__ __device__ uint4 operator+(uint4 a, uint4 b)
|
476 |
+
{
|
477 |
+
return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
|
478 |
+
}
|
479 |
+
inline __host__ __device__ void operator+=(uint4 &a, uint4 b)
|
480 |
+
{
|
481 |
+
a.x += b.x;
|
482 |
+
a.y += b.y;
|
483 |
+
a.z += b.z;
|
484 |
+
a.w += b.w;
|
485 |
+
}
|
486 |
+
inline __host__ __device__ uint4 operator+(uint4 a, uint b)
|
487 |
+
{
|
488 |
+
return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
|
489 |
+
}
|
490 |
+
inline __host__ __device__ uint4 operator+(uint b, uint4 a)
|
491 |
+
{
|
492 |
+
return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
|
493 |
+
}
|
494 |
+
inline __host__ __device__ void operator+=(uint4 &a, uint b)
|
495 |
+
{
|
496 |
+
a.x += b;
|
497 |
+
a.y += b;
|
498 |
+
a.z += b;
|
499 |
+
a.w += b;
|
500 |
+
}
|
501 |
+
|
502 |
+
////////////////////////////////////////////////////////////////////////////////
|
503 |
+
// subtract
|
504 |
+
////////////////////////////////////////////////////////////////////////////////
|
505 |
+
|
506 |
+
inline __host__ __device__ float2 operator-(float2 a, float2 b)
|
507 |
+
{
|
508 |
+
return make_float2(a.x - b.x, a.y - b.y);
|
509 |
+
}
|
510 |
+
inline __host__ __device__ void operator-=(float2 &a, float2 b)
|
511 |
+
{
|
512 |
+
a.x -= b.x;
|
513 |
+
a.y -= b.y;
|
514 |
+
}
|
515 |
+
inline __host__ __device__ float2 operator-(float2 a, float b)
|
516 |
+
{
|
517 |
+
return make_float2(a.x - b, a.y - b);
|
518 |
+
}
|
519 |
+
inline __host__ __device__ float2 operator-(float b, float2 a)
|
520 |
+
{
|
521 |
+
return make_float2(b - a.x, b - a.y);
|
522 |
+
}
|
523 |
+
inline __host__ __device__ void operator-=(float2 &a, float b)
|
524 |
+
{
|
525 |
+
a.x -= b;
|
526 |
+
a.y -= b;
|
527 |
+
}
|
528 |
+
|
529 |
+
inline __host__ __device__ int2 operator-(int2 a, int2 b)
|
530 |
+
{
|
531 |
+
return make_int2(a.x - b.x, a.y - b.y);
|
532 |
+
}
|
533 |
+
inline __host__ __device__ void operator-=(int2 &a, int2 b)
|
534 |
+
{
|
535 |
+
a.x -= b.x;
|
536 |
+
a.y -= b.y;
|
537 |
+
}
|
538 |
+
inline __host__ __device__ int2 operator-(int2 a, int b)
|
539 |
+
{
|
540 |
+
return make_int2(a.x - b, a.y - b);
|
541 |
+
}
|
542 |
+
inline __host__ __device__ int2 operator-(int b, int2 a)
|
543 |
+
{
|
544 |
+
return make_int2(b - a.x, b - a.y);
|
545 |
+
}
|
546 |
+
inline __host__ __device__ void operator-=(int2 &a, int b)
|
547 |
+
{
|
548 |
+
a.x -= b;
|
549 |
+
a.y -= b;
|
550 |
+
}
|
551 |
+
|
552 |
+
inline __host__ __device__ uint2 operator-(uint2 a, uint2 b)
|
553 |
+
{
|
554 |
+
return make_uint2(a.x - b.x, a.y - b.y);
|
555 |
+
}
|
556 |
+
inline __host__ __device__ void operator-=(uint2 &a, uint2 b)
|
557 |
+
{
|
558 |
+
a.x -= b.x;
|
559 |
+
a.y -= b.y;
|
560 |
+
}
|
561 |
+
inline __host__ __device__ uint2 operator-(uint2 a, uint b)
|
562 |
+
{
|
563 |
+
return make_uint2(a.x - b, a.y - b);
|
564 |
+
}
|
565 |
+
inline __host__ __device__ uint2 operator-(uint b, uint2 a)
|
566 |
+
{
|
567 |
+
return make_uint2(b - a.x, b - a.y);
|
568 |
+
}
|
569 |
+
inline __host__ __device__ void operator-=(uint2 &a, uint b)
|
570 |
+
{
|
571 |
+
a.x -= b;
|
572 |
+
a.y -= b;
|
573 |
+
}
|
574 |
+
|
575 |
+
inline __host__ __device__ float3 operator-(float3 a, float3 b)
|
576 |
+
{
|
577 |
+
return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
|
578 |
+
}
|
579 |
+
inline __host__ __device__ void operator-=(float3 &a, float3 b)
|
580 |
+
{
|
581 |
+
a.x -= b.x;
|
582 |
+
a.y -= b.y;
|
583 |
+
a.z -= b.z;
|
584 |
+
}
|
585 |
+
inline __host__ __device__ float3 operator-(float3 a, float b)
|
586 |
+
{
|
587 |
+
return make_float3(a.x - b, a.y - b, a.z - b);
|
588 |
+
}
|
589 |
+
inline __host__ __device__ float3 operator-(float b, float3 a)
|
590 |
+
{
|
591 |
+
return make_float3(b - a.x, b - a.y, b - a.z);
|
592 |
+
}
|
593 |
+
inline __host__ __device__ void operator-=(float3 &a, float b)
|
594 |
+
{
|
595 |
+
a.x -= b;
|
596 |
+
a.y -= b;
|
597 |
+
a.z -= b;
|
598 |
+
}
|
599 |
+
|
600 |
+
inline __host__ __device__ int3 operator-(int3 a, int3 b)
|
601 |
+
{
|
602 |
+
return make_int3(a.x - b.x, a.y - b.y, a.z - b.z);
|
603 |
+
}
|
604 |
+
inline __host__ __device__ void operator-=(int3 &a, int3 b)
|
605 |
+
{
|
606 |
+
a.x -= b.x;
|
607 |
+
a.y -= b.y;
|
608 |
+
a.z -= b.z;
|
609 |
+
}
|
610 |
+
inline __host__ __device__ int3 operator-(int3 a, int b)
|
611 |
+
{
|
612 |
+
return make_int3(a.x - b, a.y - b, a.z - b);
|
613 |
+
}
|
614 |
+
inline __host__ __device__ int3 operator-(int b, int3 a)
|
615 |
+
{
|
616 |
+
return make_int3(b - a.x, b - a.y, b - a.z);
|
617 |
+
}
|
618 |
+
inline __host__ __device__ void operator-=(int3 &a, int b)
|
619 |
+
{
|
620 |
+
a.x -= b;
|
621 |
+
a.y -= b;
|
622 |
+
a.z -= b;
|
623 |
+
}
|
624 |
+
|
625 |
+
inline __host__ __device__ uint3 operator-(uint3 a, uint3 b)
|
626 |
+
{
|
627 |
+
return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z);
|
628 |
+
}
|
629 |
+
inline __host__ __device__ void operator-=(uint3 &a, uint3 b)
|
630 |
+
{
|
631 |
+
a.x -= b.x;
|
632 |
+
a.y -= b.y;
|
633 |
+
a.z -= b.z;
|
634 |
+
}
|
635 |
+
inline __host__ __device__ uint3 operator-(uint3 a, uint b)
|
636 |
+
{
|
637 |
+
return make_uint3(a.x - b, a.y - b, a.z - b);
|
638 |
+
}
|
639 |
+
inline __host__ __device__ uint3 operator-(uint b, uint3 a)
|
640 |
+
{
|
641 |
+
return make_uint3(b - a.x, b - a.y, b - a.z);
|
642 |
+
}
|
643 |
+
inline __host__ __device__ void operator-=(uint3 &a, uint b)
|
644 |
+
{
|
645 |
+
a.x -= b;
|
646 |
+
a.y -= b;
|
647 |
+
a.z -= b;
|
648 |
+
}
|
649 |
+
|
650 |
+
inline __host__ __device__ float4 operator-(float4 a, float4 b)
|
651 |
+
{
|
652 |
+
return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
|
653 |
+
}
|
654 |
+
inline __host__ __device__ void operator-=(float4 &a, float4 b)
|
655 |
+
{
|
656 |
+
a.x -= b.x;
|
657 |
+
a.y -= b.y;
|
658 |
+
a.z -= b.z;
|
659 |
+
a.w -= b.w;
|
660 |
+
}
|
661 |
+
inline __host__ __device__ float4 operator-(float4 a, float b)
|
662 |
+
{
|
663 |
+
return make_float4(a.x - b, a.y - b, a.z - b, a.w - b);
|
664 |
+
}
|
665 |
+
inline __host__ __device__ void operator-=(float4 &a, float b)
|
666 |
+
{
|
667 |
+
a.x -= b;
|
668 |
+
a.y -= b;
|
669 |
+
a.z -= b;
|
670 |
+
a.w -= b;
|
671 |
+
}
|
672 |
+
|
673 |
+
inline __host__ __device__ int4 operator-(int4 a, int4 b)
|
674 |
+
{
|
675 |
+
return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
|
676 |
+
}
|
677 |
+
inline __host__ __device__ void operator-=(int4 &a, int4 b)
|
678 |
+
{
|
679 |
+
a.x -= b.x;
|
680 |
+
a.y -= b.y;
|
681 |
+
a.z -= b.z;
|
682 |
+
a.w -= b.w;
|
683 |
+
}
|
684 |
+
inline __host__ __device__ int4 operator-(int4 a, int b)
|
685 |
+
{
|
686 |
+
return make_int4(a.x - b, a.y - b, a.z - b, a.w - b);
|
687 |
+
}
|
688 |
+
inline __host__ __device__ int4 operator-(int b, int4 a)
|
689 |
+
{
|
690 |
+
return make_int4(b - a.x, b - a.y, b - a.z, b - a.w);
|
691 |
+
}
|
692 |
+
inline __host__ __device__ void operator-=(int4 &a, int b)
|
693 |
+
{
|
694 |
+
a.x -= b;
|
695 |
+
a.y -= b;
|
696 |
+
a.z -= b;
|
697 |
+
a.w -= b;
|
698 |
+
}
|
699 |
+
|
700 |
+
inline __host__ __device__ uint4 operator-(uint4 a, uint4 b)
|
701 |
+
{
|
702 |
+
return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
|
703 |
+
}
|
704 |
+
inline __host__ __device__ void operator-=(uint4 &a, uint4 b)
|
705 |
+
{
|
706 |
+
a.x -= b.x;
|
707 |
+
a.y -= b.y;
|
708 |
+
a.z -= b.z;
|
709 |
+
a.w -= b.w;
|
710 |
+
}
|
711 |
+
inline __host__ __device__ uint4 operator-(uint4 a, uint b)
|
712 |
+
{
|
713 |
+
return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b);
|
714 |
+
}
|
715 |
+
inline __host__ __device__ uint4 operator-(uint b, uint4 a)
|
716 |
+
{
|
717 |
+
return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w);
|
718 |
+
}
|
719 |
+
inline __host__ __device__ void operator-=(uint4 &a, uint b)
|
720 |
+
{
|
721 |
+
a.x -= b;
|
722 |
+
a.y -= b;
|
723 |
+
a.z -= b;
|
724 |
+
a.w -= b;
|
725 |
+
}
|
726 |
+
|
727 |
+
////////////////////////////////////////////////////////////////////////////////
|
728 |
+
// multiply
|
729 |
+
////////////////////////////////////////////////////////////////////////////////
|
730 |
+
|
731 |
+
inline __host__ __device__ float2 operator*(float2 a, float2 b)
|
732 |
+
{
|
733 |
+
return make_float2(a.x * b.x, a.y * b.y);
|
734 |
+
}
|
735 |
+
inline __host__ __device__ void operator*=(float2 &a, float2 b)
|
736 |
+
{
|
737 |
+
a.x *= b.x;
|
738 |
+
a.y *= b.y;
|
739 |
+
}
|
740 |
+
inline __host__ __device__ float2 operator*(float2 a, float b)
|
741 |
+
{
|
742 |
+
return make_float2(a.x * b, a.y * b);
|
743 |
+
}
|
744 |
+
inline __host__ __device__ float2 operator*(float b, float2 a)
|
745 |
+
{
|
746 |
+
return make_float2(b * a.x, b * a.y);
|
747 |
+
}
|
748 |
+
inline __host__ __device__ void operator*=(float2 &a, float b)
|
749 |
+
{
|
750 |
+
a.x *= b;
|
751 |
+
a.y *= b;
|
752 |
+
}
|
753 |
+
|
754 |
+
inline __host__ __device__ int2 operator*(int2 a, int2 b)
|
755 |
+
{
|
756 |
+
return make_int2(a.x * b.x, a.y * b.y);
|
757 |
+
}
|
758 |
+
inline __host__ __device__ void operator*=(int2 &a, int2 b)
|
759 |
+
{
|
760 |
+
a.x *= b.x;
|
761 |
+
a.y *= b.y;
|
762 |
+
}
|
763 |
+
inline __host__ __device__ int2 operator*(int2 a, int b)
|
764 |
+
{
|
765 |
+
return make_int2(a.x * b, a.y * b);
|
766 |
+
}
|
767 |
+
inline __host__ __device__ int2 operator*(int b, int2 a)
|
768 |
+
{
|
769 |
+
return make_int2(b * a.x, b * a.y);
|
770 |
+
}
|
771 |
+
inline __host__ __device__ void operator*=(int2 &a, int b)
|
772 |
+
{
|
773 |
+
a.x *= b;
|
774 |
+
a.y *= b;
|
775 |
+
}
|
776 |
+
|
777 |
+
inline __host__ __device__ uint2 operator*(uint2 a, uint2 b)
|
778 |
+
{
|
779 |
+
return make_uint2(a.x * b.x, a.y * b.y);
|
780 |
+
}
|
781 |
+
inline __host__ __device__ void operator*=(uint2 &a, uint2 b)
|
782 |
+
{
|
783 |
+
a.x *= b.x;
|
784 |
+
a.y *= b.y;
|
785 |
+
}
|
786 |
+
inline __host__ __device__ uint2 operator*(uint2 a, uint b)
|
787 |
+
{
|
788 |
+
return make_uint2(a.x * b, a.y * b);
|
789 |
+
}
|
790 |
+
inline __host__ __device__ uint2 operator*(uint b, uint2 a)
|
791 |
+
{
|
792 |
+
return make_uint2(b * a.x, b * a.y);
|
793 |
+
}
|
794 |
+
inline __host__ __device__ void operator*=(uint2 &a, uint b)
|
795 |
+
{
|
796 |
+
a.x *= b;
|
797 |
+
a.y *= b;
|
798 |
+
}
|
799 |
+
|
800 |
+
inline __host__ __device__ float3 operator*(float3 a, float3 b)
|
801 |
+
{
|
802 |
+
return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
|
803 |
+
}
|
804 |
+
inline __host__ __device__ void operator*=(float3 &a, float3 b)
|
805 |
+
{
|
806 |
+
a.x *= b.x;
|
807 |
+
a.y *= b.y;
|
808 |
+
a.z *= b.z;
|
809 |
+
}
|
810 |
+
inline __host__ __device__ float3 operator*(float3 a, float b)
|
811 |
+
{
|
812 |
+
return make_float3(a.x * b, a.y * b, a.z * b);
|
813 |
+
}
|
814 |
+
inline __host__ __device__ float3 operator*(float b, float3 a)
|
815 |
+
{
|
816 |
+
return make_float3(b * a.x, b * a.y, b * a.z);
|
817 |
+
}
|
818 |
+
inline __host__ __device__ void operator*=(float3 &a, float b)
|
819 |
+
{
|
820 |
+
a.x *= b;
|
821 |
+
a.y *= b;
|
822 |
+
a.z *= b;
|
823 |
+
}
|
824 |
+
|
825 |
+
inline __host__ __device__ int3 operator*(int3 a, int3 b)
|
826 |
+
{
|
827 |
+
return make_int3(a.x * b.x, a.y * b.y, a.z * b.z);
|
828 |
+
}
|
829 |
+
inline __host__ __device__ void operator*=(int3 &a, int3 b)
|
830 |
+
{
|
831 |
+
a.x *= b.x;
|
832 |
+
a.y *= b.y;
|
833 |
+
a.z *= b.z;
|
834 |
+
}
|
835 |
+
inline __host__ __device__ int3 operator*(int3 a, int b)
|
836 |
+
{
|
837 |
+
return make_int3(a.x * b, a.y * b, a.z * b);
|
838 |
+
}
|
839 |
+
inline __host__ __device__ int3 operator*(int b, int3 a)
|
840 |
+
{
|
841 |
+
return make_int3(b * a.x, b * a.y, b * a.z);
|
842 |
+
}
|
843 |
+
inline __host__ __device__ void operator*=(int3 &a, int b)
|
844 |
+
{
|
845 |
+
a.x *= b;
|
846 |
+
a.y *= b;
|
847 |
+
a.z *= b;
|
848 |
+
}
|
849 |
+
|
850 |
+
inline __host__ __device__ uint3 operator*(uint3 a, uint3 b)
|
851 |
+
{
|
852 |
+
return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z);
|
853 |
+
}
|
854 |
+
inline __host__ __device__ void operator*=(uint3 &a, uint3 b)
|
855 |
+
{
|
856 |
+
a.x *= b.x;
|
857 |
+
a.y *= b.y;
|
858 |
+
a.z *= b.z;
|
859 |
+
}
|
860 |
+
inline __host__ __device__ uint3 operator*(uint3 a, uint b)
|
861 |
+
{
|
862 |
+
return make_uint3(a.x * b, a.y * b, a.z * b);
|
863 |
+
}
|
864 |
+
inline __host__ __device__ uint3 operator*(uint b, uint3 a)
|
865 |
+
{
|
866 |
+
return make_uint3(b * a.x, b * a.y, b * a.z);
|
867 |
+
}
|
868 |
+
inline __host__ __device__ void operator*=(uint3 &a, uint b)
|
869 |
+
{
|
870 |
+
a.x *= b;
|
871 |
+
a.y *= b;
|
872 |
+
a.z *= b;
|
873 |
+
}
|
874 |
+
|
875 |
+
inline __host__ __device__ float4 operator*(float4 a, float4 b)
|
876 |
+
{
|
877 |
+
return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
878 |
+
}
|
879 |
+
inline __host__ __device__ void operator*=(float4 &a, float4 b)
|
880 |
+
{
|
881 |
+
a.x *= b.x;
|
882 |
+
a.y *= b.y;
|
883 |
+
a.z *= b.z;
|
884 |
+
a.w *= b.w;
|
885 |
+
}
|
886 |
+
inline __host__ __device__ float4 operator*(float4 a, float b)
|
887 |
+
{
|
888 |
+
return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
|
889 |
+
}
|
890 |
+
inline __host__ __device__ float4 operator*(float b, float4 a)
|
891 |
+
{
|
892 |
+
return make_float4(b * a.x, b * a.y, b * a.z, b * a.w);
|
893 |
+
}
|
894 |
+
inline __host__ __device__ void operator*=(float4 &a, float b)
|
895 |
+
{
|
896 |
+
a.x *= b;
|
897 |
+
a.y *= b;
|
898 |
+
a.z *= b;
|
899 |
+
a.w *= b;
|
900 |
+
}
|
901 |
+
|
902 |
+
inline __host__ __device__ int4 operator*(int4 a, int4 b)
|
903 |
+
{
|
904 |
+
return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
905 |
+
}
|
906 |
+
inline __host__ __device__ void operator*=(int4 &a, int4 b)
|
907 |
+
{
|
908 |
+
a.x *= b.x;
|
909 |
+
a.y *= b.y;
|
910 |
+
a.z *= b.z;
|
911 |
+
a.w *= b.w;
|
912 |
+
}
|
913 |
+
inline __host__ __device__ int4 operator*(int4 a, int b)
|
914 |
+
{
|
915 |
+
return make_int4(a.x * b, a.y * b, a.z * b, a.w * b);
|
916 |
+
}
|
917 |
+
inline __host__ __device__ int4 operator*(int b, int4 a)
|
918 |
+
{
|
919 |
+
return make_int4(b * a.x, b * a.y, b * a.z, b * a.w);
|
920 |
+
}
|
921 |
+
inline __host__ __device__ void operator*=(int4 &a, int b)
|
922 |
+
{
|
923 |
+
a.x *= b;
|
924 |
+
a.y *= b;
|
925 |
+
a.z *= b;
|
926 |
+
a.w *= b;
|
927 |
+
}
|
928 |
+
|
929 |
+
inline __host__ __device__ uint4 operator*(uint4 a, uint4 b)
|
930 |
+
{
|
931 |
+
return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
932 |
+
}
|
933 |
+
inline __host__ __device__ void operator*=(uint4 &a, uint4 b)
|
934 |
+
{
|
935 |
+
a.x *= b.x;
|
936 |
+
a.y *= b.y;
|
937 |
+
a.z *= b.z;
|
938 |
+
a.w *= b.w;
|
939 |
+
}
|
940 |
+
inline __host__ __device__ uint4 operator*(uint4 a, uint b)
|
941 |
+
{
|
942 |
+
return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b);
|
943 |
+
}
|
944 |
+
inline __host__ __device__ uint4 operator*(uint b, uint4 a)
|
945 |
+
{
|
946 |
+
return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w);
|
947 |
+
}
|
948 |
+
inline __host__ __device__ void operator*=(uint4 &a, uint b)
|
949 |
+
{
|
950 |
+
a.x *= b;
|
951 |
+
a.y *= b;
|
952 |
+
a.z *= b;
|
953 |
+
a.w *= b;
|
954 |
+
}
|
955 |
+
|
956 |
+
////////////////////////////////////////////////////////////////////////////////
|
957 |
+
// divide
|
958 |
+
////////////////////////////////////////////////////////////////////////////////
|
959 |
+
|
960 |
+
inline __host__ __device__ float2 operator/(float2 a, float2 b)
|
961 |
+
{
|
962 |
+
return make_float2(a.x / b.x, a.y / b.y);
|
963 |
+
}
|
964 |
+
inline __host__ __device__ void operator/=(float2 &a, float2 b)
|
965 |
+
{
|
966 |
+
a.x /= b.x;
|
967 |
+
a.y /= b.y;
|
968 |
+
}
|
969 |
+
inline __host__ __device__ float2 operator/(float2 a, float b)
|
970 |
+
{
|
971 |
+
return make_float2(a.x / b, a.y / b);
|
972 |
+
}
|
973 |
+
inline __host__ __device__ void operator/=(float2 &a, float b)
|
974 |
+
{
|
975 |
+
a.x /= b;
|
976 |
+
a.y /= b;
|
977 |
+
}
|
978 |
+
inline __host__ __device__ float2 operator/(float b, float2 a)
|
979 |
+
{
|
980 |
+
return make_float2(b / a.x, b / a.y);
|
981 |
+
}
|
982 |
+
|
983 |
+
inline __host__ __device__ float3 operator/(float3 a, float3 b)
|
984 |
+
{
|
985 |
+
return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
|
986 |
+
}
|
987 |
+
inline __host__ __device__ void operator/=(float3 &a, float3 b)
|
988 |
+
{
|
989 |
+
a.x /= b.x;
|
990 |
+
a.y /= b.y;
|
991 |
+
a.z /= b.z;
|
992 |
+
}
|
993 |
+
inline __host__ __device__ float3 operator/(float3 a, float b)
|
994 |
+
{
|
995 |
+
return make_float3(a.x / b, a.y / b, a.z / b);
|
996 |
+
}
|
997 |
+
inline __host__ __device__ void operator/=(float3 &a, float b)
|
998 |
+
{
|
999 |
+
a.x /= b;
|
1000 |
+
a.y /= b;
|
1001 |
+
a.z /= b;
|
1002 |
+
}
|
1003 |
+
inline __host__ __device__ float3 operator/(float b, float3 a)
|
1004 |
+
{
|
1005 |
+
return make_float3(b / a.x, b / a.y, b / a.z);
|
1006 |
+
}
|
1007 |
+
|
1008 |
+
inline __host__ __device__ float4 operator/(float4 a, float4 b)
|
1009 |
+
{
|
1010 |
+
return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
|
1011 |
+
}
|
1012 |
+
inline __host__ __device__ void operator/=(float4 &a, float4 b)
|
1013 |
+
{
|
1014 |
+
a.x /= b.x;
|
1015 |
+
a.y /= b.y;
|
1016 |
+
a.z /= b.z;
|
1017 |
+
a.w /= b.w;
|
1018 |
+
}
|
1019 |
+
inline __host__ __device__ float4 operator/(float4 a, float b)
|
1020 |
+
{
|
1021 |
+
return make_float4(a.x / b, a.y / b, a.z / b, a.w / b);
|
1022 |
+
}
|
1023 |
+
inline __host__ __device__ void operator/=(float4 &a, float b)
|
1024 |
+
{
|
1025 |
+
a.x /= b;
|
1026 |
+
a.y /= b;
|
1027 |
+
a.z /= b;
|
1028 |
+
a.w /= b;
|
1029 |
+
}
|
1030 |
+
inline __host__ __device__ float4 operator/(float b, float4 a)
|
1031 |
+
{
|
1032 |
+
return make_float4(b / a.x, b / a.y, b / a.z, b / a.w);
|
1033 |
+
}
|
1034 |
+
|
1035 |
+
////////////////////////////////////////////////////////////////////////////////
|
1036 |
+
// min
|
1037 |
+
////////////////////////////////////////////////////////////////////////////////
|
1038 |
+
|
1039 |
+
inline __host__ __device__ float2 fminf(float2 a, float2 b)
|
1040 |
+
{
|
1041 |
+
return make_float2(fminf(a.x,b.x), fminf(a.y,b.y));
|
1042 |
+
}
|
1043 |
+
inline __host__ __device__ float3 fminf(float3 a, float3 b)
|
1044 |
+
{
|
1045 |
+
return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z));
|
1046 |
+
}
|
1047 |
+
inline __host__ __device__ float4 fminf(float4 a, float4 b)
|
1048 |
+
{
|
1049 |
+
return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), fminf(a.w,b.w));
|
1050 |
+
}
|
1051 |
+
|
1052 |
+
inline __host__ __device__ int2 min(int2 a, int2 b)
|
1053 |
+
{
|
1054 |
+
return make_int2(min(a.x,b.x), min(a.y,b.y));
|
1055 |
+
}
|
1056 |
+
inline __host__ __device__ int3 min(int3 a, int3 b)
|
1057 |
+
{
|
1058 |
+
return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
|
1059 |
+
}
|
1060 |
+
inline __host__ __device__ int4 min(int4 a, int4 b)
|
1061 |
+
{
|
1062 |
+
return make_int4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
|
1063 |
+
}
|
1064 |
+
|
1065 |
+
inline __host__ __device__ uint2 min(uint2 a, uint2 b)
|
1066 |
+
{
|
1067 |
+
return make_uint2(min(a.x,b.x), min(a.y,b.y));
|
1068 |
+
}
|
1069 |
+
inline __host__ __device__ uint3 min(uint3 a, uint3 b)
|
1070 |
+
{
|
1071 |
+
return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
|
1072 |
+
}
|
1073 |
+
inline __host__ __device__ uint4 min(uint4 a, uint4 b)
|
1074 |
+
{
|
1075 |
+
return make_uint4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
|
1076 |
+
}
|
1077 |
+
|
1078 |
+
////////////////////////////////////////////////////////////////////////////////
|
1079 |
+
// max
|
1080 |
+
////////////////////////////////////////////////////////////////////////////////
|
1081 |
+
|
1082 |
+
inline __host__ __device__ float2 fmaxf(float2 a, float2 b)
|
1083 |
+
{
|
1084 |
+
return make_float2(fmaxf(a.x,b.x), fmaxf(a.y,b.y));
|
1085 |
+
}
|
1086 |
+
inline __host__ __device__ float3 fmaxf(float3 a, float3 b)
|
1087 |
+
{
|
1088 |
+
return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z));
|
1089 |
+
}
|
1090 |
+
inline __host__ __device__ float4 fmaxf(float4 a, float4 b)
|
1091 |
+
{
|
1092 |
+
return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w));
|
1093 |
+
}
|
1094 |
+
|
1095 |
+
inline __host__ __device__ int2 max(int2 a, int2 b)
|
1096 |
+
{
|
1097 |
+
return make_int2(max(a.x,b.x), max(a.y,b.y));
|
1098 |
+
}
|
1099 |
+
inline __host__ __device__ int3 max(int3 a, int3 b)
|
1100 |
+
{
|
1101 |
+
return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
|
1102 |
+
}
|
1103 |
+
inline __host__ __device__ int4 max(int4 a, int4 b)
|
1104 |
+
{
|
1105 |
+
return make_int4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
|
1106 |
+
}
|
1107 |
+
|
1108 |
+
inline __host__ __device__ uint2 max(uint2 a, uint2 b)
|
1109 |
+
{
|
1110 |
+
return make_uint2(max(a.x,b.x), max(a.y,b.y));
|
1111 |
+
}
|
1112 |
+
inline __host__ __device__ uint3 max(uint3 a, uint3 b)
|
1113 |
+
{
|
1114 |
+
return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
|
1115 |
+
}
|
1116 |
+
inline __host__ __device__ uint4 max(uint4 a, uint4 b)
|
1117 |
+
{
|
1118 |
+
return make_uint4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
|
1119 |
+
}
|
1120 |
+
|
1121 |
+
////////////////////////////////////////////////////////////////////////////////
|
1122 |
+
// lerp
|
1123 |
+
// - linear interpolation between a and b, based on value t in [0, 1] range
|
1124 |
+
////////////////////////////////////////////////////////////////////////////////
|
1125 |
+
|
1126 |
+
inline __device__ __host__ float lerp(float a, float b, float t)
|
1127 |
+
{
|
1128 |
+
return a + t*(b-a);
|
1129 |
+
}
|
1130 |
+
inline __device__ __host__ float2 lerp(float2 a, float2 b, float t)
|
1131 |
+
{
|
1132 |
+
return a + t*(b-a);
|
1133 |
+
}
|
1134 |
+
inline __device__ __host__ float3 lerp(float3 a, float3 b, float t)
|
1135 |
+
{
|
1136 |
+
return a + t*(b-a);
|
1137 |
+
}
|
1138 |
+
inline __device__ __host__ float4 lerp(float4 a, float4 b, float t)
|
1139 |
+
{
|
1140 |
+
return a + t*(b-a);
|
1141 |
+
}
|
1142 |
+
|
1143 |
+
////////////////////////////////////////////////////////////////////////////////
|
1144 |
+
// clamp
|
1145 |
+
// - clamp the value v to be in the range [a, b]
|
1146 |
+
////////////////////////////////////////////////////////////////////////////////
|
1147 |
+
|
1148 |
+
inline __device__ __host__ float clamp(float f, float a, float b)
|
1149 |
+
{
|
1150 |
+
return fmaxf(a, fminf(f, b));
|
1151 |
+
}
|
1152 |
+
inline __device__ __host__ int clamp(int f, int a, int b)
|
1153 |
+
{
|
1154 |
+
return max(a, min(f, b));
|
1155 |
+
}
|
1156 |
+
inline __device__ __host__ uint clamp(uint f, uint a, uint b)
|
1157 |
+
{
|
1158 |
+
return max(a, min(f, b));
|
1159 |
+
}
|
1160 |
+
|
1161 |
+
inline __device__ __host__ float2 clamp(float2 v, float a, float b)
|
1162 |
+
{
|
1163 |
+
return make_float2(clamp(v.x, a, b), clamp(v.y, a, b));
|
1164 |
+
}
|
1165 |
+
inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b)
|
1166 |
+
{
|
1167 |
+
return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
|
1168 |
+
}
|
1169 |
+
inline __device__ __host__ float3 clamp(float3 v, float a, float b)
|
1170 |
+
{
|
1171 |
+
return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
|
1172 |
+
}
|
1173 |
+
inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b)
|
1174 |
+
{
|
1175 |
+
return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
|
1176 |
+
}
|
1177 |
+
inline __device__ __host__ float4 clamp(float4 v, float a, float b)
|
1178 |
+
{
|
1179 |
+
return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
|
1180 |
+
}
|
1181 |
+
inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b)
|
1182 |
+
{
|
1183 |
+
return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
|
1184 |
+
}
|
1185 |
+
|
1186 |
+
inline __device__ __host__ int2 clamp(int2 v, int a, int b)
|
1187 |
+
{
|
1188 |
+
return make_int2(clamp(v.x, a, b), clamp(v.y, a, b));
|
1189 |
+
}
|
1190 |
+
inline __device__ __host__ int2 clamp(int2 v, int2 a, int2 b)
|
1191 |
+
{
|
1192 |
+
return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
|
1193 |
+
}
|
1194 |
+
inline __device__ __host__ int3 clamp(int3 v, int a, int b)
|
1195 |
+
{
|
1196 |
+
return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
|
1197 |
+
}
|
1198 |
+
inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b)
|
1199 |
+
{
|
1200 |
+
return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
|
1201 |
+
}
|
1202 |
+
inline __device__ __host__ int4 clamp(int4 v, int a, int b)
|
1203 |
+
{
|
1204 |
+
return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
|
1205 |
+
}
|
1206 |
+
inline __device__ __host__ int4 clamp(int4 v, int4 a, int4 b)
|
1207 |
+
{
|
1208 |
+
return make_int4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
|
1209 |
+
}
|
1210 |
+
|
1211 |
+
inline __device__ __host__ uint2 clamp(uint2 v, uint a, uint b)
|
1212 |
+
{
|
1213 |
+
return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b));
|
1214 |
+
}
|
1215 |
+
inline __device__ __host__ uint2 clamp(uint2 v, uint2 a, uint2 b)
|
1216 |
+
{
|
1217 |
+
return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
|
1218 |
+
}
|
1219 |
+
inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b)
|
1220 |
+
{
|
1221 |
+
return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
|
1222 |
+
}
|
1223 |
+
inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b)
|
1224 |
+
{
|
1225 |
+
return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
|
1226 |
+
}
|
1227 |
+
inline __device__ __host__ uint4 clamp(uint4 v, uint a, uint b)
|
1228 |
+
{
|
1229 |
+
return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
|
1230 |
+
}
|
1231 |
+
inline __device__ __host__ uint4 clamp(uint4 v, uint4 a, uint4 b)
|
1232 |
+
{
|
1233 |
+
return make_uint4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
|
1234 |
+
}
|
1235 |
+
|
1236 |
+
////////////////////////////////////////////////////////////////////////////////
|
1237 |
+
// dot product
|
1238 |
+
////////////////////////////////////////////////////////////////////////////////
|
1239 |
+
|
1240 |
+
inline __host__ __device__ float dot(float2 a, float2 b)
|
1241 |
+
{
|
1242 |
+
return a.x * b.x + a.y * b.y;
|
1243 |
+
}
|
1244 |
+
inline __host__ __device__ float dot(float3 a, float3 b)
|
1245 |
+
{
|
1246 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
1247 |
+
}
|
1248 |
+
inline __host__ __device__ float dot(float4 a, float4 b)
|
1249 |
+
{
|
1250 |
+
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
|
1251 |
+
}
|
1252 |
+
|
1253 |
+
inline __host__ __device__ int dot(int2 a, int2 b)
|
1254 |
+
{
|
1255 |
+
return a.x * b.x + a.y * b.y;
|
1256 |
+
}
|
1257 |
+
inline __host__ __device__ int dot(int3 a, int3 b)
|
1258 |
+
{
|
1259 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
1260 |
+
}
|
1261 |
+
inline __host__ __device__ int dot(int4 a, int4 b)
|
1262 |
+
{
|
1263 |
+
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
|
1264 |
+
}
|
1265 |
+
|
1266 |
+
inline __host__ __device__ uint dot(uint2 a, uint2 b)
|
1267 |
+
{
|
1268 |
+
return a.x * b.x + a.y * b.y;
|
1269 |
+
}
|
1270 |
+
inline __host__ __device__ uint dot(uint3 a, uint3 b)
|
1271 |
+
{
|
1272 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
1273 |
+
}
|
1274 |
+
inline __host__ __device__ uint dot(uint4 a, uint4 b)
|
1275 |
+
{
|
1276 |
+
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
|
1277 |
+
}
|
1278 |
+
|
1279 |
+
////////////////////////////////////////////////////////////////////////////////
|
1280 |
+
// length
|
1281 |
+
////////////////////////////////////////////////////////////////////////////////
|
1282 |
+
|
1283 |
+
inline __host__ __device__ float length(float2 v)
|
1284 |
+
{
|
1285 |
+
return sqrtf(dot(v, v));
|
1286 |
+
}
|
1287 |
+
inline __host__ __device__ float length(float3 v)
|
1288 |
+
{
|
1289 |
+
return sqrtf(dot(v, v));
|
1290 |
+
}
|
1291 |
+
inline __host__ __device__ float length(float4 v)
|
1292 |
+
{
|
1293 |
+
return sqrtf(dot(v, v));
|
1294 |
+
}
|
1295 |
+
|
1296 |
+
////////////////////////////////////////////////////////////////////////////////
|
1297 |
+
// normalize
|
1298 |
+
////////////////////////////////////////////////////////////////////////////////
|
1299 |
+
|
1300 |
+
inline __host__ __device__ float2 normalize(float2 v)
|
1301 |
+
{
|
1302 |
+
float invLen = rsqrtf(dot(v, v));
|
1303 |
+
return v * invLen;
|
1304 |
+
}
|
1305 |
+
inline __host__ __device__ float3 normalize(float3 v)
|
1306 |
+
{
|
1307 |
+
float invLen = rsqrtf(dot(v, v));
|
1308 |
+
return v * invLen;
|
1309 |
+
}
|
1310 |
+
inline __host__ __device__ float4 normalize(float4 v)
|
1311 |
+
{
|
1312 |
+
float invLen = rsqrtf(dot(v, v));
|
1313 |
+
return v * invLen;
|
1314 |
+
}
|
1315 |
+
|
1316 |
+
////////////////////////////////////////////////////////////////////////////////
|
1317 |
+
// floor
|
1318 |
+
////////////////////////////////////////////////////////////////////////////////
|
1319 |
+
|
1320 |
+
inline __host__ __device__ float2 floorf(float2 v)
|
1321 |
+
{
|
1322 |
+
return make_float2(floorf(v.x), floorf(v.y));
|
1323 |
+
}
|
1324 |
+
inline __host__ __device__ float3 floorf(float3 v)
|
1325 |
+
{
|
1326 |
+
return make_float3(floorf(v.x), floorf(v.y), floorf(v.z));
|
1327 |
+
}
|
1328 |
+
inline __host__ __device__ float4 floorf(float4 v)
|
1329 |
+
{
|
1330 |
+
return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w));
|
1331 |
+
}
|
1332 |
+
|
1333 |
+
////////////////////////////////////////////////////////////////////////////////
|
1334 |
+
// frac - returns the fractional portion of a scalar or each vector component
|
1335 |
+
////////////////////////////////////////////////////////////////////////////////
|
1336 |
+
|
1337 |
+
inline __host__ __device__ float fracf(float v)
|
1338 |
+
{
|
1339 |
+
return v - floorf(v);
|
1340 |
+
}
|
1341 |
+
inline __host__ __device__ float2 fracf(float2 v)
|
1342 |
+
{
|
1343 |
+
return make_float2(fracf(v.x), fracf(v.y));
|
1344 |
+
}
|
1345 |
+
inline __host__ __device__ float3 fracf(float3 v)
|
1346 |
+
{
|
1347 |
+
return make_float3(fracf(v.x), fracf(v.y), fracf(v.z));
|
1348 |
+
}
|
1349 |
+
inline __host__ __device__ float4 fracf(float4 v)
|
1350 |
+
{
|
1351 |
+
return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w));
|
1352 |
+
}
|
1353 |
+
|
1354 |
+
////////////////////////////////////////////////////////////////////////////////
|
1355 |
+
// fmod
|
1356 |
+
////////////////////////////////////////////////////////////////////////////////
|
1357 |
+
|
1358 |
+
inline __host__ __device__ float2 fmodf(float2 a, float2 b)
|
1359 |
+
{
|
1360 |
+
return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y));
|
1361 |
+
}
|
1362 |
+
inline __host__ __device__ float3 fmodf(float3 a, float3 b)
|
1363 |
+
{
|
1364 |
+
return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z));
|
1365 |
+
}
|
1366 |
+
inline __host__ __device__ float4 fmodf(float4 a, float4 b)
|
1367 |
+
{
|
1368 |
+
return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w));
|
1369 |
+
}
|
1370 |
+
|
1371 |
+
////////////////////////////////////////////////////////////////////////////////
|
1372 |
+
// absolute value
|
1373 |
+
////////////////////////////////////////////////////////////////////////////////
|
1374 |
+
|
1375 |
+
inline __host__ __device__ float2 fabs(float2 v)
|
1376 |
+
{
|
1377 |
+
return make_float2(fabs(v.x), fabs(v.y));
|
1378 |
+
}
|
1379 |
+
inline __host__ __device__ float3 fabs(float3 v)
|
1380 |
+
{
|
1381 |
+
return make_float3(fabs(v.x), fabs(v.y), fabs(v.z));
|
1382 |
+
}
|
1383 |
+
inline __host__ __device__ float4 fabs(float4 v)
|
1384 |
+
{
|
1385 |
+
return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w));
|
1386 |
+
}
|
1387 |
+
|
1388 |
+
inline __host__ __device__ int2 abs(int2 v)
|
1389 |
+
{
|
1390 |
+
return make_int2(abs(v.x), abs(v.y));
|
1391 |
+
}
|
1392 |
+
inline __host__ __device__ int3 abs(int3 v)
|
1393 |
+
{
|
1394 |
+
return make_int3(abs(v.x), abs(v.y), abs(v.z));
|
1395 |
+
}
|
1396 |
+
inline __host__ __device__ int4 abs(int4 v)
|
1397 |
+
{
|
1398 |
+
return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w));
|
1399 |
+
}
|
1400 |
+
|
1401 |
+
////////////////////////////////////////////////////////////////////////////////
|
1402 |
+
// reflect
|
1403 |
+
// - returns reflection of incident ray I around surface normal N
|
1404 |
+
// - N should be normalized, reflected vector's length is equal to length of I
|
1405 |
+
////////////////////////////////////////////////////////////////////////////////
|
1406 |
+
|
1407 |
+
inline __host__ __device__ float3 reflect(float3 i, float3 n)
|
1408 |
+
{
|
1409 |
+
return i - 2.0f * n * dot(n,i);
|
1410 |
+
}
|
1411 |
+
|
1412 |
+
////////////////////////////////////////////////////////////////////////////////
|
1413 |
+
// cross product
|
1414 |
+
////////////////////////////////////////////////////////////////////////////////
|
1415 |
+
|
1416 |
+
inline __host__ __device__ float3 cross(float3 a, float3 b)
|
1417 |
+
{
|
1418 |
+
return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);
|
1419 |
+
}
|
1420 |
+
|
1421 |
+
////////////////////////////////////////////////////////////////////////////////
|
1422 |
+
// smoothstep
|
1423 |
+
// - returns 0 if x < a
|
1424 |
+
// - returns 1 if x > b
|
1425 |
+
// - otherwise returns smooth interpolation between 0 and 1 based on x
|
1426 |
+
////////////////////////////////////////////////////////////////////////////////
|
1427 |
+
|
1428 |
+
inline __device__ __host__ float smoothstep(float a, float b, float x)
|
1429 |
+
{
|
1430 |
+
float y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
1431 |
+
return (y*y*(3.0f - (2.0f*y)));
|
1432 |
+
}
|
1433 |
+
inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x)
|
1434 |
+
{
|
1435 |
+
float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
1436 |
+
return (y*y*(make_float2(3.0f) - (make_float2(2.0f)*y)));
|
1437 |
+
}
|
1438 |
+
inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x)
|
1439 |
+
{
|
1440 |
+
float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
1441 |
+
return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y)));
|
1442 |
+
}
|
1443 |
+
inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x)
|
1444 |
+
{
|
1445 |
+
float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
1446 |
+
return (y*y*(make_float4(3.0f) - (make_float4(2.0f)*y)));
|
1447 |
+
}
|
1448 |
+
|
1449 |
+
#endif
|
utils/io_utils.py
ADDED
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import json, os, sys
|
3 |
+
import os.path as osp
|
4 |
+
from typing import List, Union, Tuple, Dict
|
5 |
+
from pathlib import Path
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
from imageio import imread, imwrite
|
9 |
+
import pickle
|
10 |
+
import pycocotools.mask as maskUtils
|
11 |
+
from einops import rearrange
|
12 |
+
from tqdm import tqdm
|
13 |
+
from PIL import Image
|
14 |
+
import io
|
15 |
+
import requests
|
16 |
+
import traceback
|
17 |
+
import base64
|
18 |
+
import time
|
19 |
+
|
20 |
+
|
21 |
+
NP_BOOL_TYPES = (np.bool_, np.bool8)
|
22 |
+
NP_FLOAT_TYPES = (np.float_, np.float16, np.float32, np.float64)
|
23 |
+
NP_INT_TYPES = (np.int_, np.int8, np.int16, np.int32, np.int64, np.uint, np.uint8, np.uint16, np.uint32, np.uint64)
|
24 |
+
|
25 |
+
class NumpyEncoder(json.JSONEncoder):
|
26 |
+
def default(self, obj):
|
27 |
+
if isinstance(obj, np.ndarray):
|
28 |
+
return obj.tolist()
|
29 |
+
elif isinstance(obj, np.ScalarType):
|
30 |
+
if isinstance(obj, NP_BOOL_TYPES):
|
31 |
+
return bool(obj)
|
32 |
+
elif isinstance(obj, NP_FLOAT_TYPES):
|
33 |
+
return float(obj)
|
34 |
+
elif isinstance(obj, NP_INT_TYPES):
|
35 |
+
return int(obj)
|
36 |
+
return json.JSONEncoder.default(self, obj)
|
37 |
+
|
38 |
+
|
39 |
+
def json2dict(json_path: str):
|
40 |
+
with open(json_path, 'r', encoding='utf8') as f:
|
41 |
+
metadata = json.loads(f.read())
|
42 |
+
return metadata
|
43 |
+
|
44 |
+
|
45 |
+
def dict2json(adict: dict, json_path: str):
|
46 |
+
with open(json_path, "w", encoding="utf-8") as f:
|
47 |
+
f.write(json.dumps(adict, ensure_ascii=False, cls=NumpyEncoder))
|
48 |
+
|
49 |
+
|
50 |
+
def dict2pickle(dumped_path: str, tgt_dict: dict):
|
51 |
+
with open(dumped_path, "wb") as f:
|
52 |
+
pickle.dump(tgt_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
|
53 |
+
|
54 |
+
|
55 |
+
def pickle2dict(pkl_path: str) -> Dict:
|
56 |
+
with open(pkl_path, "rb") as f:
|
57 |
+
dumped_data = pickle.load(f)
|
58 |
+
return dumped_data
|
59 |
+
|
60 |
+
def get_all_dirs(root_p: str) -> List[str]:
|
61 |
+
alldir = os.listdir(root_p)
|
62 |
+
dirlist = []
|
63 |
+
for dirp in alldir:
|
64 |
+
dirp = osp.join(root_p, dirp)
|
65 |
+
if osp.isdir(dirp):
|
66 |
+
dirlist.append(dirp)
|
67 |
+
return dirlist
|
68 |
+
|
69 |
+
|
70 |
+
def read_filelist(filelistp: str):
|
71 |
+
with open(filelistp, 'r', encoding='utf8') as f:
|
72 |
+
lines = f.readlines()
|
73 |
+
if len(lines) > 0 and lines[-1].strip() == '':
|
74 |
+
lines = lines[:-1]
|
75 |
+
return lines
|
76 |
+
|
77 |
+
|
78 |
+
VIDEO_EXTS = {'.flv', '.mp4', '.mkv', '.ts', '.mov', 'mpeg'}
|
79 |
+
def get_all_videos(video_dir: str, video_exts=VIDEO_EXTS, abs_path=False) -> List[str]:
|
80 |
+
filelist = os.listdir(video_dir)
|
81 |
+
vlist = []
|
82 |
+
for f in filelist:
|
83 |
+
if Path(f).suffix in video_exts:
|
84 |
+
if abs_path:
|
85 |
+
vlist.append(osp.join(video_dir, f))
|
86 |
+
else:
|
87 |
+
vlist.append(f)
|
88 |
+
return vlist
|
89 |
+
|
90 |
+
|
91 |
+
IMG_EXT = {'.bmp', '.jpg', '.png', '.jpeg'}
|
92 |
+
def find_all_imgs(img_dir, abs_path=False):
|
93 |
+
imglist = []
|
94 |
+
dir_list = os.listdir(img_dir)
|
95 |
+
for filename in dir_list:
|
96 |
+
file_suffix = Path(filename).suffix
|
97 |
+
if file_suffix.lower() not in IMG_EXT:
|
98 |
+
continue
|
99 |
+
if abs_path:
|
100 |
+
imglist.append(osp.join(img_dir, filename))
|
101 |
+
else:
|
102 |
+
imglist.append(filename)
|
103 |
+
return imglist
|
104 |
+
|
105 |
+
|
106 |
+
def find_all_files_recursive(tgt_dir: Union[List, str], ext, exclude_dirs={}):
|
107 |
+
if isinstance(tgt_dir, str):
|
108 |
+
tgt_dir = [tgt_dir]
|
109 |
+
|
110 |
+
filelst = []
|
111 |
+
for d in tgt_dir:
|
112 |
+
for root, _, files in os.walk(d):
|
113 |
+
if osp.basename(root) in exclude_dirs:
|
114 |
+
continue
|
115 |
+
for f in files:
|
116 |
+
if Path(f).suffix.lower() in ext:
|
117 |
+
filelst.append(osp.join(root, f))
|
118 |
+
|
119 |
+
return filelst
|
120 |
+
|
121 |
+
|
122 |
+
def danbooruid2relpath(id_str: str, file_ext='.jpg'):
|
123 |
+
if not isinstance(id_str, str):
|
124 |
+
id_str = str(id_str)
|
125 |
+
return id_str[-3:].zfill(4) + '/' + id_str + file_ext
|
126 |
+
|
127 |
+
|
128 |
+
def get_template_histvq(template: np.ndarray) -> Tuple[List[np.ndarray]]:
|
129 |
+
len_shape = len(template.shape)
|
130 |
+
num_c = 3
|
131 |
+
mask = None
|
132 |
+
if len_shape == 2:
|
133 |
+
num_c = 1
|
134 |
+
elif len_shape == 3 and template.shape[-1] == 4:
|
135 |
+
mask = np.where(template[..., -1])
|
136 |
+
template = template[..., :num_c][mask]
|
137 |
+
|
138 |
+
values, quantiles = [], []
|
139 |
+
for ii in range(num_c):
|
140 |
+
v, c = np.unique(template[..., ii].ravel(), return_counts=True)
|
141 |
+
q = np.cumsum(c).astype(np.float64)
|
142 |
+
if len(q) < 1:
|
143 |
+
return None, None
|
144 |
+
q /= q[-1]
|
145 |
+
values.append(v)
|
146 |
+
quantiles.append(q)
|
147 |
+
return values, quantiles
|
148 |
+
|
149 |
+
|
150 |
+
def inplace_hist_matching(img: np.ndarray, tv: List[np.ndarray], tq: List[np.ndarray]) -> None:
|
151 |
+
len_shape = len(img.shape)
|
152 |
+
num_c = 3
|
153 |
+
mask = None
|
154 |
+
|
155 |
+
tgtimg = img
|
156 |
+
if len_shape == 2:
|
157 |
+
num_c = 1
|
158 |
+
elif len_shape == 3 and img.shape[-1] == 4:
|
159 |
+
mask = np.where(img[..., -1])
|
160 |
+
tgtimg = img[..., :num_c][mask]
|
161 |
+
|
162 |
+
im_h, im_w = img.shape[:2]
|
163 |
+
oldtype = img.dtype
|
164 |
+
for ii in range(num_c):
|
165 |
+
_, bin_idx, s_counts = np.unique(tgtimg[..., ii].ravel(), return_inverse=True,
|
166 |
+
return_counts=True)
|
167 |
+
s_quantiles = np.cumsum(s_counts).astype(np.float64)
|
168 |
+
if len(s_quantiles) == 0:
|
169 |
+
return
|
170 |
+
s_quantiles /= s_quantiles[-1]
|
171 |
+
interp_t_values = np.interp(s_quantiles, tq[ii], tv[ii]).astype(oldtype)
|
172 |
+
if mask is not None:
|
173 |
+
img[..., ii][mask] = interp_t_values[bin_idx]
|
174 |
+
else:
|
175 |
+
img[..., ii] = interp_t_values[bin_idx].reshape((im_h, im_w))
|
176 |
+
# try:
|
177 |
+
# img[..., ii] = interp_t_values[bin_idx].reshape((im_h, im_w))
|
178 |
+
# except:
|
179 |
+
# LOGGER.error('##################### sth goes wrong')
|
180 |
+
# cv2.imshow('img', img)
|
181 |
+
# cv2.waitKey(0)
|
182 |
+
|
183 |
+
|
184 |
+
def fgbg_hist_matching(fg_list: List, bg: np.ndarray, min_tq_num=128):
|
185 |
+
btv, btq = get_template_histvq(bg)
|
186 |
+
ftv, ftq = get_template_histvq(fg_list[0]['image'])
|
187 |
+
num_fg = len(fg_list)
|
188 |
+
idx_matched = -1
|
189 |
+
if num_fg > 1:
|
190 |
+
_ftv, _ftq = get_template_histvq(fg_list[0]['image'])
|
191 |
+
if _ftq is not None and ftq is not None:
|
192 |
+
if len(_ftq[0]) > len(ftq[0]):
|
193 |
+
idx_matched = num_fg - 1
|
194 |
+
ftv, ftq = _ftv, _ftq
|
195 |
+
else:
|
196 |
+
idx_matched = 0
|
197 |
+
|
198 |
+
if btq is not None and ftq is not None:
|
199 |
+
if len(btq[0]) > len(ftq[0]):
|
200 |
+
tv, tq = btv, btq
|
201 |
+
idx_matched = -1
|
202 |
+
else:
|
203 |
+
tv, tq = ftv, ftq
|
204 |
+
if len(tq[0]) > min_tq_num:
|
205 |
+
inplace_hist_matching(bg, tv, tq)
|
206 |
+
|
207 |
+
if len(tq[0]) > min_tq_num:
|
208 |
+
for ii, fg_dict in enumerate(fg_list):
|
209 |
+
fg = fg_dict['image']
|
210 |
+
if ii != idx_matched and len(tq[0]) > min_tq_num:
|
211 |
+
inplace_hist_matching(fg, tv, tq)
|
212 |
+
|
213 |
+
|
214 |
+
def imread_nogrey_rgb(imp: str) -> np.ndarray:
|
215 |
+
img: np.ndarray = imread(imp)
|
216 |
+
c = 1
|
217 |
+
if len(img.shape) == 3:
|
218 |
+
c = img.shape[-1]
|
219 |
+
if c == 1:
|
220 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
221 |
+
if c == 4:
|
222 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
|
223 |
+
return img
|
224 |
+
|
225 |
+
|
226 |
+
def square_pad_resize(img: np.ndarray, tgt_size: int, pad_value: Tuple = (114, 114, 114)):
|
227 |
+
h, w = img.shape[:2]
|
228 |
+
pad_h, pad_w = 0, 0
|
229 |
+
|
230 |
+
# make square image
|
231 |
+
if w < h:
|
232 |
+
pad_w = h - w
|
233 |
+
w += pad_w
|
234 |
+
elif h < w:
|
235 |
+
pad_h = w - h
|
236 |
+
h += pad_h
|
237 |
+
|
238 |
+
pad_size = tgt_size - h
|
239 |
+
if pad_size > 0:
|
240 |
+
pad_h += pad_size
|
241 |
+
pad_w += pad_size
|
242 |
+
|
243 |
+
if pad_h > 0 or pad_w > 0:
|
244 |
+
img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=pad_value)
|
245 |
+
|
246 |
+
down_scale_ratio = tgt_size / img.shape[0]
|
247 |
+
assert down_scale_ratio <= 1
|
248 |
+
if down_scale_ratio < 1:
|
249 |
+
img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_AREA)
|
250 |
+
|
251 |
+
return img, down_scale_ratio, pad_h, pad_w
|
252 |
+
|
253 |
+
|
254 |
+
def scaledown_maxsize(img: np.ndarray, max_size: int, divisior: int = None):
|
255 |
+
|
256 |
+
im_h, im_w = img.shape[:2]
|
257 |
+
ori_h, ori_w = img.shape[:2]
|
258 |
+
resize_ratio = max_size / max(im_h, im_w)
|
259 |
+
if resize_ratio < 1:
|
260 |
+
if im_h > im_w:
|
261 |
+
im_h = max_size
|
262 |
+
im_w = max(1, int(round(im_w * resize_ratio)))
|
263 |
+
|
264 |
+
else:
|
265 |
+
im_w = max_size
|
266 |
+
im_h = max(1, int(round(im_h * resize_ratio)))
|
267 |
+
if divisior is not None:
|
268 |
+
im_w = int(np.ceil(im_w / divisior) * divisior)
|
269 |
+
im_h = int(np.ceil(im_h / divisior) * divisior)
|
270 |
+
|
271 |
+
if im_w != ori_w or im_h != ori_h:
|
272 |
+
img = cv2.resize(img, (im_w, im_h), interpolation=cv2.INTER_LINEAR)
|
273 |
+
|
274 |
+
return img
|
275 |
+
|
276 |
+
|
277 |
+
def resize_pad(img: np.ndarray, tgt_size: int, pad_value: Tuple = (0, 0, 0)):
|
278 |
+
# downscale to tgt_size and pad to square
|
279 |
+
img = scaledown_maxsize(img, tgt_size)
|
280 |
+
padl, padr, padt, padb = 0, 0, 0, 0
|
281 |
+
h, w = img.shape[:2]
|
282 |
+
# padt = (tgt_size - h) // 2
|
283 |
+
# padb = tgt_size - h - padt
|
284 |
+
# padl = (tgt_size - w) // 2
|
285 |
+
# padr = tgt_size - w - padl
|
286 |
+
padb = tgt_size - h
|
287 |
+
padr = tgt_size - w
|
288 |
+
|
289 |
+
if padt + padb + padl + padr > 0:
|
290 |
+
img = cv2.copyMakeBorder(img, padt, padb, padl, padr, cv2.BORDER_CONSTANT, value=pad_value)
|
291 |
+
|
292 |
+
return img, (padt, padb, padl, padr)
|
293 |
+
|
294 |
+
|
295 |
+
def resize_pad2divisior(img: np.ndarray, tgt_size: int, divisior: int = 64, pad_value: Tuple = (0, 0, 0)):
|
296 |
+
img = scaledown_maxsize(img, tgt_size)
|
297 |
+
img, (pad_h, pad_w) = pad2divisior(img, divisior, pad_value)
|
298 |
+
return img, (pad_h, pad_w)
|
299 |
+
|
300 |
+
|
301 |
+
def img2grey(img: Union[np.ndarray, str], is_rgb: bool = False) -> np.ndarray:
|
302 |
+
if isinstance(img, np.ndarray):
|
303 |
+
if len(img.shape) == 3:
|
304 |
+
if img.shape[-1] != 1:
|
305 |
+
if is_rgb:
|
306 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
307 |
+
else:
|
308 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
309 |
+
else:
|
310 |
+
img = img[..., 0]
|
311 |
+
return img
|
312 |
+
elif isinstance(img, str):
|
313 |
+
return cv2.imread(img, cv2.IMREAD_GRAYSCALE)
|
314 |
+
else:
|
315 |
+
raise NotImplementedError
|
316 |
+
|
317 |
+
|
318 |
+
def pad2divisior(img: np.ndarray, divisior: int, value = (0, 0, 0)) -> np.ndarray:
|
319 |
+
im_h, im_w = img.shape[:2]
|
320 |
+
pad_h = int(np.ceil(im_h / divisior)) * divisior - im_h
|
321 |
+
pad_w = int(np.ceil(im_w / divisior)) * divisior - im_w
|
322 |
+
if pad_h != 0 or pad_w != 0:
|
323 |
+
img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, value=value, borderType=cv2.BORDER_CONSTANT)
|
324 |
+
return img, (pad_h, pad_w)
|
325 |
+
|
326 |
+
|
327 |
+
def mask2rle(mask: np.ndarray, decode_for_json: bool = True) -> Dict:
|
328 |
+
mask_rle = maskUtils.encode(np.array(
|
329 |
+
mask[..., np.newaxis] > 0, order='F',
|
330 |
+
dtype='uint8'))[0]
|
331 |
+
if decode_for_json:
|
332 |
+
mask_rle['counts'] = mask_rle['counts'].decode()
|
333 |
+
return mask_rle
|
334 |
+
|
335 |
+
|
336 |
+
def bbox2xyxy(box) -> Tuple[int]:
|
337 |
+
x1, y1 = box[0], box[1]
|
338 |
+
return x1, y1, x1+box[2], y1+box[3]
|
339 |
+
|
340 |
+
|
341 |
+
def bbox_overlap_area(abox, boxb) -> int:
|
342 |
+
ax1, ay1, ax2, ay2 = bbox2xyxy(abox)
|
343 |
+
bx1, by1, bx2, by2 = bbox2xyxy(boxb)
|
344 |
+
|
345 |
+
ix = min(ax2, bx2) - max(ax1, bx1)
|
346 |
+
iy = min(ay2, by2) - max(ay1, by1)
|
347 |
+
|
348 |
+
if ix > 0 and iy > 0:
|
349 |
+
return ix * iy
|
350 |
+
else:
|
351 |
+
return 0
|
352 |
+
|
353 |
+
|
354 |
+
def bbox_overlap_xy(abox, boxb) -> Tuple[int]:
|
355 |
+
ax1, ay1, ax2, ay2 = bbox2xyxy(abox)
|
356 |
+
bx1, by1, bx2, by2 = bbox2xyxy(boxb)
|
357 |
+
|
358 |
+
ix = min(ax2, bx2) - max(ax1, bx1)
|
359 |
+
iy = min(ay2, by2) - max(ay1, by1)
|
360 |
+
|
361 |
+
return ix, iy
|
362 |
+
|
363 |
+
|
364 |
+
def xyxy_overlap_area(axyxy, bxyxy) -> int:
|
365 |
+
ax1, ay1, ax2, ay2 = axyxy
|
366 |
+
bx1, by1, bx2, by2 = bxyxy
|
367 |
+
|
368 |
+
ix = min(ax2, bx2) - max(ax1, bx1)
|
369 |
+
iy = min(ay2, by2) - max(ay1, by1)
|
370 |
+
|
371 |
+
if ix > 0 and iy > 0:
|
372 |
+
return ix * iy
|
373 |
+
else:
|
374 |
+
return 0
|
375 |
+
|
376 |
+
|
377 |
+
DIRNAME2TAG = {'rezero': 're:zero'}
|
378 |
+
def dirname2charactername(dirname, start=6):
|
379 |
+
cname = dirname[start:]
|
380 |
+
for k, v in DIRNAME2TAG.items():
|
381 |
+
cname = cname.replace(k, v)
|
382 |
+
return cname
|
383 |
+
|
384 |
+
|
385 |
+
def imglist2grid(imglist: np.ndarray, grid_size: int = 384, col=None) -> np.ndarray:
|
386 |
+
sqimlist = []
|
387 |
+
for img in imglist:
|
388 |
+
sqimlist.append(square_pad_resize(img, grid_size)[0])
|
389 |
+
|
390 |
+
nimg = len(imglist)
|
391 |
+
if nimg == 0:
|
392 |
+
return None
|
393 |
+
padn = 0
|
394 |
+
if col is None:
|
395 |
+
if nimg > 5:
|
396 |
+
row = int(np.round(np.sqrt(nimg)))
|
397 |
+
col = int(np.ceil(nimg / row))
|
398 |
+
else:
|
399 |
+
col = nimg
|
400 |
+
|
401 |
+
padn = int(np.ceil(nimg / col) * col) - nimg
|
402 |
+
if padn != 0:
|
403 |
+
padimg = np.zeros_like(sqimlist[0])
|
404 |
+
for _ in range(padn):
|
405 |
+
sqimlist.append(padimg)
|
406 |
+
|
407 |
+
return rearrange(sqimlist, '(row col) h w c -> (row h) (col w) c', col=col)
|
408 |
+
|
409 |
+
def write_jsonlines(filep: str, dict_lst: List[str], progress_bar: bool = True):
|
410 |
+
with open(filep, 'w') as out:
|
411 |
+
if progress_bar:
|
412 |
+
lst = tqdm(dict_lst)
|
413 |
+
else:
|
414 |
+
lst = dict_lst
|
415 |
+
for ddict in lst:
|
416 |
+
jout = json.dumps(ddict) + '\n'
|
417 |
+
out.write(jout)
|
418 |
+
|
419 |
+
def read_jsonlines(filep: str):
|
420 |
+
with open(filep, 'r', encoding='utf8') as f:
|
421 |
+
result = [json.loads(jline) for jline in f.read().splitlines()]
|
422 |
+
return result
|
423 |
+
|
424 |
+
|
425 |
+
def _b64encode(x: bytes) -> str:
|
426 |
+
return base64.b64encode(x).decode("utf-8")
|
427 |
+
|
428 |
+
|
429 |
+
def img2b64(img):
|
430 |
+
"""
|
431 |
+
Convert a PIL image to a base64-encoded string.
|
432 |
+
"""
|
433 |
+
if isinstance(img, np.ndarray):
|
434 |
+
img = Image.fromarray(img)
|
435 |
+
buffered = io.BytesIO()
|
436 |
+
img.save(buffered, format='PNG')
|
437 |
+
return _b64encode(buffered.getvalue())
|
438 |
+
|
439 |
+
|
440 |
+
def save_encoded_image(b64_image: str, output_path: str):
|
441 |
+
with open(output_path, "wb") as image_file:
|
442 |
+
image_file.write(base64.b64decode(b64_image))
|
443 |
+
|
444 |
+
def submit_request(url, data, exist_on_exception=True, auth=None, wait_time = 30):
|
445 |
+
response = None
|
446 |
+
try:
|
447 |
+
while True:
|
448 |
+
try:
|
449 |
+
response = requests.post(url, data=data, auth=auth)
|
450 |
+
response.raise_for_status()
|
451 |
+
break
|
452 |
+
except Exception as e:
|
453 |
+
if wait_time > 0:
|
454 |
+
print(traceback.format_exc(), file=sys.stderr)
|
455 |
+
print(f'sleep {wait_time} sec...')
|
456 |
+
time.sleep(wait_time)
|
457 |
+
continue
|
458 |
+
else:
|
459 |
+
raise e
|
460 |
+
except Exception as e:
|
461 |
+
print(traceback.format_exc(), file=sys.stderr)
|
462 |
+
if response is not None:
|
463 |
+
print('response content: ' + response.text)
|
464 |
+
if exist_on_exception:
|
465 |
+
exit()
|
466 |
+
return response
|
467 |
+
|
468 |
+
|
469 |
+
# def resize_image(input_image, resolution):
|
470 |
+
# H, W = input_image.shape[:2]
|
471 |
+
# k = float(min(resolution)) / min(H, W)
|
472 |
+
# img = cv2.resize(input_image, resolution, interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
473 |
+
# return img
|
utils/logger.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os.path as osp
|
3 |
+
from termcolor import colored
|
4 |
+
|
5 |
+
def set_logging(name=None, verbose=True):
|
6 |
+
for handler in logging.root.handlers[:]:
|
7 |
+
logging.root.removeHandler(handler)
|
8 |
+
# Sets level and returns logger
|
9 |
+
# rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
|
10 |
+
fmt = (
|
11 |
+
# colored("[%(name)s]", "magenta", attrs=["bold"])
|
12 |
+
colored("[%(asctime)s]", "blue")
|
13 |
+
+ colored("%(levelname)s:", "green")
|
14 |
+
+ colored("%(message)s", "white")
|
15 |
+
)
|
16 |
+
logging.basicConfig(format=fmt, level=logging.INFO if verbose else logging.WARNING)
|
17 |
+
return logging.getLogger(name)
|
18 |
+
|
19 |
+
LOGGER = set_logging(__name__) # define globally (used in train.py, val.py, detect.py, etc.)
|
20 |
+
|
utils/mmdet_custom_hooks.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mmengine.fileio import FileClient
|
2 |
+
from mmengine.dist import master_only
|
3 |
+
from einops import rearrange
|
4 |
+
import torch
|
5 |
+
import mmcv
|
6 |
+
import numpy as np
|
7 |
+
import os.path as osp
|
8 |
+
import cv2
|
9 |
+
from typing import Optional, Sequence
|
10 |
+
import torch.nn as nn
|
11 |
+
from mmdet.apis import inference_detector
|
12 |
+
from mmcv.transforms import Compose
|
13 |
+
from mmdet.engine import DetVisualizationHook
|
14 |
+
from mmdet.registry import HOOKS
|
15 |
+
from mmdet.structures import DetDataSample
|
16 |
+
|
17 |
+
from utils.io_utils import find_all_imgs, square_pad_resize, imglist2grid
|
18 |
+
|
19 |
+
def inference_detector(
|
20 |
+
model: nn.Module,
|
21 |
+
imgs,
|
22 |
+
test_pipeline
|
23 |
+
):
|
24 |
+
|
25 |
+
if isinstance(imgs, (list, tuple)):
|
26 |
+
is_batch = True
|
27 |
+
else:
|
28 |
+
imgs = [imgs]
|
29 |
+
is_batch = False
|
30 |
+
|
31 |
+
if len(imgs) == 0:
|
32 |
+
return []
|
33 |
+
|
34 |
+
test_pipeline = test_pipeline.copy()
|
35 |
+
if isinstance(imgs[0], np.ndarray):
|
36 |
+
# Calling this method across libraries will result
|
37 |
+
# in module unregistered error if not prefixed with mmdet.
|
38 |
+
test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
|
39 |
+
|
40 |
+
test_pipeline = Compose(test_pipeline)
|
41 |
+
|
42 |
+
result_list = []
|
43 |
+
for img in imgs:
|
44 |
+
# prepare data
|
45 |
+
if isinstance(img, np.ndarray):
|
46 |
+
# TODO: remove img_id.
|
47 |
+
data_ = dict(img=img, img_id=0)
|
48 |
+
else:
|
49 |
+
# TODO: remove img_id.
|
50 |
+
data_ = dict(img_path=img, img_id=0)
|
51 |
+
# build the data pipeline
|
52 |
+
data_ = test_pipeline(data_)
|
53 |
+
|
54 |
+
data_['inputs'] = [data_['inputs']]
|
55 |
+
data_['data_samples'] = [data_['data_samples']]
|
56 |
+
|
57 |
+
# forward the model
|
58 |
+
with torch.no_grad():
|
59 |
+
results = model.test_step(data_)[0]
|
60 |
+
|
61 |
+
result_list.append(results)
|
62 |
+
|
63 |
+
if not is_batch:
|
64 |
+
return result_list[0]
|
65 |
+
else:
|
66 |
+
return result_list
|
67 |
+
|
68 |
+
|
69 |
+
@HOOKS.register_module()
|
70 |
+
class InstanceSegVisualizationHook(DetVisualizationHook):
|
71 |
+
|
72 |
+
def __init__(self, visualize_samples: str = '',
|
73 |
+
read_rgb: bool = False,
|
74 |
+
draw: bool = False,
|
75 |
+
interval: int = 50,
|
76 |
+
score_thr: float = 0.3,
|
77 |
+
show: bool = False,
|
78 |
+
wait_time: float = 0.,
|
79 |
+
test_out_dir: Optional[str] = None,
|
80 |
+
file_client_args: dict = dict(backend='disk')):
|
81 |
+
super().__init__(draw, interval, score_thr, show, wait_time, test_out_dir, file_client_args)
|
82 |
+
self.vis_samples = []
|
83 |
+
|
84 |
+
if osp.exists(visualize_samples):
|
85 |
+
self.channel_order = channel_order = 'rgb' if read_rgb else 'bgr'
|
86 |
+
samples = find_all_imgs(visualize_samples, abs_path=True)
|
87 |
+
for imgp in samples:
|
88 |
+
img = mmcv.imread(imgp, channel_order=channel_order)
|
89 |
+
img, _, _, _ = square_pad_resize(img, 640)
|
90 |
+
self.vis_samples.append(img)
|
91 |
+
|
92 |
+
def before_val(self, runner) -> None:
|
93 |
+
total_curr_iter = runner.iter
|
94 |
+
self._visualize_data(total_curr_iter, runner)
|
95 |
+
return super().before_val(runner)
|
96 |
+
|
97 |
+
# def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
|
98 |
+
# outputs: Sequence[DetDataSample]) -> None:
|
99 |
+
# """Run after every ``self.interval`` validation iterations.
|
100 |
+
|
101 |
+
# Args:
|
102 |
+
# runner (:obj:`Runner`): The runner of the validation process.
|
103 |
+
# batch_idx (int): The index of the current batch in the val loop.
|
104 |
+
# data_batch (dict): Data from dataloader.
|
105 |
+
# outputs (Sequence[:obj:`DetDataSample`]]): A batch of data samples
|
106 |
+
# that contain annotations and predictions.
|
107 |
+
# """
|
108 |
+
# # if self.draw is False:
|
109 |
+
# # return
|
110 |
+
|
111 |
+
# if self.file_client is None:
|
112 |
+
# self.file_client = FileClient(**self.file_client_args)
|
113 |
+
|
114 |
+
|
115 |
+
# # There is no guarantee that the same batch of images
|
116 |
+
# # is visualized for each evaluation.
|
117 |
+
# total_curr_iter = runner.iter + batch_idx
|
118 |
+
|
119 |
+
# # # Visualize only the first data
|
120 |
+
# # img_path = outputs[0].img_path
|
121 |
+
# # img_bytes = self.file_client.get(img_path)
|
122 |
+
# # img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
|
123 |
+
# if total_curr_iter % self.interval == 0 and self.vis_samples:
|
124 |
+
# self._visualize_data(total_curr_iter, runner)
|
125 |
+
|
126 |
+
|
127 |
+
@master_only
|
128 |
+
def _visualize_data(self, total_curr_iter, runner):
|
129 |
+
|
130 |
+
tgt_size = 384
|
131 |
+
|
132 |
+
runner.model.eval()
|
133 |
+
outputs = inference_detector(runner.model, self.vis_samples, test_pipeline=runner.cfg.test_pipeline)
|
134 |
+
vis_results = []
|
135 |
+
for img, output in zip(self.vis_samples, outputs):
|
136 |
+
vis_img = self.add_datasample(
|
137 |
+
'val_img',
|
138 |
+
img,
|
139 |
+
data_sample=output,
|
140 |
+
show=self.show,
|
141 |
+
wait_time=self.wait_time,
|
142 |
+
pred_score_thr=self.score_thr,
|
143 |
+
draw_gt=False,
|
144 |
+
step=total_curr_iter)
|
145 |
+
vis_results.append(cv2.resize(vis_img, (tgt_size, tgt_size), interpolation=cv2.INTER_AREA))
|
146 |
+
|
147 |
+
drawn_img = imglist2grid(vis_results, tgt_size)
|
148 |
+
if drawn_img is None:
|
149 |
+
return
|
150 |
+
drawn_img = cv2.cvtColor(drawn_img, cv2.COLOR_BGR2RGB)
|
151 |
+
visualizer = self._visualizer
|
152 |
+
visualizer.set_image(drawn_img)
|
153 |
+
visualizer.add_image('val_img', drawn_img, total_curr_iter)
|
154 |
+
|
155 |
+
|
156 |
+
@master_only
|
157 |
+
def add_datasample(
|
158 |
+
self,
|
159 |
+
name: str,
|
160 |
+
image: np.ndarray,
|
161 |
+
data_sample: Optional['DetDataSample'] = None,
|
162 |
+
draw_gt: bool = True,
|
163 |
+
draw_pred: bool = True,
|
164 |
+
show: bool = False,
|
165 |
+
wait_time: float = 0,
|
166 |
+
# TODO: Supported in mmengine's Viusalizer.
|
167 |
+
out_file: Optional[str] = None,
|
168 |
+
pred_score_thr: float = 0.3,
|
169 |
+
step: int = 0) -> np.ndarray:
|
170 |
+
image = image.clip(0, 255).astype(np.uint8)
|
171 |
+
visualizer = self._visualizer
|
172 |
+
classes = visualizer.dataset_meta.get('classes', None)
|
173 |
+
palette = visualizer.dataset_meta.get('palette', None)
|
174 |
+
|
175 |
+
gt_img_data = None
|
176 |
+
pred_img_data = None
|
177 |
+
|
178 |
+
if data_sample is not None:
|
179 |
+
data_sample = data_sample.cpu()
|
180 |
+
|
181 |
+
if draw_gt and data_sample is not None:
|
182 |
+
gt_img_data = image
|
183 |
+
if 'gt_instances' in data_sample:
|
184 |
+
gt_img_data = visualizer._draw_instances(image,
|
185 |
+
data_sample.gt_instances,
|
186 |
+
classes, palette)
|
187 |
+
|
188 |
+
if 'gt_panoptic_seg' in data_sample:
|
189 |
+
assert classes is not None, 'class information is ' \
|
190 |
+
'not provided when ' \
|
191 |
+
'visualizing panoptic ' \
|
192 |
+
'segmentation results.'
|
193 |
+
gt_img_data = visualizer._draw_panoptic_seg(
|
194 |
+
gt_img_data, data_sample.gt_panoptic_seg, classes)
|
195 |
+
|
196 |
+
if draw_pred and data_sample is not None:
|
197 |
+
pred_img_data = image
|
198 |
+
if 'pred_instances' in data_sample:
|
199 |
+
pred_instances = data_sample.pred_instances
|
200 |
+
pred_instances = pred_instances[
|
201 |
+
pred_instances.scores > pred_score_thr]
|
202 |
+
pred_img_data = visualizer._draw_instances(image, pred_instances,
|
203 |
+
classes, palette)
|
204 |
+
if 'pred_panoptic_seg' in data_sample:
|
205 |
+
assert classes is not None, 'class information is ' \
|
206 |
+
'not provided when ' \
|
207 |
+
'visualizing panoptic ' \
|
208 |
+
'segmentation results.'
|
209 |
+
pred_img_data = visualizer._draw_panoptic_seg(
|
210 |
+
pred_img_data, data_sample.pred_panoptic_seg.numpy(),
|
211 |
+
classes)
|
212 |
+
|
213 |
+
if gt_img_data is not None and pred_img_data is not None:
|
214 |
+
drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
|
215 |
+
elif gt_img_data is not None:
|
216 |
+
drawn_img = gt_img_data
|
217 |
+
elif pred_img_data is not None:
|
218 |
+
drawn_img = pred_img_data
|
219 |
+
else:
|
220 |
+
# Display the original image directly if nothing is drawn.
|
221 |
+
drawn_img = image
|
222 |
+
|
223 |
+
return drawn_img
|