radames commited on
Commit
b2af35b
1 Parent(s): 0f62188

watermark and copyright

Browse files
Files changed (2) hide show
  1. visualizer_drag_gradio.py +17 -1
  2. viz/renderer.py +109 -80
visualizer_drag_gradio.py CHANGED
@@ -183,7 +183,7 @@ with gr.Blocks() as app:
183
  <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> for no queue on your own hardware.</p>
184
 
185
  * Official Repo: [XingangPan](https://github.com/XingangPan/DragGAN)
186
- * Gradio Demo by: [LeoXing1996](https://github.com/LeoXing1996) with [OpenMMLab MMagic](https://github.com/open-mmlab/mmagic)
187
  """)
188
 
189
  # renderer = Renderer()
@@ -355,6 +355,22 @@ with gr.Blocks() as app:
355
  unmasked region to remain unchanged.
356
 
357
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  """)
359
  # Network & latents tab listeners
360
 
 
183
  <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> for no queue on your own hardware.</p>
184
 
185
  * Official Repo: [XingangPan](https://github.com/XingangPan/DragGAN)
186
+ * Gradio Demo by: [LeoXing1996](https://github.com/LeoXing1996) © [OpenMMLab MMagic](https://github.com/open-mmlab/mmagic)
187
  """)
188
 
189
  # renderer = Renderer()
 
355
  unmasked region to remain unchanged.
356
 
357
 
358
+ """)
359
+ gr.HTML("""
360
+ <style>
361
+ .container {
362
+ position: absolute;
363
+ height: 50px;
364
+ text-align: center;
365
+ line-height: 50px;
366
+ width: 100%;
367
+ }
368
+ </style>
369
+ <div class="container">
370
+ Gradio demo supported by
371
+ <img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
372
+ <a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
373
+ </div>
374
  """)
375
  # Network & latents tab listeners
376
 
viz/renderer.py CHANGED
@@ -20,9 +20,10 @@ import torch.nn.functional as F
20
  import matplotlib.cm
21
  import dnnlib
22
  from torch_utils.ops import upfirdn2d
23
- import legacy # pylint: disable=import-error
 
 
24
 
25
- #----------------------------------------------------------------------------
26
 
27
  class CapturedException(Exception):
28
  def __init__(self, msg=None):
@@ -36,16 +37,18 @@ class CapturedException(Exception):
36
  assert isinstance(msg, str)
37
  super().__init__(msg)
38
 
39
- #----------------------------------------------------------------------------
 
40
 
41
  class CaptureSuccess(Exception):
42
  def __init__(self, out):
43
  super().__init__()
44
  self.out = out
45
 
46
- #----------------------------------------------------------------------------
 
47
 
48
- def add_watermark_np(input_image_array, watermark_text="Watermark"):
49
  image = Image.fromarray(np.uint8(input_image_array)).convert("RGBA")
50
 
51
  # Initialize text image
@@ -54,8 +57,10 @@ def add_watermark_np(input_image_array, watermark_text="Watermark"):
54
  d = ImageDraw.Draw(txt)
55
 
56
  text_width, text_height = font.getsize(watermark_text)
57
- text_position = (image.size[0] - text_width - 10, image.size[1] - text_height - 10)
58
- text_color = (255, 255, 255, 128) # white color with the alpha channel set to semi-transparent
 
 
59
 
60
  # Draw the text onto the text canvas
61
  d.text(text_position, watermark_text, font=font, fill=text_color)
@@ -65,21 +70,22 @@ def add_watermark_np(input_image_array, watermark_text="Watermark"):
65
  watermarked_array = np.array(watermarked)
66
  return watermarked_array
67
 
68
- #----------------------------------------------------------------------------
 
69
 
70
  class Renderer:
71
  def __init__(self, disable_timing=False):
72
- self._device = torch.device('cuda')
73
- self._pkl_data = dict() # {pkl: dict | CapturedException, ...}
74
- self._networks = dict() # {cache_key: torch.nn.Module, ...}
75
- self._pinned_bufs = dict() # {(shape, dtype): torch.Tensor, ...}
76
- self._cmaps = dict() # {name: torch.Tensor, ...}
77
- self._is_timing = False
78
  if not disable_timing:
79
- self._start_event = torch.cuda.Event(enable_timing=True)
80
- self._end_event = torch.cuda.Event(enable_timing=True)
81
  self._disable_timing = disable_timing
82
- self._net_layers = dict() # {cache_key: [dnnlib.EasyDict, ...], ...}
83
 
84
  def render(self, **args):
85
  if self._disable_timing:
@@ -122,7 +128,8 @@ class Renderer:
122
 
123
  if self._is_timing and not self._disable_timing:
124
  self._end_event.synchronize()
125
- res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3
 
126
  self._is_timing = False
127
  return res
128
 
@@ -143,7 +150,8 @@ class Renderer:
143
  raise data
144
 
145
  orig_net = data[key]
146
- cache_key = (orig_net, self._device, tuple(sorted(tweak_kwargs.items())))
 
147
  net = self._networks.get(cache_key, None)
148
  if net is None:
149
  try:
@@ -159,9 +167,11 @@ class Renderer:
159
  print(data[key].init_args)
160
  print(data[key].init_kwargs)
161
  if 'stylegan_human' in pkl:
162
- net = Generator(*data[key].init_args, **data[key].init_kwargs, square=False, padding=True)
 
163
  else:
164
- net = Generator(*data[key].init_args, **data[key].init_kwargs)
 
165
  net.load_state_dict(data[key].state_dict())
166
  net.to(self._device)
167
  except:
@@ -202,25 +212,27 @@ class Renderer:
202
  return x
203
 
204
  def init_network(self, res,
205
- pkl = None,
206
- w0_seed = 0,
207
- w_load = None,
208
- w_plus = True,
209
- noise_mode = 'const',
210
- trunc_psi = 0.7,
211
- trunc_cutoff = None,
212
- input_transform = None,
213
- lr = 0.001,
214
- **kwargs
215
- ):
216
  # Dig up network details.
217
  self.pkl = pkl
218
  G = self.get_network(pkl, 'G_ema')
219
  self.G = G
220
  res.img_resolution = G.img_resolution
221
  res.num_ws = G.num_ws
222
- res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers())
223
- res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))
 
 
224
 
