Johannes Kolbe
commited on
Commit
·
ed6b6d6
1
Parent(s):
dd2f594
enable model loading from hf hub
Browse files- .gitignore +1 -0
- .ipynb_checkpoints/model_to_hf_hub-checkpoint.ipynb +255 -0
- app.py +2 -2
- interface.py +4 -5
- model_to_hf_hub.ipynb +297 -0
- models/model_zoo.py +5 -8
- models/pggan_generator.py +48 -2
- models/stylegan2_generator.py +44 -2
- models/stylegan_generator.py +49 -2
- utils.py +19 -13
.gitignore
CHANGED
@@ -20,6 +20,7 @@ __pycache__/
|
|
20 |
*.zip
|
21 |
events.*
|
22 |
|
|
|
23 |
*.pkl
|
24 |
*.h5
|
25 |
*.dat
|
|
|
20 |
*.zip
|
21 |
events.*
|
22 |
|
23 |
+
/checkpoints/
|
24 |
*.pkl
|
25 |
*.h5
|
26 |
*.dat
|
.ipynb_checkpoints/model_to_hf_hub-checkpoint.ipynb
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 15,
|
6 |
+
"metadata": {
|
7 |
+
"pycharm": {
|
8 |
+
"name": "#%%\n"
|
9 |
+
}
|
10 |
+
},
|
11 |
+
"outputs": [],
|
12 |
+
"source": [
|
13 |
+
"import huggingface_hub\n",
|
14 |
+
"import utils"
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "code",
|
19 |
+
"execution_count": 16,
|
20 |
+
"metadata": {
|
21 |
+
"pycharm": {
|
22 |
+
"name": "#%%\n"
|
23 |
+
}
|
24 |
+
},
|
25 |
+
"outputs": [
|
26 |
+
{
|
27 |
+
"data": {
|
28 |
+
"application/vnd.jupyter.widget-view+json": {
|
29 |
+
"model_id": "525a0eaa021f4fdebd9138f4e7c5ab65",
|
30 |
+
"version_major": 2,
|
31 |
+
"version_minor": 0
|
32 |
+
},
|
33 |
+
"text/plain": [
|
34 |
+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
"metadata": {},
|
38 |
+
"output_type": "display_data"
|
39 |
+
}
|
40 |
+
],
|
41 |
+
"source": [
|
42 |
+
"huggingface_hub.notebook_login()"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": 13,
|
48 |
+
"metadata": {
|
49 |
+
"pycharm": {
|
50 |
+
"name": "#%%\n"
|
51 |
+
}
|
52 |
+
},
|
53 |
+
"outputs": [
|
54 |
+
{
|
55 |
+
"name": "stdout",
|
56 |
+
"output_type": "stream",
|
57 |
+
"text": [
|
58 |
+
"Building generator for model `stylegan_animeface512` ...\n",
|
59 |
+
"Finish building generator.\n",
|
60 |
+
"Loading checkpoint from `checkpoints/stylegan_animeface512.pth` ...\n",
|
61 |
+
"Finish loading checkpoint.\n"
|
62 |
+
]
|
63 |
+
}
|
64 |
+
],
|
65 |
+
"source": [
|
66 |
+
"animeface_model = utils.load_generator('stylegan_animeface512')"
|
67 |
+
]
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"cell_type": "code",
|
71 |
+
"execution_count": 5,
|
72 |
+
"metadata": {
|
73 |
+
"pycharm": {
|
74 |
+
"name": "#%%\n"
|
75 |
+
}
|
76 |
+
},
|
77 |
+
"outputs": [
|
78 |
+
{
|
79 |
+
"name": "stderr",
|
80 |
+
"output_type": "stream",
|
81 |
+
"text": [
|
82 |
+
"Cloning https://huggingface.co/johko/stylegan_animeface512 into local empty directory.\n"
|
83 |
+
]
|
84 |
+
},
|
85 |
+
{
|
86 |
+
"data": {
|
87 |
+
"application/vnd.jupyter.widget-view+json": {
|
88 |
+
"model_id": "6e51c5ae4a504617aa0f1c1ac798ed15",
|
89 |
+
"version_major": 2,
|
90 |
+
"version_minor": 0
|
91 |
+
},
|
92 |
+
"text/plain": [
|
93 |
+
"Upload file pytorch_model.bin: 0%| | 32.0k/103M [00:00<?, ?B/s]"
|
94 |
+
]
|
95 |
+
},
|
96 |
+
"metadata": {},
|
97 |
+
"output_type": "display_data"
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"name": "stderr",
|
101 |
+
"output_type": "stream",
|
102 |
+
"text": [
|
103 |
+
"To https://huggingface.co/johko/stylegan_animeface512\n",
|
104 |
+
" 750cd03..2841156 main -> main\n",
|
105 |
+
"\n"
|
106 |
+
]
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"data": {
|
110 |
+
"text/plain": [
|
111 |
+
"'https://huggingface.co/johko/stylegan_animeface512/commit/2841156bad3c5a5f47f3edbf4a41880ea8fd3ad3'"
|
112 |
+
]
|
113 |
+
},
|
114 |
+
"execution_count": 5,
|
115 |
+
"metadata": {},
|
116 |
+
"output_type": "execute_result"
|
117 |
+
}
|
118 |
+
],
|
119 |
+
"source": [
|
120 |
+
"animeface_model.push_to_hub(\"johko/stylegan_animeface512\")"
|
121 |
+
]
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"cell_type": "code",
|
125 |
+
"execution_count": 11,
|
126 |
+
"metadata": {
|
127 |
+
"pycharm": {
|
128 |
+
"name": "#%%\n"
|
129 |
+
}
|
130 |
+
},
|
131 |
+
"outputs": [
|
132 |
+
{
|
133 |
+
"name": "stdout",
|
134 |
+
"output_type": "stream",
|
135 |
+
"text": [
|
136 |
+
"Building generator for model `pggan_celebahq1024` ...\n",
|
137 |
+
"Finish building generator.\n",
|
138 |
+
"Loading checkpoint from `checkpoints/pggan_celebahq1024.pth` ...\n",
|
139 |
+
"Finish loading checkpoint.\n"
|
140 |
+
]
|
141 |
+
}
|
142 |
+
],
|
143 |
+
"source": [
|
144 |
+
"celebhq_model = utils.load_generator(\"pggan_celebahq1024\")"
|
145 |
+
]
|
146 |
+
},
|
147 |
+
{
|
148 |
+
"cell_type": "code",
|
149 |
+
"execution_count": 7,
|
150 |
+
"metadata": {
|
151 |
+
"pycharm": {
|
152 |
+
"name": "#%%\n"
|
153 |
+
}
|
154 |
+
},
|
155 |
+
"outputs": [
|
156 |
+
{
|
157 |
+
"name": "stderr",
|
158 |
+
"output_type": "stream",
|
159 |
+
"text": [
|
160 |
+
"Cloning https://huggingface.co/johko/pggan-celebahq-1024 into local empty directory.\n"
|
161 |
+
]
|
162 |
+
},
|
163 |
+
{
|
164 |
+
"data": {
|
165 |
+
"application/vnd.jupyter.widget-view+json": {
|
166 |
+
"model_id": "ef4086b23a654b079bd6a3678140c50d",
|
167 |
+
"version_major": 2,
|
168 |
+
"version_minor": 0
|
169 |
+
},
|
170 |
+
"text/plain": [
|
171 |
+
"Upload file pytorch_model.bin: 0%| | 32.0k/88.1M [00:00<?, ?B/s]"
|
172 |
+
]
|
173 |
+
},
|
174 |
+
"metadata": {},
|
175 |
+
"output_type": "display_data"
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"name": "stderr",
|
179 |
+
"output_type": "stream",
|
180 |
+
"text": [
|
181 |
+
"To https://huggingface.co/johko/pggan-celebahq-1024\n",
|
182 |
+
" 780695e..278449f main -> main\n",
|
183 |
+
"\n"
|
184 |
+
]
|
185 |
+
},
|
186 |
+
{
|
187 |
+
"data": {
|
188 |
+
"text/plain": [
|
189 |
+
"'https://huggingface.co/johko/pggan-celebahq-1024/commit/278449f8416d38a0233c980774528d32c4eee99c'"
|
190 |
+
]
|
191 |
+
},
|
192 |
+
"execution_count": 7,
|
193 |
+
"metadata": {},
|
194 |
+
"output_type": "execute_result"
|
195 |
+
}
|
196 |
+
],
|
197 |
+
"source": [
|
198 |
+
"celebhq_model.push_to_hub(\"johko/pggan-celebahq-1024\")"
|
199 |
+
]
|
200 |
+
},
|
201 |
+
{
|
202 |
+
"cell_type": "code",
|
203 |
+
"execution_count": 17,
|
204 |
+
"metadata": {},
|
205 |
+
"outputs": [
|
206 |
+
{
|
207 |
+
"name": "stdout",
|
208 |
+
"output_type": "stream",
|
209 |
+
"text": [
|
210 |
+
"Building generator for model `stylegan_car512` ...\n",
|
211 |
+
"Finish building generator.\n",
|
212 |
+
"Loading checkpoint from `checkpoints/stylegan_car512.pth` ...\n",
|
213 |
+
"Finish loading checkpoint.\n"
|
214 |
+
]
|
215 |
+
}
|
216 |
+
],
|
217 |
+
"source": [
|
218 |
+
"cars_model = utils.load_generator(\"stylegan_car512\")"
|
219 |
+
]
|
220 |
+
},
|
221 |
+
{
|
222 |
+
"cell_type": "code",
|
223 |
+
"execution_count": null,
|
224 |
+
"metadata": {},
|
225 |
+
"outputs": [],
|
226 |
+
"source": [
|
227 |
+
"cars_model.push_to_hub(\"johko/stylegan_car512\")"
|
228 |
+
]
|
229 |
+
}
|
230 |
+
],
|
231 |
+
"metadata": {
|
232 |
+
"interpreter": {
|
233 |
+
"hash": "a8d699d01f596cc27ac2722fbc0550b939d217978c7e1ca888dca7ba146ee4bf"
|
234 |
+
},
|
235 |
+
"kernelspec": {
|
236 |
+
"display_name": "Python 3",
|
237 |
+
"language": "python",
|
238 |
+
"name": "python3"
|
239 |
+
},
|
240 |
+
"language_info": {
|
241 |
+
"codemirror_mode": {
|
242 |
+
"name": "ipython",
|
243 |
+
"version": 3
|
244 |
+
},
|
245 |
+
"file_extension": ".py",
|
246 |
+
"mimetype": "text/x-python",
|
247 |
+
"name": "python",
|
248 |
+
"nbconvert_exporter": "python",
|
249 |
+
"pygments_lexer": "ipython3",
|
250 |
+
"version": "3.9.9"
|
251 |
+
}
|
252 |
+
},
|
253 |
+
"nbformat": 4,
|
254 |
+
"nbformat_minor": 2
|
255 |
+
}
|
app.py
CHANGED
@@ -16,7 +16,7 @@ from utils import factorize_weight
|
|
16 |
@st.cache(allow_output_mutation=True, show_spinner=False)
|
17 |
def get_model(model_name):
|
18 |
"""Gets model by name."""
|
19 |
-
return load_generator(model_name)
|
20 |
|
21 |
|
22 |
@st.cache(allow_output_mutation=True, show_spinner=False)
|
@@ -72,7 +72,7 @@ layer_idx = st.sidebar.selectbox(
|
|
72 |
layers, boundaries, eigen_values = factorize_model(model, layer_idx)
|
73 |
|
74 |
num_semantics = st.sidebar.number_input(
|
75 |
-
'Number of semantics', value=
|
76 |
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
|
77 |
if gan_type == 'pggan':
|
78 |
max_step = 5.0
|
|
|
16 |
@st.cache(allow_output_mutation=True, show_spinner=False)
|
17 |
def get_model(model_name):
|
18 |
"""Gets model by name."""
|
19 |
+
return load_generator(model_name, from_hf_hub=True)
|
20 |
|
21 |
|
22 |
@st.cache(allow_output_mutation=True, show_spinner=False)
|
|
|
72 |
layers, boundaries, eigen_values = factorize_model(model, layer_idx)
|
73 |
|
74 |
num_semantics = st.sidebar.number_input(
|
75 |
+
'Number of semantics', value=5, min_value=0, max_value=None, step=1)
|
76 |
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
|
77 |
if gan_type == 'pggan':
|
78 |
max_step = 5.0
|
interface.py
CHANGED
@@ -16,7 +16,7 @@ from utils import factorize_weight
|
|
16 |
@st.cache(allow_output_mutation=True, show_spinner=False)
|
17 |
def get_model(model_name):
|
18 |
"""Gets model by name."""
|
19 |
-
return load_generator(model_name)
|
20 |
|
21 |
|
22 |
@st.cache(allow_output_mutation=True, show_spinner=False)
|
@@ -27,7 +27,7 @@ def factorize_model(model, layer_idx):
|
|
27 |
|
28 |
def sample(model, gan_type, num=1):
|
29 |
"""Samples latent codes."""
|
30 |
-
codes = torch.randn(num, model.z_space_dim)
|
31 |
if gan_type == 'pggan':
|
32 |
codes = model.layer0.pixel_norm(codes)
|
33 |
elif gan_type == 'stylegan':
|
@@ -63,8 +63,7 @@ def main():
|
|
63 |
|
64 |
model_name = st.sidebar.selectbox(
|
65 |
'Model to Interpret',
|
66 |
-
['pggan_celebahq1024', 'stylegan_animeface512', 'stylegan_car512', 'stylegan_cat256'
|
67 |
-
])
|
68 |
|
69 |
model = get_model(model_name)
|
70 |
gan_type = parse_gan_type(model)
|
@@ -74,7 +73,7 @@ def main():
|
|
74 |
layers, boundaries, eigen_values = factorize_model(model, layer_idx)
|
75 |
|
76 |
num_semantics = st.sidebar.number_input(
|
77 |
-
'Number of semantics', value=
|
78 |
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
|
79 |
if gan_type == 'pggan':
|
80 |
max_step = 5.0
|
|
|
16 |
@st.cache(allow_output_mutation=True, show_spinner=False)
|
17 |
def get_model(model_name):
|
18 |
"""Gets model by name."""
|
19 |
+
return load_generator(model_name, from_hf_hub=True)
|
20 |
|
21 |
|
22 |
@st.cache(allow_output_mutation=True, show_spinner=False)
|
|
|
27 |
|
28 |
def sample(model, gan_type, num=1):
|
29 |
"""Samples latent codes."""
|
30 |
+
codes = torch.randn(num, model.z_space_dim)
|
31 |
if gan_type == 'pggan':
|
32 |
codes = model.layer0.pixel_norm(codes)
|
33 |
elif gan_type == 'stylegan':
|
|
|
63 |
|
64 |
model_name = st.sidebar.selectbox(
|
65 |
'Model to Interpret',
|
66 |
+
['pggan_celebahq1024', 'stylegan_animeface512', 'stylegan_car512', 'stylegan_cat256',])
|
|
|
67 |
|
68 |
model = get_model(model_name)
|
69 |
gan_type = parse_gan_type(model)
|
|
|
73 |
layers, boundaries, eigen_values = factorize_model(model, layer_idx)
|
74 |
|
75 |
num_semantics = st.sidebar.number_input(
|
76 |
+
'Number of semantics', value=5, min_value=0, max_value=None, step=1)
|
77 |
steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
|
78 |
if gan_type == 'pggan':
|
79 |
max_step = 5.0
|
model_to_hf_hub.ipynb
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 15,
|
6 |
+
"metadata": {
|
7 |
+
"pycharm": {
|
8 |
+
"name": "#%%\n"
|
9 |
+
}
|
10 |
+
},
|
11 |
+
"outputs": [],
|
12 |
+
"source": [
|
13 |
+
"import huggingface_hub\n",
|
14 |
+
"import utils"
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "code",
|
19 |
+
"execution_count": 16,
|
20 |
+
"metadata": {
|
21 |
+
"pycharm": {
|
22 |
+
"name": "#%%\n"
|
23 |
+
}
|
24 |
+
},
|
25 |
+
"outputs": [
|
26 |
+
{
|
27 |
+
"data": {
|
28 |
+
"application/vnd.jupyter.widget-view+json": {
|
29 |
+
"model_id": "525a0eaa021f4fdebd9138f4e7c5ab65",
|
30 |
+
"version_major": 2,
|
31 |
+
"version_minor": 0
|
32 |
+
},
|
33 |
+
"text/plain": [
|
34 |
+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
"metadata": {},
|
38 |
+
"output_type": "display_data"
|
39 |
+
}
|
40 |
+
],
|
41 |
+
"source": [
|
42 |
+
"huggingface_hub.notebook_login()"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": 13,
|
48 |
+
"metadata": {
|
49 |
+
"pycharm": {
|
50 |
+
"name": "#%%\n"
|
51 |
+
}
|
52 |
+
},
|
53 |
+
"outputs": [
|
54 |
+
{
|
55 |
+
"name": "stdout",
|
56 |
+
"output_type": "stream",
|
57 |
+
"text": [
|
58 |
+
"Building generator for model `stylegan_animeface512` ...\n",
|
59 |
+
"Finish building generator.\n",
|
60 |
+
"Loading checkpoint from `checkpoints/stylegan_animeface512.pth` ...\n",
|
61 |
+
"Finish loading checkpoint.\n"
|
62 |
+
]
|
63 |
+
}
|
64 |
+
],
|
65 |
+
"source": [
|
66 |
+
"animeface_model = utils.load_generator('stylegan_animeface512')"
|
67 |
+
]
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"cell_type": "code",
|
71 |
+
"execution_count": 5,
|
72 |
+
"metadata": {
|
73 |
+
"pycharm": {
|
74 |
+
"name": "#%%\n"
|
75 |
+
}
|
76 |
+
},
|
77 |
+
"outputs": [
|
78 |
+
{
|
79 |
+
"name": "stderr",
|
80 |
+
"output_type": "stream",
|
81 |
+
"text": [
|
82 |
+
"Cloning https://huggingface.co/johko/stylegan_animeface512 into local empty directory.\n"
|
83 |
+
]
|
84 |
+
},
|
85 |
+
{
|
86 |
+
"data": {
|
87 |
+
"application/vnd.jupyter.widget-view+json": {
|
88 |
+
"model_id": "6e51c5ae4a504617aa0f1c1ac798ed15",
|
89 |
+
"version_major": 2,
|
90 |
+
"version_minor": 0
|
91 |
+
},
|
92 |
+
"text/plain": [
|
93 |
+
"Upload file pytorch_model.bin: 0%| | 32.0k/103M [00:00<?, ?B/s]"
|
94 |
+
]
|
95 |
+
},
|
96 |
+
"metadata": {},
|
97 |
+
"output_type": "display_data"
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"name": "stderr",
|
101 |
+
"output_type": "stream",
|
102 |
+
"text": [
|
103 |
+
"To https://huggingface.co/johko/stylegan_animeface512\n",
|
104 |
+
" 750cd03..2841156 main -> main\n",
|
105 |
+
"\n"
|
106 |
+
]
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"data": {
|
110 |
+
"text/plain": [
|
111 |
+
"'https://huggingface.co/johko/stylegan_animeface512/commit/2841156bad3c5a5f47f3edbf4a41880ea8fd3ad3'"
|
112 |
+
]
|
113 |
+
},
|
114 |
+
"execution_count": 5,
|
115 |
+
"metadata": {},
|
116 |
+
"output_type": "execute_result"
|
117 |
+
}
|
118 |
+
],
|
119 |
+
"source": [
|
120 |
+
"animeface_model.push_to_hub(\"johko/stylegan_animeface512\")"
|
121 |
+
]
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"cell_type": "code",
|
125 |
+
"execution_count": 11,
|
126 |
+
"metadata": {
|
127 |
+
"pycharm": {
|
128 |
+
"name": "#%%\n"
|
129 |
+
}
|
130 |
+
},
|
131 |
+
"outputs": [
|
132 |
+
{
|
133 |
+
"name": "stdout",
|
134 |
+
"output_type": "stream",
|
135 |
+
"text": [
|
136 |
+
"Building generator for model `pggan_celebahq1024` ...\n",
|
137 |
+
"Finish building generator.\n",
|
138 |
+
"Loading checkpoint from `checkpoints/pggan_celebahq1024.pth` ...\n",
|
139 |
+
"Finish loading checkpoint.\n"
|
140 |
+
]
|
141 |
+
}
|
142 |
+
],
|
143 |
+
"source": [
|
144 |
+
"celebhq_model = utils.load_generator(\"pggan_celebahq1024\")"
|
145 |
+
]
|
146 |
+
},
|
147 |
+
{
|
148 |
+
"cell_type": "code",
|
149 |
+
"execution_count": 7,
|
150 |
+
"metadata": {
|
151 |
+
"pycharm": {
|
152 |
+
"name": "#%%\n"
|
153 |
+
}
|
154 |
+
},
|
155 |
+
"outputs": [
|
156 |
+
{
|
157 |
+
"name": "stderr",
|
158 |
+
"output_type": "stream",
|
159 |
+
"text": [
|
160 |
+
"Cloning https://huggingface.co/johko/pggan-celebahq-1024 into local empty directory.\n"
|
161 |
+
]
|
162 |
+
},
|
163 |
+
{
|
164 |
+
"data": {
|
165 |
+
"application/vnd.jupyter.widget-view+json": {
|
166 |
+
"model_id": "ef4086b23a654b079bd6a3678140c50d",
|
167 |
+
"version_major": 2,
|
168 |
+
"version_minor": 0
|
169 |
+
},
|
170 |
+
"text/plain": [
|
171 |
+
"Upload file pytorch_model.bin: 0%| | 32.0k/88.1M [00:00<?, ?B/s]"
|
172 |
+
]
|
173 |
+
},
|
174 |
+
"metadata": {},
|
175 |
+
"output_type": "display_data"
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"name": "stderr",
|
179 |
+
"output_type": "stream",
|
180 |
+
"text": [
|
181 |
+
"To https://huggingface.co/johko/pggan-celebahq-1024\n",
|
182 |
+
" 780695e..278449f main -> main\n",
|
183 |
+
"\n"
|
184 |
+
]
|
185 |
+
},
|
186 |
+
{
|
187 |
+
"data": {
|
188 |
+
"text/plain": [
|
189 |
+
"'https://huggingface.co/johko/pggan-celebahq-1024/commit/278449f8416d38a0233c980774528d32c4eee99c'"
|
190 |
+
]
|
191 |
+
},
|
192 |
+
"execution_count": 7,
|
193 |
+
"metadata": {},
|
194 |
+
"output_type": "execute_result"
|
195 |
+
}
|
196 |
+
],
|
197 |
+
"source": [
|
198 |
+
"celebhq_model.push_to_hub(\"johko/pggan-celebahq-1024\")"
|
199 |
+
]
|
200 |
+
},
|
201 |
+
{
|
202 |
+
"cell_type": "code",
|
203 |
+
"execution_count": 17,
|
204 |
+
"metadata": {},
|
205 |
+
"outputs": [
|
206 |
+
{
|
207 |
+
"name": "stdout",
|
208 |
+
"output_type": "stream",
|
209 |
+
"text": [
|
210 |
+
"Building generator for model `stylegan_car512` ...\n",
|
211 |
+
"Finish building generator.\n",
|
212 |
+
"Loading checkpoint from `checkpoints/stylegan_car512.pth` ...\n",
|
213 |
+
"Finish loading checkpoint.\n"
|
214 |
+
]
|
215 |
+
}
|
216 |
+
],
|
217 |
+
"source": [
|
218 |
+
"cars_model = utils.load_generator(\"stylegan_car512\")"
|
219 |
+
]
|
220 |
+
},
|
221 |
+
{
|
222 |
+
"cell_type": "code",
|
223 |
+
"execution_count": 21,
|
224 |
+
"metadata": {},
|
225 |
+
"outputs": [
|
226 |
+
{
|
227 |
+
"name": "stdout",
|
228 |
+
"output_type": "stream",
|
229 |
+
"text": [
|
230 |
+
"Building generator for model `stylegan_cat256` ...\n",
|
231 |
+
"Finish building generator.\n",
|
232 |
+
"Loading checkpoint from `checkpoints/stylegan_cat256.pth` ...\n",
|
233 |
+
"Finish loading checkpoint.\n"
|
234 |
+
]
|
235 |
+
}
|
236 |
+
],
|
237 |
+
"source": [
|
238 |
+
"cats_model = utils.load_generator(\"stylegan_cat256\")"
|
239 |
+
]
|
240 |
+
},
|
241 |
+
{
|
242 |
+
"cell_type": "code",
|
243 |
+
"execution_count": null,
|
244 |
+
"metadata": {},
|
245 |
+
"outputs": [
|
246 |
+
{
|
247 |
+
"name": "stderr",
|
248 |
+
"output_type": "stream",
|
249 |
+
"text": [
|
250 |
+
"Cloning https://huggingface.co/johko/stylegan_cat256 into local empty directory.\n"
|
251 |
+
]
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"data": {
|
255 |
+
"application/vnd.jupyter.widget-view+json": {
|
256 |
+
"model_id": "651e9bff9c9f4555814171195e36d4d3",
|
257 |
+
"version_major": 2,
|
258 |
+
"version_minor": 0
|
259 |
+
},
|
260 |
+
"text/plain": [
|
261 |
+
"Upload file pytorch_model.bin: 0%| | 32.0k/100M [00:00<?, ?B/s]"
|
262 |
+
]
|
263 |
+
},
|
264 |
+
"metadata": {},
|
265 |
+
"output_type": "display_data"
|
266 |
+
}
|
267 |
+
],
|
268 |
+
"source": [
|
269 |
+
"cats_model.push_to_hub(\"johko/stylegan_cat256\")"
|
270 |
+
]
|
271 |
+
}
|
272 |
+
],
|
273 |
+
"metadata": {
|
274 |
+
"interpreter": {
|
275 |
+
"hash": "a8d699d01f596cc27ac2722fbc0550b939d217978c7e1ca888dca7ba146ee4bf"
|
276 |
+
},
|
277 |
+
"kernelspec": {
|
278 |
+
"display_name": "Python 3",
|
279 |
+
"language": "python",
|
280 |
+
"name": "python3"
|
281 |
+
},
|
282 |
+
"language_info": {
|
283 |
+
"codemirror_mode": {
|
284 |
+
"name": "ipython",
|
285 |
+
"version": 3
|
286 |
+
},
|
287 |
+
"file_extension": ".py",
|
288 |
+
"mimetype": "text/x-python",
|
289 |
+
"name": "python",
|
290 |
+
"nbconvert_exporter": "python",
|
291 |
+
"pygments_lexer": "ipython3",
|
292 |
+
"version": "3.9.9"
|
293 |
+
}
|
294 |
+
},
|
295 |
+
"nbformat": 4,
|
296 |
+
"nbformat_minor": 2
|
297 |
+
}
|
models/model_zoo.py
CHANGED
@@ -9,6 +9,7 @@ MODEL_ZOO = {
|
|
9 |
gan_type='pggan',
|
10 |
resolution=1024,
|
11 |
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EW_3jQ6E7xlKvCSHYrbmkQQBAB8tgIv5W5evdT6-GuXiWw?e=gRifVa&download=1',
|
|
|
12 |
),
|
13 |
'pggan_bedroom256': dict(
|
14 |
gan_type='pggan',
|
@@ -181,11 +182,13 @@ MODEL_ZOO = {
|
|
181 |
gan_type='stylegan',
|
182 |
resolution=256,
|
183 |
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EVjX8u9HuehLip3z0hRfIHcB7QtoFkTB7NiRDb8nrKOl2w?e=lHcp1B&download=1',
|
|
|
184 |
),
|
185 |
'stylegan_car512': dict(
|
186 |
gan_type='stylegan',
|
187 |
resolution=512,
|
188 |
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EcRJNNzzUzJGjI2X53S9HjkBhXkKT5JRd6Q3IIhCY1AyRw?e=FvMRNj&download=1',
|
|
|
189 |
),
|
190 |
|
191 |
# StyleGAN ours.
|
@@ -260,6 +263,7 @@ MODEL_ZOO = {
|
|
260 |
gan_type='stylegan',
|
261 |
resolution=512,
|
262 |
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EWDWflY6lBpGgX0CGQpd2Z4B5wTEVamTOA9JRYne7zdCvA?e=tOzgYA&download=1',
|
|
|
263 |
),
|
264 |
'stylegan_animeportrait512': dict(
|
265 |
gan_type='stylegan',
|
@@ -296,15 +300,8 @@ MODEL_ZOO = {
|
|
296 |
'stylegan2_car512': dict(
|
297 |
gan_type='stylegan2',
|
298 |
resolution=512,
|
299 |
-
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EYSnUsxU8KJFuMHhZm-JLWoB0nHxdlbrLHNZ_Qkoe3b9LA?e=Ycjp5A&download=1'
|
300 |
),
|
301 |
-
|
302 |
-
#huggingface models
|
303 |
-
'akhaliq/OneshotCLIP-stylegan2-ffhq' : dict(
|
304 |
-
gan_type='stylegan2',
|
305 |
-
resolution=512,
|
306 |
-
url='akhaliq/OneshotCLIP-stylegan2-ffhq',
|
307 |
-
)
|
308 |
}
|
309 |
|
310 |
# pylint: enable=line-too-long
|
|
|
9 |
gan_type='pggan',
|
10 |
resolution=1024,
|
11 |
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EW_3jQ6E7xlKvCSHYrbmkQQBAB8tgIv5W5evdT6-GuXiWw?e=gRifVa&download=1',
|
12 |
+
hf_hub_repo='huggan/pggan-celebahq-1024'
|
13 |
),
|
14 |
'pggan_bedroom256': dict(
|
15 |
gan_type='pggan',
|
|
|
182 |
gan_type='stylegan',
|
183 |
resolution=256,
|
184 |
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EVjX8u9HuehLip3z0hRfIHcB7QtoFkTB7NiRDb8nrKOl2w?e=lHcp1B&download=1',
|
185 |
+
hf_hub_repo="huggan/stylegan_cat256"
|
186 |
),
|
187 |
'stylegan_car512': dict(
|
188 |
gan_type='stylegan',
|
189 |
resolution=512,
|
190 |
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EcRJNNzzUzJGjI2X53S9HjkBhXkKT5JRd6Q3IIhCY1AyRw?e=FvMRNj&download=1',
|
191 |
+
hf_hub_repo="huggan/stylegan_car512"
|
192 |
),
|
193 |
|
194 |
# StyleGAN ours.
|
|
|
263 |
gan_type='stylegan',
|
264 |
resolution=512,
|
265 |
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EWDWflY6lBpGgX0CGQpd2Z4B5wTEVamTOA9JRYne7zdCvA?e=tOzgYA&download=1',
|
266 |
+
hf_hub_repo='huggan/stylegan_animeface512'
|
267 |
),
|
268 |
'stylegan_animeportrait512': dict(
|
269 |
gan_type='stylegan',
|
|
|
300 |
'stylegan2_car512': dict(
|
301 |
gan_type='stylegan2',
|
302 |
resolution=512,
|
303 |
+
url='https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EYSnUsxU8KJFuMHhZm-JLWoB0nHxdlbrLHNZ_Qkoe3b9LA?e=Ycjp5A&download=1'
|
304 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
}
|
306 |
|
307 |
# pylint: enable=line-too-long
|
models/pggan_generator.py
CHANGED
@@ -6,6 +6,7 @@ Paper: https://arxiv.org/pdf/1710.10196.pdf
|
|
6 |
Official TensorFlow implementation:
|
7 |
https://github.com/tkarras/progressive_growing_of_gans
|
8 |
"""
|
|
|
9 |
|
10 |
import numpy as np
|
11 |
|
@@ -13,6 +14,8 @@ import torch
|
|
13 |
import torch.nn as nn
|
14 |
import torch.nn.functional as F
|
15 |
|
|
|
|
|
16 |
__all__ = ['PGGANGenerator']
|
17 |
|
18 |
# Resolutions allowed.
|
@@ -25,7 +28,7 @@ _INIT_RES = 4
|
|
25 |
_WSCALE_GAIN = np.sqrt(2.0)
|
26 |
|
27 |
|
28 |
-
class PGGANGenerator(nn.Module):
|
29 |
"""Defines the generator network in PGGAN.
|
30 |
|
31 |
NOTE: The synthesized images are with `RGB` channel order and pixel range
|
@@ -57,7 +60,8 @@ class PGGANGenerator(nn.Module):
|
|
57 |
fused_scale=False,
|
58 |
use_wscale=True,
|
59 |
fmaps_base=16 << 10,
|
60 |
-
fmaps_max=512
|
|
|
61 |
"""Initializes with basic settings.
|
62 |
|
63 |
Raises:
|
@@ -81,6 +85,8 @@ class PGGANGenerator(nn.Module):
|
|
81 |
self.use_wscale = use_wscale
|
82 |
self.fmaps_base = fmaps_base
|
83 |
self.fmaps_max = fmaps_max
|
|
|
|
|
84 |
|
85 |
# Number of convolutional layers.
|
86 |
self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2
|
@@ -202,6 +208,46 @@ class PGGANGenerator(nn.Module):
|
|
202 |
}
|
203 |
return results
|
204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
|
206 |
class PixelNormLayer(nn.Module):
|
207 |
"""Implements pixel-wise feature vector normalization layer."""
|
|
|
6 |
Official TensorFlow implementation:
|
7 |
https://github.com/tkarras/progressive_growing_of_gans
|
8 |
"""
|
9 |
+
import os
|
10 |
|
11 |
import numpy as np
|
12 |
|
|
|
14 |
import torch.nn as nn
|
15 |
import torch.nn.functional as F
|
16 |
|
17 |
+
from huggingface_hub import PyTorchModelHubMixin, PYTORCH_WEIGHTS_NAME, hf_hub_download
|
18 |
+
|
19 |
__all__ = ['PGGANGenerator']
|
20 |
|
21 |
# Resolutions allowed.
|
|
|
28 |
_WSCALE_GAIN = np.sqrt(2.0)
|
29 |
|
30 |
|
31 |
+
class PGGANGenerator(nn.Module, PyTorchModelHubMixin):
|
32 |
"""Defines the generator network in PGGAN.
|
33 |
|
34 |
NOTE: The synthesized images are with `RGB` channel order and pixel range
|
|
|
60 |
fused_scale=False,
|
61 |
use_wscale=True,
|
62 |
fmaps_base=16 << 10,
|
63 |
+
fmaps_max=512,
|
64 |
+
**kwargs):
|
65 |
"""Initializes with basic settings.
|
66 |
|
67 |
Raises:
|
|
|
85 |
self.use_wscale = use_wscale
|
86 |
self.fmaps_base = fmaps_base
|
87 |
self.fmaps_max = fmaps_max
|
88 |
+
|
89 |
+
self.config = kwargs.pop("config", None)
|
90 |
|
91 |
# Number of convolutional layers.
|
92 |
self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2
|
|
|
208 |
}
|
209 |
return results
|
210 |
|
211 |
+
@classmethod
|
212 |
+
def _from_pretrained(
|
213 |
+
cls,
|
214 |
+
model_id,
|
215 |
+
revision,
|
216 |
+
cache_dir,
|
217 |
+
force_download,
|
218 |
+
proxies,
|
219 |
+
resume_download,
|
220 |
+
local_files_only,
|
221 |
+
use_auth_token,
|
222 |
+
map_location="cpu",
|
223 |
+
strict=False,
|
224 |
+
**model_kwargs,
|
225 |
+
):
|
226 |
+
"""
|
227 |
+
Overwrite this method in case you wish to initialize your model in a
|
228 |
+
different way.
|
229 |
+
"""
|
230 |
+
map_location = torch.device(map_location)
|
231 |
+
|
232 |
+
if os.path.isdir(model_id):
|
233 |
+
print("Loading weights from local directory")
|
234 |
+
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
|
235 |
+
else:
|
236 |
+
model_file = hf_hub_download(
|
237 |
+
repo_id=model_id,
|
238 |
+
filename=PYTORCH_WEIGHTS_NAME,
|
239 |
+
revision=revision,
|
240 |
+
cache_dir=cache_dir,
|
241 |
+
force_download=force_download,
|
242 |
+
proxies=proxies,
|
243 |
+
resume_download=resume_download,
|
244 |
+
use_auth_token=use_auth_token,
|
245 |
+
local_files_only=local_files_only,
|
246 |
+
)
|
247 |
+
|
248 |
+
pretrained = torch.load(model_file, map_location=map_location)
|
249 |
+
return pretrained
|
250 |
+
|
251 |
|
252 |
class PixelNormLayer(nn.Module):
|
253 |
"""Implements pixel-wise feature vector normalization layer."""
|
models/stylegan2_generator.py
CHANGED
@@ -9,12 +9,14 @@ Paper: https://arxiv.org/pdf/1912.04958.pdf
|
|
9 |
|
10 |
Official TensorFlow implementation: https://github.com/NVlabs/stylegan2
|
11 |
"""
|
|
|
12 |
|
13 |
import numpy as np
|
14 |
|
15 |
import torch
|
16 |
import torch.nn as nn
|
17 |
import torch.nn.functional as F
|
|
|
18 |
|
19 |
from .sync_op import all_gather
|
20 |
|
@@ -33,7 +35,7 @@ _ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin']
|
|
33 |
_WSCALE_GAIN = 1.0
|
34 |
|
35 |
|
36 |
-
class StyleGAN2Generator(nn.Module):
|
37 |
"""Defines the generator network in StyleGAN2.
|
38 |
|
39 |
NOTE: The synthesized images are with `RGB` channel order and pixel range
|
@@ -88,7 +90,8 @@ class StyleGAN2Generator(nn.Module):
|
|
88 |
demodulate=True,
|
89 |
use_wscale=True,
|
90 |
fmaps_base=32 << 10,
|
91 |
-
fmaps_max=512
|
|
|
92 |
"""Initializes with basic settings.
|
93 |
|
94 |
Raises:
|
@@ -195,6 +198,45 @@ class StyleGAN2Generator(nn.Module):
|
|
195 |
|
196 |
return {**mapping_results, **synthesis_results}
|
197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
|
199 |
class MappingModule(nn.Module):
|
200 |
"""Implements the latent space mapping module.
|
|
|
9 |
|
10 |
Official TensorFlow implementation: https://github.com/NVlabs/stylegan2
|
11 |
"""
|
12 |
+
import os
|
13 |
|
14 |
import numpy as np
|
15 |
|
16 |
import torch
|
17 |
import torch.nn as nn
|
18 |
import torch.nn.functional as F
|
19 |
+
from huggingface_hub import PYTORCH_WEIGHTS_NAME, hf_hub_download, PyTorchModelHubMixin
|
20 |
|
21 |
from .sync_op import all_gather
|
22 |
|
|
|
35 |
_WSCALE_GAIN = 1.0
|
36 |
|
37 |
|
38 |
+
class StyleGAN2Generator(nn.Module, PyTorchModelHubMixin):
|
39 |
"""Defines the generator network in StyleGAN2.
|
40 |
|
41 |
NOTE: The synthesized images are with `RGB` channel order and pixel range
|
|
|
90 |
demodulate=True,
|
91 |
use_wscale=True,
|
92 |
fmaps_base=32 << 10,
|
93 |
+
fmaps_max=512,
|
94 |
+
**kwargs):
|
95 |
"""Initializes with basic settings.
|
96 |
|
97 |
Raises:
|
|
|
198 |
|
199 |
return {**mapping_results, **synthesis_results}
|
200 |
|
201 |
+
@classmethod
|
202 |
+
def _from_pretrained(
|
203 |
+
cls,
|
204 |
+
model_id,
|
205 |
+
revision,
|
206 |
+
cache_dir,
|
207 |
+
force_download,
|
208 |
+
proxies,
|
209 |
+
resume_download,
|
210 |
+
local_files_only,
|
211 |
+
use_auth_token,
|
212 |
+
map_location="cpu",
|
213 |
+
strict=False,
|
214 |
+
**model_kwargs,
|
215 |
+
):
|
216 |
+
"""
|
217 |
+
Overwrite this method in case you wish to initialize your model in a
|
218 |
+
different way.
|
219 |
+
"""
|
220 |
+
map_location = torch.device(map_location)
|
221 |
+
|
222 |
+
if os.path.isdir(model_id):
|
223 |
+
print("Loading weights from local directory")
|
224 |
+
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
|
225 |
+
else:
|
226 |
+
model_file = hf_hub_download(
|
227 |
+
repo_id=model_id,
|
228 |
+
filename="stylegan2-ffhq-config-f.pt",
|
229 |
+
revision=revision,
|
230 |
+
cache_dir=cache_dir,
|
231 |
+
force_download=force_download,
|
232 |
+
proxies=proxies,
|
233 |
+
resume_download=resume_download,
|
234 |
+
use_auth_token=use_auth_token,
|
235 |
+
local_files_only=local_files_only,
|
236 |
+
)
|
237 |
+
|
238 |
+
pretrained = torch.load(model_file, map_location=map_location)
|
239 |
+
return pretrained
|
240 |
|
241 |
class MappingModule(nn.Module):
|
242 |
"""Implements the latent space mapping module.
|
models/stylegan_generator.py
CHANGED
@@ -5,6 +5,7 @@ Paper: https://arxiv.org/pdf/1812.04948.pdf
|
|
5 |
|
6 |
Official TensorFlow implementation: https://github.com/NVlabs/stylegan
|
7 |
"""
|
|
|
8 |
|
9 |
import numpy as np
|
10 |
|
@@ -14,6 +15,8 @@ import torch.nn.functional as F
|
|
14 |
|
15 |
from .sync_op import all_gather
|
16 |
|
|
|
|
|
17 |
__all__ = ['StyleGANGenerator']
|
18 |
|
19 |
# Resolutions allowed.
|
@@ -33,7 +36,7 @@ _WSCALE_GAIN = np.sqrt(2.0)
|
|
33 |
_STYLEMOD_WSCALE_GAIN = 1.0
|
34 |
|
35 |
|
36 |
-
class StyleGANGenerator(nn.Module):
|
37 |
"""Defines the generator network in StyleGAN.
|
38 |
|
39 |
NOTE: The synthesized images are with `RGB` channel order and pixel range
|
@@ -83,7 +86,8 @@ class StyleGANGenerator(nn.Module):
|
|
83 |
fused_scale='auto',
|
84 |
use_wscale=True,
|
85 |
fmaps_base=16 << 10,
|
86 |
-
fmaps_max=512
|
|
|
87 |
"""Initializes with basic settings.
|
88 |
|
89 |
Raises:
|
@@ -115,6 +119,9 @@ class StyleGANGenerator(nn.Module):
|
|
115 |
self.use_wscale = use_wscale
|
116 |
self.fmaps_base = fmaps_base
|
117 |
self.fmaps_max = fmaps_max
|
|
|
|
|
|
|
118 |
|
119 |
self.num_layers = int(np.log2(self.resolution // self.init_res * 2)) * 2
|
120 |
|
@@ -188,6 +195,46 @@ class StyleGANGenerator(nn.Module):
|
|
188 |
|
189 |
return {**mapping_results, **synthesis_results}
|
190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
|
192 |
class MappingModule(nn.Module):
|
193 |
"""Implements the latent space mapping module.
|
|
|
5 |
|
6 |
Official TensorFlow implementation: https://github.com/NVlabs/stylegan
|
7 |
"""
|
8 |
+
import os
|
9 |
|
10 |
import numpy as np
|
11 |
|
|
|
15 |
|
16 |
from .sync_op import all_gather
|
17 |
|
18 |
+
from huggingface_hub import PyTorchModelHubMixin, PYTORCH_WEIGHTS_NAME, hf_hub_download
|
19 |
+
|
20 |
__all__ = ['StyleGANGenerator']
|
21 |
|
22 |
# Resolutions allowed.
|
|
|
36 |
_STYLEMOD_WSCALE_GAIN = 1.0
|
37 |
|
38 |
|
39 |
+
class StyleGANGenerator(nn.Module, PyTorchModelHubMixin):
|
40 |
"""Defines the generator network in StyleGAN.
|
41 |
|
42 |
NOTE: The synthesized images are with `RGB` channel order and pixel range
|
|
|
86 |
fused_scale='auto',
|
87 |
use_wscale=True,
|
88 |
fmaps_base=16 << 10,
|
89 |
+
fmaps_max=512,
|
90 |
+
**kwargs):
|
91 |
"""Initializes with basic settings.
|
92 |
|
93 |
Raises:
|
|
|
119 |
self.use_wscale = use_wscale
|
120 |
self.fmaps_base = fmaps_base
|
121 |
self.fmaps_max = fmaps_max
|
122 |
+
|
123 |
+
self.config = kwargs.pop("config", None)
|
124 |
+
|
125 |
|
126 |
self.num_layers = int(np.log2(self.resolution // self.init_res * 2)) * 2
|
127 |
|
|
|
195 |
|
196 |
return {**mapping_results, **synthesis_results}
|
197 |
|
198 |
+
@classmethod
|
199 |
+
def _from_pretrained(
|
200 |
+
cls,
|
201 |
+
model_id,
|
202 |
+
revision,
|
203 |
+
cache_dir,
|
204 |
+
force_download,
|
205 |
+
proxies,
|
206 |
+
resume_download,
|
207 |
+
local_files_only,
|
208 |
+
use_auth_token,
|
209 |
+
map_location="cpu",
|
210 |
+
strict=False,
|
211 |
+
**model_kwargs,
|
212 |
+
):
|
213 |
+
"""
|
214 |
+
Overwrite this method in case you wish to initialize your model in a
|
215 |
+
different way.
|
216 |
+
"""
|
217 |
+
map_location = torch.device(map_location)
|
218 |
+
|
219 |
+
if os.path.isdir(model_id):
|
220 |
+
print("Loading weights from local directory")
|
221 |
+
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
|
222 |
+
else:
|
223 |
+
model_file = hf_hub_download(
|
224 |
+
repo_id=model_id,
|
225 |
+
filename=PYTORCH_WEIGHTS_NAME,
|
226 |
+
revision=revision,
|
227 |
+
cache_dir=cache_dir,
|
228 |
+
force_download=force_download,
|
229 |
+
proxies=proxies,
|
230 |
+
resume_download=resume_download,
|
231 |
+
use_auth_token=use_auth_token,
|
232 |
+
local_files_only=local_files_only,
|
233 |
+
)
|
234 |
+
|
235 |
+
pretrained = torch.load(model_file, map_location=map_location)
|
236 |
+
return pretrained
|
237 |
+
|
238 |
|
239 |
class MappingModule(nn.Module):
|
240 |
"""Implements the latent space mapping module.
|
utils.py
CHANGED
@@ -50,7 +50,7 @@ def postprocess(images, min_val=-1.0, max_val=1.0):
|
|
50 |
return images
|
51 |
|
52 |
|
53 |
-
def load_generator(model_name):
|
54 |
"""Loads pre-trained generator.
|
55 |
|
56 |
Args:
|
@@ -74,19 +74,25 @@ def load_generator(model_name):
|
|
74 |
generator = build_generator(**model_config)
|
75 |
print(f'Finish building generator.')
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
if not os.path.exists(checkpoint_path):
|
82 |
-
print(f' Downloading checkpoint from `{url}` ...')
|
83 |
-
subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url])
|
84 |
-
print(f' Finish downloading checkpoint.')
|
85 |
-
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
86 |
-
if 'generator_smooth' in checkpoint:
|
87 |
-
generator.load_state_dict(checkpoint['generator_smooth'])
|
88 |
else:
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
#generator = generator.cuda()
|
91 |
generator.eval()
|
92 |
print(f'Finish loading checkpoint.')
|
|
|
50 |
return images
|
51 |
|
52 |
|
53 |
+
def load_generator(model_name, from_hf_hub=False):
|
54 |
"""Loads pre-trained generator.
|
55 |
|
56 |
Args:
|
|
|
74 |
generator = build_generator(**model_config)
|
75 |
print(f'Finish building generator.')
|
76 |
|
77 |
+
if from_hf_hub and "hf_hub_repo" in model_config.keys():
|
78 |
+
checkpoint = generator.from_pretrained(model_config["hf_hub_repo"])
|
79 |
+
generator.load_state_dict(checkpoint)
|
80 |
+
print("loaded from hf_hub")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
else:
|
82 |
+
# Load pre-trained weights.
|
83 |
+
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
|
84 |
+
checkpoint_path = os.path.join(CHECKPOINT_DIR, model_name + '.pth')
|
85 |
+
print(f'Loading checkpoint from `{checkpoint_path}` ...')
|
86 |
+
if not os.path.exists(checkpoint_path):
|
87 |
+
print(f' Downloading checkpoint from `{url}` ...')
|
88 |
+
subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url])
|
89 |
+
print(f' Finish downloading checkpoint.')
|
90 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
91 |
+
|
92 |
+
if 'generator_smooth' in checkpoint:
|
93 |
+
generator.load_state_dict(checkpoint['generator_smooth'])
|
94 |
+
else:
|
95 |
+
generator.load_state_dict(checkpoint['generator'])
|
96 |
#generator = generator.cuda()
|
97 |
generator.eval()
|
98 |
print(f'Finish loading checkpoint.')
|