watermark and copyright
Files changed:
- visualizer_drag_gradio.py  +17 -1
- viz/renderer.py  +109 -80
visualizer_drag_gradio.py
@@ -183,7 +183,7 @@ with gr.Blocks() as app:
     <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> for no queue on your own hardware.</p>
 
     * Official Repo: [XingangPan](https://github.com/XingangPan/DragGAN)
-    * Gradio Demo by: [LeoXing1996](https://github.com/LeoXing1996)
+    * Gradio Demo by: [LeoXing1996](https://github.com/LeoXing1996) © [OpenMMLab MMagic](https://github.com/open-mmlab/mmagic)
     """)
 
     # renderer = Renderer()
@@ -355,6 +355,22 @@ with gr.Blocks() as app:
            unmasked region to remain unchanged.
 
 
+            """)
+            gr.HTML("""
+                <style>
+                    .container {
+                        position: absolute;
+                        height: 50px;
+                        text-align: center;
+                        line-height: 50px;
+                        width: 100%;
+                    }
+                </style>
+                <div class="container">
+                    Gradio demo supported by
+                    <img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
+                    <a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
+                </div>
            """)
    # Network & latents tab listeners
 
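For context, gr.HTML simply injects raw markup and CSS into the Blocks page. A minimal, self-contained sketch of the same footer pattern (a hypothetical standalone app, not part of this commit):

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("Demo")
    # Raw HTML + CSS is rendered verbatim below the Markdown content.
    gr.HTML("""
        <style>
            .footer { text-align: center; line-height: 50px; width: 100%; }
        </style>
        <div class="footer">Credits footer rendered as raw HTML</div>
    """)

demo.launch()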
viz/renderer.py
@@ -20,9 +20,10 @@ import torch.nn.functional as F
 import matplotlib.cm
 import dnnlib
 from torch_utils.ops import upfirdn2d
-import legacy
+import legacy  # pylint: disable=import-error
 
-#----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------
+
 
 class CapturedException(Exception):
     def __init__(self, msg=None):
@@ -36,16 +37,18 @@ class CapturedException:
         assert isinstance(msg, str)
         super().__init__(msg)
 
-
+# ----------------------------------------------------------------------------
+
 
 class CaptureSuccess(Exception):
     def __init__(self, out):
         super().__init__()
         self.out = out
 
-
+# ----------------------------------------------------------------------------
+
 
-def add_watermark_np(input_image_array, watermark_text="Watermark"):
+def add_watermark_np(input_image_array, watermark_text="AI Generated"):
     image = Image.fromarray(np.uint8(input_image_array)).convert("RGBA")
 
     # Initialize text image
@@ -54,8 +57,10 @@ def add_watermark_np(input_image_array, watermark_text="Watermark"):
     d = ImageDraw.Draw(txt)
 
     text_width, text_height = font.getsize(watermark_text)
-    text_position = (image.size[0] - text_width - 10, image.size[1] - text_height - 10)
-    text_color = (255, 255, 255, 128)  # white color with the alpha channel set to semi-transparent
+    text_position = (image.size[0] - text_width -
+                     10, image.size[1] - text_height - 10)
+    # white color with the alpha channel set to semi-transparent
+    text_color = (255, 255, 255, 128)
 
     # Draw the text onto the text canvas
     d.text(text_position, watermark_text, font=font, fill=text_color)
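As a usage sketch (the input array here is a stand-in, not from the repo): add_watermark_np takes an H x W x 3 array, draws the semi-transparent text onto an RGBA canvas, and alpha-composites it back, so the spatial size is preserved:

import numpy as np

frame = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for a generated image
marked = add_watermark_np(frame)                 # default text: "AI Generated"
assert marked.shape[:2] == frame.shape[:2]       # same height and width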
@@ -65,21 +70,22 @@ def add_watermark_np(input_image_array, watermark_text="Watermark"):
     watermarked_array = np.array(watermarked)
     return watermarked_array
 
-
+# ----------------------------------------------------------------------------
+
 
 class Renderer:
     def __init__(self, disable_timing=False):
-        self._device        = torch.device('cuda')
-        self._pkl_data      = dict()    # {pkl: dict | CapturedException, ...}
-        self._networks      = dict()    # {cache_key: torch.nn.Module, ...}
-        self._pinned_bufs   = dict()    # {(shape, dtype): torch.Tensor, ...}
-        self._cmaps         = dict()    # {name: torch.Tensor, ...}
-        self._is_timing     = False
+        self._device = torch.device('cuda')
+        self._pkl_data = dict()  # {pkl: dict | CapturedException, ...}
+        self._networks = dict()  # {cache_key: torch.nn.Module, ...}
+        self._pinned_bufs = dict()  # {(shape, dtype): torch.Tensor, ...}
+        self._cmaps = dict()  # {name: torch.Tensor, ...}
+        self._is_timing = False
         if not disable_timing:
-            self._start_event   = torch.cuda.Event(enable_timing=True)
-            self._end_event     = torch.cuda.Event(enable_timing=True)
+            self._start_event = torch.cuda.Event(enable_timing=True)
+            self._end_event = torch.cuda.Event(enable_timing=True)
         self._disable_timing = disable_timing
-        self._net_layers    = dict()    # {cache_key: [dnnlib.EasyDict, ...], ...}
+        self._net_layers = dict()  # {cache_key: [dnnlib.EasyDict, ...], ...}
 
     def render(self, **args):
         if self._disable_timing:
@@ -122,7 +128,8 @@ class Renderer:
 
         if self._is_timing and not self._disable_timing:
             self._end_event.synchronize()
-            res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3
+            res.render_time = self._start_event.elapsed_time(
+                self._end_event) * 1e-3
             self._is_timing = False
         return res
 
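The wrapped line above is the standard torch.cuda.Event timing pattern: elapsed_time() returns milliseconds, hence the * 1e-3 to report seconds. A self-contained sketch (assumes a CUDA device is available):

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
_ = torch.randn(1024, 1024, device='cuda') @ torch.randn(1024, 1024, device='cuda')
end.record()
end.synchronize()                         # block until the GPU reaches the end event
seconds = start.elapsed_time(end) * 1e-3  # elapsed_time() is in milliseconds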
@@ -143,7 +150,8 @@ class Renderer:
             raise data
 
         orig_net = data[key]
-        cache_key = (orig_net, self._device, tuple(sorted(tweak_kwargs.items())))
+        cache_key = (orig_net, self._device, tuple(
+            sorted(tweak_kwargs.items())))
         net = self._networks.get(cache_key, None)
         if net is None:
             try:
@@ -159,9 +167,11 @@ class Renderer:
                 print(data[key].init_args)
                 print(data[key].init_kwargs)
                 if 'stylegan_human' in pkl:
-                    net = Generator(*data[key].init_args, **data[key].init_kwargs, square=False, padding=True)
+                    net = Generator(
+                        *data[key].init_args, **data[key].init_kwargs, square=False, padding=True)
                 else:
-                    net = Generator(*data[key].init_args, **data[key].init_kwargs)
+                    net = Generator(*data[key].init_args,
+                                    **data[key].init_kwargs)
                 net.load_state_dict(data[key].state_dict())
                 net.to(self._device)
             except:
@@ -202,25 +212,27 @@ class Renderer:
         return x
 
     def init_network(self, res,
-        pkl             = None,
-        w0_seed         = 0,
-        w_load          = None,
-        w_plus          = True,
-        noise_mode      = 'const',
-        trunc_psi       = 0.7,
-        trunc_cutoff    = None,
-        input_transform = None,
-        lr              = 0.001,
-        **kwargs
-        ):
+                     pkl=None,
+                     w0_seed=0,
+                     w_load=None,
+                     w_plus=True,
+                     noise_mode='const',
+                     trunc_psi=0.7,
+                     trunc_cutoff=None,
+                     input_transform=None,
+                     lr=0.001,
+                     **kwargs
+                     ):
         # Dig up network details.
         self.pkl = pkl
         G = self.get_network(pkl, 'G_ema')
         self.G = G
         res.img_resolution = G.img_resolution
         res.num_ws = G.num_ws
-        res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers())
-        res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))
+        res.has_noise = any('noise_const' in name for name,
+                            _buf in G.synthesis.named_buffers())
+        res.has_input_transform = (
+            hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))
 
         # Set input transform.
         if res.has_input_transform:
@@ -238,11 +250,13 @@ class Renderer:
 
         if self.w_load is None:
             # Generate random latents.
-            z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device).float()
+            z = torch.from_numpy(np.random.RandomState(
+                w0_seed).randn(1, 512)).to(self._device).float()
 
             # Run mapping network.
             label = torch.zeros([1, G.c_dim], device=self._device)
-            w = G.mapping(z, label, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff)
+            w = G.mapping(z, label, truncation_psi=trunc_psi,
+                          truncation_cutoff=trunc_cutoff)
         else:
             w = self.w_load.clone().to(self._device)
 
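Behavior is unchanged by the re-wrap: np.random.RandomState(w0_seed) is a self-contained generator, so a fixed w0_seed always reproduces the same starting latent. A quick check with the 512-dim z used in the diff:

import numpy as np
import torch

z_a = torch.from_numpy(np.random.RandomState(0).randn(1, 512)).float()
z_b = torch.from_numpy(np.random.RandomState(0).randn(1, 512)).float()
assert torch.equal(z_a, z_b)  # same seed, same latent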
@@ -266,34 +280,34 @@ class Renderer:
         print(' Remain feat_refs and points0_pt')
 
     def _render_drag_impl(self, res,
-        points          = [],
-        targets         = [],
-        mask            = None,
-        lambda_mask     = 10,
-        reg             = 0,
-        feature_idx     = 5,
-        r1              = 3,
-        r2              = 12,
-        random_seed     = 0,
-        noise_mode      = 'const',
-        trunc_psi       = 0.7,
-        force_fp32      = False,
-        layer_name      = None,
-        sel_channels    = 3,
-        base_channel    = 0,
-        img_scale_db    = 0,
-        img_normalize   = False,
-        untransform     = False,
-        is_drag         = False,
-        reset           = False,
-        to_pil          = False,
-        **kwargs
-    ):
+                          points=[],
+                          targets=[],
+                          mask=None,
+                          lambda_mask=10,
+                          reg=0,
+                          feature_idx=5,
+                          r1=3,
+                          r2=12,
+                          random_seed=0,
+                          noise_mode='const',
+                          trunc_psi=0.7,
+                          force_fp32=False,
+                          layer_name=None,
+                          sel_channels=3,
+                          base_channel=0,
+                          img_scale_db=0,
+                          img_normalize=False,
+                          untransform=False,
+                          is_drag=False,
+                          reset=False,
+                          to_pil=False,
+                          **kwargs
+                          ):
         G = self.G
         ws = self.w
         if ws.dim() == 2:
-            ws = ws.unsqueeze(1).repeat(1,6,1)
-            ws = torch.cat([ws[:,:6,:], self.w0[:,6:,:]], dim=1)
+            ws = ws.unsqueeze(1).repeat(1, 6, 1)
+            ws = torch.cat([ws[:, :6, :], self.w0[:, 6:, :]], dim=1)
         if hasattr(self, 'points'):
             if len(points) != len(self.points):
                 reset = True
@@ -304,7 +318,8 @@ class Renderer:
 
         # Run synthesis network.
         label = torch.zeros([1, G.c_dim], device=self._device)
-        img, feat = G(ws, label, truncation_psi=trunc_psi, noise_mode=noise_mode, input_is_w=True, return_feature=True)
+        img, feat = G(ws, label, truncation_psi=trunc_psi,
+                      noise_mode=noise_mode, input_is_w=True, return_feature=True)
 
         h, w = G.img_resolution, G.img_resolution
 
@@ -312,14 +327,17 @@ class Renderer:
         X = torch.linspace(0, h, h)
         Y = torch.linspace(0, w, w)
         xx, yy = torch.meshgrid(X, Y)
-        feat_resize = F.interpolate(feat[feature_idx], [h, w], mode='bilinear')
+        feat_resize = F.interpolate(
+            feat[feature_idx], [h, w], mode='bilinear')
         if self.feat_refs is None:
-            self.feat0_resize = F.interpolate(feat[feature_idx].detach(), [h, w], mode='bilinear')
+            self.feat0_resize = F.interpolate(
+                feat[feature_idx].detach(), [h, w], mode='bilinear')
             self.feat_refs = []
             for point in points:
                 py, px = round(point[0]), round(point[1])
-                self.feat_refs.append(self.feat0_resize[:,:,py,px])
-            self.points0_pt = torch.Tensor(points).unsqueeze(0).to(self._device) # 1, N, 2
+                self.feat_refs.append(self.feat0_resize[:, :, py, px])
+            self.points0_pt = torch.Tensor(points).unsqueeze(
+                0).to(self._device)  # 1, N, 2
 
         # Point tracking with feature matching
         with torch.no_grad():
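The interpolation upsamples the selected intermediate feature map to the output resolution so that image-space pixel coordinates can index features directly. A shape-only sketch with toy dimensions (not the real channel counts):

import torch
import torch.nn.functional as F

feat = torch.randn(1, 128, 64, 64)  # toy N x C x h x w feature map
feat_resize = F.interpolate(feat, [256, 256], mode='bilinear')
assert feat_resize.shape == (1, 128, 256, 256)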
@@ -329,11 +347,13 @@ class Renderer:
                 down = min(point[0] + r + 1, h)
                 left = max(point[1] - r, 0)
                 right = min(point[1] + r + 1, w)
-                feat_patch = feat_resize[:,:,up:down,left:right]
-                L2 = torch.linalg.norm(feat_patch - self.feat_refs[j].reshape(1,-1,1,1), dim=1)
-                _, idx = torch.min(L2.view(1,-1), -1)
+                feat_patch = feat_resize[:, :, up:down, left:right]
+                L2 = torch.linalg.norm(
+                    feat_patch - self.feat_refs[j].reshape(1, -1, 1, 1), dim=1)
+                _, idx = torch.min(L2.view(1, -1), -1)
                 width = right - left
-                point = [idx.item() // width + up, idx.item() % width + left]
+                point = [idx.item() // width + up, idx.item() %
+                         width + left]
                 points[j] = point
 
         res.points = [[point[0], point[1]] for point in points]
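In the tracking step, torch.min over the flattened patch returns a flat index, and `idx // width + up` / `idx % width + left` map it back to 2D coordinates. A toy check with up = left = 0:

import torch

L2 = torch.tensor([[3.0, 1.0], [0.5, 2.0]])  # toy 2x2 distance map
_, idx = torch.min(L2.view(1, -1), -1)
width = L2.shape[1]
row, col = idx.item() // width, idx.item() % width
assert (row, col) == (1, 0)  # location of the 0.5 entry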
@@ -342,24 +362,31 @@ class Renderer:
             loss_motion = 0
             res.stop = True
             for j, point in enumerate(points):
-                direction = torch.Tensor([targets[j][1] - point[1], targets[j][0] - point[0]])
+                direction = torch.Tensor(
+                    [targets[j][1] - point[1], targets[j][0] - point[0]])
                 if torch.linalg.norm(direction) > max(2 / 512 * h, 2):
                     res.stop = False
                 if torch.linalg.norm(direction) > 1:
-                    distance = ((xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
+                    distance = (
+                        (xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
                     relis, reljs = torch.where(distance < round(r1 / 512 * h))
-                    direction = direction / (torch.linalg.norm(direction) + 1e-7)
+                    direction = direction / \
+                        (torch.linalg.norm(direction) + 1e-7)
                     gridh = (relis-direction[1]) / (h-1) * 2 - 1
                     gridw = (reljs-direction[0]) / (w-1) * 2 - 1
-                    grid = torch.stack([gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)
-                    target = F.grid_sample(feat_resize.float(), grid, align_corners=True).squeeze(2)
-                    loss_motion += F.l1_loss(feat_resize[:,:,relis,reljs], target.detach())
+                    grid = torch.stack(
+                        [gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)
+                    target = F.grid_sample(
+                        feat_resize.float(), grid, align_corners=True).squeeze(2)
+                    loss_motion += F.l1_loss(
+                        feat_resize[:, :, relis, reljs], target.detach())
 
             loss = loss_motion
             if mask is not None:
                 if mask.min() == 0 and mask.max() == 1:
                     mask_usq = mask.to(self._device).unsqueeze(0).unsqueeze(0)
-                    loss_fix = F.l1_loss(feat_resize * mask_usq, self.feat0_resize * mask_usq)
+                    loss_fix = F.l1_loss(
+                        feat_resize * mask_usq, self.feat0_resize * mask_usq)
                     loss += lambda_mask * loss_fix
 
             loss += reg * F.l1_loss(ws, self.w0) # latent code regularization
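The motion-supervision lines rely on F.grid_sample's convention that sampling coordinates are normalized to [-1, 1], which is what the `x / (n - 1) * 2 - 1` terms compute; the detached samples act as the moving target. A minimal sketch of that convention:

import torch
import torch.nn.functional as F

feat = torch.arange(16.0).reshape(1, 1, 4, 4)                # toy 1x1x4x4 feature map
ys, xs = torch.tensor([1.0, 2.0]), torch.tensor([0.0, 3.0])  # pixel coordinates
gridh = ys / (4 - 1) * 2 - 1                                 # rows -> [-1, 1]
gridw = xs / (4 - 1) * 2 - 1                                 # cols -> [-1, 1]
grid = torch.stack([gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)  # 1, 1, N, 2
sampled = F.grid_sample(feat, grid, align_corners=True).squeeze(2)
assert torch.allclose(sampled[0, 0], torch.tensor([4.0, 11.0]))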
@@ -371,13 +398,15 @@ class Renderer:
         # Scale and convert to uint8.
         img = img[0]
         if img_normalize:
-            img = img / img.norm(float('inf'), dim=[1,2], keepdim=True).clip(1e-8, 1e8)
+            img = img / img.norm(float('inf'),
+                                 dim=[1, 2], keepdim=True).clip(1e-8, 1e8)
         img = img * (10 ** (img_scale_db / 20))
-        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
+        img = (img * 127.5 + 128).clamp(0,
+                                        255).to(torch.uint8).permute(1, 2, 0)
         if to_pil:
             from PIL import Image
             img = img.cpu().numpy()
             img = Image.fromarray(img)
         res.image = img
 
-#----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------
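The re-wrapped conversion keeps the usual StyleGAN output mapping from [-1, 1] floats to uint8 [0, 255]; checking the endpoints:

import torch

img = torch.tensor([-1.0, 0.0, 1.0])
out = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
assert out.tolist() == [0, 128, 255]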