225
  # Set input transform.
226
  if res.has_input_transform:
@@ -238,11 +250,13 @@ class Renderer:
238
 
239
  if self.w_load is None:
240
  # Generate random latents.
241
- z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device).float()
 
242
 
243
  # Run mapping network.
244
  label = torch.zeros([1, G.c_dim], device=self._device)
245
- w = G.mapping(z, label, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff)
 
246
  else:
247
  w = self.w_load.clone().to(self._device)
248
 
@@ -266,34 +280,34 @@ class Renderer:
266
  print(' Remain feat_refs and points0_pt')
267
 
268
  def _render_drag_impl(self, res,
269
- points = [],
270
- targets = [],
271
- mask = None,
272
- lambda_mask = 10,
273
- reg = 0,
274
- feature_idx = 5,
275
- r1 = 3,
276
- r2 = 12,
277
- random_seed = 0,
278
- noise_mode = 'const',
279
- trunc_psi = 0.7,
280
- force_fp32 = False,
281
- layer_name = None,
282
- sel_channels = 3,
283
- base_channel = 0,
284
- img_scale_db = 0,
285
- img_normalize = False,
286
- untransform = False,
287
- is_drag = False,
288
- reset = False,
289
- to_pil = False,
290
- **kwargs
291
- ):
292
  G = self.G
293
  ws = self.w
294
  if ws.dim() == 2:
295
- ws = ws.unsqueeze(1).repeat(1,6,1)
296
- ws = torch.cat([ws[:,:6,:], self.w0[:,6:,:]], dim=1)
297
  if hasattr(self, 'points'):
298
  if len(points) != len(self.points):
299
  reset = True
@@ -304,7 +318,8 @@ class Renderer:
304
 
305
  # Run synthesis network.
306
  label = torch.zeros([1, G.c_dim], device=self._device)
307
- img, feat = G(ws, label, truncation_psi=trunc_psi, noise_mode=noise_mode, input_is_w=True, return_feature=True)
 
308
 
309
  h, w = G.img_resolution, G.img_resolution
310
 
@@ -312,14 +327,17 @@ class Renderer:
312
  X = torch.linspace(0, h, h)
313
  Y = torch.linspace(0, w, w)
314
  xx, yy = torch.meshgrid(X, Y)
315
- feat_resize = F.interpolate(feat[feature_idx], [h, w], mode='bilinear')
 
316
  if self.feat_refs is None:
317
- self.feat0_resize = F.interpolate(feat[feature_idx].detach(), [h, w], mode='bilinear')
 
