devforfu committed
Commit ea847ad
Parent(s): Init

Browse files
- .gitignore +135 -0
- nbs/prepare.ipynb +568 -0
- realfake/bin/check_files.py +51 -0
- realfake/bin/create_metadata.py +57 -0
- realfake/bin/diffusion_db.py +36 -0
- realfake/bin/download_s3.py +71 -0
- realfake/bin/imagenet.py +32 -0
- realfake/bin/inference.py +49 -0
- realfake/bin/unpack_diffusion_db.py +41 -0
- realfake/callbacks.py +65 -0
- realfake/config.py +4 -0
- realfake/data.py +101 -0
- realfake/models.py +110 -0
- realfake/train.py +68 -0
- realfake/train_cluster.py +33 -0
- realfake/utils.py +121 -0
- submit.sh +24 -0
.gitignore
ADDED
@@ -0,0 +1,135 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

**/lightning_logs
.*
!.gitignore
*.out

nbs/prepare.ipynb
ADDED
@@ -0,0 +1,568 @@
In [1]:
%cd ..
/admin/home-devforfu/realfake

In [4]:
import json
from pathlib import Path
import pandas as pd
from realfake.utils import read_jsonl, write_jsonl

In [6]:
df_all = pd.DataFrame(read_jsonl("metadata/all.jsonl"))

In [7]:
df_fail = pd.DataFrame(read_jsonl("metadata/all.failed.jsonl"))

In [8]:
df_all.head(3)
Out[8]:
                                                path label      class
0  /fsx/home-devforfu/data/real_imagenet1k/n02797...  real  n02797295
1  /fsx/home-devforfu/data/real_imagenet1k/n02797...  real  n02797295
2  /fsx/home-devforfu/data/real_imagenet1k/n02797...  real  n02797295

In [9]:
df_ok = df_all[~df_all.path.isin(df_fail.path)].reset_index(drop=True)

In [10]:
df_ok.sample(3)
Out[10]:
                                                      path label      class
1517638  /fsx/home-devforfu/data/fake_imagenet1k/n02027...  fake  n02027492
1026755  /fsx/home-devforfu/data/real_imagenet1k/n01669...  real  n01669191
7790495  /fsx/home-devforfu/data/fake_2m_all/d8713853-0...  fake       None

In [11]:
df_ok["label"].value_counts()
Out[11]:
real    4184273
fake    4160720
Name: label, dtype: int64

In [12]:
from sklearn.model_selection import train_test_split

def create_metadata(dataset, test_size: float = 0.1, sample: int = None, seed: int = 1):
    if sample is not None:
        real = dataset[dataset["label"] == "real"].sample(sample)
        fake = dataset[dataset["label"] == "fake"].sample(sample)
        dataset = pd.concat([real, fake])

    imagenet_classes = dataset["class"].dropna().unique()

    trn, val = train_test_split(imagenet_classes, test_size=test_size, random_state=seed)
    trn_data = dataset[dataset["class"].isin(trn)]
    val_data = dataset[dataset["class"].isin(val)]

    no_class = dataset[dataset["class"].isna()]
    trn_data_null, val_data_null = train_test_split(no_class, test_size=test_size, random_state=seed)

    trn_data = pd.concat([trn_data, trn_data_null])
    trn_data["valid"] = False
    val_data = pd.concat([val_data, val_data_null])
    val_data["valid"] = True

    assert not set(trn_data["class"].dropna()).intersection(val_data["class"].dropna())

    return pd.concat([trn_data, val_data])

In [13]:
n = 1_000_000

In [14]:
df = create_metadata(df_ok, sample=n)
df
Out[14]:
                                                      path label      class  valid
135038   /fsx/home-devforfu/data/real_imagenet1k/n01917...  real  n01917289  False
803039   /fsx/home-devforfu/data/real_imagenet1k/n01697...  real  n01697457  False
1280747  /fsx/home-devforfu/data/real_imagenet1k/n02992...  real  n02992211  False
130185   /fsx/home-devforfu/data/real_imagenet1k/n04599...  real  n04599235  False
701554   /fsx/home-devforfu/data/real_imagenet1k/n02108...  real  n02108000  False
...                                                    ...   ...        ...    ...
7879868  /fsx/home-devforfu/data/fake_2m_all/3cf77f54-2...  fake       None   True
3542472  /fsx/home-devforfu/data/real_aes_400_700/00485...  real       None   True
6454613  /fsx/home-devforfu/data/fake_2m_all/1e7c20a8-8...  fake       None   True
5466667  /fsx/home-devforfu/data/real_aes_400_700/00441...  real       None   True
6469539  /fsx/home-devforfu/data/fake_2m_all/1b126896-e...  fake       None   True

[2000000 rows x 4 columns]

In [20]:
filename = "prepared.jsonl" if n is None else f"prepared.{2*n//1000}k.jsonl"
filename = f"metadata/{filename}"
filename
Out[20]:
'metadata/prepared.2000k.jsonl'

In [37]:
write_jsonl(filename, df.to_dict("records"))

In [21]:
df = pd.DataFrame(read_jsonl(filename))

In [22]:
df.shape
Out[22]:
(2000000, 4)

In [23]:
df["valid"].value_counts(normalize=True)
Out[23]:
False    0.899385
True     0.100614
Name: valid, dtype: float64

[Notebook metadata: kernel "Python 3 (ipykernel)", Python 3.8.10; nbformat 4, minor 5]

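Note that create_metadata splits at the level of ImageNet class ids rather than individual images, so a class never appears in both train and validation. A minimal, self-contained sketch of the same idea on synthetic records (all paths, labels, and class ids below are invented):

import pandas as pd
from sklearn.model_selection import train_test_split

# Toy records mimicking the metadata above; nothing here comes from the real dataset.
toy = pd.DataFrame({
    "path": [f"img_{i}.jpg" for i in range(8)],
    "label": ["real", "fake"] * 4,
    "class": ["n01", "n01", "n02", "n02", "n03", "n03", None, None],
})

# Split the *classes*, then assign rows to train/valid by class membership.
classes = toy["class"].dropna().unique()
trn_classes, val_classes = train_test_split(classes, test_size=1, random_state=1)
assert not set(trn_classes) & set(val_classes)  # no class leaks across the split

trn = toy[toy["class"].isin(trn_classes)]
val = toy[toy["class"].isin(val_classes)]
print(len(trn), len(val))  # 4 2 (rows without a class id are split separately in the notebook)
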
realfake/bin/check_files.py
ADDED
@@ -0,0 +1,51 @@
"""
Check that the files referenced from a JSONL file are valid.
"""
import json
from pathlib import Path

import numpy as np
import PIL.Image
from joblib import delayed, Parallel

from realfake.utils import inject_args, Args, read_jsonl


class CheckFilesArgs(Args):
    jsonl_file: Path


@inject_args
def main(args: CheckFilesArgs) -> None:
    records = read_jsonl(args.jsonl_file)
    results = Parallel(n_jobs=-1, verbose=100)(delayed(check_file)(record) for record in records)
    failed = [result for result in results if result["error"] is not None]
    if not failed:
        print("All files are valid")
    else:
        saved_file = args.jsonl_file.with_suffix(".failed.jsonl")
        print(f"{len(failed)} files are invalid, saved errors to {saved_file}")
        with open(saved_file, "w") as f:
            for record in failed:
                f.write(json.dumps(record) + "\n")


def check_file(record: dict) -> dict:
    path = Path(record["path"])
    error = None
    if not path.exists():
        error = "File does not exist"
    elif not path.is_file():
        error = "Path is not a file"
    elif path.suffix.lower() not in (".jpg", ".jpeg", ".png"):
        error = "File is not an image file"
    else:
        try:
            np.asarray(PIL.Image.open(path))
        except Exception as e:
            error = f"Image cannot be opened: {e}"
    return dict(record, error=error)


if __name__ == '__main__':
    main()

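For orientation, a small sketch of how check_file behaves on individual records; it assumes the repository root is on PYTHONPATH so realfake.bin.check_files is importable, and the paths are placeholders:

from realfake.bin.check_files import check_file

# Records keep their original fields and gain an "error" key; these paths are made up.
records = [
    {"path": "/fsx/home-devforfu/data/real_imagenet1k/n02797295/example.jpg", "label": "real"},
    {"path": "/does/not/exist.png", "label": "fake"},
]
for checked in map(check_file, records):
    print(checked["path"], "->", checked["error"])
# The first record reports None only if the file actually exists and PIL can decode it;
# the second one always reports "File does not exist".
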
realfake/bin/create_metadata.py
ADDED
@@ -0,0 +1,57 @@
"""
Creates a metadata file by combining information from the directory structure.
"""
import json
from pathlib import Path

from pydantic import Field

from realfake.config import IMAGE_FORMATS
from realfake.utils import inject_args, Args


class CreateMetadataArgs(Args):
    root_dir: Path
    datasets: str = Field(..., help="Comma-separated list of datasets to include in the meta-data file")
    jsonl_file: Path = Field(..., help="Path to the output JSONL file")


@inject_args
def main(args: CreateMetadataArgs) -> None:
    datasets = args.datasets.split(",")
    records = []
    for dataset in datasets:
        label = "real" if dataset.startswith("real") else "fake"
        dirpath = args.root_dir/dataset
        assert dirpath.exists(), f"dataset dir does not exist: {dirpath}"
        records.extend((parse_imagenet if "imagenet" in dataset else parse_flat)(dirpath, label))
    with open(args.jsonl_file, "w") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")


def parse_imagenet(dirpath: Path, label: str) -> list:
    records = []
    for classdir in dirpath.iterdir():
        assert classdir.is_dir(), f"class directory is not a directory: {classdir}"
        for fn in classdir.iterdir():
            if fn.suffix.lower() in IMAGE_FORMATS:
                records.append({"path": str(fn), "label": label, "class": classdir.name})
            else:
                print("Not an image file:", fn)
    return records


def parse_flat(dirpath: Path, label: str) -> list:
    records = []
    for fn in dirpath.iterdir():
        if fn.suffix.lower() in IMAGE_FORMATS:
            records.append({"path": str(fn), "label": label, "class": None})
        else:
            print("Not an image file:", fn)
    return records


if __name__ == "__main__":
    main()

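A throwaway illustration of the two directory layouts the script distinguishes, class-per-subdirectory for the ImageNet dumps versus a flat folder of generated images; it assumes realfake.bin is importable and builds a temporary directory, so none of the names below reflect real data:

import tempfile
from pathlib import Path
from realfake.bin.create_metadata import parse_flat, parse_imagenet

root = Path(tempfile.mkdtemp())
# ImageNet-style layout: one subdirectory per class id.
(root/"real_imagenet1k"/"n02797295").mkdir(parents=True)
(root/"real_imagenet1k"/"n02797295"/"a.jpg").touch()
# Flat layout: generated images directly in the dataset folder.
(root/"fake_2m_all").mkdir()
(root/"fake_2m_all"/"b.png").touch()

print(parse_imagenet(root/"real_imagenet1k", "real"))
# [{'path': '.../n02797295/a.jpg', 'label': 'real', 'class': 'n02797295'}]
print(parse_flat(root/"fake_2m_all", "fake"))
# [{'path': '.../b.png', 'label': 'fake', 'class': None}]
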
realfake/bin/diffusion_db.py
ADDED
@@ -0,0 +1,36 @@
import json
from hashlib import md5
from pathlib import Path

import datasets
from tqdm import tqdm

from realfake.utils import Args, inject_args


class DownloadParams(Args):
    output_dir: Path
    subset: str = "2m_first_1k"


@inject_args
def main(params: DownloadParams) -> None:
    dataset = datasets.load_dataset("poloclub/diffusiondb", params.subset, split="train", streaming=True)

    output_dir = params.output_dir/params.subset
    output_dir.mkdir(parents=True, exist_ok=True)

    with (output_dir/"test.jsonl").open("w") as fp:
        for item in tqdm(dataset, total=None):
            image_id = md5((item["prompt"] + str(item["seed"])).encode()).hexdigest()
            filename = output_dir/f"{image_id}.png"
            if not filename.exists():
                item["image"].save(filename)
            record = {"path": str(filename), "label": "fake", "class": None, "valid": False}
            fp.write(f"{json.dumps(record)}\n")

    print(f"Saved records to {output_dir}")


if __name__ == "__main__":
    main()

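The file name for each downloaded image is a deterministic digest of the prompt and seed, which is what makes the `if not filename.exists()` check a cheap resume mechanism. A tiny illustration with made-up values:

from hashlib import md5

prompt, seed = "an astronaut riding a horse", 42  # placeholder prompt and seed
image_id = md5((prompt + str(seed)).encode()).hexdigest()
print(image_id)  # the same 32-character hex id on every run, so re-runs skip already-saved images
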
realfake/bin/download_s3.py
ADDED
@@ -0,0 +1,71 @@
from __future__ import annotations
import tarfile
from dataclasses import dataclass
from pathlib import Path

import boto3
from joblib import Parallel, delayed

from realfake.utils import get_user_name


def main() -> None:
    bucket, prefix = "s-datasets", "laion-aesthetic/data/laion2B-en-aesthetic/"
    start_idx, end_idx = 400, 700
    keys_range = list(range(start_idx, end_idx))

    output_dir = Path(f"/fsx/{get_user_name()}/data/real_aes_{start_idx}_{end_idx}")
    output_dir.mkdir(parents=True, exist_ok=True)

    jobs = get_jobs(keys_range, bucket, prefix, output_dir)

    Parallel(n_jobs=-1, backend="multiprocessing", verbose=100)(delayed(download_and_extract)(job) for job in jobs)


@dataclass
class Job:
    bucket: str
    key: Path
    output_dir: Path


def get_jobs(keys_range: list, bucket: str, prefix: str, output_dir: Path) -> list[Job]:
    client = boto3.client("s3")

    token, jobs = None, []

    while True:
        conf = dict(Bucket=bucket, Prefix=prefix)
        if token is not None: conf["ContinuationToken"] = token
        response = client.list_objects_v2(**conf)

        for item in response.get("Contents"):
            key = Path(item["Key"])
            if key.suffix == ".tar" and int(key.stem) in keys_range:
                jobs.append(Job(bucket, key, output_dir))

        if not response["IsTruncated"]: break
        token = response["NextContinuationToken"]

    return jobs


def download_and_extract(job: Job) -> None:
    client = boto3.client("s3")
    tar_file = job.output_dir / job.key.name

    print(f"{job.key}: downloading...")
    client.download_file(job.bucket, str(job.key), tar_file)

    print(f"{job.key}: extracting...")
    with tarfile.open(tar_file) as tar:
        for name in tar.getnames():
            if name.endswith(".jpg"):
                tar.extract(name, job.output_dir)

    print(f"{job.key}: done!")
    tar_file.unlink()


if __name__ == "__main__":
    main()

realfake/bin/imagenet.py
ADDED
@@ -0,0 +1,32 @@
"""
Unpacks tar files from the ImageNet-1k dataset while keeping the original directory structure.

The script only unpacks files from the training subset.
"""
import tarfile
from pathlib import Path
from joblib import delayed, Parallel
from realfake.utils import inject_args, Args


class ImagenetArgs(Args):
    imagenet_dir: Path
    unpacked_dir: Path


@inject_args
def main(args: ImagenetArgs) -> None:
    train_dir = args.imagenet_dir/"train"
    assert train_dir.exists(), f"Directory {train_dir} does not exist"
    archives = train_dir.glob("*.tar")
    Parallel(n_jobs=-1, verbose=100)(delayed(unpack_tar)(tar_file, args.unpacked_dir) for tar_file in archives)


def unpack_tar(tar_file: Path, output_dir: Path) -> None:
    output_subdir = output_dir/tar_file.stem
    with tarfile.open(tar_file) as tar:
        tar.extractall(output_subdir)


if __name__ == "__main__":
    main()

realfake/bin/inference.py
ADDED
@@ -0,0 +1,49 @@
import random
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from realfake.data import DictDataset, get_augs
from realfake.models import RealFakeClassifier, RealFakeParams
from realfake.utils import Args, inject_args, read_jsonl


class InferenceParams(Args):
    checkpoint_path: Path
    test_file: Path
    map_location: str = "cpu"
    num_workers: int = 16


@inject_args
def main(params: InferenceParams) -> None:
    checkpoint = torch.load(params.checkpoint_path, map_location=params.map_location)

    # todo: use PL mechanism to store hparams
    model = RealFakeClassifier(RealFakeParams.parse_file(params.checkpoint_path.parent/"params.json"))
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    records = read_jsonl(params.test_file)

    for _ in range(10):
        selected = random.sample(records, k=1000)

        with torch.inference_mode():
            ds = DictDataset(selected, get_augs(train=False))
            dl = DataLoader(ds, batch_size=32, num_workers=params.num_workers, shuffle=False)
            matched, total = 0, len(ds)

            for batch in dl:
                _, logits, y_true_onehot = model(batch)
                y_true = y_true_onehot.argmax(dim=1)
                y_pred = logits.softmax(dim=1).argmax(dim=1)
                matched += (y_true == y_pred).sum().item()

            print(f"Accuracy: {matched/total:2.2%}")


if __name__ == "__main__":
    main()

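The accuracy bookkeeping above compares the argmax over softmaxed logits with the argmax over the one-hot targets returned by the model's forward pass. A toy numeric check with invented values:

import torch

logits = torch.tensor([[2.0, -1.0], [0.5, 1.5], [3.0, 0.0]])
y_true_onehot = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])

y_pred = logits.softmax(dim=1).argmax(dim=1)  # tensor([0, 1, 0])
y_true = y_true_onehot.argmax(dim=1)          # tensor([0, 1, 1])

matched = (y_true == y_pred).sum().item()     # 2 out of 3
print(f"Accuracy: {matched / len(y_true):2.2%}")  # Accuracy: 66.67%
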
realfake/bin/unpack_diffusion_db.py
ADDED
@@ -0,0 +1,41 @@
import json
import zipfile
from itertools import chain
from pathlib import Path
from joblib import Parallel, delayed
from realfake.utils import get_user_name, inject_args, Args


class UnpackParams(Args):
    meta_file: Path
    jsonl_file: Path
    num_workers: int = 16


def unpack(zip_path: Path, output_dir: Path):
    print("extracting", zip_path)
    with zipfile.ZipFile(zip_path, "r") as arch:
        paths = [str(output_dir/fn) for fn in arch.namelist() if fn.endswith(".png")]
        arch.extractall(output_dir)
    return paths


@inject_args
def main(params: UnpackParams) -> None:
    subset_name = params.meta_file.stem
    output_dir = Path(f"/fsx/{get_user_name()}/data/fake_{subset_name}")
    output_dir.mkdir(parents=True, exist_ok=True)
    meta = json.loads(params.meta_file.read_text())
    with Parallel(n_jobs=params.num_workers, verbose=100) as parallel:
        results = parallel(delayed(unpack)(Path(m["path"]), output_dir) for m in meta if m["ok"])
    records = [
        {"path": str(fn), "label": "fake", "class": None, "valid": None}
        for fn in chain.from_iterable(results)
    ]
    with params.jsonl_file.open("w") as fp:
        for record in records:
            fp.write(json.dumps(record) + "\n")


if __name__ == "__main__":
    main()

realfake/callbacks.py
ADDED
@@ -0,0 +1,65 @@
from typing import Iterable

import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BaseFinetuning, Callback
from pytorch_lightning.utilities import rank_zero_info


class ConsoleLogger(Callback):

    def __init__(self):
        super().__init__()
        self._reset()

    def get_history(self) -> list:
        return list(self._history)

    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        self._reset()

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        [lr] = trainer.lr_scheduler_configs[0].scheduler.get_last_lr()  # type: ignore
        log = {"epoch": trainer.current_epoch, "lr": lr}
        log.update({name: tensor.item() for name, tensor in trainer.logged_metrics.items()})
        self._history.append(log)
        formatted = []
        for key, value in log.items():
            if isinstance(value, int):
                kv = f"{key}={value:3d}"
            elif isinstance(value, float):
                kv = f"{key}={value:.4f}"
            else:
                kv = f"{key}={value}"
            formatted.append(kv)
        rank_zero_info(" | ".join(formatted))

    def _reset(self):
        self._history = []


class FeatureExtractorFreezeUnfreeze(BaseFinetuning):

    def __init__(self, unfreeze_at_epoch: int):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
        rank_zero_info("Freezing backbone")
        self.freeze(_get_backbone(pl_module.model))

    def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer, opt_idx: int) -> None:
        if epoch == self._unfreeze_at_epoch:
            rank_zero_info(f"Unfreezing backbone at epoch {epoch}")
            self.unfreeze_and_add_param_group(
                modules=_get_backbone(pl_module.model),
                optimizer=optimizer,
                train_bn=True,
            )


def _get_backbone(module: pl.LightningModule) -> Iterable[nn.Module]:
    for name, child in module.named_children():
        if name.startswith("head"):
            continue
        yield child

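FeatureExtractorFreezeUnfreeze freezes everything _get_backbone yields, that is, every child module whose name does not start with "head"; in timm ConvNeXt models the classifier lives in the head attribute, so only it stays trainable until unfreeze_at_epoch. A minimal sketch with a hypothetical module layout (_get_backbone only needs named_children, so a plain nn.Module works for the demo):

import torch.nn as nn
from realfake.callbacks import _get_backbone

class ToyNet(nn.Module):
    """Hypothetical timm-like layout: feature blocks plus a `head` classifier."""
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, 3)
        self.stages = nn.Sequential(nn.Conv2d(8, 8, 3), nn.ReLU())
        self.head = nn.Linear(8, 2)

backbone = [type(m).__name__ for m in _get_backbone(ToyNet())]
print(backbone)  # ['Conv2d', 'Sequential']; the head is excluded from the initial freeze
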
realfake/config.py
ADDED
@@ -0,0 +1,4 @@
SEED = 1
LABELS = {"real": 0, "fake": 1}
SUBSETS = ("train", "validation")
IMAGE_FORMATS = ".jpeg", ".jpg", ".png"

realfake/data.py
ADDED
@@ -0,0 +1,101 @@
from __future__ import annotations
import random

import albumentations as A
import numpy as np
import PIL.Image
from albumentations.pytorch.transforms import ToTensorV2
from torch.utils.data import Dataset, DataLoader

from realfake.config import LABELS

IMG_RESIZE = 256
IMG_CROP = 224


class DictDataset(Dataset):
    def __init__(self, records: list[dict], transform_x=None):
        self.records = records
        self.transform_x = transform_x

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        record = self.records[idx]
        image = np.asarray(PIL.Image.open(record["path"]))
        if self.transform_x is not None:
            image = self.transform_x(image=image)["image"]
        item = {"image": image}
        if "label" in record:
            item["label"] = LABELS[record["label"]]
        return item


def get_augs(train: bool = True) -> A.Compose:
    if train:
        return A.Compose([
            A.Resize(IMG_RESIZE, IMG_RESIZE),
            A.RandomCrop(IMG_CROP, IMG_CROP),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.RandomBrightnessContrast(),
            A.Affine(),
            A.Rotate(),
            A.CoarseDropout(),
            ExpandChannels(),
            RGBAtoRGB(),
            A.Normalize(),
            ToTensorV2(),
        ])
    else:
        return A.Compose([
            A.Resize(IMG_RESIZE, IMG_RESIZE),
            A.CenterCrop(IMG_CROP, IMG_CROP),
            ExpandChannels(),
            RGBAtoRGB(),
            A.Normalize(),
            ToTensorV2(),
        ])


class ExpandChannels(A.ImageOnlyTransform):
    """Expands image up to three channels if the image is grayscale."""

    def __init__(self, always_apply: bool = False, p: float = 0.5):
        super().__init__(True, 1.0)

    def apply(self, image, **params):
        if image.ndim == 2:
            image = np.repeat(image[..., None], 3, axis=2)
        elif image.shape[2] == 1:
            image = np.repeat(image, 3, axis=2)
        return image


class RGBAtoRGB(A.ImageOnlyTransform):
    """Converts RGBA image to RGB."""

    def __init__(self, always_apply: bool = False, p: float = 0.5):
        super().__init__(True, 1.0)

    def apply(self, image, **params):
        if image.shape[2] == 4:
            image = image[:, :, :3]
        return image


def get_dss(records: list) -> tuple[DictDataset, DictDataset]:
    train_records = [x for x in records if not x["valid"]]
    valid_records = [x for x in records if x["valid"]]
    assert len(train_records) + len(valid_records) == len(records)
    random.shuffle(train_records)
    train_ds = DictDataset(train_records, transform_x=get_augs(train=True))
    valid_ds = DictDataset(valid_records, transform_x=get_augs(train=False))
    return train_ds, valid_ds


def get_dls(train_ds: DictDataset, valid_ds: DictDataset, bs: int, num_workers: int) -> tuple[DataLoader, DataLoader]:
    train_dl = DataLoader(train_ds, batch_size=bs, num_workers=num_workers)
    valid_dl = DataLoader(valid_ds, batch_size=bs, num_workers=num_workers, shuffle=False)
    return train_dl, valid_dl

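A quick sanity check of the transform stack, not part of the commit: the validation pipeline resizes, center-crops, fixes the channel count, normalizes, and converts to a tensor, so any reasonably shaped input should come out as a 3x224x224 float tensor. The synthetic RGBA array below stands in for a real image file:

import numpy as np
from realfake.data import get_augs

rgba = np.zeros((300, 400, 4), dtype=np.uint8)  # synthetic RGBA input
out = get_augs(train=False)(image=rgba)["image"]
print(out.shape, out.dtype)  # torch.Size([3, 224, 224]) torch.float32
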
realfake/models.py
ADDED
@@ -0,0 +1,110 @@
from __future__ import annotations
import json
from pathlib import Path

import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn
import torchmetrics
from pydantic import BaseModel, Field

from realfake.data import get_dss, get_dls
from realfake.utils import Args

N_CLASSES = 2


class AcceleratorParams(BaseModel):
    """PyTorch Lightning accelerator parameters."""

    name: str = Field("gpu")
    devices: int = Field(4)
    strategy: str = Field("dp")
    precision: int = Field(16)
    override_float32_matmul: bool = Field(True)
    float32_matmul: str = Field("medium")


class RealFakeParams(Args):
    jsonl_file: Path
    dry_run: bool = Field(False)
    model_name: str = Field("convnext_tiny")
    batch_size: int = Field(256)
    freeze_epochs: int = Field(3)
    epochs: int = Field(6)
    base_lr: float = Field(1e-3)
    pretrained: bool = Field(True)
    accelerator: AcceleratorParams = Field(default_factory=AcceleratorParams)


class RealFakeDataModule(pl.LightningDataModule):

    def __init__(self, jsonl_records: Path, batch_size: int, num_workers: int = 0):
        super().__init__()
        self.jsonl_records = jsonl_records
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dss = self.dls = None

    def setup(self, stage=None):
        records = [json.loads(line) for line in self.jsonl_records.open()]
        self.dss = get_dss(records)
        self.dls = get_dls(*self.dss, self.batch_size, self.num_workers)

    def train_dataloader(self):
        return self.dls[0]

    def val_dataloader(self):
        return self.dls[1]


class RealFakeClassifier(pl.LightningModule):

    def __init__(self, params: RealFakeParams):
        super().__init__()
        self.params = params
        self.ce = nn.BCEWithLogitsLoss()
        self.model = timm.create_model(params.model_name, pretrained=params.pretrained, num_classes=N_CLASSES)
        self.acc = torchmetrics.Accuracy(task="binary")

    def train_dataloader(self):
        return self.dls.train

    def val_dataloader(self):
        return self.dls.valid

    def forward(self, batch):
        x, y = batch["image"], batch["label"]
        y = torch.nn.functional.one_hot(y, num_classes=N_CLASSES).float()
        out = self.model(x)
        loss = self.ce(out, y)
        return loss, out, y

    def training_step(self, batch, batch_idx):
        loss, _, _ = self.forward(batch)
        self.log("train_loss", loss, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, out, y = self.forward(batch)
        y_pred = out.sigmoid().argmax(dim=-1)
        y_true = y.argmax(dim=-1)
        self.log("val_loss", loss, on_epoch=True, on_step=False)
        return {"gt": y_true, "yhat": y_pred}

    def validation_step_end(self, outputs):
        self.acc.update(outputs["yhat"], outputs["gt"])

    def validation_epoch_end(self, outputs):
        self.log("val_acc", self.acc.compute(), on_epoch=True)
        self.acc.reset()

    def configure_optimizers(self):
        adamw = torch.optim.AdamW(self.parameters(), lr=self.params.base_lr)
        one_cycle = torch.optim.lr_scheduler.OneCycleLR(
            adamw,
            max_lr=self.params.base_lr,
            total_steps=self.trainer.estimated_stepping_batches
        )
        return [adamw], [one_cycle]

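A minimal smoke test of the classifier wiring, not from the commit: the jsonl path is a placeholder and pretrained=False avoids downloading timm weights just for a shape check.

import torch
from pathlib import Path
from realfake.models import RealFakeClassifier, RealFakeParams

params = RealFakeParams(jsonl_file=Path("metadata/prepared.2000k.jsonl"), pretrained=False)
model = RealFakeClassifier(params)

# forward() one-hot encodes the labels and applies BCEWithLogitsLoss to the two-class logits.
batch = {"image": torch.randn(2, 3, 224, 224), "label": torch.tensor([0, 1])}
loss, logits, y_onehot = model(batch)
print(loss.item(), logits.shape, y_onehot.shape)  # scalar loss, (2, 2) logits, (2, 2) one-hot targets
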
realfake/train.py
ADDED
@@ -0,0 +1,68 @@
from __future__ import annotations
import os
import signal
from pathlib import Path

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins.environments import SLURMEnvironment

from realfake.callbacks import ConsoleLogger
from realfake.models import RealFakeParams
from realfake.utils import get_checkpoints_dir, find_latest_checkpoint


def get_existing_checkpoint(job_id: str | None = None) -> tuple:
    if job_id is None:
        checkpoints_dir = get_checkpoints_dir(timestamp=True)
    else:
        checkpoints_dir = get_checkpoints_dir(timestamp=False)/job_id
    checkpoints_dir.mkdir(parents=True, exist_ok=True)
    existing_checkpoint = find_latest_checkpoint(checkpoints_dir)
    return checkpoints_dir, existing_checkpoint


def prepare_trainer(args: RealFakeParams) -> pl.Trainer:
    job_id = os.environ.get("SLURM_JOB_ID")
    checkpoints_dir, existing_checkpoint = get_existing_checkpoint(job_id)

    if job_id is None:
        print("SLURM job id is not found, running locally.")

    if existing_checkpoint is None:
        print("No existing checkpoint found, starting from scratch.")

    if args.accelerator.override_float32_matmul:
        torch.set_float32_matmul_precision(args.accelerator.float32_matmul)

    with (checkpoints_dir/"params.json").open("w") as fp:
        fp.write(args.json())

    trainer_params = dict(
        accelerator=args.accelerator.name,
        devices=args.accelerator.devices,
        precision=args.accelerator.precision,
        max_epochs=args.epochs,
        num_nodes=1,
        num_sanity_val_steps=0,
        enable_progress_bar=False,
        callbacks=[
            ConsoleLogger(),
            ModelCheckpoint(
                monitor="val_acc",
                mode="max",
                save_last=True,
                save_top_k=1,
                dirpath=checkpoints_dir,
                filename="%s-{epoch:02d}-{val_acc:.4f}" % args.model_name,
            ),
        ],
        resume_from_checkpoint=existing_checkpoint,
    )

    if job_id is not None:
        trainer_params["plugins"] = SLURMEnvironment(requeue_signal=signal.SIGHUP),
        trainer_params["strategy"] = args.accelerator.strategy

    return pl.Trainer(**trainer_params)

realfake/train_cluster.py
ADDED
@@ -0,0 +1,33 @@
import warnings

import pytorch_lightning as pl

from realfake.config import SEED
from realfake.models import RealFakeClassifier, RealFakeDataModule, RealFakeParams
from realfake.train import prepare_trainer


def main() -> None:
    pl.seed_everything(SEED)
    args = RealFakeParams.from_args()
    model = RealFakeClassifier(args)
    data = RealFakeDataModule(args.jsonl_file, args.batch_size, args.accelerator.devices * 4)
    trainer = prepare_trainer(args)

    if args.dry_run:
        print("Dry run, skipping training.")
        print("Model summary:")
        print(model)
        print("Data summary:")
        data.setup()
        print("Train batches:", len(data.dls[0]))
        print("Valid batches:", len(data.dls[1]))

    else:
        trainer.fit(model, datamodule=data)


if __name__ == "__main__":
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=Warning)
        main()

realfake/utils.py
ADDED
@@ -0,0 +1,121 @@
from __future__ import annotations
import argparse
import datetime
import json
import os
from operator import itemgetter
from pathlib import Path
from typing import Callable

import requests
import pynvml
import PIL.Image
import torch
from pydantic import BaseSettings, BaseModel


class Args(BaseSettings):

    @classmethod
    def from_args(cls):
        parser = argparse.ArgumentParser()
        for field in cls.__fields__.values():
            if issubclass(field.type_, BaseModel):
                prefix = field.type_.__name__.lower()
                for subfield in field.type_.__fields__.values():
                    short = "".join([x[0] for x in subfield.name.split("_")])
                    parser.add_argument(f"--{prefix}.{subfield.name}", default=subfield.default, required=subfield.required)
            else:
                short = "".join([x[0] for x in field.name.split("_")])
                parser.add_argument(f"-{short}", f"--{field.name}", default=field.default, required=field.required)
        args = vars(parser.parse_known_args()[0])
        to_delete = set()
        for field in cls.__fields__.values():
            if issubclass(field.type_, BaseModel):
                prefix = field.type_.__name__.lower()
                sub_args = {}
                for k, v in args.items():
                    if k.startswith(prefix):
                        to_delete.add(k)
                        sub_args[k.replace(f"{prefix}.", "")] = v
                args[field.name] = sub_args
        args = {k: v for k, v in args.items() if k not in to_delete}
        return cls(**args)

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"
        env_prefix = "ARG_"


def inject_args(func: Callable) -> Callable:
    """Decorates a function to inject the arguments."""

    injected = None
    for type_ in func.__annotations__.values():
        if issubclass(type_, Args):
            injected = type_.from_args()
            break

    if injected is None:
        raise ValueError(f"Function {func.__name__} is not annotated with an Args subclass.")

    def wrapper(*args, **kwargs):
        return func(injected, *args, **kwargs)

    return wrapper


def get_free_gpu() -> int:
    pynvml.nvmlInit()
    total = torch.cuda.device_count()
    gpus = []
    for i in range(total):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        gpus.append((i, info.free))
    gpus = sorted(gpus, key=itemgetter(1), reverse=True)
    return gpus[0][0]


def get_user_name() -> str:
    return Path(os.environ["HOME"]).stem


def get_storage_dir() -> Path:
    return Path(f"/fsx/{get_user_name()}")


def get_checkpoints_dir(*, timestamp: bool) -> Path:
    base_dir = get_storage_dir()/"checkpoints"
    return Path(f"{base_dir}/{now()}") if timestamp else base_dir


def now() -> str:
    return datetime.datetime.utcnow().strftime("%Y%m%d_%H%M%S")


def read_jsonl(path: Path) -> list:
    return [json.loads(x) for x in Path(path).read_text().split("\n") if x]


def write_jsonl(path: Path, data: list):
    with Path(path).open("w") as f:
        for x in data:
            f.write(json.dumps(x) + "\n")


def get_image(url: str, filename: Path | None = None):
    if filename is None: filename = Path(f"{url.split('/')[-1]}.jpg")
    filename = Path(filename)
    if filename.exists(): return filename
    PIL.Image.open(requests.get(url, stream=True).raw).save(filename)
    return filename


def find_latest_checkpoint(dirname: Path) -> Path:
    checkpoints = list(dirname.glob("*.ckpt"))
    if not checkpoints:
        return None
    latest = max(checkpoints, key=lambda path: path.stat().st_mtime)
    return latest

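To make the CLI behaviour of Args.from_args concrete, here is a small sketch with a hypothetical DemoArgs class: short flags are built from the first letter of each underscore-separated word of a field name, and the parsed strings are coerced by pydantic. It assumes the realfake package is importable and fakes sys.argv purely for the demo.

import sys
from pathlib import Path
from realfake.utils import Args, inject_args

class DemoArgs(Args):           # hypothetical class, for illustration only
    jsonl_file: Path            # required  -> --jsonl_file / -jf
    batch_size: int = 32        # optional  -> --batch_size / -bs

sys.argv = ["demo", "-jf", "metadata/all.jsonl", "-bs", "64"]

@inject_args                    # parses sys.argv at decoration time and injects a DemoArgs instance
def main(args: DemoArgs) -> None:
    print(args.jsonl_file, args.batch_size)  # metadata/all.jsonl 64

main()
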
submit.sh
ADDED
@@ -0,0 +1,24 @@
#!/bin/bash -l

# SLURM SUBMIT SCRIPT
#SBATCH --partition=g40
#SBATCH --nodes=1
#SBATCH --gpus=8
#SBATCH --cpus-per-gpu=6
#SBATCH --job-name=realfake
#SBATCH --comment=laion
#SBATCH --signal=SIGUSR1@90

source "${HOME}/venv/bin/activate"

export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1
export PYTHONPATH="${HOME}/realfake"

echo "Working directory: `pwd`"

srun python3 realfake/train_cluster.py \
    -jf "${HOME}/realfake/metadata/prepared.2000k.jsonl" \
    -mn convnext_large -e 5 -bs 128 \
    --acceleratorparams.devices=8 \
    --acceleratorparams.strategy=ddp_find_unused_parameters_false

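The short flags in the srun call map onto RealFakeParams fields through the first-letter rule in Args.from_args, while the --acceleratorparams.* options address the nested AcceleratorParams model (the prefix is the nested model's class name lower-cased). A tiny sketch of the mapping, for illustration only:

# How the submit.sh flags are derived: first letter of each underscore-separated word.
for name in ["jsonl_file", "model_name", "epochs", "batch_size"]:
    short = "".join(word[0] for word in name.split("_"))
    print(f"-{short}  ->  --{name}")
# -jf  ->  --jsonl_file
# -mn  ->  --model_name
# -e  ->  --epochs
# -bs  ->  --batch_size
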