Jacob Logas committed
Commit 1173b78 · 0 Parent(s):

first commit

.gitignore ADDED
@@ -0,0 +1,162 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,16 @@
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v2.3.0
4
+ hooks:
5
+ - id: check-yaml
6
+ - id: end-of-file-fixer
7
+ - id: trailing-whitespace
8
+ - repo: https://github.com/astral-sh/ruff-pre-commit
9
+ # Ruff version.
10
+ rev: v0.5.2
11
+ hooks:
12
+ # Run the linter.
13
+ - id: ruff
14
+ args: [ --fix ]
15
+ # Run the formatter.
16
+ - id: ruff-format
README.md ADDED
@@ -0,0 +1,10 @@
1
+ ---
2
+ title: LowKey
3
+ emoji: 😒
4
+ colorFrom: pink
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 4.38.1
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
align/__init__.py ADDED
File without changes
align/align_trans.py ADDED
@@ -0,0 +1,295 @@
1
+ import numpy as np
2
+ import cv2
3
+ from align.matlab_cp2tform import get_similarity_transform_for_cv2
4
+
5
+
6
+ # reference facial points, a list of coordinates (x,y)
7
+ REFERENCE_FACIAL_POINTS = [ # default reference facial points for crop_size = (112, 112); should adjust REFERENCE_FACIAL_POINTS accordingly for other crop_size
8
+ [30.29459953, 51.69630051],
9
+ [65.53179932, 51.50139999],
10
+ [48.02519989, 71.73660278],
11
+ [33.54930115, 92.3655014],
12
+ [62.72990036, 92.20410156],
13
+ ]
14
+
15
+ DEFAULT_CROP_SIZE = (96, 112)
16
+
17
+
18
+ class FaceWarpException(Exception):
19
+ def __str__(self):
20
+ return "In File {}:{}".format(__file__, super.__str__(self))
21
+
22
+
23
+ def get_reference_facial_points(
24
+ output_size=None,
25
+ inner_padding_factor=0.0,
26
+ outer_padding=(0, 0),
27
+ default_square=False,
28
+ ):
29
+ """
30
+ Function:
31
+ ----------
32
+ get reference 5 key points according to crop settings:
33
+ 0. Set default crop_size:
34
+ if default_square:
35
+ crop_size = (112, 112)
36
+ else:
37
+ crop_size = (96, 112)
38
+ 1. Pad the crop_size by inner_padding_factor in each side;
39
+ 2. Resize crop_size into (output_size - outer_padding*2),
40
+ pad into output_size with outer_padding;
41
+ 3. Output reference_5point;
42
+ Parameters:
43
+ ----------
44
+ @output_size: (w, h) or None
45
+ size of aligned face image
46
+ @inner_padding_factor: (w_factor, h_factor)
47
+ padding factor for inner (w, h)
48
+ @outer_padding: (w_pad, h_pad)
49
+ each row is a pair of coordinates (x, y)
50
+ @default_square: True or False
51
+ if True:
52
+ default crop_size = (112, 112)
53
+ else:
54
+ default crop_size = (96, 112);
55
+ !!! make sure, if output_size is not None:
56
+ (output_size - outer_padding)
57
+ = some_scale * (default crop_size * (1.0 + inner_padding_factor))
58
+ Returns:
59
+ ----------
60
+ @reference_5point: 5x2 np.array
61
+ each row is a pair of transformed coordinates (x, y)
62
+ """
63
+ # print('\n===> get_reference_facial_points():')
64
+
65
+ # print('---> Params:')
66
+ # print(' output_size: ', output_size)
67
+ # print(' inner_padding_factor: ', inner_padding_factor)
68
+ # print(' outer_padding:', outer_padding)
69
+ # print(' default_square: ', default_square)
70
+
71
+ tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
72
+ tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
73
+
74
+ # 0) make the inner region a square
75
+ if default_square:
76
+ size_diff = max(tmp_crop_size) - tmp_crop_size
77
+ tmp_5pts += size_diff / 2
78
+ tmp_crop_size += size_diff
79
+
80
+ # print('---> default:')
81
+ # print(' crop_size = ', tmp_crop_size)
82
+ # print(' reference_5pts = ', tmp_5pts)
83
+
84
+ if (
85
+ output_size
86
+ and output_size[0] == tmp_crop_size[0]
87
+ and output_size[1] == tmp_crop_size[1]
88
+ ):
89
+ # print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
90
+ return tmp_5pts
91
+
92
+ if inner_padding_factor == 0 and outer_padding == (0, 0):
93
+ if output_size is None:
94
+ # print('No paddings to do: return default reference points')
95
+ return tmp_5pts
96
+ else:
97
+ raise FaceWarpException(
98
+ "No paddings to do, output_size must be None or {}".format(
99
+ tmp_crop_size
100
+ )
101
+ )
102
+
103
+ # check output size
104
+ if not (0 <= inner_padding_factor <= 1.0):
105
+ raise FaceWarpException("Not (0 <= inner_padding_factor <= 1.0)")
106
+
107
+ if (
108
+ inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0
109
+ ) and output_size is None:
110
+ output_size = (tmp_crop_size * (1 + inner_padding_factor * 2)).astype(np.int32)
111
+ output_size += np.array(outer_padding)
112
+ # print(' deduced from paddings, output_size = ', output_size)
113
+
114
+ if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
115
+ raise FaceWarpException(
116
+ "Not (outer_padding[0] < output_size[0]"
117
+ "and outer_padding[1] < output_size[1])"
118
+ )
119
+
120
+ # 1) pad the inner region according inner_padding_factor
121
+ # print('---> STEP1: pad the inner region according inner_padding_factor')
122
+ if inner_padding_factor > 0:
123
+ size_diff = tmp_crop_size * inner_padding_factor * 2
124
+ tmp_5pts += size_diff / 2
125
+ tmp_crop_size += np.round(size_diff).astype(np.int32)
126
+
127
+ # print(' crop_size = ', tmp_crop_size)
128
+ # print(' reference_5pts = ', tmp_5pts)
129
+
130
+ # 2) resize the padded inner region
131
+ # print('---> STEP2: resize the padded inner region')
132
+ size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
133
+ # print(' crop_size = ', tmp_crop_size)
134
+ # print(' size_bf_outer_pad = ', size_bf_outer_pad)
135
+
136
+ if (
137
+ size_bf_outer_pad[0] * tmp_crop_size[1]
138
+ != size_bf_outer_pad[1] * tmp_crop_size[0]
139
+ ):
140
+ raise FaceWarpException(
141
+ "Must have (output_size - outer_padding)"
142
+ "= some_scale * (crop_size * (1.0 + inner_padding_factor)"
143
+ )
144
+
145
+ scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
146
+ # print(' resize scale_factor = ', scale_factor)
147
+ tmp_5pts = tmp_5pts * scale_factor
148
+ # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
149
+ # tmp_5pts = tmp_5pts + size_diff / 2
150
+ tmp_crop_size = size_bf_outer_pad
151
+ # print(' crop_size = ', tmp_crop_size)
152
+ # print(' reference_5pts = ', tmp_5pts)
153
+
154
+ # 3) add outer_padding to make output_size
155
+ reference_5point = tmp_5pts + np.array(outer_padding)
156
+ tmp_crop_size = output_size
157
+ # print('---> STEP3: add outer_padding to make output_size')
158
+ # print(' crop_size = ', tmp_crop_size)
159
+ # print(' reference_5pts = ', tmp_5pts)
160
+
161
+ # print('===> end get_reference_facial_points\n')
162
+
163
+ return reference_5point
164
+
165
+
166
+ def get_affine_transform_matrix(src_pts, dst_pts):
167
+ """
168
+ Function:
169
+ ----------
170
+ get affine transform matrix 'tfm' from src_pts to dst_pts
171
+ Parameters:
172
+ ----------
173
+ @src_pts: Kx2 np.array
174
+ source points matrix, each row is a pair of coordinates (x, y)
175
+ @dst_pts: Kx2 np.array
176
+ destination points matrix, each row is a pair of coordinates (x, y)
177
+ Returns:
178
+ ----------
179
+ @tfm: 2x3 np.array
180
+ transform matrix from src_pts to dst_pts
181
+ """
182
+
183
+ tfm = np.float32([[1, 0, 0], [0, 1, 0]])
184
+ n_pts = src_pts.shape[0]
185
+ ones = np.ones((n_pts, 1), src_pts.dtype)
186
+ src_pts_ = np.hstack([src_pts, ones])
187
+ dst_pts_ = np.hstack([dst_pts, ones])
188
+
189
+ # #print(('src_pts_:\n' + str(src_pts_))
190
+ # #print(('dst_pts_:\n' + str(dst_pts_))
191
+
192
+ A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_, rcond=None)
193
+
194
+ # #print(('np.linalg.lstsq return A: \n' + str(A))
195
+ # #print(('np.linalg.lstsq return res: \n' + str(res))
196
+ # #print(('np.linalg.lstsq return rank: \n' + str(rank))
197
+ # #print(('np.linalg.lstsq return s: \n' + str(s))
198
+
199
+ if rank == 3:
200
+ tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
201
+ elif rank == 2:
202
+ tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
203
+
204
+ return tfm
205
+
206
+
207
+ def warp_and_crop_face(
208
+ src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type="similarity"
209
+ ):
210
+ """
211
+ Function:
212
+ ----------
213
+ warp src_img to the reference facial points and crop the aligned face
214
+ Parameters:
215
+ ----------
216
+ @src_img: HxWx3 np.array
217
+ input image
218
+ @facial_pts: could be
219
+ 1)a list of K coordinates (x,y)
220
+ or
221
+ 2) Kx2 or 2xK np.array
222
+ each row or col is a pair of coordinates (x, y)
223
+ @reference_pts: could be
224
+ 1) a list of K coordinates (x,y)
225
+ or
226
+ 2) Kx2 or 2xK np.array
227
+ each row or col is a pair of coordinates (x, y)
228
+ or
229
+ 3) None
230
+ if None, use default reference facial points
231
+ @crop_size: (w, h)
232
+ output face image size
233
+ @align_type: transform type, could be one of
234
+ 1) 'similarity': use similarity transform
235
+ 2) 'cv2_affine': use the first 3 points to do affine transform,
236
+ by calling cv2.getAffineTransform()
237
+ 3) 'affine': use all points to do affine transform
238
+ Returns:
239
+ ----------
240
+ @face_img: output face image with size (w, h) = @crop_size
241
+ """
242
+
243
+ if reference_pts is None:
244
+ if crop_size[0] == 96 and crop_size[1] == 112:
245
+ reference_pts = REFERENCE_FACIAL_POINTS
246
+ else:
247
+ default_square = False
248
+ inner_padding_factor = 0
249
+ outer_padding = (0, 0)
250
+ output_size = crop_size
251
+
252
+ reference_pts = get_reference_facial_points(
253
+ output_size, inner_padding_factor, outer_padding, default_square
254
+ )
255
+
256
+ ref_pts = np.float32(reference_pts)
257
+ ref_pts_shp = ref_pts.shape
258
+ if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
259
+ raise FaceWarpException("reference_pts.shape must be (K,2) or (2,K) and K>2")
260
+
261
+ if ref_pts_shp[0] == 2:
262
+ ref_pts = ref_pts.T
263
+
264
+ src_pts = np.float32(facial_pts)
265
+ src_pts_shp = src_pts.shape
266
+ if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
267
+ raise FaceWarpException("facial_pts.shape must be (K,2) or (2,K) and K>2")
268
+
269
+ if src_pts_shp[0] == 2:
270
+ src_pts = src_pts.T
271
+
272
+ # #print('--->src_pts:\n', src_pts
273
+ # #print('--->ref_pts\n', ref_pts
274
+
275
+ if src_pts.shape != ref_pts.shape:
276
+ raise FaceWarpException("facial_pts and reference_pts must have the same shape")
277
+
278
+ if align_type == "cv2_affine":
279
+ tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
280
+ # #print(('cv2.getAffineTransform() returns tfm=\n' + str(tfm))
281
+ elif align_type == "affine":
282
+ tfm = get_affine_transform_matrix(src_pts, ref_pts)
283
+ # #print(('get_affine_transform_matrix() returns tfm=\n' + str(tfm))
284
+ else:
285
+ tfm, tfm_inv = get_similarity_transform_for_cv2(src_pts, ref_pts)
286
+ # #print(('get_similarity_transform_for_cv2() returns tfm=\n' + str(tfm))
287
+
288
+ # #print('--->Transform matrix: '
289
+ # #print(('type(tfm):' + str(type(tfm)))
290
+ # #print(('tfm.dtype:' + str(tfm.dtype))
291
+ # #print( tfm
292
+
293
+ face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
294
+
295
+ return face_img, tfm
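A minimal usage sketch for the helpers above: take five detected landmarks and produce an aligned 112x112 crop. The landmark coordinates and `face.jpg` path are made-up placeholders; in the real pipeline the five points come from the MTCNN detector in `align/detector.py`.

```python
import cv2
import numpy as np

from align.align_trans import get_reference_facial_points, warp_and_crop_face

# Placeholder input image and five (x, y) landmarks: eyes, nose tip, mouth corners.
img = cv2.imread("face.jpg")  # hypothetical path
facial5points = [[214.0, 238.0], [290.0, 241.0], [252.0, 287.0], [221.0, 328.0], [284.0, 330.0]]

crop_size = 112
# Square 112x112 reference points; the scale factor is 1.0 because crop_size == 112.
reference = get_reference_facial_points(default_square=True) * (crop_size / 112.0)

# The default align_type uses the similarity transform from align/matlab_cp2tform.py.
aligned, tfm = warp_and_crop_face(
    np.array(img), facial5points, reference_pts=reference, crop_size=(crop_size, crop_size)
)
cv2.imwrite("face_aligned.jpg", aligned)
print(tfm)  # the 2x3 matrix passed to cv2.warpAffine
```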
align/box_utils.py ADDED
@@ -0,0 +1,239 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+
4
+
5
+ def nms(boxes, overlap_threshold=0.5, mode="union"):
6
+ """Non-maximum suppression.
7
+
8
+ Arguments:
9
+ boxes: a float numpy array of shape [n, 5],
10
+ where each row is (xmin, ymin, xmax, ymax, score).
11
+ overlap_threshold: a float number.
12
+ mode: 'union' or 'min'.
13
+
14
+ Returns:
15
+ list with indices of the selected boxes
16
+ """
17
+
18
+ # if there are no boxes, return the empty list
19
+ if len(boxes) == 0:
20
+ return []
21
+
22
+ # list of picked indices
23
+ pick = []
24
+
25
+ # grab the coordinates of the bounding boxes
26
+ x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)]
27
+
28
+ area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0)
29
+ ids = np.argsort(score) # in increasing order
30
+
31
+ while len(ids) > 0:
32
+ # grab index of the largest value
33
+ last = len(ids) - 1
34
+ i = ids[last]
35
+ pick.append(i)
36
+
37
+ # compute intersections
38
+ # of the box with the largest score
39
+ # with the rest of boxes
40
+
41
+ # left top corner of intersection boxes
42
+ ix1 = np.maximum(x1[i], x1[ids[:last]])
43
+ iy1 = np.maximum(y1[i], y1[ids[:last]])
44
+
45
+ # right bottom corner of intersection boxes
46
+ ix2 = np.minimum(x2[i], x2[ids[:last]])
47
+ iy2 = np.minimum(y2[i], y2[ids[:last]])
48
+
49
+ # width and height of intersection boxes
50
+ w = np.maximum(0.0, ix2 - ix1 + 1.0)
51
+ h = np.maximum(0.0, iy2 - iy1 + 1.0)
52
+
53
+ # intersections' areas
54
+ inter = w * h
55
+ if mode == "min":
56
+ overlap = inter / np.minimum(area[i], area[ids[:last]])
57
+ elif mode == "union":
58
+ # intersection over union (IoU)
59
+ overlap = inter / (area[i] + area[ids[:last]] - inter)
60
+
61
+ # delete all boxes where overlap is too big
62
+ ids = np.delete(
63
+ ids, np.concatenate([[last], np.where(overlap > overlap_threshold)[0]])
64
+ )
65
+
66
+ return pick
67
+
68
+
69
+ def convert_to_square(bboxes):
70
+ """Convert bounding boxes to a square form.
71
+
72
+ Arguments:
73
+ bboxes: a float numpy array of shape [n, 5].
74
+
75
+ Returns:
76
+ a float numpy array of shape [n, 5],
77
+ squared bounding boxes.
78
+ """
79
+
80
+ square_bboxes = np.zeros_like(bboxes)
81
+ x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
82
+ h = y2 - y1 + 1.0
83
+ w = x2 - x1 + 1.0
84
+ max_side = np.maximum(h, w)
85
+ square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5
86
+ square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5
87
+ square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0
88
+ square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0
89
+ return square_bboxes
90
+
91
+
92
+ def calibrate_box(bboxes, offsets):
93
+ """Transform bounding boxes to be more like true bounding boxes.
94
+ 'offsets' is one of the outputs of the nets.
95
+
96
+ Arguments:
97
+ bboxes: a float numpy array of shape [n, 5].
98
+ offsets: a float numpy array of shape [n, 4].
99
+
100
+ Returns:
101
+ a float numpy array of shape [n, 5].
102
+ """
103
+ x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
104
+ w = x2 - x1 + 1.0
105
+ h = y2 - y1 + 1.0
106
+ w = np.expand_dims(w, 1)
107
+ h = np.expand_dims(h, 1)
108
+
109
+ # this is what happening here:
110
+ # tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)]
111
+ # x1_true = x1 + tx1*w
112
+ # y1_true = y1 + ty1*h
113
+ # x2_true = x2 + tx2*w
114
+ # y2_true = y2 + ty2*h
115
+ # below is just more compact form of this
116
+
117
+ # are offsets always such that
118
+ # x1 < x2 and y1 < y2 ?
119
+
120
+ translation = np.hstack([w, h, w, h]) * offsets
121
+ bboxes[:, 0:4] = bboxes[:, 0:4] + translation
122
+ return bboxes
123
+
124
+
125
+ def get_image_boxes(bounding_boxes, img, size=24):
126
+ """Cut out boxes from the image.
127
+
128
+ Arguments:
129
+ bounding_boxes: a float numpy array of shape [n, 5].
130
+ img: an instance of PIL.Image.
131
+ size: an integer, size of cutouts.
132
+
133
+ Returns:
134
+ a float numpy array of shape [n, 3, size, size].
135
+ """
136
+
137
+ num_boxes = len(bounding_boxes)
138
+ width, height = img.size
139
+
140
+ [dy, edy, dx, edx, y, ey, x, ex, w, h] = correct_bboxes(
141
+ bounding_boxes, width, height
142
+ )
143
+ img_boxes = np.zeros((num_boxes, 3, size, size), "float32")
144
+
145
+ for i in range(num_boxes):
146
+ img_box = np.zeros((h[i], w[i], 3), "uint8")
147
+
148
+ img_array = np.asarray(img, "uint8")
149
+ img_box[dy[i] : (edy[i] + 1), dx[i] : (edx[i] + 1), :] = img_array[
150
+ y[i] : (ey[i] + 1), x[i] : (ex[i] + 1), :
151
+ ]
152
+
153
+ # resize
154
+ img_box = Image.fromarray(img_box)
155
+ img_box = img_box.resize((size, size), Image.BILINEAR)
156
+ img_box = np.asarray(img_box, "float32")
157
+
158
+ img_boxes[i, :, :, :] = _preprocess(img_box)
159
+
160
+ return img_boxes
161
+
162
+
163
+ def correct_bboxes(bboxes, width, height):
164
+ """Crop boxes that are too big and get coordinates
165
+ with respect to cutouts.
166
+
167
+ Arguments:
168
+ bboxes: a float numpy array of shape [n, 5],
169
+ where each row is (xmin, ymin, xmax, ymax, score).
170
+ width: a float number.
171
+ height: a float number.
172
+
173
+ Returns:
174
+ dy, dx, edy, edx: a int numpy arrays of shape [n],
175
+ coordinates of the boxes with respect to the cutouts.
176
+ y, x, ey, ex: int numpy arrays of shape [n],
177
+ corrected ymin, xmin, ymax, xmax.
178
+ h, w: int numpy arrays of shape [n],
179
+ just heights and widths of boxes.
180
+
181
+ in the following order:
182
+ [dy, edy, dx, edx, y, ey, x, ex, w, h].
183
+ """
184
+
185
+ x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
186
+ w, h = x2 - x1 + 1.0, y2 - y1 + 1.0
187
+ num_boxes = bboxes.shape[0]
188
+
189
+ # 'e' stands for end
190
+ # (x, y) -> (ex, ey)
191
+ x, y, ex, ey = x1, y1, x2, y2
192
+
193
+ # we need to cut out a box from the image.
194
+ # (x, y, ex, ey) are corrected coordinates of the box
195
+ # in the image.
196
+ # (dx, dy, edx, edy) are coordinates of the box in the cutout
197
+ # from the image.
198
+ dx, dy = np.zeros((num_boxes,)), np.zeros((num_boxes,))
199
+ edx, edy = w.copy() - 1.0, h.copy() - 1.0
200
+
201
+ # if box's bottom right corner is too far right
202
+ ind = np.where(ex > width - 1.0)[0]
203
+ edx[ind] = w[ind] + width - 2.0 - ex[ind]
204
+ ex[ind] = width - 1.0
205
+
206
+ # if box's bottom right corner is too low
207
+ ind = np.where(ey > height - 1.0)[0]
208
+ edy[ind] = h[ind] + height - 2.0 - ey[ind]
209
+ ey[ind] = height - 1.0
210
+
211
+ # if box's top left corner is too far left
212
+ ind = np.where(x < 0.0)[0]
213
+ dx[ind] = 0.0 - x[ind]
214
+ x[ind] = 0.0
215
+
216
+ # if box's top left corner is too high
217
+ ind = np.where(y < 0.0)[0]
218
+ dy[ind] = 0.0 - y[ind]
219
+ y[ind] = 0.0
220
+
221
+ return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h]
222
+ return_list = [i.astype("int32") for i in return_list]
223
+
224
+ return return_list
225
+
226
+
227
+ def _preprocess(img):
228
+ """Preprocessing step before feeding the network.
229
+
230
+ Arguments:
231
+ img: a float numpy array of shape [h, w, c].
232
+
233
+ Returns:
234
+ a float numpy array of shape [1, c, h, w].
235
+ """
236
+ img = img.transpose((2, 0, 1))
237
+ img = np.expand_dims(img, 0)
238
+ img = (img - 127.5) * 0.0078125
239
+ return img
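A small self-contained sketch of `nms` and `convert_to_square` on toy boxes, illustrating the expected `[n, 5]` row layout `(xmin, ymin, xmax, ymax, score)`; the numbers are arbitrary.

```python
import numpy as np

from align.box_utils import convert_to_square, nms

# Three toy boxes: the first two overlap heavily, the third is separate.
boxes = np.array(
    [
        [10.0, 10.0, 60.0, 60.0, 0.90],
        [12.0, 12.0, 62.0, 62.0, 0.80],
        [100.0, 100.0, 140.0, 160.0, 0.70],
    ]
)

keep = nms(boxes, overlap_threshold=0.5, mode="union")
print(keep)  # [0, 2]: the lower-scoring duplicate of box 0 is suppressed

# Squared versions of the kept boxes (note the score column is left at 0 by convert_to_square).
print(convert_to_square(boxes[keep]))
```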
align/detector.py ADDED
@@ -0,0 +1,133 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.autograd import Variable
4
+ import sys
5
+
6
+ sys.path.append("./")
7
+ from align.get_nets import PNet, RNet, ONet
8
+ from align.box_utils import nms, calibrate_box, get_image_boxes, convert_to_square
9
+ from align.first_stage import run_first_stage
10
+
11
+
12
+ def detect_faces(
13
+ image,
14
+ min_face_size=20.0,
15
+ thresholds=[0.6, 0.7, 0.8],
16
+ nms_thresholds=[0.7, 0.7, 0.7],
17
+ ):
18
+ """
19
+ Arguments:
20
+ image: an instance of PIL.Image.
21
+ min_face_size: a float number.
22
+ thresholds: a list of length 3.
23
+ nms_thresholds: a list of length 3.
24
+
25
+ Returns:
26
+ two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10],
27
+ bounding boxes and facial landmarks.
28
+ """
29
+ # LOAD MODELS
30
+ pnet = PNet()
31
+ rnet = RNet()
32
+ onet = ONet()
33
+ onet.eval()
34
+
35
+ # BUILD AN IMAGE PYRAMID
36
+ width, height = image.size
37
+ min_length = min(height, width)
38
+
39
+ min_detection_size = 12
40
+ factor = 0.707 # sqrt(0.5)
41
+
42
+ # scales for scaling the image
43
+ scales = []
44
+
45
+ # scales the image so that
46
+ # minimum size that we can detect equals to
47
+ # minimum face size that we want to detect
48
+ m = min_detection_size / min_face_size
49
+ min_length *= m
50
+
51
+ factor_count = 0
52
+ while min_length > min_detection_size:
53
+ scales.append(m * factor**factor_count)
54
+ min_length *= factor
55
+ factor_count += 1
56
+
57
+ # STAGE 1
58
+
59
+ # it will be returned
60
+ bounding_boxes = []
61
+
62
+ # run P-Net on different scales
63
+ for s in scales:
64
+ boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0])
65
+ bounding_boxes.append(boxes)
66
+
67
+ # collect boxes (and offsets, and scores) from different scales
68
+ bounding_boxes = [i for i in bounding_boxes if i is not None]
69
+ bounding_boxes = np.vstack(bounding_boxes)
70
+
71
+ keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
72
+ bounding_boxes = bounding_boxes[keep]
73
+
74
+ # use offsets predicted by pnet to transform bounding boxes
75
+ bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:])
76
+ # shape [n_boxes, 5]
77
+
78
+ bounding_boxes = convert_to_square(bounding_boxes)
79
+ bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
80
+
81
+ # STAGE 2
82
+
83
+ img_boxes = get_image_boxes(bounding_boxes, image, size=24)
84
+ img_boxes = Variable(torch.FloatTensor(img_boxes), volatile=True)
85
+ output = rnet(img_boxes)
86
+ offsets = output[0].data.numpy() # shape [n_boxes, 4]
87
+ probs = output[1].data.numpy() # shape [n_boxes, 2]
88
+
89
+ keep = np.where(probs[:, 1] > thresholds[1])[0]
90
+ bounding_boxes = bounding_boxes[keep]
91
+ bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
92
+ offsets = offsets[keep]
93
+
94
+ keep = nms(bounding_boxes, nms_thresholds[1])
95
+ bounding_boxes = bounding_boxes[keep]
96
+ bounding_boxes = calibrate_box(bounding_boxes, offsets[keep])
97
+ bounding_boxes = convert_to_square(bounding_boxes)
98
+ bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
99
+
100
+ # STAGE 3
101
+
102
+ img_boxes = get_image_boxes(bounding_boxes, image, size=48)
103
+ if len(img_boxes) == 0:
104
+ return [], []
105
+ img_boxes = Variable(torch.FloatTensor(img_boxes), volatile=True)
106
+ output = onet(img_boxes)
107
+ landmarks = output[0].data.numpy() # shape [n_boxes, 10]
108
+ offsets = output[1].data.numpy() # shape [n_boxes, 4]
109
+ probs = output[2].data.numpy() # shape [n_boxes, 2]
110
+
111
+ keep = np.where(probs[:, 1] > thresholds[2])[0]
112
+ bounding_boxes = bounding_boxes[keep]
113
+ bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
114
+ offsets = offsets[keep]
115
+ landmarks = landmarks[keep]
116
+
117
+ # compute landmark points
118
+ width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0
119
+ height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0
120
+ xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1]
121
+ landmarks[:, 0:5] = (
122
+ np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5]
123
+ )
124
+ landmarks[:, 5:10] = (
125
+ np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10]
126
+ )
127
+
128
+ bounding_boxes = calibrate_box(bounding_boxes, offsets)
129
+ keep = nms(bounding_boxes, nms_thresholds[2], mode="min")
130
+ bounding_boxes = bounding_boxes[keep]
131
+ landmarks = landmarks[keep]
132
+
133
+ return bounding_boxes, landmarks
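A usage sketch for `detect_faces`. It assumes the P-Net/R-Net/O-Net weight files (`align/pnet.npy`, `align/rnet.npy`, `align/onet.npy`) are present, that the script runs from the repository root, and that `photo.jpg` (a placeholder path) contains at least one detectable face.

```python
from PIL import Image

from align.detector import detect_faces

img = Image.open("photo.jpg").convert("RGB")  # placeholder path

# bounding_boxes: [n, 5] rows of (xmin, ymin, xmax, ymax, score)
# landmarks:      [n, 10] rows of five x-coordinates followed by five y-coordinates
bounding_boxes, landmarks = detect_faces(img, min_face_size=20.0)

for box, lm in zip(bounding_boxes, landmarks):
    print("face at", box[:4].astype(int), "score", round(float(box[4]), 3))
    facial5points = [[lm[j], lm[j + 5]] for j in range(5)]  # (x, y) pairs, as consumed in app.py
    print("landmarks:", facial5points)
```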
align/first_stage.py ADDED
@@ -0,0 +1,100 @@
1
+ import torch
2
+ from torch.autograd import Variable
3
+ import math
4
+ from PIL import Image
5
+ import numpy as np
6
+ from align.box_utils import nms, _preprocess
7
+
8
+
9
+ def run_first_stage(image, net, scale, threshold):
10
+ """Run P-Net, generate bounding boxes, and do NMS.
11
+
12
+ Arguments:
13
+ image: an instance of PIL.Image.
14
+ net: an instance of pytorch's nn.Module, P-Net.
15
+ scale: a float number,
16
+ scale width and height of the image by this number.
17
+ threshold: a float number,
18
+ threshold on the probability of a face when generating
19
+ bounding boxes from predictions of the net.
20
+
21
+ Returns:
22
+ a float numpy array of shape [n_boxes, 9],
23
+ bounding boxes with scores and offsets (4 + 1 + 4).
24
+ """
25
+
26
+ # scale the image and convert it to a float array
27
+ width, height = image.size
28
+ sw, sh = math.ceil(width * scale), math.ceil(height * scale)
29
+ img = image.resize((sw, sh), Image.BILINEAR)
30
+ img = np.asarray(img, "float32")
31
+
32
+ img = Variable(torch.FloatTensor(_preprocess(img)), volatile=True)
33
+ output = net(img)
34
+ probs = output[1].data.numpy()[0, 1, :, :]
35
+ offsets = output[0].data.numpy()
36
+ # probs: probability of a face at each sliding window
37
+ # offsets: transformations to true bounding boxes
38
+
39
+ boxes = _generate_bboxes(probs, offsets, scale, threshold)
40
+ if len(boxes) == 0:
41
+ return None
42
+
43
+ keep = nms(boxes[:, 0:5], overlap_threshold=0.5)
44
+ return boxes[keep]
45
+
46
+
47
+ def _generate_bboxes(probs, offsets, scale, threshold):
48
+ """Generate bounding boxes at places
49
+ where there is probably a face.
50
+
51
+ Arguments:
52
+ probs: a float numpy array of shape [n, m].
53
+ offsets: a float numpy array of shape [1, 4, n, m].
54
+ scale: a float number,
55
+ width and height of the image were scaled by this number.
56
+ threshold: a float number.
57
+
58
+ Returns:
59
+ a float numpy array of shape [n_boxes, 9]
60
+ """
61
+
62
+ # applying P-Net is equivalent, in some sense, to
63
+ # moving 12x12 window with stride 2
64
+ stride = 2
65
+ cell_size = 12
66
+
67
+ # indices of boxes where there is probably a face
68
+ inds = np.where(probs > threshold)
69
+
70
+ if inds[0].size == 0:
71
+ return np.array([])
72
+
73
+ # transformations of bounding boxes
74
+ tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)]
75
+ # they are defined as:
76
+ # w = x2 - x1 + 1
77
+ # h = y2 - y1 + 1
78
+ # x1_true = x1 + tx1*w
79
+ # x2_true = x2 + tx2*w
80
+ # y1_true = y1 + ty1*h
81
+ # y2_true = y2 + ty2*h
82
+
83
+ offsets = np.array([tx1, ty1, tx2, ty2])
84
+ score = probs[inds[0], inds[1]]
85
+
86
+ # P-Net is applied to scaled images
87
+ # so we need to rescale bounding boxes back
88
+ bounding_boxes = np.vstack(
89
+ [
90
+ np.round((stride * inds[1] + 1.0) / scale),
91
+ np.round((stride * inds[0] + 1.0) / scale),
92
+ np.round((stride * inds[1] + 1.0 + cell_size) / scale),
93
+ np.round((stride * inds[0] + 1.0 + cell_size) / scale),
94
+ score,
95
+ offsets,
96
+ ]
97
+ )
98
+ # why one is added?
99
+
100
+ return bounding_boxes.T
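The scale pyramid that `detect_faces` feeds into `run_first_stage` can be reproduced in isolation. The sketch below mirrors the loop in `align/detector.py` for a hypothetical 640x480 image and the default `min_face_size=20`.

```python
# P-Net sees 12x12 windows, so scale the image until a min_face_size face shrinks to 12 px.
min_detection_size = 12
min_face_size = 20.0
factor = 0.707  # ~sqrt(0.5): each pyramid level halves the image area

width, height = 640, 480  # hypothetical image size
m = min_detection_size / min_face_size
min_length = min(width, height) * m

scales = []
while min_length > min_detection_size:
    scales.append(m * factor ** len(scales))
    min_length *= factor

print(scales)  # one entry per pyramid level passed to run_first_stage
```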
align/get_nets.py ADDED
@@ -0,0 +1,165 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from collections import OrderedDict
5
+ import numpy as np
6
+
7
+
8
+ class Flatten(nn.Module):
9
+ def __init__(self):
10
+ super(Flatten, self).__init__()
11
+
12
+ def forward(self, x):
13
+ """
14
+ Arguments:
15
+ x: a float tensor with shape [batch_size, c, h, w].
16
+ Returns:
17
+ a float tensor with shape [batch_size, c*h*w].
18
+ """
19
+
20
+ # without this pretrained model isn't working
21
+ x = x.transpose(3, 2).contiguous()
22
+
23
+ return x.view(x.size(0), -1)
24
+
25
+
26
+ class PNet(nn.Module):
27
+ def __init__(self):
28
+ super(PNet, self).__init__()
29
+
30
+ # suppose we have input with size HxW, then
31
+ # after first layer: H - 2,
32
+ # after pool: ceil((H - 2)/2),
33
+ # after second conv: ceil((H - 2)/2) - 2,
34
+ # after last conv: ceil((H - 2)/2) - 4,
35
+ # and the same for W
36
+
37
+ self.features = nn.Sequential(
38
+ OrderedDict(
39
+ [
40
+ ("conv1", nn.Conv2d(3, 10, 3, 1)),
41
+ ("prelu1", nn.PReLU(10)),
42
+ ("pool1", nn.MaxPool2d(2, 2, ceil_mode=True)),
43
+ ("conv2", nn.Conv2d(10, 16, 3, 1)),
44
+ ("prelu2", nn.PReLU(16)),
45
+ ("conv3", nn.Conv2d(16, 32, 3, 1)),
46
+ ("prelu3", nn.PReLU(32)),
47
+ ]
48
+ )
49
+ )
50
+
51
+ self.conv4_1 = nn.Conv2d(32, 2, 1, 1)
52
+ self.conv4_2 = nn.Conv2d(32, 4, 1, 1)
53
+
54
+ weights = np.load("align/pnet.npy", allow_pickle=True)[()]
55
+ for n, p in self.named_parameters():
56
+ p.data = torch.FloatTensor(weights[n])
57
+
58
+ def forward(self, x):
59
+ """
60
+ Arguments:
61
+ x: a float tensor with shape [batch_size, 3, h, w].
62
+ Returns:
63
+ b: a float tensor with shape [batch_size, 4, h', w'].
64
+ a: a float tensor with shape [batch_size, 2, h', w'].
65
+ """
66
+ x = self.features(x)
67
+ a = self.conv4_1(x)
68
+ b = self.conv4_2(x)
69
+ a = F.softmax(a, dim=1)
70
+ return b, a
71
+
72
+
73
+ class RNet(nn.Module):
74
+ def __init__(self):
75
+ super(RNet, self).__init__()
76
+
77
+ self.features = nn.Sequential(
78
+ OrderedDict(
79
+ [
80
+ ("conv1", nn.Conv2d(3, 28, 3, 1)),
81
+ ("prelu1", nn.PReLU(28)),
82
+ ("pool1", nn.MaxPool2d(3, 2, ceil_mode=True)),
83
+ ("conv2", nn.Conv2d(28, 48, 3, 1)),
84
+ ("prelu2", nn.PReLU(48)),
85
+ ("pool2", nn.MaxPool2d(3, 2, ceil_mode=True)),
86
+ ("conv3", nn.Conv2d(48, 64, 2, 1)),
87
+ ("prelu3", nn.PReLU(64)),
88
+ ("flatten", Flatten()),
89
+ ("conv4", nn.Linear(576, 128)),
90
+ ("prelu4", nn.PReLU(128)),
91
+ ]
92
+ )
93
+ )
94
+
95
+ self.conv5_1 = nn.Linear(128, 2)
96
+ self.conv5_2 = nn.Linear(128, 4)
97
+
98
+ weights = np.load("align/rnet.npy", allow_pickle=True)[()]
99
+ for n, p in self.named_parameters():
100
+ p.data = torch.FloatTensor(weights[n])
101
+
102
+ def forward(self, x):
103
+ """
104
+ Arguments:
105
+ x: a float tensor with shape [batch_size, 3, h, w].
106
+ Returns:
107
+ b: a float tensor with shape [batch_size, 4].
108
+ a: a float tensor with shape [batch_size, 2].
109
+ """
110
+ x = self.features(x)
111
+ a = self.conv5_1(x)
112
+ b = self.conv5_2(x)
113
+ a = F.softmax(a, dim=1)
114
+ return b, a
115
+
116
+
117
+ class ONet(nn.Module):
118
+ def __init__(self):
119
+ super(ONet, self).__init__()
120
+
121
+ self.features = nn.Sequential(
122
+ OrderedDict(
123
+ [
124
+ ("conv1", nn.Conv2d(3, 32, 3, 1)),
125
+ ("prelu1", nn.PReLU(32)),
126
+ ("pool1", nn.MaxPool2d(3, 2, ceil_mode=True)),
127
+ ("conv2", nn.Conv2d(32, 64, 3, 1)),
128
+ ("prelu2", nn.PReLU(64)),
129
+ ("pool2", nn.MaxPool2d(3, 2, ceil_mode=True)),
130
+ ("conv3", nn.Conv2d(64, 64, 3, 1)),
131
+ ("prelu3", nn.PReLU(64)),
132
+ ("pool3", nn.MaxPool2d(2, 2, ceil_mode=True)),
133
+ ("conv4", nn.Conv2d(64, 128, 2, 1)),
134
+ ("prelu4", nn.PReLU(128)),
135
+ ("flatten", Flatten()),
136
+ ("conv5", nn.Linear(1152, 256)),
137
+ ("drop5", nn.Dropout(0.25)),
138
+ ("prelu5", nn.PReLU(256)),
139
+ ]
140
+ )
141
+ )
142
+
143
+ self.conv6_1 = nn.Linear(256, 2)
144
+ self.conv6_2 = nn.Linear(256, 4)
145
+ self.conv6_3 = nn.Linear(256, 10)
146
+
147
+ weights = np.load("align/onet.npy", allow_pickle=True)[()]
148
+ for n, p in self.named_parameters():
149
+ p.data = torch.FloatTensor(weights[n])
150
+
151
+ def forward(self, x):
152
+ """
153
+ Arguments:
154
+ x: a float tensor with shape [batch_size, 3, h, w].
155
+ Returns:
156
+ c: a float tensor with shape [batch_size, 10].
157
+ b: a float tensor with shape [batch_size, 4].
158
+ a: a float tensor with shape [batch_size, 2].
159
+ """
160
+ x = self.features(x)
161
+ a = self.conv6_1(x)
162
+ b = self.conv6_2(x)
163
+ c = self.conv6_3(x)
164
+ a = F.softmax(a, dim=1)
165
+ return c, b, a
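A quick shape check for P-Net, assuming `align/pnet.npy` is present and the code runs from the repository root (the constructor loads the weights eagerly). The dummy input size is arbitrary.

```python
import torch

from align.get_nets import PNet

pnet = PNet().eval()

with torch.no_grad():  # inference only
    x = torch.zeros(1, 3, 120, 160)  # dummy batch, H=120, W=160
    offsets, probs = pnet(x)

# Per the size comment in PNet.__init__: H' = ceil((120 - 2)/2) - 4 = 55, W' = 75.
print(offsets.shape)  # torch.Size([1, 4, 55, 75])
print(probs.shape)    # torch.Size([1, 2, 55, 75])
```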
align/matlab_cp2tform.py ADDED
@@ -0,0 +1,329 @@
1
+ import numpy as np
2
+ from numpy.linalg import inv, norm, lstsq
3
+ from numpy.linalg import matrix_rank as rank
4
+
5
+
6
+ class MatlabCp2tformException(Exception):
7
+ def __str__(self):
8
+ return "In File {}:{}".format(__file__, super.__str__(self))
9
+
10
+
11
+ def tformfwd(trans, uv):
12
+ """
13
+ Function:
14
+ ----------
15
+ apply affine transform 'trans' to uv
16
+
17
+ Parameters:
18
+ ----------
19
+ @trans: 3x3 np.array
20
+ transform matrix
21
+ @uv: Kx2 np.array
22
+ each row is a pair of coordinates (x, y)
23
+
24
+ Returns:
25
+ ----------
26
+ @xy: Kx2 np.array
27
+ each row is a pair of transformed coordinates (x, y)
28
+ """
29
+ uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
30
+ xy = np.dot(uv, trans)
31
+ xy = xy[:, 0:-1]
32
+ return xy
33
+
34
+
35
+ def tforminv(trans, uv):
36
+ """
37
+ Function:
38
+ ----------
39
+ apply the inverse of affine transform 'trans' to uv
40
+
41
+ Parameters:
42
+ ----------
43
+ @trans: 3x3 np.array
44
+ transform matrix
45
+ @uv: Kx2 np.array
46
+ each row is a pair of coordinates (x, y)
47
+
48
+ Returns:
49
+ ----------
50
+ @xy: Kx2 np.array
51
+ each row is a pair of inverse-transformed coordinates (x, y)
52
+ """
53
+ Tinv = inv(trans)
54
+ xy = tformfwd(Tinv, uv)
55
+ return xy
56
+
57
+
58
+ def findNonreflectiveSimilarity(uv, xy, options=None):
59
+ options = {"K": 2}
60
+
61
+ K = options["K"]
62
+ M = xy.shape[0]
63
+ x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
64
+ y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
65
+ # print('--->x, y:\n', x, y
66
+
67
+ tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
68
+ tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
69
+ X = np.vstack((tmp1, tmp2))
70
+ # print('--->X.shape: ', X.shape
71
+ # print('X:\n', X
72
+
73
+ u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
74
+ v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
75
+ U = np.vstack((u, v))
76
+ # print('--->U.shape: ', U.shape
77
+ # print('U:\n', U
78
+
79
+ # We know that X * r = U
80
+ if rank(X) >= 2 * K:
81
+ r, _, _, _ = lstsq(X, U, rcond=None)
82
+ r = np.squeeze(r)
83
+ else:
84
+ raise Exception("cp2tform: two Unique Points Req")
85
+
86
+ # print('--->r:\n', r
87
+
88
+ sc = r[0]
89
+ ss = r[1]
90
+ tx = r[2]
91
+ ty = r[3]
92
+
93
+ Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
94
+
95
+ # print('--->Tinv:\n', Tinv
96
+
97
+ T = inv(Tinv)
98
+ # print('--->T:\n', T
99
+
100
+ T[:, 2] = np.array([0, 0, 1])
101
+
102
+ return T, Tinv
103
+
104
+
105
+ def findSimilarity(uv, xy, options=None):
106
+ options = {"K": 2}
107
+
108
+ # uv = np.array(uv)
109
+ # xy = np.array(xy)
110
+
111
+ # Solve for trans1
112
+ trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
113
+
114
+ # Solve for trans2
115
+
116
+ # manually reflect the xy data across the Y-axis
117
+ xyR = xy.copy()  # copy so the reflection below does not modify xy in place
118
+ xyR[:, 0] = -1 * xyR[:, 0]
119
+
120
+ trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
121
+
122
+ # manually reflect the tform to undo the reflection done on xyR
123
+ TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
124
+
125
+ trans2 = np.dot(trans2r, TreflectY)
126
+
127
+ # Figure out if trans1 or trans2 is better
128
+ xy1 = tformfwd(trans1, uv)
129
+ norm1 = norm(xy1 - xy)
130
+
131
+ xy2 = tformfwd(trans2, uv)
132
+ norm2 = norm(xy2 - xy)
133
+
134
+ if norm1 <= norm2:
135
+ return trans1, trans1_inv
136
+ else:
137
+ trans2_inv = inv(trans2)
138
+ return trans2, trans2_inv
139
+
140
+
141
+ def get_similarity_transform(src_pts, dst_pts, reflective=True):
142
+ """
143
+ Function:
144
+ ----------
145
+ Find Similarity Transform Matrix 'trans':
146
+ u = src_pts[:, 0]
147
+ v = src_pts[:, 1]
148
+ x = dst_pts[:, 0]
149
+ y = dst_pts[:, 1]
150
+ [x, y, 1] = [u, v, 1] * trans
151
+
152
+ Parameters:
153
+ ----------
154
+ @src_pts: Kx2 np.array
155
+ source points, each row is a pair of coordinates (x, y)
156
+ @dst_pts: Kx2 np.array
157
+ destination points, each row is a pair of transformed
158
+ coordinates (x, y)
159
+ @reflective: True or False
160
+ if True:
161
+ use reflective similarity transform
162
+ else:
163
+ use non-reflective similarity transform
164
+
165
+ Returns:
166
+ ----------
167
+ @trans: 3x3 np.array
168
+ transform matrix from uv to xy
169
+ trans_inv: 3x3 np.array
170
+ inverse of trans, transform matrix from xy to uv
171
+ """
172
+
173
+ if reflective:
174
+ trans, trans_inv = findSimilarity(src_pts, dst_pts)
175
+ else:
176
+ trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
177
+
178
+ return trans, trans_inv
179
+
180
+
181
+ def cvt_tform_mat_for_cv2(trans):
182
+ """
183
+ Function:
184
+ ----------
185
+ Convert Transform Matrix 'trans' into 'cv2_trans' which could be
186
+ directly used by cv2.warpAffine():
187
+ u = src_pts[:, 0]
188
+ v = src_pts[:, 1]
189
+ x = dst_pts[:, 0]
190
+ y = dst_pts[:, 1]
191
+ [x, y].T = cv_trans * [u, v, 1].T
192
+
193
+ Parameters:
194
+ ----------
195
+ @trans: 3x3 np.array
196
+ transform matrix from uv to xy
197
+
198
+ Returns:
199
+ ----------
200
+ @cv2_trans: 2x3 np.array
201
+ transform matrix from src_pts to dst_pts, could be directly used
202
+ for cv2.warpAffine()
203
+ """
204
+ cv2_trans = trans[:, 0:2].T
205
+
206
+ return cv2_trans
207
+
208
+
209
+ def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
210
+ """
211
+ Function:
212
+ ----------
213
+ Find Similarity Transform Matrix 'cv2_trans' which could be
214
+ directly used by cv2.warpAffine():
215
+ u = src_pts[:, 0]
216
+ v = src_pts[:, 1]
217
+ x = dst_pts[:, 0]
218
+ y = dst_pts[:, 1]
219
+ [x, y].T = cv_trans * [u, v, 1].T
220
+
221
+ Parameters:
222
+ ----------
223
+ @src_pts: Kx2 np.array
224
+ source points, each row is a pair of coordinates (x, y)
225
+ @dst_pts: Kx2 np.array
226
+ destination points, each row is a pair of transformed
227
+ coordinates (x, y)
228
+ reflective: True or False
229
+ if True:
230
+ use reflective similarity transform
231
+ else:
232
+ use non-reflective similarity transform
233
+
234
+ Returns:
235
+ ----------
236
+ @cv2_trans: 2x3 np.array
237
+ transform matrix from src_pts to dst_pts, could be directly used
238
+ for cv2.warpAffine()
239
+ """
240
+ trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
241
+ cv2_trans = cvt_tform_mat_for_cv2(trans)
242
+ cv2_trans_inv = cvt_tform_mat_for_cv2(trans_inv)
243
+
244
+ return cv2_trans, cv2_trans_inv
245
+
246
+
247
+ if __name__ == "__main__":
248
+ """
249
+ u = [0, 6, -2]
250
+ v = [0, 3, 5]
251
+ x = [-1, 0, 4]
252
+ y = [-1, -10, 4]
253
+
254
+ # In Matlab, run:
255
+ #
256
+ # uv = [u'; v'];
257
+ # xy = [x'; y'];
258
+ # tform_sim=cp2tform(uv,xy,'similarity');
259
+ #
260
+ # trans = tform_sim.tdata.T
261
+ # ans =
262
+ # -0.0764 -1.6190 0
263
+ # 1.6190 -0.0764 0
264
+ # -3.2156 0.0290 1.0000
265
+ # trans_inv = tform_sim.tdata.Tinv
266
+ # ans =
267
+ #
268
+ # -0.0291 0.6163 0
269
+ # -0.6163 -0.0291 0
270
+ # -0.0756 1.9826 1.0000
271
+ # xy_m=tformfwd(tform_sim, u,v)
272
+ #
273
+ # xy_m =
274
+ #
275
+ # -3.2156 0.0290
276
+ # 1.1833 -9.9143
277
+ # 5.0323 2.8853
278
+ # uv_m=tforminv(tform_sim, x,y)
279
+ #
280
+ # uv_m =
281
+ #
282
+ # 0.5698 1.3953
283
+ # 6.0872 2.2733
284
+ # -2.6570 4.3314
285
+ """
286
+ u = [0, 6, -2]
287
+ v = [0, 3, 5]
288
+ x = [-1, 0, 4]
289
+ y = [-1, -10, 4]
290
+
291
+ uv = np.array((u, v)).T
292
+ xy = np.array((x, y)).T
293
+
294
+ print("\n--->uv:")
295
+ print(uv)
296
+ print("\n--->xy:")
297
+ print(xy)
298
+
299
+ trans, trans_inv = get_similarity_transform(uv, xy)
300
+
301
+ print("\n--->trans matrix:")
302
+ print(trans)
303
+
304
+ print("\n--->trans_inv matrix:")
305
+ print(trans_inv)
306
+
307
+ print("\n---> apply transform to uv")
308
+ print("\nxy_m = uv_augmented * trans")
309
+ uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1))))
310
+ xy_m = np.dot(uv_aug, trans)
311
+ print(xy_m)
312
+
313
+ print("\nxy_m = tformfwd(trans, uv)")
314
+ xy_m = tformfwd(trans, uv)
315
+ print(xy_m)
316
+
317
+ print("\n---> apply inverse transform to xy")
318
+ print("\nuv_m = xy_augmented * trans_inv")
319
+ xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1))))
320
+ uv_m = np.dot(xy_aug, trans_inv)
321
+ print(uv_m)
322
+
323
+ print("\nuv_m = tformfwd(trans_inv, xy)")
324
+ uv_m = tformfwd(trans_inv, xy)
325
+ print(uv_m)
326
+
327
+ uv_m = tforminv(trans, xy)
328
+ print("\nuv_m = tforminv(trans, xy)")
329
+ print(uv_m)
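A short sketch tying `get_similarity_transform_for_cv2` to `cv2.warpAffine`, which is how `warp_and_crop_face` in `align/align_trans.py` consumes it. The source points are arbitrary stand-ins for detected landmarks; the destination points are the canonical 96x112 reference points.

```python
import cv2
import numpy as np

from align.matlab_cp2tform import get_similarity_transform_for_cv2

src_pts = np.array([[35.0, 40.0], [72.0, 38.0], [54.0, 62.0], [38.0, 80.0], [70.0, 78.0]])
dst_pts = np.array(
    [
        [30.2946, 51.6963],
        [65.5318, 51.5014],
        [48.0252, 71.7366],
        [33.5493, 92.3655],
        [62.7299, 92.2041],
    ]
)

cv2_trans, cv2_trans_inv = get_similarity_transform_for_cv2(src_pts, dst_pts)

img = np.zeros((112, 96, 3), dtype=np.uint8)  # placeholder image
aligned = cv2.warpAffine(img, cv2_trans, (96, 112))
print(cv2_trans)  # 2x3 matrix: [x, y].T = cv2_trans @ [u, v, 1].T
```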
app.py ADDED
@@ -0,0 +1,116 @@
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ import numpy as np
5
+ from util.feature_extraction_utils import normalize_transforms
6
+ from util.attack_utils import Attack
7
+ from util.prepare_utils import prepare_models, prepare_dir_vec, get_ensemble
8
+ from align.detector import detect_faces
9
+ from align.align_trans import get_reference_facial_points, warp_and_crop_face
10
+ import torchvision.transforms as transforms
11
+
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ print(device)
14
+ to_tensor = transforms.ToTensor()
15
+
16
+ eps = 0.05
17
+ n_iters = 50
18
+ input_size = [112, 112]
19
+ attack_type = "lpips"
20
+ c_tv = None
21
+ c_sim = 0.05
22
+ lr = 0.0025
23
+ net_type = "alex"
24
+ noise_size = 0.005
25
+ n_starts = 1
26
+ kernel_size_gf = 7
27
+ sigma_gf = 3
28
+ combination = True
29
+ using_subspace = False
30
+ V_reduction_root = "./"
31
+ model_backbones = ["IR_152", "IR_152", "ResNet_152", "ResNet_152"]
32
+ model_roots = [
33
+ "models/Backbone_IR_152_Arcface_Epoch_112.pth",
34
+ "models/Backbone_IR_152_Cosface_Epoch_70.pth",
35
+ "models/Backbone_ResNet_152_Arcface_Epoch_65.pth",
36
+ "models/Backbone_ResNet_152_Cosface_Epoch_68.pth",
37
+ ]
38
+ direction = 1
39
+ crop_size = 112
40
+ scale = crop_size / 112.0
41
+
42
+ models_attack, V_reduction, dim = prepare_models(
43
+ model_backbones,
44
+ input_size,
45
+ model_roots,
46
+ kernel_size_gf,
47
+ sigma_gf,
48
+ combination,
49
+ using_subspace,
50
+ V_reduction_root,
51
+ )
52
+
53
+
54
+ def protect(img):
55
+ img = Image.fromarray(img)
56
+ reference = get_reference_facial_points(default_square=True) * scale
57
+ h, w, c = np.array(img).shape
58
+
59
+ _, landmarks = detect_faces(img)
60
+ facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)]
61
+
62
+ _, tfm = warp_and_crop_face(
63
+ np.array(img), facial5points, reference, crop_size=(crop_size, crop_size)
64
+ )
65
+
66
+ # pytorch transform
67
+ theta = normalize_transforms(tfm, w, h)
68
+ tensor_img = to_tensor(img).unsqueeze(0).to(device)
69
+
70
+ V_reduction = None
71
+ dim = 512
72
+
73
+ # Find gradient direction vector
74
+ dir_vec_extractor = get_ensemble(
75
+ models=models_attack,
76
+ sigma_gf=None,
77
+ kernel_size_gf=None,
78
+ combination=False,
79
+ V_reduction=V_reduction,
80
+ warp=True,
81
+ theta_warp=theta,
82
+ )
83
+ dir_vec = prepare_dir_vec(dir_vec_extractor, tensor_img, dim, combination)
84
+
85
+ img_attacked = tensor_img.clone()
86
+ attack = Attack(
87
+ models_attack,
88
+ dim,
89
+ attack_type,
90
+ eps,
91
+ c_sim,
92
+ net_type,
93
+ lr,
94
+ n_iters,
95
+ noise_size,
96
+ n_starts,
97
+ c_tv,
98
+ sigma_gf,
99
+ kernel_size_gf,
100
+ combination,
101
+ warp=True,
102
+ theta_warp=theta,
103
+ V_reduction=V_reduction,
104
+ )
105
+ img_attacked = attack.execute(tensor_img, dir_vec, direction).detach().cpu()
106
+
107
+ img_attacked_pil = transforms.ToPILImage()(img_attacked[0])
108
+ return img_attacked_pil
109
+
110
+
111
+ gr.Interface(
112
+ fn=protect,
113
+ inputs=gr.components.Image(shape=(512, 512)),
114
+ outputs=gr.components.Image(type="pil"),
115
+ allow_flagging="never",
116
+ ).launch(show_error=True, quiet=False, share=True)
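`protect` can in principle be reused outside the web UI. The sketch below assumes the `gr.Interface(...).launch(...)` call has been moved under an `if __name__ == "__main__":` guard (as written, it runs at import time), that the four checkpoints listed in `model_roots` exist under `models/`, and that `person.jpg` (a placeholder path) contains a detectable face.

```python
import numpy as np
from PIL import Image

from app import protect  # heavy import: builds the four-model ensemble

img = np.array(Image.open("person.jpg").convert("RGB"))  # placeholder path
protected = protect(img)  # PIL image with the adversarial perturbation applied
protected.save("person_protected.jpg")
```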
backbone/__init__.py ADDED
File without changes
backbone/model_irse.py ADDED
@@ -0,0 +1,260 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import (
4
+ Linear,
5
+ Conv2d,
6
+ BatchNorm1d,
7
+ BatchNorm2d,
8
+ PReLU,
9
+ ReLU,
10
+ Sigmoid,
11
+ Dropout,
12
+ MaxPool2d,
13
+ AdaptiveAvgPool2d,
14
+ Sequential,
15
+ Module,
16
+ )
17
+ from collections import namedtuple
18
+
19
+
20
+ # Support: ['IR_50', 'IR_101', 'IR_152', 'IR_SE_50', 'IR_SE_101', 'IR_SE_152']
21
+
22
+
23
+ class Flatten(Module):
24
+ def forward(self, input):
25
+ return input.view(input.size(0), -1)
26
+
27
+
28
+ def l2_norm(input, axis=1):
29
+ norm = torch.norm(input, 2, axis, True)
30
+ output = torch.div(input, norm)
31
+
32
+ return output
33
+
34
+
35
+ class SEModule(Module):
36
+ def __init__(self, channels, reduction):
37
+ super(SEModule, self).__init__()
38
+ self.avg_pool = AdaptiveAvgPool2d(1)
39
+ self.fc1 = Conv2d(
40
+ channels, channels // reduction, kernel_size=1, padding=0, bias=False
41
+ )
42
+
43
+ nn.init.xavier_uniform_(self.fc1.weight.data)
44
+
45
+ self.relu = ReLU(inplace=True)
46
+ self.fc2 = Conv2d(
47
+ channels // reduction, channels, kernel_size=1, padding=0, bias=False
48
+ )
49
+
50
+ self.sigmoid = Sigmoid()
51
+
52
+ def forward(self, x):
53
+ module_input = x
54
+ x = self.avg_pool(x)
55
+ x = self.fc1(x)
56
+ x = self.relu(x)
57
+ x = self.fc2(x)
58
+ x = self.sigmoid(x)
59
+
60
+ return module_input * x
61
+
62
+
63
+ class bottleneck_IR(Module):
64
+ def __init__(self, in_channel, depth, stride):
65
+ super(bottleneck_IR, self).__init__()
66
+ if in_channel == depth:
67
+ self.shortcut_layer = MaxPool2d(1, stride)
68
+ else:
69
+ self.shortcut_layer = Sequential(
70
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
71
+ BatchNorm2d(depth),
72
+ )
73
+ self.res_layer = Sequential(
74
+ BatchNorm2d(in_channel),
75
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
76
+ PReLU(depth),
77
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
78
+ BatchNorm2d(depth),
79
+ )
80
+
81
+ def forward(self, x):
82
+ shortcut = self.shortcut_layer(x)
83
+ res = self.res_layer(x)
84
+
85
+ return res + shortcut
86
+
87
+
88
+ class bottleneck_IR_SE(Module):
89
+ def __init__(self, in_channel, depth, stride):
90
+ super(bottleneck_IR_SE, self).__init__()
91
+ if in_channel == depth:
92
+ self.shortcut_layer = MaxPool2d(1, stride)
93
+ else:
94
+ self.shortcut_layer = Sequential(
95
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
96
+ BatchNorm2d(depth),
97
+ )
98
+ self.res_layer = Sequential(
99
+ BatchNorm2d(in_channel),
100
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
101
+ PReLU(depth),
102
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
103
+ BatchNorm2d(depth),
104
+ SEModule(depth, 16),
105
+ )
106
+
107
+ def forward(self, x):
108
+ shortcut = self.shortcut_layer(x)
109
+ res = self.res_layer(x)
110
+
111
+ return res + shortcut
112
+
113
+
114
+ class Bottleneck(namedtuple("Block", ["in_channel", "depth", "stride"])):
115
+ """A named tuple describing a ResNet block."""
116
+
117
+
118
+ def get_block(in_channel, depth, num_units, stride=2):
119
+ return [Bottleneck(in_channel, depth, stride)] + [
120
+ Bottleneck(depth, depth, 1) for i in range(num_units - 1)
121
+ ]
122
+
123
+
124
+ def get_blocks(num_layers):
125
+ if num_layers == 50:
126
+ blocks = [
127
+ get_block(in_channel=64, depth=64, num_units=3),
128
+ get_block(in_channel=64, depth=128, num_units=4),
129
+ get_block(in_channel=128, depth=256, num_units=14),
130
+ get_block(in_channel=256, depth=512, num_units=3),
131
+ ]
132
+ elif num_layers == 100:
133
+ blocks = [
134
+ get_block(in_channel=64, depth=64, num_units=3),
135
+ get_block(in_channel=64, depth=128, num_units=13),
136
+ get_block(in_channel=128, depth=256, num_units=30),
137
+ get_block(in_channel=256, depth=512, num_units=3),
138
+ ]
139
+ elif num_layers == 152:
140
+ blocks = [
141
+ get_block(in_channel=64, depth=64, num_units=3),
142
+ get_block(in_channel=64, depth=128, num_units=8),
143
+ get_block(in_channel=128, depth=256, num_units=36),
144
+ get_block(in_channel=256, depth=512, num_units=3),
145
+ ]
146
+
147
+ return blocks
148
+
149
+
150
+ class Backbone(Module):
151
+ def __init__(self, input_size, num_layers, mode="ir"):
152
+ super(Backbone, self).__init__()
153
+ assert input_size[0] in [
154
+ 112,
155
+ 224,
156
+ ], "input_size should be [112, 112] or [224, 224]"
157
+ assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
158
+ assert mode in ["ir", "ir_se"], "mode should be ir or ir_se"
159
+ blocks = get_blocks(num_layers)
160
+ if mode == "ir":
161
+ unit_module = bottleneck_IR
162
+ elif mode == "ir_se":
163
+ unit_module = bottleneck_IR_SE
164
+ self.input_layer = Sequential(
165
+ Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)
166
+ )
167
+ if input_size[0] == 112:
168
+ self.output_layer = Sequential(
169
+ BatchNorm2d(512),
170
+ Dropout(),
171
+ Flatten(),
172
+ Linear(512 * 7 * 7, 512),
173
+ BatchNorm1d(512),
174
+ )
175
+ else:
176
+ self.output_layer = Sequential(
177
+ BatchNorm2d(512),
178
+ Dropout(),
179
+ Flatten(),
180
+ Linear(512 * 14 * 14, 512),
181
+ BatchNorm1d(512),
182
+ )
183
+
184
+ modules = []
185
+ for block in blocks:
186
+ for bottleneck in block:
187
+ modules.append(
188
+ unit_module(
189
+ bottleneck.in_channel, bottleneck.depth, bottleneck.stride
190
+ )
191
+ )
192
+ self.body = Sequential(*modules)
193
+
194
+ self._initialize_weights()
195
+
196
+ def forward(self, x):
197
+ x = self.input_layer(x)
198
+ x = self.body(x)
199
+ x = self.output_layer(x)
200
+
201
+ return x
202
+
203
+ def _initialize_weights(self):
204
+ for m in self.modules():
205
+ if isinstance(m, nn.Conv2d):
206
+ nn.init.xavier_uniform_(m.weight.data)
207
+ if m.bias is not None:
208
+ m.bias.data.zero_()
209
+ elif isinstance(m, nn.BatchNorm2d):
210
+ m.weight.data.fill_(1)
211
+ m.bias.data.zero_()
212
+ elif isinstance(m, nn.BatchNorm1d):
213
+ m.weight.data.fill_(1)
214
+ m.bias.data.zero_()
215
+ elif isinstance(m, nn.Linear):
216
+ nn.init.xavier_uniform_(m.weight.data)
217
+ if m.bias is not None:
218
+ m.bias.data.zero_()
219
+
220
+
221
+ def IR_50(input_size):
222
+ """Constructs a ir-50 model."""
223
+ model = Backbone(input_size, 50, "ir")
224
+
225
+ return model
226
+
227
+
228
+ def IR_101(input_size):
229
+ """Constructs a ir-101 model."""
230
+ model = Backbone(input_size, 100, "ir")
231
+
232
+ return model
233
+
234
+
235
+ def IR_152(input_size):
236
+ """Constructs a ir-152 model."""
237
+ model = Backbone(input_size, 152, "ir")
238
+
239
+ return model
240
+
241
+
242
+ def IR_SE_50(input_size):
243
+ """Constructs a ir_se-50 model."""
244
+ model = Backbone(input_size, 50, "ir_se")
245
+
246
+ return model
247
+
248
+
249
+ def IR_SE_101(input_size):
250
+ """Constructs a ir_se-101 model."""
251
+ model = Backbone(input_size, 100, "ir_se")
252
+
253
+ return model
254
+
255
+
256
+ def IR_SE_152(input_size):
257
+ """Constructs a ir_se-152 model."""
258
+ model = Backbone(input_size, 152, "ir_se")
259
+
260
+ return model
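For orientation, a minimal usage sketch of the factory functions above; the dummy batch, its size, and the expected output shape are illustrative assumptions, not part of the committed file:

import torch
from backbone.model_irse import IR_50

# Build the IR-50 backbone for 112x112 aligned face crops and run a dummy batch.
model = IR_50([112, 112])
model.eval()  # use running BatchNorm statistics, disable Dropout
with torch.no_grad():
    dummy = torch.randn(2, 3, 112, 112)
    embeddings = model(dummy)
print(embeddings.shape)  # torch.Size([2, 512])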
backbone/model_resnet.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from torch.nn import (
3
+ Linear,
4
+ Conv2d,
5
+ BatchNorm1d,
6
+ BatchNorm2d,
7
+ ReLU,
8
+ Dropout,
9
+ MaxPool2d,
10
+ Sequential,
11
+ Module,
12
+ )
13
+
14
+
15
+ # Support: ['ResNet_50', 'ResNet_101', 'ResNet_152']
16
+
17
+
18
+ def conv3x3(in_planes, out_planes, stride=1):
19
+ """3x3 convolution with padding"""
20
+
21
+ return Conv2d(
22
+ in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
23
+ )
24
+
25
+
26
+ def conv1x1(in_planes, out_planes, stride=1):
27
+ """1x1 convolution"""
28
+
29
+ return Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
30
+
31
+
32
+ class BasicBlock(Module):
33
+ expansion = 1
34
+
35
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
36
+ super(BasicBlock, self).__init__()
37
+ self.conv1 = conv3x3(inplanes, planes, stride)
38
+ self.bn1 = BatchNorm2d(planes)
39
+ self.relu = ReLU(inplace=True)
40
+ self.conv2 = conv3x3(planes, planes)
41
+ self.bn2 = BatchNorm2d(planes)
42
+ self.downsample = downsample
43
+ self.stride = stride
44
+
45
+ def forward(self, x):
46
+ identity = x
47
+
48
+ out = self.conv1(x)
49
+ out = self.bn1(out)
50
+ out = self.relu(out)
51
+
52
+ out = self.conv2(out)
53
+ out = self.bn2(out)
54
+
55
+ if self.downsample is not None:
56
+ identity = self.downsample(x)
57
+
58
+ out += identity
59
+ out = self.relu(out)
60
+
61
+ return out
62
+
63
+
64
+ class Bottleneck(Module):
65
+ expansion = 4
66
+
67
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
68
+ super(Bottleneck, self).__init__()
69
+ self.conv1 = conv1x1(inplanes, planes)
70
+ self.bn1 = BatchNorm2d(planes)
71
+ self.conv2 = conv3x3(planes, planes, stride)
72
+ self.bn2 = BatchNorm2d(planes)
73
+ self.conv3 = conv1x1(planes, planes * self.expansion)
74
+ self.bn3 = BatchNorm2d(planes * self.expansion)
75
+ self.relu = ReLU(inplace=True)
76
+ self.downsample = downsample
77
+ self.stride = stride
78
+
79
+ def forward(self, x):
80
+ identity = x
81
+
82
+ out = self.conv1(x)
83
+ out = self.bn1(out)
84
+ out = self.relu(out)
85
+
86
+ out = self.conv2(out)
87
+ out = self.bn2(out)
88
+ out = self.relu(out)
89
+
90
+ out = self.conv3(out)
91
+ out = self.bn3(out)
92
+
93
+ if self.downsample is not None:
94
+ identity = self.downsample(x)
95
+
96
+ out += identity
97
+ out = self.relu(out)
98
+
99
+ return out
100
+
101
+
102
+ class ResNet(Module):
103
+ def __init__(self, input_size, block, layers, zero_init_residual=True):
104
+ super(ResNet, self).__init__()
105
+ assert input_size[0] in [
106
+ 112,
107
+ 224,
108
+ ], "input_size should be [112, 112] or [224, 224]"
109
+ self.inplanes = 64
110
+ self.conv1 = Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
111
+ self.bn1 = BatchNorm2d(64)
112
+ self.relu = ReLU(inplace=True)
113
+ self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)
114
+ self.layer1 = self._make_layer(block, 64, layers[0])
115
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
116
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
117
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
118
+
119
+ self.bn_o1 = BatchNorm2d(2048)
120
+ self.dropout = Dropout()
121
+ if input_size[0] == 112:
122
+ self.fc = Linear(2048 * 4 * 4, 512)
123
+ else:
124
+ self.fc = Linear(2048 * 8 * 8, 512)
125
+ self.bn_o2 = BatchNorm1d(512)
126
+
127
+ for m in self.modules():
128
+ if isinstance(m, Conv2d):
129
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
130
+ elif isinstance(m, BatchNorm2d):
131
+ nn.init.constant_(m.weight, 1)
132
+ nn.init.constant_(m.bias, 0)
133
+
134
+ # Zero-initialize the last BN in each residual branch,
135
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
136
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
137
+ if zero_init_residual:
138
+ for m in self.modules():
139
+ if isinstance(m, Bottleneck):
140
+ nn.init.constant_(m.bn3.weight, 0)
141
+ elif isinstance(m, BasicBlock):
142
+ nn.init.constant_(m.bn2.weight, 0)
143
+
144
+ def _make_layer(self, block, planes, blocks, stride=1):
145
+ downsample = None
146
+ if stride != 1 or self.inplanes != planes * block.expansion:
147
+ downsample = Sequential(
148
+ conv1x1(self.inplanes, planes * block.expansion, stride),
149
+ BatchNorm2d(planes * block.expansion),
150
+ )
151
+
152
+ layers = []
153
+ layers.append(block(self.inplanes, planes, stride, downsample))
154
+ self.inplanes = planes * block.expansion
155
+ for _ in range(1, blocks):
156
+ layers.append(block(self.inplanes, planes))
157
+
158
+ return Sequential(*layers)
159
+
160
+ def forward(self, x):
161
+ x = self.conv1(x)
162
+ x = self.bn1(x)
163
+ x = self.relu(x)
164
+ x = self.maxpool(x)
165
+
166
+ x = self.layer1(x)
167
+ x = self.layer2(x)
168
+ x = self.layer3(x)
169
+ x = self.layer4(x)
170
+
171
+ x = self.bn_o1(x)
172
+ x = self.dropout(x)
173
+ x = x.view(x.size(0), -1)
174
+ x = self.fc(x)
175
+ x = self.bn_o2(x)
176
+
177
+ return x
178
+
179
+
180
+ def ResNet_18(input_size, **kwargs):
181
+ """Constructs a ResNet-50 model."""
182
+ model = ResNet(input_size, Bottleneck, [2, 2, 2, 2], **kwargs)
183
+
184
+ return model
185
+
186
+
187
+ def ResNet_50(input_size, **kwargs):
188
+ """Constructs a ResNet-50 model."""
189
+ model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs)
190
+
191
+ return model
192
+
193
+
194
+ def ResNet_101(input_size, **kwargs):
195
+ """Constructs a ResNet-101 model."""
196
+ model = ResNet(input_size, Bottleneck, [3, 4, 23, 3], **kwargs)
197
+
198
+ return model
199
+
200
+
201
+ def ResNet_152(input_size, **kwargs):
202
+ """Constructs a ResNet-152 model."""
203
+ model = ResNet(input_size, Bottleneck, [3, 8, 36, 3], **kwargs)
204
+
205
+ return model
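A quick shape check for the 112x112 branch of the head above (the dummy batch is an illustrative assumption): conv1 and the max-pool halve the input twice (112 -> 56 -> 28) and layers 2-4 halve it three more times (28 -> 14 -> 7 -> 4), so the flattened feature entering the fully connected layer is 2048 * 4 * 4.

import torch
from backbone.model_resnet import ResNet_50

# Build the ResNet-50 backbone for 112x112 crops and confirm the embedding shape.
model = ResNet_50([112, 112])
model.eval()
with torch.no_grad():
    embeddings = model(torch.randn(2, 3, 112, 112))
print(embeddings.shape)  # torch.Size([2, 512])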
backbone/models2.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn import (
2
+ Linear,
3
+ Conv2d,
4
+ BatchNorm1d,
5
+ BatchNorm2d,
6
+ PReLU,
7
+ ReLU,
8
+ Sigmoid,
9
+ Dropout,
10
+ MaxPool2d,
11
+ AdaptiveAvgPool2d,
12
+ Sequential,
13
+ Module,
14
+ Parameter,
15
+ )
16
+ import torch
17
+ from collections import namedtuple
18
+ import math
19
+
20
+ ################################## Original Arcface Model #############################################################
21
+
22
+
23
+ class Flatten(Module):
24
+ def forward(self, input):
25
+ return input.view(input.size(0), -1)
26
+
27
+
28
+ def l2_norm(input, axis=1):
29
+ norm = torch.norm(input, 2, axis, True)
30
+ output = torch.div(input, norm)
31
+ return output
32
+
33
+
34
+ class SEModule(Module):
35
+ def __init__(self, channels, reduction):
36
+ super(SEModule, self).__init__()
37
+ self.avg_pool = AdaptiveAvgPool2d(1)
38
+ self.fc1 = Conv2d(
39
+ channels, channels // reduction, kernel_size=1, padding=0, bias=False
40
+ )
41
+ self.relu = ReLU(inplace=True)
42
+ self.fc2 = Conv2d(
43
+ channels // reduction, channels, kernel_size=1, padding=0, bias=False
44
+ )
45
+ self.sigmoid = Sigmoid()
46
+
47
+ def forward(self, x):
48
+ module_input = x
49
+ x = self.avg_pool(x)
50
+ x = self.fc1(x)
51
+ x = self.relu(x)
52
+ x = self.fc2(x)
53
+ x = self.sigmoid(x)
54
+ return module_input * x
55
+
56
+
57
+ class bottleneck_IR(Module):
58
+ def __init__(self, in_channel, depth, stride):
59
+ super(bottleneck_IR, self).__init__()
60
+ if in_channel == depth:
61
+ self.shortcut_layer = MaxPool2d(1, stride)
62
+ else:
63
+ self.shortcut_layer = Sequential(
64
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
65
+ BatchNorm2d(depth),
66
+ )
67
+ self.res_layer = Sequential(
68
+ BatchNorm2d(in_channel),
69
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
70
+ PReLU(depth),
71
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
72
+ BatchNorm2d(depth),
73
+ )
74
+
75
+ def forward(self, x):
76
+ shortcut = self.shortcut_layer(x)
77
+ res = self.res_layer(x)
78
+ return res + shortcut
79
+
80
+
81
+ class bottleneck_IR_SE(Module):
82
+ def __init__(self, in_channel, depth, stride):
83
+ super(bottleneck_IR_SE, self).__init__()
84
+ if in_channel == depth:
85
+ self.shortcut_layer = MaxPool2d(1, stride)
86
+ else:
87
+ self.shortcut_layer = Sequential(
88
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
89
+ BatchNorm2d(depth),
90
+ )
91
+ self.res_layer = Sequential(
92
+ BatchNorm2d(in_channel),
93
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
94
+ PReLU(depth),
95
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
96
+ BatchNorm2d(depth),
97
+ SEModule(depth, 16),
98
+ )
99
+
100
+ def forward(self, x):
101
+ shortcut = self.shortcut_layer(x)
102
+ res = self.res_layer(x)
103
+ return res + shortcut
104
+
105
+
106
+ class Bottleneck(namedtuple("Block", ["in_channel", "depth", "stride"])):
107
+ """A named tuple describing a ResNet block."""
108
+
109
+
110
+ def get_block(in_channel, depth, num_units, stride=2):
111
+ return [Bottleneck(in_channel, depth, stride)] + [
112
+ Bottleneck(depth, depth, 1) for i in range(num_units - 1)
113
+ ]
114
+
115
+
116
+ def get_blocks(num_layers):
117
+ if num_layers == 50:
118
+ blocks = [
119
+ get_block(in_channel=64, depth=64, num_units=3),
120
+ get_block(in_channel=64, depth=128, num_units=4),
121
+ get_block(in_channel=128, depth=256, num_units=14),
122
+ get_block(in_channel=256, depth=512, num_units=3),
123
+ ]
124
+ elif num_layers == 100:
125
+ blocks = [
126
+ get_block(in_channel=64, depth=64, num_units=3),
127
+ get_block(in_channel=64, depth=128, num_units=13),
128
+ get_block(in_channel=128, depth=256, num_units=30),
129
+ get_block(in_channel=256, depth=512, num_units=3),
130
+ ]
131
+ elif num_layers == 152:
132
+ blocks = [
133
+ get_block(in_channel=64, depth=64, num_units=3),
134
+ get_block(in_channel=64, depth=128, num_units=8),
135
+ get_block(in_channel=128, depth=256, num_units=36),
136
+ get_block(in_channel=256, depth=512, num_units=3),
137
+ ]
138
+ return blocks
139
+
140
+
141
+ class Backbone(Module):
142
+ def __init__(self, num_layers, drop_ratio, mode="ir"):
143
+ super(Backbone, self).__init__()
144
+ assert num_layers in [50, 100, 152], "num_layers should be 50, 100, or 152"
145
+ assert mode in ["ir", "ir_se"], "mode should be ir or ir_se"
146
+ blocks = get_blocks(num_layers)
147
+ if mode == "ir":
148
+ unit_module = bottleneck_IR
149
+ elif mode == "ir_se":
150
+ unit_module = bottleneck_IR_SE
151
+ self.input_layer = Sequential(
152
+ Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)
153
+ )
154
+ self.output_layer = Sequential(
155
+ BatchNorm2d(512),
156
+ Dropout(drop_ratio),
157
+ Flatten(),
158
+ Linear(512 * 7 * 7, 512),
159
+ BatchNorm1d(512),
160
+ )
161
+ modules = []
162
+ for block in blocks:
163
+ for bottleneck in block:
164
+ modules.append(
165
+ unit_module(
166
+ bottleneck.in_channel, bottleneck.depth, bottleneck.stride
167
+ )
168
+ )
169
+ self.body = Sequential(*modules)
170
+
171
+ def forward(self, x):
172
+ x = self.input_layer(x)
173
+ x = self.body(x)
174
+ x = self.output_layer(x)
175
+ return l2_norm(x)
176
+
177
+
178
+ ################################## MobileFaceNet #############################################################
179
+
180
+
181
+ class Conv_block(Module):
182
+ def __init__(
183
+ self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1
184
+ ):
185
+ super(Conv_block, self).__init__()
186
+ self.conv = Conv2d(
187
+ in_c,
188
+ out_channels=out_c,
189
+ kernel_size=kernel,
190
+ groups=groups,
191
+ stride=stride,
192
+ padding=padding,
193
+ bias=False,
194
+ )
195
+ self.bn = BatchNorm2d(out_c)
196
+ self.prelu = PReLU(out_c)
197
+
198
+ def forward(self, x):
199
+ x = self.conv(x)
200
+ x = self.bn(x)
201
+ x = self.prelu(x)
202
+ return x
203
+
204
+
205
+ class Linear_block(Module):
206
+ def __init__(
207
+ self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1
208
+ ):
209
+ super(Linear_block, self).__init__()
210
+ self.conv = Conv2d(
211
+ in_c,
212
+ out_channels=out_c,
213
+ kernel_size=kernel,
214
+ groups=groups,
215
+ stride=stride,
216
+ padding=padding,
217
+ bias=False,
218
+ )
219
+ self.bn = BatchNorm2d(out_c)
220
+
221
+ def forward(self, x):
222
+ x = self.conv(x)
223
+ x = self.bn(x)
224
+ return x
225
+
226
+
227
+ class Depth_Wise(Module):
228
+ def __init__(
229
+ self,
230
+ in_c,
231
+ out_c,
232
+ residual=False,
233
+ kernel=(3, 3),
234
+ stride=(2, 2),
235
+ padding=(1, 1),
236
+ groups=1,
237
+ ):
238
+ super(Depth_Wise, self).__init__()
239
+ self.conv = Conv_block(
240
+ in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)
241
+ )
242
+ self.conv_dw = Conv_block(
243
+ groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride
244
+ )
245
+ self.project = Linear_block(
246
+ groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)
247
+ )
248
+ self.residual = residual
249
+
250
+ def forward(self, x):
251
+ if self.residual:
252
+ short_cut = x
253
+ x = self.conv(x)
254
+ x = self.conv_dw(x)
255
+ x = self.project(x)
256
+ if self.residual:
257
+ output = short_cut + x
258
+ else:
259
+ output = x
260
+ return output
261
+
262
+
263
+ class Residual(Module):
264
+ def __init__(
265
+ self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
266
+ ):
267
+ super(Residual, self).__init__()
268
+ modules = []
269
+ for _ in range(num_block):
270
+ modules.append(
271
+ Depth_Wise(
272
+ c,
273
+ c,
274
+ residual=True,
275
+ kernel=kernel,
276
+ padding=padding,
277
+ stride=stride,
278
+ groups=groups,
279
+ )
280
+ )
281
+ self.model = Sequential(*modules)
282
+
283
+ def forward(self, x):
284
+ return self.model(x)
285
+
286
+
287
+ class MobileFaceNet(Module):
288
+ def __init__(self, embedding_size):
289
+ super(MobileFaceNet, self).__init__()
290
+ self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
291
+ self.conv2_dw = Conv_block(
292
+ 64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64
293
+ )
294
+ self.conv_23 = Depth_Wise(
295
+ 64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128
296
+ )
297
+ self.conv_3 = Residual(
298
+ 64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
299
+ )
300
+ self.conv_34 = Depth_Wise(
301
+ 64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256
302
+ )
303
+ self.conv_4 = Residual(
304
+ 128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
305
+ )
306
+ self.conv_45 = Depth_Wise(
307
+ 128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512
308
+ )
309
+ self.conv_5 = Residual(
310
+ 128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)
311
+ )
312
+ self.conv_6_sep = Conv_block(
313
+ 128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)
314
+ )
315
+ self.conv_6_dw = Linear_block(
316
+ 512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)
317
+ )
318
+ self.conv_6_flatten = Flatten()
319
+ self.linear = Linear(512, embedding_size, bias=False)
320
+ self.bn = BatchNorm1d(embedding_size)
321
+
322
+ def forward(self, x):
323
+ out = self.conv1(x)
324
+
325
+ out = self.conv2_dw(out)
326
+
327
+ out = self.conv_23(out)
328
+
329
+ out = self.conv_3(out)
330
+
331
+ out = self.conv_34(out)
332
+
333
+ out = self.conv_4(out)
334
+
335
+ out = self.conv_45(out)
336
+
337
+ out = self.conv_5(out)
338
+
339
+ out = self.conv_6_sep(out)
340
+
341
+ out = self.conv_6_dw(out)
342
+
343
+ out = self.conv_6_flatten(out)
344
+
345
+ out = self.linear(out)
346
+
347
+ out = self.bn(out)
348
+ return l2_norm(out)
349
+
350
+
351
+ ################################## Arcface head #############################################################
352
+
353
+
354
+ class Arcface(Module):
355
+ # implementation of the additive angular margin (ArcFace) loss in https://arxiv.org/abs/1801.07698
356
+ def __init__(self, embedding_size=512, classnum=51332, s=64.0, m=0.5):
357
+ super(Arcface, self).__init__()
358
+ self.classnum = classnum
359
+ self.kernel = Parameter(torch.Tensor(embedding_size, classnum))
360
+ # initial kernel
361
+ self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
362
+ self.m = m # the margin value, default is 0.5
363
+ self.s = s # scalar value default is 64, see normface https://arxiv.org/abs/1704.06369
364
+ self.cos_m = math.cos(m)
365
+ self.sin_m = math.sin(m)
366
+ self.mm = self.sin_m * m # issue 1
367
+ self.threshold = math.cos(math.pi - m)
368
+
369
+ def forward(self, embbedings, label):
370
+ # weights norm
371
+ nB = len(embbedings)
372
+ kernel_norm = l2_norm(self.kernel, axis=0)
373
+ # cos(theta+m)
374
+ cos_theta = torch.mm(embbedings, kernel_norm)
375
+ # output = torch.mm(embbedings,kernel_norm)
376
+ cos_theta = cos_theta.clamp(-1, 1) # for numerical stability
377
+ cos_theta_2 = torch.pow(cos_theta, 2)
378
+ sin_theta_2 = 1 - cos_theta_2
379
+ sin_theta = torch.sqrt(sin_theta_2)
380
+ cos_theta_m = cos_theta * self.cos_m - sin_theta * self.sin_m
381
+ # this condition controls the theta+m should in range [0, pi]
382
+ # 0<=theta+m<=pi
383
+ # -m<=theta<=pi-m
384
+ cond_v = cos_theta - self.threshold
385
+ cond_mask = cond_v <= 0
386
+ keep_val = cos_theta - self.mm # when theta not in [0,pi], use cosface instead
387
+ cos_theta_m[cond_mask] = keep_val[cond_mask]
388
+ output = (
389
+ cos_theta * 1.0
390
+ ) # a little bit hacky way to prevent in_place operation on cos_theta
391
+ idx_ = torch.arange(0, nB, dtype=torch.long)
392
+ output[idx_, label] = cos_theta_m[idx_, label]
393
+ output *= (
394
+ self.s
395
+ ) # scale up in order to make softmax work, first introduced in normface
396
+ return output
397
+
398
+
399
+ ################################## Cosface head #############################################################
400
+
401
+
402
+ class Am_softmax(Module):
403
+ # implementation of additive margin softmax loss in https://arxiv.org/abs/1801.05599
404
+ def __init__(self, embedding_size=512, classnum=51332):
405
+ super(Am_softmax, self).__init__()
406
+ self.classnum = classnum
407
+ self.kernel = Parameter(torch.Tensor(embedding_size, classnum))
408
+ # initial kernel
409
+ self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
410
+ self.m = 0.35 # additive margin recommended by the paper
411
+ self.s = 30.0 # see normface https://arxiv.org/abs/1704.06369
412
+
413
+ def forward(self, embbedings, label):
414
+ kernel_norm = l2_norm(self.kernel, axis=0)
415
+ cos_theta = torch.mm(embbedings, kernel_norm)
416
+ cos_theta = cos_theta.clamp(-1, 1) # for numerical stability
417
+ phi = cos_theta - self.m
418
+ label = label.view(-1, 1) # size=(B,1)
419
+ index = cos_theta.data * 0.0 # size=(B,Classnum)
420
+ index.scatter_(1, label.data.view(-1, 1), 1)
421
+ index = index.bool()
422
+ output = cos_theta * 1.0
423
+ output[index] = phi[index] # only change the correct predicted output
424
+ output *= (
425
+ self.s
426
+ ) # scale up in order to make softmax work, first introduced in normface
427
+ return output
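To illustrate how the heads above are meant to be combined with the Backbone in this file, a minimal sketch; the class count, batch, and labels are made up for the example:

import torch
import torch.nn.functional as F
from backbone.models2 import Backbone, Arcface

backbone = Backbone(num_layers=50, drop_ratio=0.4, mode="ir")
head = Arcface(embedding_size=512, classnum=10)
backbone.eval()  # this Backbone already returns L2-normalized embeddings

images = torch.randn(4, 3, 112, 112)   # stand-in for aligned 112x112 crops
labels = torch.randint(0, 10, (4,))
embeddings = backbone(images)          # (4, 512), unit-norm rows
logits = head(embeddings, labels)      # angular margin applied at the true class
loss = F.cross_entropy(logits, labels)
print(loss.item())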
requirements-dev.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ruff
2
+ pre-commit
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=4.38.1
2
+ numpy>=2.0.0
3
+ Pillow>=10.4.0
4
+ torch>=2.3.1
5
+ torchvision>=0.18.1
6
+ tqdm>=4.66.4
7
+ lpips>=0.1.4
util/__init__.py ADDED
File without changes
util/attack_utils.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Helper function for extracting features from pre-trained models
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+ from torch.autograd import Variable
6
+ from util.feature_extraction_utils import warp_image, normalize_batch
7
+ from util.prepare_utils import get_ensemble, extract_features
8
+ from lpips_pytorch import LPIPS
9
+
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ tensor_transform = transforms.ToTensor()
12
+ pil_transform = transforms.ToPILImage()
13
+
14
+
15
+ class Attack(nn.Module):
16
+ def __init__(
17
+ self,
18
+ models,
19
+ dim,
20
+ attack_type,
21
+ eps,
22
+ c_sim=0.5,
23
+ net_type="alex",
24
+ lr=0.05,
25
+ n_iters=100,
26
+ noise_size=0.001,
27
+ n_starts=10,
28
+ c_tv=None,
29
+ sigma_gf=None,
30
+ kernel_size_gf=None,
31
+ combination=False,
32
+ warp=False,
33
+ theta_warp=None,
34
+ V_reduction=None,
35
+ ):
36
+ super(Attack, self).__init__()
37
+ self.extractor_ens = get_ensemble(
38
+ models, sigma_gf, kernel_size_gf, combination, V_reduction, warp, theta_warp
39
+ )
40
+ # print("There are '{}'' models in the attack ensemble".format(len(self.extractor_ens)))
41
+ self.dim = dim
42
+ self.eps = eps
43
+ self.c_sim = c_sim
44
+ self.net_type = net_type
45
+ self.lr = lr
46
+ self.n_iters = n_iters
47
+ self.noise_size = noise_size
48
+ self.n_starts = n_starts
49
+ self.c_tv = c_tv
50
+ self.attack_type = attack_type
51
+ self.warp = warp
52
+ self.theta_warp = theta_warp
53
+ if self.attack_type == "lpips":
54
+ self.lpips_loss = LPIPS(self.net_type).to(device)
55
+
56
+ def execute(self, images, dir_vec, direction):
57
+ images = Variable(images).to(device)
58
+ dir_vec = dir_vec.to(device)
59
+ # take norm wrt dim
60
+ dir_vec_norm = dir_vec.norm(dim=2).unsqueeze(2).to(device)
61
+ dist = torch.zeros(images.shape[0]).to(device)
62
+ adv_images = images.detach().clone()
63
+
64
+ if self.warp:
65
+ self.face_img = warp_image(images, self.theta_warp)
66
+
67
+ for start in range(self.n_starts):
68
+ # save the previous adversarial images and distances before this restart
69
+ adv_images_old = adv_images.detach().clone()
70
+ dist_old = dist.clone()
71
+ # add noise to initialize ( - noise_size, noise_size)
72
+ noise_uniform = Variable(
73
+ 2 * self.noise_size * torch.rand(images.size()) - self.noise_size
74
+ ).to(device)
75
+ adv_images = Variable(
76
+ images.detach().clone() + noise_uniform, requires_grad=True
77
+ ).to(device)
78
+
79
+ for i in range(self.n_iters):
80
+ adv_features = extract_features(
81
+ adv_images, self.extractor_ens, self.dim
82
+ ).to(device)
83
+ # normalize feature vectors in ensembles
84
+ loss = direction * torch.mean(
85
+ (adv_features - dir_vec) ** 2 / dir_vec_norm
86
+ )
87
+
88
+ if self.c_tv is not None:
89
+ tv_out = self.total_var_reg(images, adv_images)
90
+ loss -= self.c_tv * tv_out
91
+
92
+ if self.attack_type == "lpips":
93
+ lpips_out = self.lpips_reg(images, adv_images)
94
+ loss -= self.c_sim * lpips_out
95
+
96
+ grad = torch.autograd.grad(loss, [adv_images])
97
+ adv_images = adv_images + self.lr * grad[0].sign()
98
+ perturbation = adv_images - images
99
+
100
+ if self.attack_type == "sgd":
101
+ perturbation = torch.clamp(
102
+ perturbation, min=-self.eps, max=self.eps
103
+ )
104
+ adv_images = images + perturbation
105
+
106
+ adv_images = torch.clamp(adv_images, min=0, max=1)
107
+ adv_features = extract_features(
108
+ adv_images, self.extractor_ens, self.dim
109
+ ).to(device)
110
+ dist = torch.mean((adv_features - dir_vec) ** 2 / dir_vec_norm, dim=[1, 2])
111
+
112
+ if direction == 1:
113
+ adv_images[dist < dist_old] = adv_images_old[dist < dist_old]
114
+ dist[dist < dist_old] = dist_old[dist < dist_old]
115
+ else:
116
+ adv_images[dist > dist_old] = adv_images_old[dist > dist_old]
117
+ dist[dist > dist_old] = dist_old[dist > dist_old]
118
+
119
+ return adv_images.detach().cpu()
120
+
121
+ def lpips_reg(self, images, adv_images):
122
+ if self.warp:
123
+ face_adv = warp_image(adv_images, self.theta_warp)
124
+ lpips_out = self.lpips_loss(
125
+ normalize_batch(self.face_img).to(device),
126
+ normalize_batch(face_adv).to(device),
127
+ )[0][0][0][0] / (2 * adv_images.shape[0])
128
+ lpips_out += self.lpips_loss(
129
+ normalize_batch(images).to(device),
130
+ normalize_batch(adv_images).to(device),
131
+ )[0][0][0][0] / (2 * adv_images.shape[0])
132
+
133
+ else:
134
+ lpips_out = (
135
+ self.lpips_loss(
136
+ normalize_batch(images).to(device),
137
+ normalize_batch(adv_images).to(device),
138
+ )[0][0][0][0]
139
+ / adv_images.shape[0]
140
+ )
141
+
142
+ return lpips_out
143
+
144
+ def total_var_reg(self, images, adv_images):
145
+ perturbation = adv_images - images
146
+ tv = torch.mean(
147
+ torch.abs(perturbation[:, :, :, :-1] - perturbation[:, :, :, 1:])
148
+ ) + torch.mean(
149
+ torch.abs(perturbation[:, :, :-1, :] - perturbation[:, :, 1:, :])
150
+ )
151
+
152
+ return tv
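A minimal sketch of how this Attack class appears to be driven; the backbone is left randomly initialized and the hyper-parameters are placeholders (a real run would load a trained checkpoint and choose eps, lr, and n_iters deliberately):

import torch
from backbone.model_irse import IR_50
from util.attack_utils import Attack
from util.prepare_utils import extract_features

model = IR_50([112, 112])  # placeholder: load a trained checkpoint here in practice

attack = Attack(
    models=[model], dim=512, attack_type="sgd", eps=0.05,
    lr=0.01, n_iters=10, n_starts=1,
)
images = torch.rand(1, 3, 112, 112)  # image to protect, pixel values in [0, 1]
# Direction vector: the clean image's own features; direction=1 pushes the
# adversarial features away from it.
dir_vec = extract_features(images, attack.extractor_ens, 512)
adv_images = attack.execute(images, dir_vec, direction=1)
print(adv_images.shape)  # torch.Size([1, 3, 112, 112])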
util/feature_extraction_utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Helper function for extracting features from pre-trained models
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torchvision.transforms as transforms
5
+ import torch.nn as nn
6
+ import numpy as np
7
+
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+
11
+ def warp_image(tensor_img, theta_warp, crop_size=112):
12
+ # applies affine transform theta to image and crops it
13
+
14
+ theta_warp = torch.Tensor(theta_warp).unsqueeze(0).to(device)
15
+ grid = F.affine_grid(theta_warp, tensor_img.size())
16
+ img_warped = F.grid_sample(tensor_img, grid)
17
+ img_cropped = img_warped[:, :, 0:crop_size, 0:crop_size]
18
+ return img_cropped
19
+
20
+
21
+ def normalize_transforms(tfm, W, H):
22
+ # normalizes affine transform from cv2 for pytorch
23
+ tfm_t = np.concatenate((tfm, np.array([[0, 0, 1]])), axis=0)
24
+ transforms = np.linalg.inv(tfm_t)[0:2, :]
25
+ transforms[0, 0] = transforms[0, 0]
26
+ transforms[0, 1] = transforms[0, 1] * H / W
27
+ transforms[0, 2] = (
28
+ transforms[0, 2] * 2 / W + transforms[0, 0] + transforms[0, 1] - 1
29
+ )
30
+
31
+ transforms[1, 0] = transforms[1, 0] * W / H
32
+ transforms[1, 1] = transforms[1, 1]
33
+ transforms[1, 2] = (
34
+ transforms[1, 2] * 2 / H + transforms[1, 0] + transforms[1, 1] - 1
35
+ )
36
+
37
+ return transforms
38
+
39
+
40
+ def l2_norm(input, axis=1):
41
+ # normalizes input with respect to its L2 norm
42
+ norm = torch.norm(input, 2, axis, True)
43
+ output = torch.div(input, norm)
44
+ return output
45
+
46
+
47
+ def de_preprocess(tensor):
48
+ # normalize images from [-1,1] to [0,1]
49
+ return tensor * 0.5 + 0.5
50
+
51
+
52
+ # normalize image to [-1,1]
53
+ normalize = transforms.Compose([transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
54
+
55
+
56
+ def normalize_batch(imgs_tensor):
57
+ normalized_imgs = torch.empty_like(imgs_tensor)
58
+ for i, img_ten in enumerate(imgs_tensor):
59
+ normalized_imgs[i] = normalize(img_ten)
60
+
61
+ return normalized_imgs
62
+
63
+
64
+ def resize2d(img, size):
65
+ # resizes image
66
+ return F.adaptive_avg_pool2d(img, size)
67
+
68
+
69
+ class face_extractor(nn.Module):
70
+ def __init__(self, crop_size=112, warp=False, theta_warp=None):
71
+ super(face_extractor, self).__init__()
72
+ self.crop_size = crop_size
73
+ self.warp = warp
74
+ self.theta_warp = theta_warp
75
+
76
+ def forward(self, input):
77
+ if self.warp:
78
+ assert input.shape[0] == 1
79
+ input = warp_image(input, self.theta_warp, self.crop_size)
80
+
81
+ return input
82
+
83
+
84
+ class feature_extractor(nn.Module):
85
+ def __init__(self, model, crop_size=112, tta=True, warp=False, theta_warp=None):
86
+ super(feature_extractor, self).__init__()
87
+ self.model = model
88
+ self.crop_size = crop_size
89
+ self.tta = tta
90
+ self.warp = warp
91
+ self.theta_warp = theta_warp
92
+
93
+ self.model = model
94
+
95
+ def forward(self, input):
96
+ if self.warp:
97
+ assert input.shape[0] == 1
98
+ input = warp_image(input, self.theta_warp, self.crop_size)
99
+
100
+ batch_normalized = normalize_batch(input)
101
+ batch_flipped = torch.flip(batch_normalized, [3])
102
+ # extract features
103
+ self.model.eval() # set to evaluation mode
104
+ if self.tta:
105
+ embed = self.model(batch_normalized) + self.model(batch_flipped)
106
+ features = l2_norm(embed)
107
+ else:
108
+ features = l2_norm(self.model(batch_normalized))
109
+ return features
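To show what normalize_transforms and warp_image above are for, a small sketch; the 2x3 matrix is a made-up pixel-space similarity transform of the kind cv2 alignment code produces:

import numpy as np
import torch
from util.feature_extraction_utils import normalize_transforms, warp_image

W = H = 112
tfm = np.array([[0.9, 0.05, 4.0],
                [-0.05, 0.9, 6.0]])      # pixel-space affine (e.g. from cv2)
theta = normalize_transforms(tfm, W, H)  # rescaled for F.affine_grid's [-1, 1] grid

device = "cuda" if torch.cuda.is_available() else "cpu"
img = torch.rand(1, 3, H, W).to(device)
aligned = warp_image(img, theta, crop_size=112)
print(aligned.shape)  # torch.Size([1, 3, 112, 112])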
util/prepare_utils.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Helper function for extracting features from pre-trained models
2
+ import math
3
+ import numbers
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torchvision.transforms as transforms
8
+ import numpy as np
9
+ import torchvision.datasets as datasets
10
+ from util.feature_extraction_utils import feature_extractor
11
+ from backbone.model_irse import IR_50, IR_152
12
+ from backbone.model_resnet import ResNet_50, ResNet_152
13
+
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ tensor_transform = transforms.ToTensor()
16
+ pil_transform = transforms.ToPILImage()
17
+
18
+
19
+ class ImageFolderWithPaths(datasets.ImageFolder):
20
+ """Custom dataset that includes image file paths. Extends
21
+ torchvision.datasets.ImageFolder
22
+ """
23
+
24
+ # override the __getitem__ method. this is the method that dataloader calls
25
+ def __getitem__(self, index):
26
+ # this is what ImageFolder normally returns
27
+ original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
28
+ # the image file path
29
+ path = self.imgs[index][0]
30
+ # make a new tuple that includes original and the path
31
+ tuple_with_path = original_tuple + (path,)
32
+ return tuple_with_path
33
+
34
+
35
+ class GaussianSmoothing(nn.Module):
36
+ """
37
+ Apply gaussian smoothing on a
38
+ 1d, 2d or 3d tensor. Filtering is performed separately for each channel
39
+ in the input using a depthwise convolution.
40
+ Arguments:
41
+ channels (int, sequence): Number of channels of the input tensors. Output will
42
+ have this number of channels as well.
43
+ kernel_size (int, sequence): Size of the gaussian kernel.
44
+ sigma (float, sequence): Standard deviation of the gaussian kernel.
45
+ dim (int, optional): The number of dimensions of the data.
46
+ Default value is 2 (spatial).
47
+ """
48
+
49
+ def __init__(self, channels, kernel_size, sigma, dim=2):
50
+ super(GaussianSmoothing, self).__init__()
51
+ if isinstance(kernel_size, numbers.Number):
52
+ kernel_size = [kernel_size] * dim
53
+ if isinstance(sigma, numbers.Number):
54
+ sigma = [sigma] * dim
55
+
56
+ # The gaussian kernel is the product of the
57
+ # gaussian function of each dimension.
58
+ kernel = 1
59
+ meshgrids = torch.meshgrid(
60
+ [torch.arange(size, dtype=torch.float32) for size in kernel_size]
61
+ )
62
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
63
+ mean = (size - 1) / 2
64
+ kernel *= (
65
+ 1
66
+ / (std * math.sqrt(2 * math.pi))
67
+ * torch.exp(-(((mgrid - mean) / std) ** 2) / 2)
68
+ )
69
+
70
+ # Make sure sum of values in gaussian kernel equals 1.
71
+ kernel = kernel / torch.sum(kernel)
72
+
73
+ # Reshape to depthwise convolutional weight
74
+ kernel = kernel.view(1, 1, *kernel.size())
75
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
76
+
77
+ self.register_buffer("weight", kernel)
78
+ self.groups = channels
79
+
80
+ if dim == 1:
81
+ self.conv = F.conv1d
82
+ elif dim == 2:
83
+ self.conv = F.conv2d
84
+ elif dim == 3:
85
+ self.conv = F.conv3d
86
+ else:
87
+ raise RuntimeError(
88
+ "Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim)
89
+ )
90
+ self.pad_size = int(kernel_size[0] / 2)
91
+
92
+ def forward(self, input):
93
+ """
94
+ Apply gaussian filter to input.
95
+ Arguments:
96
+ input (torch.Tensor): Input to apply gaussian filter on.
97
+ Returns:
98
+ filtered (torch.Tensor): Filtered output.
99
+ """
100
+ input = F.pad(
101
+ input,
102
+ (self.pad_size, self.pad_size, self.pad_size, self.pad_size),
103
+ mode="reflect",
104
+ )
105
+ return self.conv(input, weight=self.weight, groups=self.groups)
106
+
107
+
108
+ class dim_reduction(nn.Module):
109
+ def __init__(self, V):
110
+ super(dim_reduction, self).__init__()
111
+ self.V = V
112
+
113
+ def forward(self, input):
114
+ return torch.matmul(input, self.V.to(input.device))
115
+
116
+
117
+ def get_ensemble(
118
+ models,
119
+ sigma_gf,
120
+ kernel_size_gf,
121
+ combination,
122
+ V_reduction,
123
+ warp=False,
124
+ theta_warp=None,
125
+ ):
126
+ # function prepares ensemble of feature extractors
127
+ # outputs list of pytorch nn models
128
+ feature_extractor_ensemble = []
129
+ if sigma_gf is not None:
130
+ # if applying gaussian filtering during the attack
131
+ gaussian_filtering = GaussianSmoothing(3, kernel_size_gf, sigma_gf)
132
+ if V_reduction is None:
133
+ for model in models:
134
+ feature_extractor_model = nn.DataParallel(
135
+ nn.Sequential(
136
+ gaussian_filtering,
137
+ feature_extractor(
138
+ model=model, warp=warp, theta_warp=theta_warp
139
+ ),
140
+ )
141
+ ).to(device)
142
+ feature_extractor_ensemble.append(feature_extractor_model)
143
+ if combination:
144
+ feature_extractor_model = nn.DataParallel(
145
+ feature_extractor(model=model, warp=warp, theta_warp=theta_warp)
146
+ ).to(device)
147
+ feature_extractor_ensemble.append(feature_extractor_model)
148
+
149
+ else:
150
+ for i, model in enumerate(models):
151
+ feature_extractor_model = nn.DataParallel(
152
+ nn.Sequential(
153
+ gaussian_filtering,
154
+ feature_extractor(
155
+ model=model, warp=warp, theta_warp=theta_warp
156
+ ),
157
+ dim_reduction(V_reduction[i]),
158
+ )
159
+ ).to(device)
160
+ feature_extractor_ensemble.append(feature_extractor_model)
161
+ if combination:
162
+ feature_extractor_model = nn.DataParallel(
163
+ nn.Sequential(
164
+ feature_extractor(
165
+ model=model, warp=warp, theta_warp=theta_warp
166
+ ),
167
+ dim_reduction(V_reduction[i]),
168
+ )
169
+ ).to(device)
170
+ feature_extractor_ensemble.append(feature_extractor_model)
171
+
172
+ else:
173
+ if V_reduction is None:
174
+ for model in models:
175
+ feature_extractor_model = nn.DataParallel(
176
+ feature_extractor(model=model, warp=warp, theta_warp=theta_warp)
177
+ ).to(device)
178
+ feature_extractor_ensemble.append(feature_extractor_model)
179
+ else:
180
+ for i, model in enumerate(models):
181
+ feature_extractor_model = nn.DataParallel(
182
+ nn.Sequential(
183
+ feature_extractor(
184
+ model=model, warp=warp, theta_warp=theta_warp
185
+ ),
186
+ dim_reduction(V_reduction[i]),
187
+ )
188
+ ).to(device)
189
+ feature_extractor_ensemble.append(feature_extractor_model)
190
+
191
+ return feature_extractor_ensemble
192
+
193
+
194
+ def extract_features(imgs, feature_extractor_ensemble, dim):
195
+ # function computes feature vectors of the images with each extractor in the ensemble
196
+
197
+ features = torch.zeros(imgs.shape[0], len(feature_extractor_ensemble), dim)
198
+ for i, feature_extractor_model in enumerate(feature_extractor_ensemble):
199
+ # batch size, model in ensemble, dim
200
+ features_model = feature_extractor_model(imgs)
201
+ features[:, i, :] = features_model
202
+
203
+ return features
204
+
205
+
206
+ def prepare_models(
207
+ model_backbones,
208
+ input_size,
209
+ model_roots,
210
+ kernel_size_attack,
211
+ sigma_attack,
212
+ combination,
213
+ using_subspace,
214
+ V_reduction_root,
215
+ ):
216
+ backbone_dict = {
217
+ "IR_50": IR_50(input_size),
218
+ "IR_152": IR_152(input_size),
219
+ "ResNet_50": ResNet_50(input_size),
220
+ "ResNet_152": ResNet_152(input_size),
221
+ }
222
+
223
+ print("Loading Attack Backbone Checkpoint '{}'".format(model_roots))
224
+ print("=" * 20)
225
+
226
+ models_attack = []
227
+ for i in range(len(model_backbones)):
228
+ model = backbone_dict[model_backbones[i]]
229
+ model.load_state_dict(torch.load(model_roots[i], map_location=device))
230
+ models_attack.append(model)
231
+
232
+ if using_subspace:
233
+ V_reduction = []
234
+ for i in range(len(model_backbones)):
235
+ V_reduction.append(torch.tensor(np.load(V_reduction_root[i])))
236
+
237
+ dim = V_reduction[0].shape[1]
238
+ else:
239
+ V_reduction = None
240
+ dim = 512
241
+
242
+ return models_attack, V_reduction, dim
243
+
244
+
245
+ def prepare_data(
246
+ query_data_root, target_data_root, freq, batch_size, warp=False, theta_warp=None
247
+ ):
248
+ data = datasets.ImageFolder(query_data_root, tensor_transform)
249
+
250
+ subset_query = list(range(0, len(data), freq))
251
+ subset_gallery = [x for x in list(range(0, len(data))) if x not in subset_query]
252
+ query_set = torch.utils.data.Subset(data, subset_query)
253
+ gallery_set = torch.utils.data.Subset(data, subset_gallery)
254
+
255
+ if target_data_root is not None:
256
+ target_data = datasets.ImageFolder(target_data_root, tensor_transform)
257
+ target_loader = torch.utils.data.DataLoader(target_data, batch_size=batch_size)
258
+ else:
259
+ target_loader = None
260
+
261
+ query_loader = torch.utils.data.DataLoader(query_set, batch_size=batch_size)
262
+ gallery_loader = torch.utils.data.DataLoader(gallery_set, batch_size=batch_size)
263
+
264
+ return query_loader, gallery_loader, target_loader
265
+
266
+
267
+ def prepare_dir_vec(dir_vec_extractor, imgs, dim, combination):
268
+ dir_vec = extract_features(imgs, dir_vec_extractor, dim).detach().cpu()
269
+ if combination:
270
+ dir_vec = torch.repeat_interleave(dir_vec, 2, 1)
271
+ return dir_vec
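Finally, a sketch of how these helpers chain together; a randomly initialized IR-50 stands in for a loaded checkpoint, so the shapes rather than the feature values are the point:

import torch
from backbone.model_irse import IR_50
from util.prepare_utils import get_ensemble, prepare_dir_vec

ensemble = get_ensemble(
    [IR_50([112, 112])],
    sigma_gf=None, kernel_size_gf=None,
    combination=False, V_reduction=None,
)
images = torch.rand(2, 3, 112, 112)
dir_vec = prepare_dir_vec(ensemble, images, dim=512, combination=False)
print(dir_vec.shape)  # torch.Size([2, 1, 512]): (batch, models in ensemble, feature dim)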