318
  self.feat_refs = []
319
  for point in points:
320
  py, px = round(point[0]), round(point[1])
321
- self.feat_refs.append(self.feat0_resize[:,:,py,px])
322
- self.points0_pt = torch.Tensor(points).unsqueeze(0).to(self._device) # 1, N, 2
 
323
 
324
  # Point tracking with feature matching
325
  with torch.no_grad():
@@ -329,11 +347,13 @@ class Renderer:
329
  down = min(point[0] + r + 1, h)
330
  left = max(point[1] - r, 0)
331
  right = min(point[1] + r + 1, w)
332
- feat_patch = feat_resize[:,:,up:down,left:right]
333
- L2 = torch.linalg.norm(feat_patch - self.feat_refs[j].reshape(1,-1,1,1), dim=1)
334
- _, idx = torch.min(L2.view(1,-1), -1)
 
335
  width = right - left
336
- point = [idx.item() // width + up, idx.item() % width + left]
 
337
  points[j] = point
338
 
339
  res.points = [[point[0], point[1]] for point in points]
@@ -342,24 +362,31 @@ class Renderer:
342
  loss_motion = 0
343
  res.stop = True
344
  for j, point in enumerate(points):
345
- direction = torch.Tensor([targets[j][1] - point[1], targets[j][0] - point[0]])
 
346
  if torch.linalg.norm(direction) > max(2 / 512 * h, 2):
347
  res.stop = False
348
  if torch.linalg.norm(direction) > 1:
349
- distance = ((xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
 
350
  relis, reljs = torch.where(distance < round(r1 / 512 * h))
351
- direction = direction / (torch.linalg.norm(direction) + 1e-7)
 
352
  gridh = (relis-direction[1]) / (h-1) * 2 - 1
353
  gridw = (reljs-direction[0]) / (w-1) * 2 - 1
354
- grid = torch.stack([gridw,gridh], dim=-1).unsqueeze(0).unsqueeze(0)
355
- target = F.grid_sample(feat_resize.float(), grid, align_corners=True).squeeze(2)
356
- loss_motion += F.l1_loss(feat_resize[:,:,relis,reljs], target.detach())
 
 
 
357
 
358
  loss = loss_motion
359
  if mask is not None:
360
  if mask.min() == 0 and mask.max() == 1:
361
  mask_usq = mask.to(self._device).unsqueeze(0).unsqueeze(0)
362
- loss_fix = F.l1_loss(feat_resize * mask_usq, self.feat0_resize * mask_usq)
 
363
  loss += lambda_mask * loss_fix
364
 
365
  loss += reg * F.l1_loss(ws, self.w0) # latent code regularization
@@ -371,13 +398,15 @@ class Renderer:
371
  # Scale and convert to uint8.
372
  img = img[0]
373
  if img_normalize:
374
- img = img / img.norm(float('inf'), dim=[1,2], keepdim=True).clip(1e-8, 1e8)
 
375
  img = img * (10 ** (img_scale_db / 20))
376
- img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
 
377
  if to_pil:
378
  from PIL import Image
379
  img = img.cpu().numpy()
380
  img = Image.fromarray(img)
381
  res.image = img
382
 
383
- #----------------------------------------------------------------------------
 
20
  import matplotlib.cm
21
  import dnnlib
22
  from torch_utils.ops import upfirdn2d
23
+ import legacy # pylint: disable=import-error
24
+
25
+ # ----------------------------------------------------------------------------
26
 
 
27
 
28
  class CapturedException(Exception):
29
  def __init__(self, msg=None):
 
37
  assert isinstance(msg, str)
38
  super().__init__(msg)
39
 
40
+ # ----------------------------------------------------------------------------
41
+
42
 
43
  class CaptureSuccess(Exception):
44
  def __init__(self, out):
45
  super().__init__()
46
  self.out = out
47
 
48
+ # ----------------------------------------------------------------------------
49
+
50
 
51
+ def add_watermark_np(input_image_array, watermark_text="AI Generated"):
52
  image = Image.fromarray(np.uint8(input_image_array)).convert("RGBA")
53
 
54
  # Initialize text image
 
57
  d = ImageDraw.Draw(txt)
58
 
59
  text_width, text_height = font.getsize(watermark_text)
60
+ text_position = (image.size[0] - text_width -
61
+ 10, image.size[1] - text_height - 10)
62
+ # white color with the alpha channel set to semi-transparent
63
+ text_color = (255, 255, 255, 128)
64
 
65
  # Draw the text onto the text canvas
66
  d.text(text_position, watermark_text, font=font, fill=text_color)
 
70
  watermarked_array = np.array(watermarked)
71
  return watermarked_array
72
 
73
+ # ----------------------------------------------------------------------------
74
+
75
 
76
  class Renderer:
77
  def __init__(self, disable_timing=False):
78
+ self._device = torch.device('cuda')
79
+ self._pkl_data = dict() # {pkl: dict | CapturedException, ...}
80
+ self._networks = dict() # {cache_key: torch.nn.Module, ...}
81
+ self._pinned_bufs = dict() # {(shape, dtype): torch.Tensor, ...}
82
+ self._cmaps = dict() # {name: torch.Tensor, ...}
83
+ self._is_timing = False
84
  if not disable_timing:
85
+ self._start_event = torch.cuda.Event(enable_timing=True)
86
+ self._end_event = torch.cuda.Event(enable_timing=True)
87
  self._disable_timing = disable_timing
88
+ self._net_layers = dict() # {cache_key: [dnnlib.EasyDict, ...], ...}
89
 
90
  def render(self, **args):
91
  if self._disable_timing:
 
128
 
129
  if self._is_timing and not self._disable_timing:
130
  self._end_event.synchronize()
131
+ res.render_time = self._start_event.elapsed_time(
132
+ self._end_event) * 1e-3
133
  self._is_timing = False
134
  return res
135
 
 
150
  raise data
151
 
152
  orig_net = data[key]
153
+ cache_key = (orig_net, self._device, tuple(
154
+ sorted(tweak_kwargs.items())))
155
  net = self._networks.get(cache_key, None)
156
  if net is None:
157
  try:
 
167
  print(data[key].init_args)
168
  print(data[key].init_kwargs)
169
  if 'stylegan_human' in pkl:
170
+ net = Generator(
171
+ *data[key].init_args, **data[key].init_kwargs, square=False, padding=True)
172
  else:
173
+ net = Generator(*data[key].init_args,
174
+ **data[key].init_kwargs)
175
  net.load_state_dict(data[key].state_dict())
176
  net.to(self._device)
177
  except:
 
212
  return x
213
 
214
  def init_network(self, res,
215
+ pkl=None,
216
+ w0_seed=0,
217
+ w_load=None,
218
+ w_plus=True,
219
+ noise_mode='const',
220
+ trunc_psi=0.7,
221
+ trunc_cutoff=None,
222
+ input_transform=None,
223
+ lr=0.001,
224
+ **kwargs
225
+ ):
226
  # Dig up network details.
227
  self.pkl = pkl
228
  G = self.get_network(pkl, 'G_ema')
229
  self.G = G
230
  res.img_resolution = G.img_resolution
231
  res.num_ws = G.num_ws
232
+ res.has_noise = any('noise_const' in name for name,
233
+ _buf in G.synthesis.named_buffers())
234
+ res.has_input_transform = (
235
+ hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))
236
 
237
  # Set input transform.
238
  if res.has_input_transform:
 
250
 
251
  if self.w_load is None:
252
  # Generate random latents.
253
+ z = torch.from_numpy(np.random.RandomState(
254
+ w0_seed).randn(1, 512)).to(self._device).float()
255
 
256
  # Run mapping network.
257
  label = torch.zeros([1, G.c_dim], device=self._device)
258
+ w = G.mapping(z, label, truncation_psi=trunc_psi,
259
+ truncation_cutoff=trunc_cutoff)
260
  else:
261
  w = self.w_load.clone().to(self._device)
262
 
 
280
  print(' Remain feat_refs and points0_pt')
281
 
282
  def _render_drag_impl(self, res,
283
+ points=[],
284
+ targets=[],
285
+ mask=None,
286
+ lambda_mask=10,
287
+ reg=0,
288
+ feature_idx=5,
289
+ r1=3,
290
+ r2=12,
291
+ random_seed=0,
292
+ noise_mode='const',
293
+ trunc_psi=0.7,
294
+ force_fp32=False,
295
+ layer_name=None,
296
+ sel_channels=3,
297
+ base_channel=0,
298
+ img_scale_db=0,
299
+ img_normalize=False,
300
+ untransform=False,
301
+ is_drag=False,
302
+ reset=False,
303
+ to_pil=False,
304
+ **kwargs
305
+ ):
306
  G = self.G
307
  ws = self.w
308
  if ws.dim() == 2:
309
+ ws = ws.unsqueeze(1).repeat(1, 6, 1)
310
+ ws = torch.cat([ws[:, :6, :], self.w0[:, 6:, :]], dim=1)
311
  if hasattr(self, 'points'):
312
  if len(points) != len(self.points):
313
  reset = True
 
318
 
319
  # Run synthesis network.
320
  label = torch.zeros([1, G.c_dim], device=self._device)
321
+ img, feat = G(ws, label, truncation_psi=trunc_psi,
322
+ noise_mode=noise_mode, input_is_w=True, return_feature=True)
323
 
324
  h, w = G.img_resolution, G.img_resolution
325
 
 
327
  X = torch.linspace(0, h, h)
328
  Y = torch.linspace(0, w, w)
329
  xx, yy = torch.meshgrid(X, Y)
330
+ feat_resize = F.interpolate(
331
+ feat[feature_idx], [h, w], mode='bilinear')
332
  if self.feat_refs is None:
333
+ self.feat0_resize = F.interpolate(
334
+ feat[feature_idx].detach(), [h, w], mode='bilinear')
335
  self.feat_refs = []
336
  for point in points:
337
  py, px = round(point[0]), round(point[1])
338
+ self.feat_refs.append(self.feat0_resize[:, :, py, px])
339
+ self.points0_pt = torch.Tensor(points).unsqueeze(
340
+ 0).to(self._device) # 1, N, 2
341
 
342
  # Point tracking with feature matching
343
  with torch.no_grad():
 
347
  down = min(point[0] + r + 1, h)
348
  left = max(point[1] - r, 0)
349
  right = min(point[1] + r + 1, w)
350
+ feat_patch = feat_resize[:, :, up:down, left:right]
351
+ L2 = torch.linalg.norm(
352
+ feat_patch - self.feat_refs[j].reshape(1, -1, 1, 1), dim=1)
353
+ _, idx = torch.min(L2.view(1, -1), -1)
354
  width = right - left
355
+ point = [idx.item() // width + up, idx.item() %
356
+ width + left]
357
  points[j] = point
358
 
359
  res.points = [[point[0], point[1]] for point in points]
 
362
  loss_motion = 0
363
  res.stop = True
364
  for j, point in enumerate(points):
365
+ direction = torch.Tensor(
366
+ [targets[j][1] - point[1], targets[j][0] - point[0]])
367
  if torch.linalg.norm(direction) > max(2 / 512 * h, 2):
368
  res.stop = False
369
  if torch.linalg.norm(direction) > 1:
370
+ distance = (
371
+ (xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
372
  relis, reljs = torch.where(distance < round(r1 / 512 * h))
373
+ direction = direction / \
374
+ (torch.linalg.norm(direction) + 1e-7)
375
  gridh = (relis-direction[1]) / (h-1) * 2 - 1
376
  gridw = (reljs-direction[0]) / (w-1) * 2 - 1
377
+ grid = torch.stack(
378
+ [gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)
379
+ target = F.grid_sample(
380
+ feat_resize.float(), grid, align_corners=True).squeeze(2)
381
+ loss_motion += F.l1_loss(
382
+ feat_resize[:, :, relis, reljs], target.detach())
383
 
384
  loss = loss_motion
385
  if mask is not None:
386
  if mask.min() == 0 and mask.max() == 1:
387
  mask_usq = mask.to(self._device).unsqueeze(0).unsqueeze(0)
388
+ loss_fix = F.l1_loss(
389
+ feat_resize * mask_usq, self.feat0_resize * mask_usq)
390
  loss += lambda_mask * loss_fix
391
 
392
  loss += reg * F.l1_loss(ws, self.w0) # latent code regularization
 
398
  # Scale and convert to uint8.
399
  img = img[0]
400
  if img_normalize:
401
+ img = img / img.norm(float('inf'),
402
+ dim=[1, 2], keepdim=True).clip(1e-8, 1e8)
403
  img = img * (10 ** (img_scale_db / 20))
404
+ img = (img * 127.5 + 128).clamp(0,
405
+ 255).to(torch.uint8).permute(1, 2, 0)
406
  if to_pil:
407
  from PIL import Image
408
  img = img.cpu().numpy()
409
  img = Image.fromarray(img)
410
  res.image = img
411
 
412
+ # ----------------------------------------------------------------------------