devforfu commited on
Commit
ea847ad
·
0 Parent(s):
.gitignore ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ **/lightning_logs
132
+ .*
133
+ !.gitignore
134
+ *.out
135
+
nbs/prepare.ipynb ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "a0dae3b6-0612-4744-a466-5c8be9c62923",
7
+ "metadata": {
8
+ "tags": []
9
+ },
10
+ "outputs": [
11
+ {
12
+ "name": "stdout",
13
+ "output_type": "stream",
14
+ "text": [
15
+ "/admin/home-devforfu/realfake\n"
16
+ ]
17
+ }
18
+ ],
19
+ "source": [
20
+ "%cd .."
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 4,
26
+ "id": "4b5f4d1b-40a8-4a61-88ea-502103368b1c",
27
+ "metadata": {
28
+ "tags": []
29
+ },
30
+ "outputs": [],
31
+ "source": [
32
+ "import json\n",
33
+ "from pathlib import Path\n",
34
+ "import pandas as pd\n",
35
+ "from realfake.utils import read_jsonl, write_jsonl"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": 6,
41
+ "id": "df083619-aa93-49b5-aa76-1d2680062927",
42
+ "metadata": {
43
+ "tags": []
44
+ },
45
+ "outputs": [],
46
+ "source": [
47
+ "df_all = pd.DataFrame(read_jsonl(\"metadata/all.jsonl\"))"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": 7,
53
+ "id": "f91769e7-a4b9-4012-9079-441c364d32b3",
54
+ "metadata": {
55
+ "tags": []
56
+ },
57
+ "outputs": [],
58
+ "source": [
59
+ "df_fail = pd.DataFrame(read_jsonl(\"metadata/all.failed.jsonl\"))"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": 8,
65
+ "id": "88c02485-fdf4-42c0-8896-70f43bfaf76f",
66
+ "metadata": {
67
+ "tags": []
68
+ },
69
+ "outputs": [
70
+ {
71
+ "data": {
72
+ "text/html": [
73
+ "<div>\n",
74
+ "<style scoped>\n",
75
+ " .dataframe tbody tr th:only-of-type {\n",
76
+ " vertical-align: middle;\n",
77
+ " }\n",
78
+ "\n",
79
+ " .dataframe tbody tr th {\n",
80
+ " vertical-align: top;\n",
81
+ " }\n",
82
+ "\n",
83
+ " .dataframe thead th {\n",
84
+ " text-align: right;\n",
85
+ " }\n",
86
+ "</style>\n",
87
+ "<table border=\"1\" class=\"dataframe\">\n",
88
+ " <thead>\n",
89
+ " <tr style=\"text-align: right;\">\n",
90
+ " <th></th>\n",
91
+ " <th>path</th>\n",
92
+ " <th>label</th>\n",
93
+ " <th>class</th>\n",
94
+ " </tr>\n",
95
+ " </thead>\n",
96
+ " <tbody>\n",
97
+ " <tr>\n",
98
+ " <th>0</th>\n",
99
+ " <td>/fsx/home-devforfu/data/real_imagenet1k/n02797...</td>\n",
100
+ " <td>real</td>\n",
101
+ " <td>n02797295</td>\n",
102
+ " </tr>\n",
103
+ " <tr>\n",
104
+ " <th>1</th>\n",
105
+ " <td>/fsx/home-devforfu/data/real_imagenet1k/n02797...</td>\n",
106
+ " <td>real</td>\n",
107
+ " <td>n02797295</td>\n",
108
+ " </tr>\n",
109
+ " <tr>\n",
110
+ " <th>2</th>\n",
111
+ " <td>/fsx/home-devforfu/data/real_imagenet1k/n02797...</td>\n",
112
+ " <td>real</td>\n",
113
+ " <td>n02797295</td>\n",
114
+ " </tr>\n",
115
+ " </tbody>\n",
116
+ "</table>\n",
117
+ "</div>"
118
+ ],
119
+ "text/plain": [
120
+ " path label class\n",
121
+ "0 /fsx/home-devforfu/data/real_imagenet1k/n02797... real n02797295\n",
122
+ "1 /fsx/home-devforfu/data/real_imagenet1k/n02797... real n02797295\n",
123
+ "2 /fsx/home-devforfu/data/real_imagenet1k/n02797... real n02797295"
124
+ ]
125
+ },
126
+ "execution_count": 8,
127
+ "metadata": {},
128
+ "output_type": "execute_result"
129
+ }
130
+ ],
131
+ "source": [
132
+ "df_all.head(3)"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": 9,
138
+ "id": "9bbd0247-1f34-4adc-8647-525792e6d3e5",
139
+ "metadata": {
140
+ "tags": []
141
+ },
142
+ "outputs": [],
143
+ "source": [
144
+ "df_ok = df_all[~df_all.path.isin(df_fail.path)].reset_index(drop=True)"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": 10,
150
+ "id": "77b400a8-83aa-4e38-983a-d56b71245ac9",
151
+ "metadata": {
152
+ "tags": []
153
+ },
154
+ "outputs": [
155
+ {
156
+ "data": {
157
+ "text/html": [
158
+ "<div>\n",
159
+ "<style scoped>\n",
160
+ " .dataframe tbody tr th:only-of-type {\n",
161
+ " vertical-align: middle;\n",
162
+ " }\n",
163
+ "\n",
164
+ " .dataframe tbody tr th {\n",
165
+ " vertical-align: top;\n",
166
+ " }\n",
167
+ "\n",
168
+ " .dataframe thead th {\n",
169
+ " text-align: right;\n",
170
+ " }\n",
171
+ "</style>\n",
172
+ "<table border=\"1\" class=\"dataframe\">\n",
173
+ " <thead>\n",
174
+ " <tr style=\"text-align: right;\">\n",
175
+ " <th></th>\n",
176
+ " <th>path</th>\n",
177
+ " <th>label</th>\n",
178
+ " <th>class</th>\n",
179
+ " </tr>\n",
180
+ " </thead>\n",
181
+ " <tbody>\n",
182
+ " <tr>\n",
183
+ " <th>1517638</th>\n",
184
+ " <td>/fsx/home-devforfu/data/fake_imagenet1k/n02027...</td>\n",
185
+ " <td>fake</td>\n",
186
+ " <td>n02027492</td>\n",
187
+ " </tr>\n",
188
+ " <tr>\n",
189
+ " <th>1026755</th>\n",
190
+ " <td>/fsx/home-devforfu/data/real_imagenet1k/n01669...</td>\n",
191
+ " <td>real</td>\n",
192
+ " <td>n01669191</td>\n",
193
+ " </tr>\n",
194
+ " <tr>\n",
195
+ " <th>7790495</th>\n",
196
+ " <td>/fsx/home-devforfu/data/fake_2m_all/d8713853-0...</td>\n",
197
+ " <td>fake</td>\n",
198
+ " <td>None</td>\n",
199
+ " </tr>\n",
200
+ " </tbody>\n",
201
+ "</table>\n",
202
+ "</div>"
203
+ ],
204
+ "text/plain": [
205
+ " path label class\n",
206
+ "1517638 /fsx/home-devforfu/data/fake_imagenet1k/n02027... fake n02027492\n",
207
+ "1026755 /fsx/home-devforfu/data/real_imagenet1k/n01669... real n01669191\n",
208
+ "7790495 /fsx/home-devforfu/data/fake_2m_all/d8713853-0... fake None"
209
+ ]
210
+ },
211
+ "execution_count": 10,
212
+ "metadata": {},
213
+ "output_type": "execute_result"
214
+ }
215
+ ],
216
+ "source": [
217
+ "df_ok.sample(3)"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": 11,
223
+ "id": "321ac973-5b3b-4626-aa95-e4306680099e",
224
+ "metadata": {
225
+ "tags": []
226
+ },
227
+ "outputs": [
228
+ {
229
+ "data": {
230
+ "text/plain": [
231
+ "real 4184273\n",
232
+ "fake 4160720\n",
233
+ "Name: label, dtype: int64"
234
+ ]
235
+ },
236
+ "execution_count": 11,
237
+ "metadata": {},
238
+ "output_type": "execute_result"
239
+ }
240
+ ],
241
+ "source": [
242
+ "df_ok[\"label\"].value_counts()"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "code",
247
+ "execution_count": 12,
248
+ "id": "b0c8b90b-49b3-4aaf-89f2-be2277459ec7",
249
+ "metadata": {
250
+ "tags": []
251
+ },
252
+ "outputs": [],
253
+ "source": [
254
+ "from sklearn.model_selection import train_test_split\n",
255
+ "\n",
256
+ "def create_metadata(dataset, test_size: float = 0.1, sample: int = None, seed: int = 1):\n",
257
+ " if sample is not None:\n",
258
+ " real = dataset[dataset[\"label\"] == \"real\"].sample(sample)\n",
259
+ " fake = dataset[dataset[\"label\"] == \"fake\"].sample(sample)\n",
260
+ " dataset = pd.concat([real, fake])\n",
261
+ " \n",
262
+ " imagenet_classes = dataset[\"class\"].dropna().unique()\n",
263
+ " \n",
264
+ " trn, val = train_test_split(imagenet_classes, test_size=test_size, random_state=seed)\n",
265
+ " trn_data = dataset[dataset[\"class\"].isin(trn)]\n",
266
+ " val_data = dataset[dataset[\"class\"].isin(val)]\n",
267
+ "\n",
268
+ " no_class = dataset[dataset[\"class\"].isna()]\n",
269
+ " trn_data_null, val_data_null = train_test_split(no_class, test_size=test_size, random_state=seed)\n",
270
+ " \n",
271
+ " trn_data = pd.concat([trn_data, trn_data_null])\n",
272
+ " trn_data[\"valid\"] = False\n",
273
+ " val_data = pd.concat([val_data, val_data_null])\n",
274
+ " val_data[\"valid\"] = True\n",
275
+ " \n",
276
+ " assert not set(trn_data[\"class\"].dropna()).intersection(val_data[\"class\"].dropna())\n",
277
+ " \n",
278
+ " return pd.concat([trn_data, val_data])"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 13,
284
+ "id": "ae4e4d48-9f81-4ea8-bb84-592b50fef3a9",
285
+ "metadata": {
286
+ "tags": []
287
+ },
288
+ "outputs": [],
289
+ "source": [
290
+ "n = 1_000_000"
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "execution_count": 14,
296
+ "id": "9ee41625-a30e-4eb9-8989-bda898094a83",
297
+ "metadata": {
298
+ "tags": []
299
+ },
300
+ "outputs": [
301
+ {
302
+ "data": {
303
+ "text/html": [
304
+ "<div>\n",
305
+ "<style scoped>\n",
306
+ " .dataframe tbody tr th:only-of-type {\n",
307
+ " vertical-align: middle;\n",
308
+ " }\n",
309
+ "\n",
310
+ " .dataframe tbody tr th {\n",
311
+ " vertical-align: top;\n",
312
+ " }\n",
313
+ "\n",
314
+ " .dataframe thead th {\n",
315
+ " text-align: right;\n",
316
+ " }\n",
317
+ "</style>\n",
318
+ "<table border=\"1\" class=\"dataframe\">\n",
319
+ " <thead>\n",
320
+ " <tr style=\"text-align: right;\">\n",
321
+ " <th></th>\n",
322
+ " <th>path</th>\n",
323
+ " <th>label</th>\n",
324
+ " <th>class</th>\n",
325
+ " <th>valid</th>\n",
326
+ " </tr>\n",
327
+ " </thead>\n",
328
+ " <tbody>\n",
329
+ " <tr>\n",
330
+ " <th>135038</th>\n",
331
+ " <td>/fsx/home-devforfu/data/real_imagenet1k/n01917...</td>\n",
332
+ " <td>real</td>\n",
333
+ " <td>n01917289</td>\n",
334
+ " <td>False</td>\n",
335
+ " </tr>\n",
336
+ " <tr>\n",
337
+ " <th>803039</th>\n",
338
+ " <td>/fsx/home-devforfu/data/real_imagenet1k/n01697...</td>\n",
339
+ " <td>real</td>\n",
340
+ " <td>n01697457</td>\n",
341
+ " <td>False</td>\n",
342
+ " </tr>\n",
343
+ " <tr>\n",
344
+ " <th>1280747</th>\n",
345
+ " <td>/fsx/home-devforfu/data/real_imagenet1k/n02992...</td>\n",
346
+ " <td>real</td>\n",
347
+ " <td>n02992211</td>\n",
348
+ " <td>False</td>\n",
349
+ " </tr>\n",
350
+ " <tr>\n",
351
+ " <th>130185</th>\n",
352
+ " <td>/fsx/home-devforfu/data/real_imagenet1k/n04599...</td>\n",
353
+ " <td>real</td>\n",
354
+ " <td>n04599235</td>\n",
355
+ " <td>False</td>\n",
356
+ " </tr>\n",
357
+ " <tr>\n",
358
+ " <th>701554</th>\n",
359
+ " <td>/fsx/home-devforfu/data/real_imagenet1k/n02108...</td>\n",
360
+ " <td>real</td>\n",
361
+ " <td>n02108000</td>\n",
362
+ " <td>False</td>\n",
363
+ " </tr>\n",
364
+ " <tr>\n",
365
+ " <th>...</th>\n",
366
+ " <td>...</td>\n",
367
+ " <td>...</td>\n",
368
+ " <td>...</td>\n",
369
+ " <td>...</td>\n",
370
+ " </tr>\n",
371
+ " <tr>\n",
372
+ " <th>7879868</th>\n",
373
+ " <td>/fsx/home-devforfu/data/fake_2m_all/3cf77f54-2...</td>\n",
374
+ " <td>fake</td>\n",
375
+ " <td>None</td>\n",
376
+ " <td>True</td>\n",
377
+ " </tr>\n",
378
+ " <tr>\n",
379
+ " <th>3542472</th>\n",
380
+ " <td>/fsx/home-devforfu/data/real_aes_400_700/00485...</td>\n",
381
+ " <td>real</td>\n",
382
+ " <td>None</td>\n",
383
+ " <td>True</td>\n",
384
+ " </tr>\n",
385
+ " <tr>\n",
386
+ " <th>6454613</th>\n",
387
+ " <td>/fsx/home-devforfu/data/fake_2m_all/1e7c20a8-8...</td>\n",
388
+ " <td>fake</td>\n",
389
+ " <td>None</td>\n",
390
+ " <td>True</td>\n",
391
+ " </tr>\n",
392
+ " <tr>\n",
393
+ " <th>5466667</th>\n",
394
+ " <td>/fsx/home-devforfu/data/real_aes_400_700/00441...</td>\n",
395
+ " <td>real</td>\n",
396
+ " <td>None</td>\n",
397
+ " <td>True</td>\n",
398
+ " </tr>\n",
399
+ " <tr>\n",
400
+ " <th>6469539</th>\n",
401
+ " <td>/fsx/home-devforfu/data/fake_2m_all/1b126896-e...</td>\n",
402
+ " <td>fake</td>\n",
403
+ " <td>None</td>\n",
404
+ " <td>True</td>\n",
405
+ " </tr>\n",
406
+ " </tbody>\n",
407
+ "</table>\n",
408
+ "<p>2000000 rows × 4 columns</p>\n",
409
+ "</div>"
410
+ ],
411
+ "text/plain": [
412
+ " path label class \\\n",
413
+ "135038 /fsx/home-devforfu/data/real_imagenet1k/n01917... real n01917289 \n",
414
+ "803039 /fsx/home-devforfu/data/real_imagenet1k/n01697... real n01697457 \n",
415
+ "1280747 /fsx/home-devforfu/data/real_imagenet1k/n02992... real n02992211 \n",
416
+ "130185 /fsx/home-devforfu/data/real_imagenet1k/n04599... real n04599235 \n",
417
+ "701554 /fsx/home-devforfu/data/real_imagenet1k/n02108... real n02108000 \n",
418
+ "... ... ... ... \n",
419
+ "7879868 /fsx/home-devforfu/data/fake_2m_all/3cf77f54-2... fake None \n",
420
+ "3542472 /fsx/home-devforfu/data/real_aes_400_700/00485... real None \n",
421
+ "6454613 /fsx/home-devforfu/data/fake_2m_all/1e7c20a8-8... fake None \n",
422
+ "5466667 /fsx/home-devforfu/data/real_aes_400_700/00441... real None \n",
423
+ "6469539 /fsx/home-devforfu/data/fake_2m_all/1b126896-e... fake None \n",
424
+ "\n",
425
+ " valid \n",
426
+ "135038 False \n",
427
+ "803039 False \n",
428
+ "1280747 False \n",
429
+ "130185 False \n",
430
+ "701554 False \n",
431
+ "... ... \n",
432
+ "7879868 True \n",
433
+ "3542472 True \n",
434
+ "6454613 True \n",
435
+ "5466667 True \n",
436
+ "6469539 True \n",
437
+ "\n",
438
+ "[2000000 rows x 4 columns]"
439
+ ]
440
+ },
441
+ "execution_count": 14,
442
+ "metadata": {},
443
+ "output_type": "execute_result"
444
+ }
445
+ ],
446
+ "source": [
447
+ "df = create_metadata(df_ok, sample=n)\n",
448
+ "df"
449
+ ]
450
+ },
451
+ {
452
+ "cell_type": "code",
453
+ "execution_count": 20,
454
+ "id": "3094035a-cdf2-4d81-bbaf-d8d185150a27",
455
+ "metadata": {
456
+ "tags": []
457
+ },
458
+ "outputs": [
459
+ {
460
+ "data": {
461
+ "text/plain": [
462
+ "'metadata/prepared.2000k.jsonl'"
463
+ ]
464
+ },
465
+ "execution_count": 20,
466
+ "metadata": {},
467
+ "output_type": "execute_result"
468
+ }
469
+ ],
470
+ "source": [
471
+ "filename = \"prepared.jsonl\" if n is None else f\"prepared.{2*n//1000}k.jsonl\" \n",
472
+ "filename = f\"metadata/{filename}\"\n",
473
+ "filename"
474
+ ]
475
+ },
476
+ {
477
+ "cell_type": "code",
478
+ "execution_count": 37,
479
+ "id": "bc990be8-16ca-4a3b-bbc6-eaa652d46d81",
480
+ "metadata": {
481
+ "tags": []
482
+ },
483
+ "outputs": [],
484
+ "source": [
485
+ "write_jsonl(filename, df.to_dict(\"records\"))"
486
+ ]
487
+ },
488
+ {
489
+ "cell_type": "code",
490
+ "execution_count": 21,
491
+ "id": "7f6a3549-affe-459b-bf4a-bb8e2bc0ac62",
492
+ "metadata": {},
493
+ "outputs": [],
494
+ "source": [
495
+ "df = pd.DataFrame(read_jsonl(filename))"
496
+ ]
497
+ },
498
+ {
499
+ "cell_type": "code",
500
+ "execution_count": 22,
501
+ "id": "f36ed292-4f7e-4876-a1ca-b6fc841227b8",
502
+ "metadata": {
503
+ "tags": []
504
+ },
505
+ "outputs": [
506
+ {
507
+ "data": {
508
+ "text/plain": [
509
+ "(2000000, 4)"
510
+ ]
511
+ },
512
+ "execution_count": 22,
513
+ "metadata": {},
514
+ "output_type": "execute_result"
515
+ }
516
+ ],
517
+ "source": [
518
+ "df.shape"
519
+ ]
520
+ },
521
+ {
522
+ "cell_type": "code",
523
+ "execution_count": 23,
524
+ "id": "b72d0d93-d265-4149-b816-7a62f7a5a17a",
525
+ "metadata": {
526
+ "tags": []
527
+ },
528
+ "outputs": [
529
+ {
530
+ "data": {
531
+ "text/plain": [
532
+ "False 0.899385\n",
533
+ "True 0.100614\n",
534
+ "Name: valid, dtype: float64"
535
+ ]
536
+ },
537
+ "execution_count": 23,
538
+ "metadata": {},
539
+ "output_type": "execute_result"
540
+ }
541
+ ],
542
+ "source": [
543
+ "df[\"valid\"].value_counts(normalize=True)"
544
+ ]
545
+ }
546
+ ],
547
+ "metadata": {
548
+ "kernelspec": {
549
+ "display_name": "Python 3 (ipykernel)",
550
+ "language": "python",
551
+ "name": "python3"
552
+ },
553
+ "language_info": {
554
+ "codemirror_mode": {
555
+ "name": "ipython",
556
+ "version": 3
557
+ },
558
+ "file_extension": ".py",
559
+ "mimetype": "text/x-python",
560
+ "name": "python",
561
+ "nbconvert_exporter": "python",
562
+ "pygments_lexer": "ipython3",
563
+ "version": "3.8.10"
564
+ }
565
+ },
566
+ "nbformat": 4,
567
+ "nbformat_minor": 5
568
+ }
realfake/bin/check_files.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Check that files that are references from JSONL file are valid.
3
+ """
4
+ import json
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import PIL.Image
9
+ from joblib import delayed, Parallel
10
+
11
+ from realfake.utils import inject_args, Args, read_jsonl
12
+
13
+
14
+ class CheckFilesArgs(Args):
15
+ jsonl_file: Path
16
+
17
+
18
+ @inject_args
19
+ def main(args: CheckFilesArgs) -> None:
20
+ records = read_jsonl(args.jsonl_file)
21
+ results = Parallel(n_jobs=-1, verbose=100)(delayed(check_file)(record) for record in records)
22
+ failed = [result for result in results if result["error"] is not None]
23
+ if not failed:
24
+ print("All files are valid")
25
+ else:
26
+ saved_file = args.jsonl_file.with_suffix(".failed.jsonl")
27
+ print(f"{len(failed)} files are invalid, saved errors to {saved_file}")
28
+ with open(saved_file, "w") as f:
29
+ for record in failed:
30
+ f.write(json.dumps(record) + "\n")
31
+
32
+
33
+ def check_file(record: dict) -> dict:
34
+ path = Path(record["path"])
35
+ error = None
36
+ if not path.exists():
37
+ error = "File does not exist"
38
+ elif not path.is_file():
39
+ error = "Path is not a file"
40
+ elif path.suffix.lower() not in (".jpg", ".jpeg", ".png"):
41
+ error = "File is not an image file"
42
+ else:
43
+ try:
44
+ np.asarray(PIL.Image.open(path))
45
+ except Exception as e:
46
+ error = f"Image cannot be opened: {e}"
47
+ return dict(record, error=error)
48
+
49
+
50
+ if __name__ == '__main__':
51
+ main()
realfake/bin/create_metadata.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Creates a meta-data file by combining the information from directory structure.
3
+ """
4
+ import json
5
+ from pathlib import Path
6
+
7
+ from pydantic import Field
8
+
9
+ from realfake.config import IMAGE_FORMATS
10
+ from realfake.utils import inject_args, Args
11
+
12
+
13
+ class CreateMetadataArgs(Args):
14
+ root_dir: Path
15
+ datasets: str = Field(..., help="Comma-separated list of datasets to include in the meta-data file")
16
+ jsonl_file: Path = Field(..., help="Path to the output JSONL file")
17
+
18
+
19
+ @inject_args
20
+ def main(args: CreateMetadataArgs) -> None:
21
+ datasets = args.datasets.split(",")
22
+ records = []
23
+ for dataset in datasets:
24
+ label = "real" if dataset.startswith("real") else "fake"
25
+ dirpath = args.root_dir/dataset
26
+ assert dirpath.exists(), f"dataset dir does not exist: {dirpath}"
27
+ records.extend((parse_imagenet if "imagenet" in dataset else parse_flat)(dirpath, label))
28
+ with open(args.jsonl_file, "w") as f:
29
+ for record in records:
30
+ f.write(json.dumps(record) + "\n")
31
+
32
+
33
+ def parse_imagenet(dirpath: Path, label: str) -> list:
34
+ records = []
35
+ for classdir in dirpath.iterdir():
36
+ assert classdir.is_dir(), f"class directory is not a directory: {classdir}"
37
+ for fn in classdir.iterdir():
38
+ if fn.suffix.lower() in IMAGE_FORMATS:
39
+ records.append({"path": str(fn), "label": label, "class": classdir.name})
40
+ else:
41
+ print("Not an image file:", fn)
42
+ return records
43
+
44
+
45
+ def parse_flat(dirpath: Path, label: str) -> list:
46
+ records = []
47
+ for fn in dirpath.iterdir():
48
+ if fn.suffix.lower() in IMAGE_FORMATS:
49
+ records.append({"path": str(fn), "label": label, "class": None})
50
+ else:
51
+ print("Not an image file:", fn)
52
+ return records
53
+
54
+
55
+
56
+ if __name__ == "__main__":
57
+ main()
realfake/bin/diffusion_db.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from hashlib import md5
3
+ from pathlib import Path
4
+
5
+ import datasets
6
+ from tqdm import tqdm
7
+
8
+ from realfake.utils import Args, inject_args
9
+
10
+
11
+ class DownloadParams(Args):
12
+ output_dir: Path
13
+ subset: str = "2m_first_1k"
14
+
15
+
16
+ @inject_args
17
+ def main(params: DownloadParams) -> None:
18
+ dataset = datasets.load_dataset("poloclub/diffusiondb", params.subset, split="train", streaming=True)
19
+
20
+ output_dir = params.output_dir/params.subset
21
+ output_dir.mkdir(parents=True, exist_ok=True)
22
+
23
+ with (output_dir/"test.jsonl").open("w") as fp:
24
+ for item in tqdm(dataset, total=None):
25
+ image_id = md5((item["prompt"] + str(item["seed"])).encode()).hexdigest()
26
+ filename = output_dir/f"{image_id}.png"
27
+ if not filename.exists():
28
+ item["image"].save(filename)
29
+ record = {"path": str(filename), "label": "fake", "class": None, "valid": False}
30
+ fp.write(f"{json.dumps(record)}\n")
31
+
32
+ print(f"Saved records to {output_dir}")
33
+
34
+
35
+ if __name__ == "__main__":
36
+ main()
realfake/bin/download_s3.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import tarfile
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+
6
+ import boto3
7
+ from joblib import Parallel, delayed
8
+
9
+ from realfake.utils import get_user_name
10
+
11
+
12
+ def main() -> None:
13
+ bucket, prefix = "s-datasets", "laion-aesthetic/data/laion2B-en-aesthetic/"
14
+ start_idx, end_idx = 400, 700
15
+ keys_range = list(range(start_idx, end_idx))
16
+
17
+ output_dir = Path(f"/fsx/{get_user_name()}/data/real_aes_{start_idx}_{end_idx}")
18
+ output_dir.mkdir(parents=True, exist_ok=True)
19
+
20
+ jobs = get_jobs(keys_range, bucket, prefix, output_dir)
21
+
22
+ Parallel(n_jobs=-1, backend="multiprocessing", verbose=100)(delayed(download_and_extract)(job) for job in jobs)
23
+
24
+
25
+ @dataclass
26
+ class Job:
27
+ bucket: str
28
+ key: Path
29
+ output_dir: Path
30
+
31
+
32
+ def get_jobs(keys_range: list, bucket: str, prefix: str, output_dir: Path) -> list[Job]:
33
+ client = boto3.client("s3")
34
+
35
+ token, jobs = None, []
36
+
37
+ while True:
38
+ conf = dict(Bucket=bucket, Prefix=prefix)
39
+ if token is not None: conf["ContinuationToken"] = token
40
+ response = client.list_objects_v2(**conf)
41
+
42
+ for item in response.get("Contents"):
43
+ key = Path(item["Key"])
44
+ if key.suffix == ".tar" and int(key.stem) in keys_range:
45
+ jobs.append(Job(bucket, key, output_dir))
46
+
47
+ if not response["IsTruncated"]: break
48
+ token = response["NextContinuationToken"]
49
+
50
+ return jobs
51
+
52
+
53
+ def download_and_extract(job: Job) -> None:
54
+ client = boto3.client("s3")
55
+ tar_file = job.output_dir / job.key.name
56
+
57
+ print(f"{job.key}: downloading...")
58
+ client.download_file(job.bucket, str(job.key), tar_file)
59
+
60
+ print(f"{job.key}: extracting...")
61
+ with tarfile.open(tar_file) as tar:
62
+ for name in tar.getnames():
63
+ if name.endswith(".jpg"):
64
+ tar.extract(name, job.output_dir)
65
+
66
+ print(f"{job.key}: done!")
67
+ tar_file.unlink()
68
+
69
+
70
+ if __name__ == "__main__":
71
+ main()
realfake/bin/imagenet.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unpacks tar files from Imagenet-1k dataset while keeping the original directory structure.
3
+
4
+ The script only unpacks the files from training subset.
5
+ """
6
+ import tarfile
7
+ from pathlib import Path
8
+ from joblib import delayed, Parallel
9
+ from realfake.utils import inject_args, Args
10
+
11
+
12
+ class ImagenetArgs(Args):
13
+ imagenet_dir: Path
14
+ unpacked_dir: Path
15
+
16
+
17
+ @inject_args
18
+ def main(args: ImagenetArgs) -> None:
19
+ train_dir = args.imagenet_dir/"train"
20
+ assert train_dir.exists(), f"Directory {train_dir} does not exist"
21
+ archives = train_dir.glob("*.tar")
22
+ Parallel(n_jobs=-1, verbose=100)(delayed(unpack_tar)(tar_file, args.unpacked_dir) for tar_file in archives)
23
+
24
+
25
+ def unpack_tar(tar_file: Path, output_dir: Path) -> None:
26
+ output_subdir = output_dir/tar_file.stem
27
+ with tarfile.open(tar_file) as tar:
28
+ tar.extractall(output_subdir)
29
+
30
+
31
+ if __name__ == "__main__":
32
+ main()
realfake/bin/inference.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import random
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ from torch.utils.data import DataLoader
7
+
8
+ from realfake.data import DictDataset, get_augs
9
+ from realfake.models import RealFakeClassifier, RealFakeParams
10
+ from realfake.utils import Args, inject_args, read_jsonl
11
+
12
+
13
+ class InferenceParams(Args):
14
+ checkpoint_path: Path
15
+ test_file: Path
16
+ map_location: str = "cpu"
17
+ num_workers: int = 16
18
+
19
+
20
+ @inject_args
21
+ def main(params: InferenceParams) -> None:
22
+ checkpoint = torch.load(params.checkpoint_path, map_location=params.map_location)
23
+
24
+ # todo: use PL mechanism to store hparams
25
+ model = RealFakeClassifier(RealFakeParams.parse_file(params.checkpoint_path.parent/"params.json"))
26
+ model.load_state_dict(checkpoint["state_dict"])
27
+ model.eval()
28
+
29
+ records = read_jsonl(params.test_file)
30
+
31
+ for _ in range(10):
32
+ selected = random.sample(records, k=1000)
33
+
34
+ with torch.inference_mode():
35
+ ds = DictDataset(selected, get_augs(train=False))
36
+ dl = DataLoader(ds, batch_size=32, num_workers=params.num_workers, shuffle=False)
37
+ matched, total = 0, len(ds)
38
+
39
+ for batch in dl:
40
+ _, logits, y_true_onehot = model(batch)
41
+ y_true = y_true_onehot.argmax(dim=1)
42
+ y_pred = logits.softmax(dim=1).argmax(dim=1)
43
+ matched += (y_true == y_pred).sum().item()
44
+
45
+ print(f"Accuracy: {matched/total:2.2%}")
46
+
47
+
48
+ if __name__ == "__main__":
49
+ main()
realfake/bin/unpack_diffusion_db.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import zipfile
3
+ from itertools import chain
4
+ from pathlib import Path
5
+ from joblib import Parallel, delayed
6
+ from realfake.utils import get_user_name, inject_args, Args
7
+
8
+
9
+ class UnpackParams(Args):
10
+ meta_file: Path
11
+ jsonl_file: Path
12
+ num_workers: int = 16
13
+
14
+
15
+ def unpack(zip_path: Path, output_dir: Path):
16
+ print("extracting", zip_path)
17
+ with zipfile.ZipFile(zip_path, "r") as arch:
18
+ paths = [str(output_dir/fn) for fn in arch.namelist() if fn.endswith(".png")]
19
+ arch.extractall(output_dir)
20
+ return paths
21
+
22
+
23
+ @inject_args
24
+ def main(params: UnpackParams) -> None:
25
+ subset_name = params.meta_file.stem
26
+ output_dir = Path(f"/fsx/{get_user_name()}/data/fake_{subset_name}")
27
+ output_dir.mkdir(parents=True, exist_ok=True)
28
+ meta = json.loads(params.meta_file.read_text())
29
+ with Parallel(n_jobs=params.num_workers, verbose=100) as parallel:
30
+ results = parallel(delayed(unpack)(Path(m["path"]), output_dir) for m in meta if m["ok"])
31
+ records = [
32
+ {"path": str(fn), "label": "fake", "class": None, "valid": None}
33
+ for fn in chain.from_iterable(results)
34
+ ]
35
+ with params.jsonl_file.open("w") as fp:
36
+ for record in records:
37
+ fp.write(json.dumps(record) + "\n")
38
+
39
+
40
+ if __name__ == "__main__":
41
+ main()
realfake/callbacks.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Iterable
2
+
3
+ import torch.nn as nn
4
+ import pytorch_lightning as pl
5
+ from pytorch_lightning.callbacks import BaseFinetuning, Callback
6
+ from pytorch_lightning.utilities import rank_zero_info
7
+
8
+
9
+ class ConsoleLogger(Callback):
10
+
11
+ def __init__(self):
12
+ super().__init__()
13
+ self._reset()
14
+
15
+ def get_history(self) -> list:
16
+ return list(self._history)
17
+
18
+ def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
19
+ self._reset()
20
+
21
+ def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
22
+ [lr] = trainer.lr_scheduler_configs[0].scheduler.get_last_lr() # type: ignore
23
+ log = {"epoch": trainer.current_epoch, "lr": lr}
24
+ log.update({name: tensor.item() for name, tensor in trainer.logged_metrics.items()})
25
+ self._history.append(log)
26
+ formatted = []
27
+ for key, value in log.items():
28
+ if isinstance(value, int):
29
+ kv = f"{key}={value:3d}"
30
+ elif isinstance(value, float):
31
+ kv = f"{key}={value:.4f}"
32
+ else:
33
+ kv = f"{key}={value}"
34
+ formatted.append(kv)
35
+ rank_zero_info(" | ".join(formatted))
36
+
37
+ def _reset(self):
38
+ self._history = []
39
+
40
+
41
+ class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
42
+
43
+ def __init__(self, unfreeze_at_epoch: int):
44
+ super().__init__()
45
+ self._unfreeze_at_epoch = unfreeze_at_epoch
46
+
47
+ def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
48
+ rank_zero_info("Freezing backbone")
49
+ self.freeze(_get_backbone(pl_module.model))
50
+
51
+ def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer, opt_idx: int) -> None:
52
+ if epoch == self._unfreeze_at_epoch:
53
+ rank_zero_info(f"Unfreezing backbone at epoch {epoch}")
54
+ self.unfreeze_and_add_param_group(
55
+ modules=_get_backbone(pl_module.model),
56
+ optimizer=optimizer,
57
+ train_bn=True,
58
+ )
59
+
60
+
61
+ def _get_backbone(module: pl.LightningModule) -> Iterable[nn.Module]:
62
+ for name, child in module.named_children():
63
+ if name.startswith("head"):
64
+ continue
65
+ yield child
realfake/config.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ SEED = 1
2
+ LABELS = {"real": 0, "fake": 1}
3
+ SUBSETS = ("train", "validation")
4
+ IMAGE_FORMATS = ".jpeg", ".jpg", ".png"
realfake/data.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import random
3
+
4
+ import albumentations as A
5
+ import numpy as np
6
+ import PIL.Image
7
+ from albumentations.pytorch.transforms import ToTensorV2
8
+ from torch.utils.data import Dataset, DataLoader
9
+
10
+ from realfake.config import LABELS
11
+
12
+ IMG_RESIZE = 256
13
+ IMG_CROP = 224
14
+
15
+
16
+ class DictDataset(Dataset):
17
+ def __init__(self, records: list[dict], transform_x=None):
18
+ self.records = records
19
+ self.transform_x = transform_x
20
+
21
+ def __len__(self):
22
+ return len(self.records)
23
+
24
+ def __getitem__(self, idx):
25
+ record = self.records[idx]
26
+ image = np.asarray(PIL.Image.open(record["path"]))
27
+ if self.transform_x is not None:
28
+ image = self.transform_x(image=image)["image"]
29
+ item = {"image": image}
30
+ if "label" in record:
31
+ item["label"] = LABELS[record["label"]]
32
+ return item
33
+
34
+
35
+ def get_augs(train: bool = True) -> A.Compose:
36
+ if train:
37
+ return A.Compose([
38
+ A.Resize(IMG_RESIZE, IMG_RESIZE),
39
+ A.RandomCrop(IMG_CROP, IMG_CROP),
40
+ A.HorizontalFlip(),
41
+ A.VerticalFlip(),
42
+ A.RandomBrightnessContrast(),
43
+ A.Affine(),
44
+ A.Rotate(),
45
+ A.CoarseDropout(),
46
+ ExpandChannels(),
47
+ RGBAtoRGB(),
48
+ A.Normalize(),
49
+ ToTensorV2(),
50
+ ])
51
+ else:
52
+ return A.Compose([
53
+ A.Resize(IMG_RESIZE, IMG_RESIZE),
54
+ A.CenterCrop(IMG_CROP, IMG_CROP),
55
+ ExpandChannels(),
56
+ RGBAtoRGB(),
57
+ A.Normalize(),
58
+ ToTensorV2(),
59
+ ])
60
+
61
+
62
+ class ExpandChannels(A.ImageOnlyTransform):
63
+ """Expands image up to three channes if the image is grayscale."""
64
+
65
+ def __init__(self, always_apply: bool = False, p: float = 0.5):
66
+ super().__init__(True, 1.0)
67
+
68
+ def apply(self, image, **params):
69
+ if image.ndim == 2:
70
+ image = np.repeat(image[..., None], 3, axis=2)
71
+ elif image.shape[2] == 1:
72
+ image = np.repeat(image, 3, axis=2)
73
+ return image
74
+
75
+
76
+ class RGBAtoRGB(A.ImageOnlyTransform):
77
+ """Converts RGBA image to RGB."""
78
+
79
+ def __init__(self, always_apply: bool = False, p: float = 0.5):
80
+ super().__init__(True, 1.0)
81
+
82
+ def apply(self, image, **params):
83
+ if image.shape[2] == 4:
84
+ image = image[:, :, :3]
85
+ return image
86
+
87
+
88
+ def get_dss(records: list) -> tuple[DictDataset, DictDataset]:
89
+ train_records = [x for x in records if not x["valid"]]
90
+ valid_records = [x for x in records if x["valid"]]
91
+ assert len(train_records) + len(valid_records) == len(records)
92
+ random.shuffle(train_records)
93
+ train_ds = DictDataset(train_records, transform_x=get_augs(train=True))
94
+ valid_ds = DictDataset(valid_records, transform_x=get_augs(train=False))
95
+ return train_ds, valid_ds
96
+
97
+
98
+ def get_dls(train_ds: DictDataset, valid_ds: DictDataset, bs: int, num_workers: int) -> tuple[DataLoader, DataLoader]:
99
+ train_dl = DataLoader(train_ds, batch_size=bs, num_workers=num_workers)
100
+ valid_dl = DataLoader(valid_ds, batch_size=bs, num_workers=num_workers, shuffle=False)
101
+ return train_dl, valid_dl
realfake/models.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import json
3
+ from pathlib import Path
4
+
5
+ import pytorch_lightning as pl
6
+ import timm
7
+ import torch
8
+ import torch.nn as nn
9
+ import torchmetrics
10
+ from pydantic import BaseModel, Field
11
+
12
+ from realfake.data import get_dss, get_dls
13
+ from realfake.utils import Args
14
+
15
+ N_CLASSES = 2
16
+
17
+
18
+ class AcceleratorParams(BaseModel):
19
+ """PyTorch Lightning accelerator parameters."""
20
+
21
+ name: str = Field("gpu")
22
+ devices: int = Field(4)
23
+ strategy: str = Field("dp")
24
+ precision: int = Field(16)
25
+ override_float32_matmul: bool = Field(True)
26
+ float32_matmul: str = Field("medium")
27
+
28
+
29
+ class RealFakeParams(Args):
30
+ jsonl_file: Path
31
+ dry_run: bool = Field(False)
32
+ model_name: str = Field("convnext_tiny")
33
+ batch_size: int = Field(256)
34
+ freeze_epochs: int = Field(3)
35
+ epochs: int = Field(6)
36
+ base_lr: float = Field(1e-3)
37
+ pretrained: bool = Field(True)
38
+ accelerator: AcceleratorParams = Field(default_factory=AcceleratorParams)
39
+
40
+
41
+ class RealFakeDataModule(pl.LightningDataModule):
42
+
43
+ def __init__(self, jsonl_records: Path, batch_size: int, num_workers: int = 0):
44
+ super().__init__()
45
+ self.jsonl_records = jsonl_records
46
+ self.batch_size = batch_size
47
+ self.num_workers = num_workers
48
+ self.dss = self.dls = None
49
+
50
+ def setup(self, stage=None):
51
+ records = [json.loads(line) for line in self.jsonl_records.open()]
52
+ self.dss = get_dss(records)
53
+ self.dls = get_dls(*self.dss, self.batch_size, self.num_workers)
54
+
55
+ def train_dataloader(self):
56
+ return self.dls[0]
57
+
58
+ def val_dataloader(self):
59
+ return self.dls[1]
60
+
61
+
62
+ class RealFakeClassifier(pl.LightningModule):
63
+
64
+ def __init__(self, params: RealFakeParams):
65
+ super().__init__()
66
+ self.params = params
67
+ self.ce = nn.BCEWithLogitsLoss()
68
+ self.model = timm.create_model(params.model_name, pretrained=params.pretrained, num_classes=N_CLASSES)
69
+ self.acc = torchmetrics.Accuracy(task="binary")
70
+
71
+ def train_dataloader(self):
72
+ return self.dls.train
73
+
74
+ def val_dataloader(self):
75
+ return self.dls.valid
76
+
77
+ def forward(self, batch):
78
+ x, y = batch["image"], batch["label"]
79
+ y = torch.nn.functional.one_hot(y, num_classes=N_CLASSES).float()
80
+ out = self.model(x)
81
+ loss = self.ce(out, y)
82
+ return loss, out, y
83
+
84
+ def training_step(self, batch, batch_idx):
85
+ loss, _, _ = self.forward(batch)
86
+ self.log("train_loss", loss, on_epoch=True, on_step=False)
87
+ return loss
88
+
89
+ def validation_step(self, batch, batch_idx):
90
+ loss, out, y = self.forward(batch)
91
+ y_pred = out.sigmoid().argmax(dim=-1)
92
+ y_true = y.argmax(dim=-1)
93
+ self.log("val_loss", loss, on_epoch=True, on_step=False)
94
+ return {"gt": y_true, "yhat": y_pred}
95
+
96
+ def validation_step_end(self, outputs):
97
+ self.acc.update(outputs["yhat"], outputs["gt"])
98
+
99
+ def validation_epoch_end(self, outputs):
100
+ self.log("val_acc", self.acc.compute(), on_epoch=True)
101
+ self.acc.reset()
102
+
103
+ def configure_optimizers(self):
104
+ adamw = torch.optim.AdamW(self.parameters(), lr=self.params.base_lr)
105
+ one_cycle = torch.optim.lr_scheduler.OneCycleLR(
106
+ adamw,
107
+ max_lr=self.params.base_lr,
108
+ total_steps=self.trainer.estimated_stepping_batches
109
+ )
110
+ return [adamw], [one_cycle]
realfake/train.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import os
3
+ import signal
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ import pytorch_lightning as pl
8
+ from pytorch_lightning.callbacks import ModelCheckpoint
9
+ from pytorch_lightning.plugins.environments import SLURMEnvironment
10
+
11
+ from realfake.callbacks import ConsoleLogger
12
+ from realfake.models import RealFakeParams
13
+ from realfake.utils import get_checkpoints_dir, find_latest_checkpoint
14
+
15
+
16
+ def get_existing_checkpoint(job_id: str | None = None) -> tuple:
17
+ if job_id is None:
18
+ checkpoints_dir = get_checkpoints_dir(timestamp=True)
19
+ else:
20
+ checkpoints_dir = get_checkpoints_dir(timestamp=False)/job_id
21
+ checkpoints_dir.mkdir(parents=True, exist_ok=True)
22
+ existing_checkpoint = find_latest_checkpoint(checkpoints_dir)
23
+ return checkpoints_dir, existing_checkpoint
24
+
25
+
26
+ def prepare_trainer(args: RealFakeParams) -> pl.Trainer:
27
+ job_id = os.environ.get("SLURM_JOB_ID")
28
+ checkpoints_dir, existing_checkpoint = get_existing_checkpoint(job_id)
29
+
30
+ if job_id is None:
31
+ print("SLURM job id is not found, running locally.")
32
+
33
+ if existing_checkpoint is None:
34
+ print("No existing checkpoint found, starting from scratch.")
35
+
36
+ if args.accelerator.override_float32_matmul:
37
+ torch.set_float32_matmul_precision(args.accelerator.float32_matmul)
38
+
39
+ with (checkpoints_dir/"params.json").open("w") as fp:
40
+ fp.write(args.json())
41
+
42
+ trainer_params = dict(
43
+ accelerator=args.accelerator.name,
44
+ devices=args.accelerator.devices,
45
+ precision=args.accelerator.precision,
46
+ max_epochs=args.epochs,
47
+ num_nodes=1,
48
+ num_sanity_val_steps=0,
49
+ enable_progress_bar=False,
50
+ callbacks=[
51
+ ConsoleLogger(),
52
+ ModelCheckpoint(
53
+ monitor="val_acc",
54
+ mode="max",
55
+ save_last=True,
56
+ save_top_k=1,
57
+ dirpath=checkpoints_dir,
58
+ filename="%s-{epoch:02d}-{val_acc:.4f}" % args.model_name,
59
+ ),
60
+ ],
61
+ resume_from_checkpoint=existing_checkpoint,
62
+ )
63
+
64
+ if job_id is not None:
65
+ trainer_params["plugins"] = SLURMEnvironment(requeue_signal=signal.SIGHUP),
66
+ trainer_params["strategy"] = args.accelerator.strategy
67
+
68
+ return pl.Trainer(**trainer_params)
realfake/train_cluster.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import pytorch_lightning as pl
4
+
5
+ from realfake.config import SEED
6
+ from realfake.models import RealFakeClassifier, RealFakeDataModule, RealFakeParams
7
+ from realfake.train import prepare_trainer
8
+
9
+
10
+ def main() -> None:
11
+ pl.seed_everything(SEED)
12
+ args = RealFakeParams.from_args()
13
+ model = RealFakeClassifier(args)
14
+ data = RealFakeDataModule(args.jsonl_file, args.batch_size, args.accelerator.devices * 4)
15
+ trainer = prepare_trainer(args)
16
+
17
+ if args.dry_run:
18
+ print("Dry run, skipping training.")
19
+ print("Model summary:")
20
+ print(model)
21
+ print("Data summary:")
22
+ data.setup()
23
+ print("Train batches:", len(data.dls[0]))
24
+ print("Valid batches:", len(data.dls[1]))
25
+
26
+ else:
27
+ trainer.fit(model, datamodule=data)
28
+
29
+
30
+ if __name__ == "__main__":
31
+ with warnings.catch_warnings():
32
+ warnings.filterwarnings("ignore", category=Warning)
33
+ main()
realfake/utils.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import argparse
3
+ import datetime
4
+ import json
5
+ import os
6
+ from operator import itemgetter
7
+ from pathlib import Path
8
+ from typing import Callable
9
+
10
+ import requests
11
+ import pynvml
12
+ import PIL.Image
13
+ import torch
14
+ from pydantic import BaseSettings, BaseModel
15
+
16
+
17
+ class Args(BaseSettings):
18
+
19
+ @classmethod
20
+ def from_args(cls):
21
+ parser = argparse.ArgumentParser()
22
+ for field in cls.__fields__.values():
23
+ if issubclass(field.type_, BaseModel):
24
+ prefix = field.type_.__name__.lower()
25
+ for subfield in field.type_.__fields__.values():
26
+ short = "".join([x[0] for x in subfield.name.split("_")])
27
+ parser.add_argument(f"--{prefix}.{subfield.name}", default=subfield.default, required=subfield.required)
28
+ else:
29
+ short = "".join([x[0] for x in field.name.split("_")])
30
+ parser.add_argument(f"-{short}", f"--{field.name}", default=field.default, required=field.required)
31
+ args = vars(parser.parse_known_args()[0])
32
+ to_delete = set()
33
+ for field in cls.__fields__.values():
34
+ if issubclass(field.type_, BaseModel):
35
+ prefix = field.type_.__name__.lower()
36
+ sub_args = {}
37
+ for k, v in args.items():
38
+ if k.startswith(prefix):
39
+ to_delete.add(k)
40
+ sub_args[k.replace(f"{prefix}.", "")] = v
41
+ args[field.name] = sub_args
42
+ args = {k: v for k, v in args.items() if k not in to_delete}
43
+ return cls(**args)
44
+
45
+ class Config:
46
+ env_file = ".env"
47
+ env_file_encoding = "utf-8"
48
+ env_prefix = "ARG_"
49
+
50
+
51
+ def inject_args(func: Callable) -> Callable:
52
+ """Decorates a function to inject the arguments."""
53
+
54
+ injected = None
55
+ for type_ in func.__annotations__.values():
56
+ if issubclass(type_, Args):
57
+ injected = type_.from_args()
58
+ break
59
+
60
+ if injected is None:
61
+ raise ValueError(f"Function {func.__name__} is not annotated with an Args subclass.")
62
+
63
+ def wrapper(*args, **kwargs):
64
+ return func(injected, *args, **kwargs)
65
+
66
+ return wrapper
67
+
68
+
69
+ def get_free_gpu() -> int:
70
+ pynvml.nvmlInit()
71
+ total = torch.cuda.device_count()
72
+ gpus = []
73
+ for i in range(total):
74
+ handle = pynvml.nvmlDeviceGetHandleByIndex(i)
75
+ info = pynvml.nvmlDeviceGetMemoryInfo(handle)
76
+ gpus.append((i, info.free))
77
+ gpus = sorted(gpus, key=itemgetter(1), reverse=True)
78
+ return gpus[0][0]
79
+
80
+
81
+ def get_user_name() -> str:
82
+ return Path(os.environ["HOME"]).stem
83
+
84
+
85
+ def get_storage_dir() -> Path:
86
+ return Path(f"/fsx/{get_user_name()}")
87
+
88
+
89
+ def get_checkpoints_dir(*, timestamp: bool) -> Path:
90
+ base_dir = get_storage_dir()/"checkpoints"
91
+ return Path(f"{base_dir}/{now()}") if timestamp else base_dir
92
+
93
+
94
+ def now() -> str:
95
+ return datetime.datetime.utcnow().strftime("%Y%m%d_%H%M%S")
96
+
97
+
98
+ def read_jsonl(path: Path) -> list:
99
+ return [json.loads(x) for x in Path(path).read_text().split("\n") if x]
100
+
101
+
102
+ def write_jsonl(path: Path, data: list):
103
+ with Path(path).open("w") as f:
104
+ for x in data:
105
+ f.write(json.dumps(x) + "\n")
106
+
107
+
108
+ def get_image(url: str, filename: Path | None = None):
109
+ if filename is None: filename = Path(f"{url.split('/')[-1]}.jpg")
110
+ filename = Path(filename)
111
+ if filename.exists(): return filename
112
+ PIL.Image.open(requests.get(url, stream=True).raw).save(filename)
113
+ return filename
114
+
115
+
116
+ def find_latest_checkpoint(dirname: Path) -> Path:
117
+ checkpoints = list(dirname.glob("*.ckpt"))
118
+ if not checkpoints:
119
+ return None
120
+ latest = max(checkpoints, key=lambda path: path.stat().st_mtime)
121
+ return latest
submit.sh ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash -l
2
+
3
+ # SLURM SUBMIT SCRIPT
4
+ #SBATCH --partition=g40
5
+ #SBATCH --nodes=1
6
+ #SBATCH --gpus=8
7
+ #SBATCH --cpus-per-gpu=6
8
+ #SBATCH --job-name=realfake
9
+ #SBATCH --comment=laion
10
+ #SBATCH --signal=SIGUSR1@90
11
+
12
+ source "${HOME}/venv/bin/activate"
13
+
14
+ export NCCL_DEBUG=INFO
15
+ export PYTHONFAULTHANDLER=1
16
+ export PYTHONPATH="${HOME}/realfake"
17
+
18
+ echo "Working directory: `pwd`"
19
+
20
+ srun python3 realfake/train_cluster.py \
21
+ -jf "${HOME}/realfake/metadata/prepared.2000k.jsonl" \
22
+ -mn convnext_large -e 5 -bs 128 \
23
+ --acceleratorparams.devices=8 \
24
+ --acceleratorparams.strategy=ddp_find_unused_parameters_false