Sara Mandelli committed • Commit 6bd8735 • 1 Parent(s): f6b58ff

Update detector

Changed files:
- gan_vs_real_detector.py +38 -60
- utils/architectures.py +422 -0
- utils/python_patch_extractor/PatchExtractor.py +306 -0
- utils/python_patch_extractor/__init__.py +0 -0
gan_vs_real_detector.py
CHANGED
@@ -11,39 +11,14 @@ torch.multiprocessing.set_sharing_strategy('file_system')
 import albumentations as A
 import albumentations.pytorch as Ap
 from utils import architectures
+from utils.python_patch_extractor.PatchExtractor import PatchExtractor
 from PIL import Image


 class Detector:
     def __init__(self):

-
-        # model_A_dir = 'weights/method_A/net-EfficientNetB4_lr-0.001_img_aug-[\'flip\', \'rotate\', \'clahe\', \'blur\', ' \
-        #               '\'brightness&contrast\', \'jitter\', \'downscale\', \'hsv\', \'resize\', \'jpeg\']' \
-        #               '_img_aug_p-0.5_patch_size-128_patch_number-1_batch_size-250_num_classes-2'
-        #
-        # # model directory and path for detector B
-        # model_B_dir = 'weights/method_B/net-EfficientNetB4_lr-0.001_aug-[\'flip\', \'rotate\', \'clahe\', \'blur\', ' \
-        #               '\'crop&resize\', \'brightness&contrast\', \'jitter\', \'downscale\', \'hsv\']' \
-        #               '_aug_p-0.5_jpeg_aug_p-0.7_patch_size-128_patch_number-1_batch_size-250_num_classes-2'
-        #
-        # # model directory and path for detector C
-        # model_C_dir = 'weights/method_C/net-EfficientNetB4_lr-0.001_aug-[\'flip\', \'rotate\', \'clahe\', \'blur\',' \
-        #               ' \'crop&resize\', \'brightness&contrast\', \'jitter\', \'downscale\', \'hsv\']' \
-        #               '_aug_p-0.5_jpeg_aug_p-0_patch_size-128_patch_number-5_batch_size-50_num_classes-2'
-        #
-        # # model directory and path for detector D
-        # model_D_dir = 'weights/method_D/net-EfficientNetB4_lr-0.001_aug-[\'flip\', \'rotate\', \'clahe\', \'blur\',' \
-        #               '\'crop&resize\', \'brightness&contrast\', \'jitter\', \'downscale\', \'hsv\']' \
-        #               '_aug_p-0.5_jpeg_aug_p-0_patch_size-128_patch_number-10_batch_size-25_num_classes-2'
-        #
-        # # model directory for detector E
-        # model_E_dir = 'weights/method_E/net-EfficientNetB4_lr-0.001_aug-[\'flip\', \'rotate\', \'clahe\', \'blur\',' \
-        #               ' \'crop&resize\', \'brightness&contrast\', \'jitter\', \'downscale\', \'hsv\']' \
-        #               '_aug_p-0.5_jpeg_aug_p-0.7_patch_size-128_patch_number-1_batch_size-250_num_classes-2'
-
-        self.weights_path_list = [os.path.join('weights', f'method_{x}.pth') for x in 'ABCDE']
-        # self.model_path = os.path.join(model_dir, 'bestval.pth')
+        self.weights_path_list = [os.path.join('/nas/home/nbonettini/projects/StyleGAN3-detection/weights', f'method_{x}.pth') for x in 'ABCDE']

         # GPU configuration if available
         self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

@@ -72,17 +47,15 @@ class Detector:
             Ap.transforms.ToTensorV2()
         ]
         self.trans = A.Compose(transform)
-
         self.cropper = A.RandomCrop(width=128, height=128, always_apply=True, p=1.)
-
         self.criterion = torch.nn.CrossEntropyLoss(reduction='none')

-    def synth_real_detector(self, img_path: str, n_patch: int =
+    def synth_real_detector(self, img_path: str, n_patch: int = 200):

         # Load image:
         img = np.asarray(Image.open(img_path))

-        #
+        # Opt-out if image is non conforming
         if img.shape == ():
             print('{} None dimension'.format(img_path))
             return None

@@ -96,47 +69,52 @@
             print('Omitting alpha channel')
             img = img[:, :, :3]

-
-
+        img_net_scores = []
+        for net_idx, net in enumerate(self.nets):

-
-        transf_patch_list = [self.trans(image=patch)['image'] for patch in patch_list]
+            if net_idx == 0:

-
-
-
-
-        softmax_scores = torch.softmax(patch_scores, dim=1)
-        predictions = torch.argmax(softmax_scores, dim=1)
+                # only for detector A, extract N = 200 random patches per image
+                patch_list = [self.cropper(image=img)['image'] for _ in range(n_patch)]
+
+            else:

-
-
-
-
-
+                # for detectors B, C, D, E, extract patches aligned with the 8 x 8 pixel grid:
+                # we want more or less 200 patches per img
+                stride_0 = ((((img.shape[0] - 128) // 20) + 7) // 8) * 8
+                stride_1 = (((img.shape[1] - 128) // 10 + 7) // 8) * 8
+                pe = PatchExtractor(dim=(128, 128, 3), stride=(stride_0, stride_1, 3))
+                patches = pe.extract(img)
+                patch_list = list(patches.reshape((patches.shape[0]*patches.shape[1], 128, 128, 3)))

-
-
-        # LLR > 0: synthetic
+            # Normalization
+            transf_patch_list = [self.trans(image=patch)['image'] for patch in patch_list]

-
-
-
-
+            # Compute scores
+            transf_patch_tensor = torch.stack(transf_patch_list, dim=0).to(self.device)
+            with torch.no_grad():
+                patch_scores = net(transf_patch_tensor).cpu().numpy()
+            patch_predictions = np.argmax(patch_scores, axis=1)

-
+            maj_voting = np.any(patch_predictions).astype(int)
+            scores_maj_voting = patch_scores[:, maj_voting]
+            img_net_scores.append(np.nanmax(scores_maj_voting) if maj_voting == 1 else -np.nanmax(scores_maj_voting))
+
+        # final score is the average among the 5 scores returned by the detectors
+        img_score = np.mean(img_net_scores)
+
+        return img_score


 def main():
-    # img_path
-    img_path = "/nas/public/exchange/semafor/eval1/stylegan2/100k-generated-images/car-512x384_cropped/stylegan2-" \
-               "config-f-psi-0.5/097000/097001.png"

-    #
-
+    # img_path on fermi:
+    img_path = '/home/nbonettini/nvidia_temp/nvidia-alias-free-gan/faces/alias-free-r-afhqv2-512x512/seed40000.png'

     detector = Detector()
-    detector.synth_real_detector(img_path
+    score = detector.synth_real_detector(img_path)
+
+    print('Image Score: {}'.format(score))

     return 0

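The new scoring loop iterates over self.nets, which is presumably populated elsewhere in __init__ (outside the hunks shown) from weights_path_list. To make the stride arithmetic above concrete, here is a minimal standalone sketch, assuming a hypothetical 512 x 512 x 3 input (the size is only for illustration), that reproduces the formulas and the resulting patch count of roughly 200:

    # Hypothetical 512 x 512 input, only to illustrate the stride arithmetic above.
    h, w, patch = 512, 512, 128

    # Same formulas as in synth_real_detector: each stride is rounded up to a multiple of 8.
    stride_0 = ((((h - patch) // 20) + 7) // 8) * 8   # -> 24
    stride_1 = ((((w - patch) // 10) + 7) // 8) * 8   # -> 40

    n_rows = (h - patch) // stride_0 + 1              # -> 17
    n_cols = (w - patch) // stride_1 + 1              # -> 10
    print(n_rows * n_cols)                            # -> 170, i.e. roughly 200 patches

Rounding each stride up to a multiple of 8 keeps every extracted patch aligned with the 8 x 8 pixel grid, as the in-code comment notes.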
utils/architectures.py
ADDED
@@ -0,0 +1,422 @@
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
import torch.nn.functional as F
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import (
    round_filters,
    round_repeats,
    drop_connect,
    get_same_padding_conv2d,
    get_model_params,
    efficientnet_params,
    load_pretrained_weights,
    Swish,
    MemoryEfficientSwish,
)
from efficientnet_pytorch.model import MBConvBlock
from torchvision.models import resnet
from pytorchcv.model_provider import get_model


class Head(nn.Module):
    def __init__(self, in_f, out_f):
        super(Head, self).__init__()

        self.f = nn.Flatten()
        self.l = nn.Linear(in_f, 512)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(512, out_f)
        self.b1 = nn.BatchNorm1d(in_f)
        self.b2 = nn.BatchNorm1d(512)
        self.r = nn.ReLU()

    def forward(self, x):
        x = self.f(x)
        x = self.b1(x)
        x = self.d(x)

        x = self.l(x)
        x = self.r(x)
        x = self.b2(x)
        x = self.d(x)

        out = self.o(x)
        return out


class FCN(nn.Module):
    def __init__(self, base, in_f, out_f):
        super(FCN, self).__init__()
        self.base = base
        self.h1 = Head(in_f, out_f)

    def forward(self, x):
        x = self.base(x)
        return self.h1(x)


class BaseFCN(nn.Module):
    def __init__(self, n_classes: int):
        super(BaseFCN, self).__init__()

        self.f = nn.Flatten()
        self.l = nn.Linear(625, 256)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(256, n_classes)

    def forward(self, x):
        x = self.f(x)
        x = self.l(x)
        x = self.d(x)
        out = self.o(x)
        return out

    def get_trainable_parameters_cooccur(self):
        return self.parameters()


class BaseFCNHigh(nn.Module):
    def __init__(self, n_classes: int):
        super(BaseFCNHigh, self).__init__()

        self.f = nn.Flatten()
        self.l = nn.Linear(625, 512)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(512, n_classes)

    def forward(self, x):
        x = self.f(x)
        x = self.l(x)
        x = self.d(x)
        out = self.o(x)
        return out

    def get_trainable_parameters_cooccur(self):
        return self.parameters()


class BaseFCN4(nn.Module):
    def __init__(self, n_classes: int):
        super(BaseFCN4, self).__init__()

        self.f = nn.Flatten()
        self.l1 = nn.Linear(625, 512)
        self.l2 = nn.Linear(512, 384)
        self.l3 = nn.Linear(384, 256)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(256, n_classes)

    def forward(self, x):
        x = self.f(x)
        x = self.l1(x)
        x = self.d(x)
        x = self.l2(x)
        x = self.d(x)
        x = self.l3(x)
        x = self.d(x)
        out = self.o(x)
        return out

    def get_trainable_parameters_cooccur(self):
        return self.parameters()


class BaseFCNBnR(nn.Module):
    def __init__(self, n_classes: int):
        super(BaseFCNBnR, self).__init__()

        self.f = nn.Flatten()
        self.b1 = nn.BatchNorm1d(625)
        self.b2 = nn.BatchNorm1d(256)
        self.l = nn.Linear(625, 256)
        self.d = nn.Dropout(0.5)
        self.o = nn.Linear(256, n_classes)
        self.r = nn.ReLU()

    def forward(self, x):
        x = self.f(x)
        x = self.b1(x)
        x = self.d(x)
        x = self.l(x)
        x = self.r(x)
        x = self.b2(x)
        x = self.d(x)
        out = self.o(x)
        return out

    def get_trainable_parameters_cooccur(self):
        return self.parameters()


def forward_resnet_conv(net, x, upto: int = 4):
    """
    Forward ResNet only in its convolutional part
    :param net:
    :param x:
    :param upto:
    :return:
    """
    x = net.conv1(x)  # N / 2
    x = net.bn1(x)
    x = net.relu(x)
    x = net.maxpool(x)  # N / 4

    if upto >= 1:
        x = net.layer1(x)  # N / 4
    if upto >= 2:
        x = net.layer2(x)  # N / 8
    if upto >= 3:
        x = net.layer3(x)  # N / 16
    if upto >= 4:
        x = net.layer4(x)  # N / 32
    return x


class FeatureExtractor(nn.Module):
    """
    Abstract class to be extended when supporting features extraction.
    It also provides standard normalized and parameters
    """

    def features(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def get_trainable_parameters(self):
        return self.parameters()

    @staticmethod
    def get_normalizer():
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


class FeatureExtractorGray(nn.Module):
    """
    Abstract class to be extended when supporting features extraction.
    It also provides standard normalized and parameters
    """

    def features(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def get_trainable_parameters(self):
        return self.parameters()

    @staticmethod
    def get_normalizer():
        return transforms.Normalize(mean=[0.479], std=[0.226])


class EfficientNetGen(FeatureExtractor):
    def __init__(self, model: str, n_classes: int, pretrained: bool):
        super(EfficientNetGen, self).__init__()

        if pretrained:
            self.efficientnet = EfficientNet.from_pretrained(model)
        else:
            self.efficientnet = EfficientNet.from_name(model)

        self.classifier = nn.Linear(self.efficientnet._conv_head.out_channels, n_classes)
        del self.efficientnet._fc

    def features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.efficientnet.extract_features(x)
        x = self.efficientnet._avg_pooling(x)
        x = x.flatten(start_dim=1)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.efficientnet._dropout(x)
        x = self.classifier(x)
        # x = F.softmax(x, dim=-1)
        return x


class EfficientNetB0(EfficientNetGen):
    def __init__(self, n_classes: int, pretrained: bool):
        super(EfficientNetB0, self).__init__(model='efficientnet-b0', n_classes=n_classes, pretrained=pretrained)


class EfficientNetB4(EfficientNetGen):
    def __init__(self, n_classes: int, pretrained: bool):
        super(EfficientNetB4, self).__init__(model='efficientnet-b4', n_classes=n_classes, pretrained=pretrained)


class EfficientNetGenPostStem(FeatureExtractor):
    def __init__(self, model: str, n_classes: int, pretrained: bool, n_ir_blocks: int):
        super(EfficientNetGenPostStem, self).__init__()

        if pretrained:
            self.efficientnet = EfficientNet.from_pretrained(model)
        else:
            self.efficientnet = EfficientNet.from_name(model)

        self.n_ir_blocks = n_ir_blocks
        self.classifier = nn.Linear(self.efficientnet._conv_head.out_channels, n_classes)

        # modify STEM
        in_channels = 3  # rgb
        out_channels = round_filters(32, self.efficientnet._global_params)
        Conv2d = get_same_padding_conv2d(image_size=self.efficientnet._global_params.image_size)
        self.efficientnet._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, bias=False)

        self.init_blocks_args = self.efficientnet._blocks_args[0]
        self.init_blocks_args = self.init_blocks_args._replace(output_filters=32)
        self.init_block = MBConvBlock(self.init_blocks_args, self.efficientnet._global_params)

        self.last_block_args = self.efficientnet._blocks_args[0]
        self.last_block_args = self.last_block_args._replace(output_filters=32, stride=2)
        self.last_block = MBConvBlock(self.last_block_args, self.efficientnet._global_params)

        del self.efficientnet._fc

    def features(self, x: torch.Tensor) -> torch.Tensor:

        x = self.efficientnet._swish(self.efficientnet._bn0(self.efficientnet._conv_stem(x)))

        # init blocks
        for b in range(self.n_ir_blocks - 1):
            x = self.init_block(x, drop_connect_rate=0)

        # last block
        x = self.last_block(x, drop_connect_rate=0)

        # standard blocks efficientNet:
        for idx, block in enumerate(self.efficientnet._blocks):
            drop_connect_rate = self.efficientnet._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self.efficientnet._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)

        x = self.efficientnet._swish(self.efficientnet._bn1(self.efficientnet._conv_head(x)))

        x = self.efficientnet._avg_pooling(x)
        x = x.flatten(start_dim=1)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.efficientnet._dropout(x)
        x = self.classifier(x)
        # x = F.softmax(x, dim=-1)
        return x


class EfficientNetB0PostStemIR(EfficientNetGenPostStem):
    def __init__(self, n_classes: int, pretrained: bool, n_ir_blocks: int):
        super(EfficientNetB0PostStemIR, self).__init__(model='efficientnet-b0', n_classes=n_classes,
                                                       pretrained=pretrained, n_ir_blocks=n_ir_blocks)


class EfficientNetGenPreStem(FeatureExtractor):
    def __init__(self, model: str, n_classes: int, pretrained: bool, n_ir_blocks: int):
        super(EfficientNetGenPreStem, self).__init__()

        if pretrained:
            self.efficientnet = EfficientNet.from_pretrained(model)
        else:
            self.efficientnet = EfficientNet.from_name(model)

        self.n_ir_blocks = n_ir_blocks
        self.classifier = nn.Linear(self.efficientnet._conv_head.out_channels, n_classes)

        self.init_block_args = self.efficientnet._blocks_args[0]
        self.init_block_args = self.init_block_args._replace(input_filters=3, output_filters=32)
        self.init_block = MBConvBlock(self.init_block_args, self.efficientnet._global_params)

        self.last_blocks_args = self.efficientnet._blocks_args[0]
        self.last_blocks_args = self.last_blocks_args._replace(output_filters=32)
        self.last_block = MBConvBlock(self.last_blocks_args, self.efficientnet._global_params)

        # modify STEM
        in_channels = 32
        out_channels = round_filters(32, self.efficientnet._global_params)
        Conv2d = get_same_padding_conv2d(image_size=self.efficientnet._global_params.image_size)
        self.efficientnet._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)

        del self.efficientnet._fc

    def features(self, x: torch.Tensor) -> torch.Tensor:

        # init block
        x = self.init_block(x, drop_connect_rate=0)

        # other blocks
        for b in range(self.n_ir_blocks - 1):
            x = self.last_block(x, drop_connect_rate=0)

        # standard stem efficientNet:
        x = self.efficientnet._swish(self.efficientnet._bn0(self.efficientnet._conv_stem(x)))

        # standard blocks efficientNet:
        for idx, block in enumerate(self.efficientnet._blocks):
            drop_connect_rate = self.efficientnet._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self.efficientnet._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)

        x = self.efficientnet._swish(self.efficientnet._bn1(self.efficientnet._conv_head(x)))

        x = self.efficientnet._avg_pooling(x)
        x = x.flatten(start_dim=1)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.efficientnet._dropout(x)
        x = self.classifier(x)
        # x = F.softmax(x, dim=-1)
        return x


class EfficientNetB0PreStemIR(EfficientNetGenPreStem):
    def __init__(self, n_classes: int, pretrained: bool, n_ir_blocks: int):
        super(EfficientNetB0PreStemIR, self).__init__(model='efficientnet-b0', n_classes=n_classes,
                                                      pretrained=pretrained, n_ir_blocks=n_ir_blocks)


class ResNet50(FeatureExtractor):
    def __init__(self, n_classes: int, pretrained: bool):
        super(ResNet50, self).__init__()
        self.resnet = resnet.resnet50(pretrained=pretrained)
        self.fc = nn.Linear(in_features=self.resnet.fc.in_features, out_features=n_classes)
        del self.resnet.fc

    def features(self, x):
        x = forward_resnet_conv(self.resnet, x)
        x = self.resnet.avgpool(x).flatten(start_dim=1)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.fc(x)
        return x


"""
Xception from Kaggle
"""


class XceptionWeiHao(FeatureExtractor):

    def __init__(self, n_classes: int, pretrained: bool):
        super(XceptionWeiHao, self).__init__()

        self.model = get_model("xception", pretrained=pretrained)
        self.model = nn.Sequential(*list(self.model.children())[:-1])  # Remove original output layer
        self.model[0].final_block.pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)))
        self.model = FCN(self.model, 2048, n_classes)

    def features(self, x: torch.Tensor) -> torch.Tensor:
        return self.model.base(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        return self.model.h1(x)
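As a quick smoke test of the architectures module, the following minimal sketch (batch size, patch size and class count are assumptions, chosen to match the detector's 128 x 128 patches and two classes) instantiates EfficientNetB4 and runs a dummy batch:

    import torch
    from utils import architectures

    # Assumed settings: 2 classes (real vs GAN) and 128 x 128 RGB patches, as in the detector.
    net = architectures.EfficientNetB4(n_classes=2, pretrained=False)
    net.eval()

    x = torch.randn(4, 3, 128, 128)   # dummy batch of 4 patches
    with torch.no_grad():
        logits = net(x)               # raw logits; softmax is commented out in forward()
    print(logits.shape)               # torch.Size([4, 2])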
utils/python_patch_extractor/PatchExtractor.py
ADDED
@@ -0,0 +1,306 @@
"""
@Author: Nicolo' Bonettini
@Author: Luca Bondi
@Author: Francesco Picetti
"""
import random
import numpy as np
from skimage.util import view_as_windows, view_as_blocks


# Score functions ---

def mid_intensity_high_texture(in_content):
    """
    Quality function that returns higher scores for mid intensity patches with high texture levels. Empirical.
    :type in_content: ndarray
    :param in_content : 2D or 3D ndarray. Values are expected in [0,1] if in_content is float, in [0,255] if in_content is uint8
    :return score: float
        score in [0,1].
    """

    if in_content.dtype == np.uint8:
        in_content = in_content / 255.

    mean_std_weight = .7

    in_content = in_content.flatten()

    mean_val = in_content.mean()
    std_val = in_content.std()

    ch_mean_score = -4 * mean_val ** 2 + 4 * mean_val
    ch_std_score = 1 - np.exp(-2 * np.log(10) * std_val)

    score = mean_std_weight * ch_mean_score + (1 - mean_std_weight) * ch_std_score
    return score


def count_patches(in_size, patch_size, patch_stride):
    """
    Compute the number of patches
    :param in_size:
    :param patch_size:
    :param patch_stride:
    :return:
    """
    win_indices_shape = (((np.array(in_size) - np.array(patch_size))
                          // np.array(patch_stride)) + 1)
    return int(np.prod(win_indices_shape))


class PatchExtractor:

    def __init__(self, dim, offset=None, stride=None, rand=None, function=None, threshold=None,
                 num=None, indexes=None):

        """
        N-dimensional patch extractor
        Args:
        :param in_content : ndarray
            the content to process as a numpy array of ndim dimensions

        :param dim : tuple
            patch_array dimensions as a tuple of ndim elements

        Named args:
        :param offset : tuple
            the offsets along each axis as a tuple of ndim elements

        :param stride : tuple
            the stride of each axis as a tuple of ndim elements

        :param rand : bool
            randomize patch_array order. Mutually exclusive with function_handler

        :param function : function
            patch quality function handler. Mutually exclusive with rand

        :param threshold: float
            minimum quality threshold

        :param num : int
            maximum number of returned patch_array. Mutually exclusive with indexes

        :param indexes : list|ndarray
            explicitly return corresponding patch indexes (function_handler or C order used to index patch_array).
            Mutually exclusive with num

        :return ndarray: patch_array
            array of patch_array
            if rand==False and function_handler==None and num==None and indexes==None:
                patch_array.ndim = 2 * in_content.ndim
            else:
                patch_array.ndim = 1 + in_content.ndim
        """

        # Arguments parser ---
        if not isinstance(dim, tuple):
            raise ValueError('dim must be a tuple')
        self.dim = dim

        ndim = len(dim)
        self.ndim = ndim

        if offset is None:
            offset = tuple([0] * ndim)
        if not isinstance(offset, tuple):
            raise ValueError('offset must be a tuple')
        if len(offset) != ndim:
            raise ValueError('offset must a tuple of length {:d}'.format(ndim))
        self.offset = offset

        if stride is None:
            stride = dim
        if not isinstance(stride, tuple):
            raise ValueError('stride must be a tuple')
        if len(stride) != ndim:
            raise ValueError('stride must a tuple of length {:d}'.format(ndim))
        self.stride = stride

        if rand is not None and function is not None:
            raise ValueError('rand and function cannot be set at the same time')

        if rand is None:
            rand = False
        if not isinstance(rand, bool):
            raise ValueError('rand must be a boolean')
        self.rand = rand

        if function is not None and not callable(function):
            raise ValueError('function must be a function handler')
        self.function_handler = function

        if threshold is None:
            threshold = 0.0
        if not isinstance(threshold, float):
            raise ValueError('threshold must be a float')
        self.threshold = threshold

        if num is not None and indexes is not None:
            raise ValueError('num and indexes cannot be set at the same time')

        if num is not None and not isinstance(num, int):
            raise ValueError('num must be an int')
        self.num = num

        if indexes is not None and not isinstance(indexes, list) and not isinstance(indexes, np.ndarray):
            raise ValueError('indexes must be an list or a 1d ndarray')
        if indexes is not None:
            indexes = np.array(indexes).flatten()
        self.indexes = indexes

        self.in_content_original_shape = None
        self.in_content_cropped_shape = None

    def extract(self, in_content):

        if not isinstance(in_content, np.ndarray):
            raise ValueError('in_content must be of type: ' + str(np.ndarray))

        if in_content.ndim != self.ndim:
            raise ValueError('in_content shape must a tuple of length {:d}'.format(self.ndim))

        self.in_content_original_shape = in_content.shape

        # Offset ---
        for dim_idx, dim_offset in enumerate(self.offset):
            dim_max = in_content.shape[dim_idx]
            in_content = in_content.take(range(dim_offset, dim_max), axis=dim_idx)

        # Patch list ---
        if self.dim == self.stride:
            in_content_crop = in_content
            for dim_idx in range(self.ndim):
                dim_max = (in_content.shape[dim_idx] // self.dim[dim_idx]) * self.dim[dim_idx]
                in_content_crop = in_content_crop.take(range(0, dim_max), axis=dim_idx)
            patch_array = view_as_blocks(in_content_crop, self.dim)
        else:
            patch_array = view_as_windows(in_content, self.dim, self.stride)

        patch_array = np.ascontiguousarray(patch_array)

        patch_idx = patch_array.shape[:self.ndim]
        self.in_content_cropped_shape = tuple((np.asarray(patch_idx) - 1) * np.asarray(self.stride) + np.asarray(self.dim))

        # Evaluate patch_array or rand sort ---
        if self.rand:
            patch_array.shape = (-1,) + self.dim
            random.shuffle(patch_array)
        else:
            if self.function_handler is not None:
                patch_array.shape = (-1,) + self.dim
                patch_scores = np.asarray(list(map(self.function_handler, patch_array)))
                sort_idxs = np.argsort(patch_scores)[::-1]
                patch_scores = patch_scores[sort_idxs]
                patch_array = patch_array[sort_idxs]
                patch_array = patch_array[patch_scores >= self.threshold]

        if self.num is not None:
            patch_array.shape = (-1,) + self.dim
            patch_array = patch_array[:self.num]

        if self.indexes is not None:
            patch_array.shape = (-1,) + self.dim
            patch_array = patch_array[self.indexes]

        return patch_array

    def extract_call(self, args):  # TODO: verify
        in_content = args.pop('in_content')
        dim = args.pop('dim')

        return self.extract(in_content)

    def reconstruct(self, patch_array):
        """
        Reconstruct the N-dim image from the patch_array that has been extracted previously
        :param patch_array: array of patches as output of patch_extractor
        :return:
        """
        # Arguments parser ---
        if not isinstance(patch_array, np.ndarray):
            raise ValueError('patch_array must be of type: ' + str(np.ndarray))

        ndim = patch_array.ndim // 2

        # if not isinstance(patch_stride, tuple):
        #     raise ValueError('patch_stride must be a tuple')
        # if len(patch_stride) != ndim:
        #     raise ValueError('patch_stride must be a tuple of length {:d}'.format(ndim))
        #
        # if not isinstance(image_shape, tuple):
        #     raise ValueError('patch_idx must be a tuple')
        # if len(image_shape) != ndim:
        #     raise ValueError('patch_idx must be a tuple of length {:d}'.format(ndim))

        patch_stride = self.stride
        image_shape = self.in_content_cropped_shape

        patch_shape = patch_array.shape[-ndim:]
        patch_idx = patch_array.shape[:ndim]
        image_shape_computed = tuple((np.array(patch_idx) - 1) * np.array(patch_stride) + np.array(patch_shape))
        if not image_shape == image_shape_computed:
            raise ValueError('There is something wrong with the dimensions!')

        if ndim > 4:
            raise ValueError('For now, it works only in 4D, sorry!')
        numpatches = count_patches(image_shape, patch_shape, patch_stride)
        patch_array_unwrapped = patch_array.reshape(numpatches, *patch_shape)
        image_recon = np.zeros(image_shape)
        norm_mask = np.zeros(image_shape)
        counter = 0

        for h in np.arange(0, image_shape[0] - patch_shape[0] + 1, patch_stride[0]):
            if ndim > 1:
                for i in np.arange(0, image_shape[1] - patch_shape[1] + 1, patch_stride[1]):
                    if ndim > 2:
                        for j in np.arange(0, image_shape[2] - patch_shape[2] + 1, patch_stride[2]):
                            if ndim > 3:
                                for k in np.arange(0, image_shape[3] - patch_shape[3] + 1, patch_stride[3]):
                                    image_recon[h:h + patch_shape[0], i:i + patch_shape[1], j:j + patch_shape[2],
                                                k:k + patch_shape[3]] += patch_array_unwrapped[counter, :, :, :, :]
                                    norm_mask[h:h + patch_shape[0], i:i + patch_shape[1], j:j + patch_shape[2],
                                              k:k + patch_shape[3]] += 1
                                    counter += 1
                            else:
                                image_recon[h:h + patch_shape[0], i:i + patch_shape[1],
                                            j:j + patch_shape[2]] += patch_array_unwrapped[counter, :, :, :]
                                norm_mask[h:h + patch_shape[0], i:i + patch_shape[1], j:j + patch_shape[2]] += 1
                                counter += 1
                    else:
                        image_recon[h:h + patch_shape[0], i:i + patch_shape[1]] += patch_array_unwrapped[counter, :, :]
                        norm_mask[h:h + patch_shape[0], i:i + patch_shape[1]] += 1
                        counter += 1
            else:
                image_recon[h:h + patch_shape[0]] += patch_array_unwrapped[counter, :]
                norm_mask[h:h + patch_shape[0]] += 1
                counter += 1

        image_recon /= norm_mask

        return image_recon


def main():
    in_shape = (644, 481, 3)
    dim = (120, 120, 3)
    stride = (7, 90, 90, 3)
    offset = (1, 0, 0, 0)
    in_content = np.random.randint(256, size=in_shape).astype(np.uint8)
    # args = {'in_content': in_content,
    #         'dim': dim,
    #         'offset': offset,
    #         'stride': stride,
    #         }

    # patch_array = patch_extractor_call(args)
    pe = PatchExtractor(dim)
    patch_array = pe.extract(in_content)
    print('patch_array.shape = ' + str(patch_array.shape))
    img_recon = pe.reconstruct(patch_array)
    print('img_recon.shape = ' + str(img_recon.shape))


if __name__ == "__main__":
    main()
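Besides the plain block extraction exercised by main() above, the extractor also supports quality-driven selection. A small sketch, assuming random dummy data and an arbitrary 0.5 threshold, keeps only the best-scoring 128 x 128 patches according to mid_intensity_high_texture:

    import numpy as np
    from utils.python_patch_extractor.PatchExtractor import PatchExtractor, mid_intensity_high_texture

    img = np.random.randint(256, size=(512, 512, 3)).astype(np.uint8)   # dummy image

    # Keep at most the 10 best 128 x 128 patches according to the empirical quality score,
    # discarding anything that scores below the 0.5 threshold.
    pe = PatchExtractor(dim=(128, 128, 3), stride=(64, 64, 3),
                        function=mid_intensity_high_texture, threshold=0.5, num=10)
    patches = pe.extract(img)
    print(patches.shape)   # (n, 128, 128, 3) with n <= 10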
utils/python_patch_extractor/__init__.py
ADDED
File without changes