radames commited on
Commit
c760a5e
1 Parent(s): 6d61f2f

add timeout generator

Browse files
Files changed (2) hide show
  1. visualizer_drag_gradio.py +1 -1
  2. viz/renderer.py +111 -81
visualizer_drag_gradio.py CHANGED
@@ -562,7 +562,7 @@ with gr.Blocks() as app:
562
  if IS_SPACE and time.time() - last_time > TIMEOUT:
563
  print('Timeout break!')
564
  break
565
- if global_state["temporal_params"]["stop"]:
566
  break
567
 
568
  # do drage here!
 
562
  if IS_SPACE and time.time() - last_time > TIMEOUT:
563
  print('Timeout break!')
564
  break
565
+ if global_state["temporal_params"]["stop"] or global_state['generator_params']["stop"]:
566
  break
567
 
568
  # do drage here!
viz/renderer.py CHANGED
@@ -20,9 +20,10 @@ import torch.nn.functional as F
20
  import matplotlib.cm
21
  import dnnlib
22
  from torch_utils.ops import upfirdn2d
23
- import legacy # pylint: disable=import-error
 
 
24
 
25
- #----------------------------------------------------------------------------
26
 
27
  class CapturedException(Exception):
28
  def __init__(self, msg=None):
@@ -36,14 +37,16 @@ class CapturedException(Exception):
36
  assert isinstance(msg, str)
37
  super().__init__(msg)
38
 
39
- #----------------------------------------------------------------------------
 
40
 
41
  class CaptureSuccess(Exception):
42
  def __init__(self, out):
43
  super().__init__()
44
  self.out = out
45
 
46
- #----------------------------------------------------------------------------
 
47
 
48
  def add_watermark_np(input_image_array, watermark_text="AI Generated"):
49
  image = Image.fromarray(np.uint8(input_image_array)).convert("RGBA")
@@ -54,8 +57,10 @@ def add_watermark_np(input_image_array, watermark_text="AI Generated"):
54
  d = ImageDraw.Draw(txt)
55
 
56
  text_width, text_height = font.getsize(watermark_text)
57
- text_position = (image.size[0] - text_width - 10, image.size[1] - text_height - 10)
58
- text_color = (255, 255, 255, 128) # white color with the alpha channel set to semi-transparent
 
 
59
 
60
  # Draw the text onto the text canvas
61
  d.text(text_position, watermark_text, font=font, fill=text_color)
@@ -65,22 +70,24 @@ def add_watermark_np(input_image_array, watermark_text="AI Generated"):
65
  watermarked_array = np.array(watermarked)
66
  return watermarked_array
67
 
68
- #----------------------------------------------------------------------------
 
69
 
70
  class Renderer:
71
  def __init__(self, disable_timing=False):
72
- self._device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
73
- self._dtype = torch.float32 if self._device.type == 'mps' else torch.float64
74
- self._pkl_data = dict() # {pkl: dict | CapturedException, ...}
75
- self._networks = dict() # {cache_key: torch.nn.Module, ...}
76
- self._pinned_bufs = dict() # {(shape, dtype): torch.Tensor, ...}
77
- self._cmaps = dict() # {name: torch.Tensor, ...}
78
- self._is_timing = False
 
79
  if not disable_timing:
80
- self._start_event = torch.cuda.Event(enable_timing=True)
81
- self._end_event = torch.cuda.Event(enable_timing=True)
82
  self._disable_timing = disable_timing
83
- self._net_layers = dict() # {cache_key: [dnnlib.EasyDict, ...], ...}
84
 
85
  def render(self, **args):
86
  if self._disable_timing:
@@ -126,7 +133,8 @@ class Renderer:
126
 
127
  if self._is_timing and not self._disable_timing:
128
  self._end_event.synchronize()
129
- res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3
 
130
  self._is_timing = False
131
  return res
132
 
@@ -147,7 +155,8 @@ class Renderer:
147
  raise data
148
 
149
  orig_net = data[key]
150
- cache_key = (orig_net, self._device, tuple(sorted(tweak_kwargs.items())))
 
151
  net = self._networks.get(cache_key, None)
152
  if net is None:
153
  try:
@@ -163,9 +172,11 @@ class Renderer:
163
  print(data[key].init_args)
164
  print(data[key].init_kwargs)
165
  if 'stylegan_human' in pkl:
166
- net = Generator(*data[key].init_args, **data[key].init_kwargs, square=False, padding=True)
 
167
  else:
168
- net = Generator(*data[key].init_args, **data[key].init_kwargs)
 
169
  net.load_state_dict(data[key].state_dict())
170
  net.to(self._device)
171
  except:
@@ -206,26 +217,28 @@ class Renderer:
206
  return x
207
 
208
  def init_network(self, res,
209
- pkl = None,
210
- w0_seed = 0,
211
- w_load = None,
212
- w_plus = True,
213
- noise_mode = 'const',
214
- trunc_psi = 0.7,
215
- trunc_cutoff = None,
216
- input_transform = None,
217
- lr = 0.001,
218
- **kwargs
219
- ):
220
  # Dig up network details.
221
  self.pkl = pkl
222
  G = self.get_network(pkl, 'G_ema')
223
  self.G = G
224
  res.img_resolution = G.img_resolution
225
  res.num_ws = G.num_ws
226
- res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers())
227
- res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))
228
-
 
 
229
  # Set input transform.
230
  if res.has_input_transform:
231
  m = np.eye(3)
@@ -242,11 +255,13 @@ class Renderer:
242
 
243
  if self.w_load is None:
244
  # Generate random latents.
245
- z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device, dtype=self._dtype)
 
246
 
247
  # Run mapping network.
248
  label = torch.zeros([1, G.c_dim], device=self._device)
249
- w = G.mapping(z, label, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff)
 
250
  else:
251
  w = self.w_load.clone().to(self._device)
252
 
@@ -270,34 +285,34 @@ class Renderer:
270
  print(' Remain feat_refs and points0_pt')
271
 
272
  def _render_drag_impl(self, res,
273
- points = [],
274
- targets = [],
275
- mask = None,
276
- lambda_mask = 10,
277
- reg = 0,
278
- feature_idx = 5,
279
- r1 = 3,
280
- r2 = 12,
281
- random_seed = 0,
282
- noise_mode = 'const',
283
- trunc_psi = 0.7,
284
- force_fp32 = False,
285
- layer_name = None,
286
- sel_channels = 3,
287
- base_channel = 0,
288
- img_scale_db = 0,
289
- img_normalize = False,
290
- untransform = False,
291
- is_drag = False,
292
- reset = False,
293
- to_pil = False,
294
- **kwargs
295
- ):
296
  G = self.G
297
  ws = self.w
298
  if ws.dim() == 2:
299
- ws = ws.unsqueeze(1).repeat(1,6,1)
300
- ws = torch.cat([ws[:,:6,:], self.w0[:,6:,:]], dim=1)
301
  if hasattr(self, 'points'):
302
  if len(points) != len(self.points):
303
  reset = True
@@ -308,7 +323,8 @@ class Renderer:
308
 
309
  # Run synthesis network.
310
  label = torch.zeros([1, G.c_dim], device=self._device)
311
- img, feat = G(ws, label, truncation_psi=trunc_psi, noise_mode=noise_mode, input_is_w=True, return_feature=True)
 
312
 
313
  h, w = G.img_resolution, G.img_resolution
314
 
@@ -316,14 +332,17 @@ class Renderer:
316
  X = torch.linspace(0, h, h)
317
  Y = torch.linspace(0, w, w)
318
  xx, yy = torch.meshgrid(X, Y)
319
- feat_resize = F.interpolate(feat[feature_idx], [h, w], mode='bilinear')
 
320
  if self.feat_refs is None:
321
- self.feat0_resize = F.interpolate(feat[feature_idx].detach(), [h, w], mode='bilinear')
 
322
  self.feat_refs = []
323
  for point in points:
324
  py, px = round(point[0]), round(point[1])
325
- self.feat_refs.append(self.feat0_resize[:,:,py,px])
326
- self.points0_pt = torch.Tensor(points).unsqueeze(0).to(self._device) # 1, N, 2
 
327
 
328
  # Point tracking with feature matching
329
  with torch.no_grad():
@@ -333,11 +352,13 @@ class Renderer:
333
  down = min(point[0] + r + 1, h)
334
  left = max(point[1] - r, 0)
335
  right = min(point[1] + r + 1, w)
336
- feat_patch = feat_resize[:,:,up:down,left:right]
337
- L2 = torch.linalg.norm(feat_patch - self.feat_refs[j].reshape(1,-1,1,1), dim=1)
338
- _, idx = torch.min(L2.view(1,-1), -1)
 
339
  width = right - left
340
- point = [idx.item() // width + up, idx.item() % width + left]
 
341
  points[j] = point
342
 
343
  res.points = [[point[0], point[1]] for point in points]
@@ -346,24 +367,31 @@ class Renderer:
346
  loss_motion = 0
347
  res.stop = True
348
  for j, point in enumerate(points):
349
- direction = torch.Tensor([targets[j][1] - point[1], targets[j][0] - point[0]])
 
350
  if torch.linalg.norm(direction) > max(2 / 512 * h, 2):
351
  res.stop = False
352
  if torch.linalg.norm(direction) > 1:
353
- distance = ((xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
 
354
  relis, reljs = torch.where(distance < round(r1 / 512 * h))
355
- direction = direction / (torch.linalg.norm(direction) + 1e-7)
 
356
  gridh = (relis-direction[1]) / (h-1) * 2 - 1
357
  gridw = (reljs-direction[0]) / (w-1) * 2 - 1
358
- grid = torch.stack([gridw,gridh], dim=-1).unsqueeze(0).unsqueeze(0)
359
- target = F.grid_sample(feat_resize.float(), grid, align_corners=True).squeeze(2)
360
- loss_motion += F.l1_loss(feat_resize[:,:,relis,reljs], target.detach())
 
 
 
361
 
362
  loss = loss_motion
363
  if mask is not None:
364
  if mask.min() == 0 and mask.max() == 1:
365
  mask_usq = mask.to(self._device).unsqueeze(0).unsqueeze(0)
366
- loss_fix = F.l1_loss(feat_resize * mask_usq, self.feat0_resize * mask_usq)
 
367
  loss += lambda_mask * loss_fix
368
 
369
  loss += reg * F.l1_loss(ws, self.w0) # latent code regularization
@@ -375,9 +403,11 @@ class Renderer:
375
  # Scale and convert to uint8.
376
  img = img[0]
377
  if img_normalize:
378
- img = img / img.norm(float('inf'), dim=[1,2], keepdim=True).clip(1e-8, 1e8)
 
379
  img = img * (10 ** (img_scale_db / 20))
380
- img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
 
381
  if to_pil:
382
  from PIL import Image
383
  img = img.cpu().numpy()
@@ -385,4 +415,4 @@ class Renderer:
385
  res.image = img
386
  res.w = ws.detach().cpu().numpy()
387
 
388
- #----------------------------------------------------------------------------
 
20
  import matplotlib.cm
21
  import dnnlib
22
  from torch_utils.ops import upfirdn2d
23
+ import legacy # pylint: disable=import-error
24
+
25
+ # ----------------------------------------------------------------------------
26
 
 
27
 
28
  class CapturedException(Exception):
29
  def __init__(self, msg=None):
 
37
  assert isinstance(msg, str)
38
  super().__init__(msg)
39
 
40
+ # ----------------------------------------------------------------------------
41
+
42
 
43
  class CaptureSuccess(Exception):
44
  def __init__(self, out):
45
  super().__init__()
46
  self.out = out
47
 
48
+ # ----------------------------------------------------------------------------
49
+
50
 
51
  def add_watermark_np(input_image_array, watermark_text="AI Generated"):
52
  image = Image.fromarray(np.uint8(input_image_array)).convert("RGBA")
 
57
  d = ImageDraw.Draw(txt)
58
 
59
  text_width, text_height = font.getsize(watermark_text)
60
+ text_position = (image.size[0] - text_width -
61
+ 10, image.size[1] - text_height - 10)
62
+ # white color with the alpha channel set to semi-transparent
63
+ text_color = (255, 255, 255, 128)
64
 
65
  # Draw the text onto the text canvas
66
  d.text(text_position, watermark_text, font=font, fill=text_color)
 
70
  watermarked_array = np.array(watermarked)
71
  return watermarked_array
72
 
73
+ # ----------------------------------------------------------------------------
74
+
75
 
76
  class Renderer:
77
  def __init__(self, disable_timing=False):
78
+ self._device = torch.device('cuda' if torch.cuda.is_available(
79
+ ) else 'mps' if torch.backends.mps.is_available() else 'cpu')
80
+ self._dtype = torch.float32 if self._device.type == 'mps' else torch.float64
81
+ self._pkl_data = dict() # {pkl: dict | CapturedException, ...}
82
+ self._networks = dict() # {cache_key: torch.nn.Module, ...}
83
+ self._pinned_bufs = dict() # {(shape, dtype): torch.Tensor, ...}
84
+ self._cmaps = dict() # {name: torch.Tensor, ...}
85
+ self._is_timing = False
86
  if not disable_timing:
87
+ self._start_event = torch.cuda.Event(enable_timing=True)
88
+ self._end_event = torch.cuda.Event(enable_timing=True)
89
  self._disable_timing = disable_timing
90
+ self._net_layers = dict() # {cache_key: [dnnlib.EasyDict, ...], ...}
91
 
92
  def render(self, **args):
93
  if self._disable_timing:
 
133
 
134
  if self._is_timing and not self._disable_timing:
135
  self._end_event.synchronize()
136
+ res.render_time = self._start_event.elapsed_time(
137
+ self._end_event) * 1e-3
138
  self._is_timing = False
139
  return res
140
 
 
155
  raise data
156
 
157
  orig_net = data[key]
158
+ cache_key = (orig_net, self._device, tuple(
159
+ sorted(tweak_kwargs.items())))
160
  net = self._networks.get(cache_key, None)
161
  if net is None:
162
  try:
 
172
  print(data[key].init_args)
173
  print(data[key].init_kwargs)
174
  if 'stylegan_human' in pkl:
175
+ net = Generator(
176
+ *data[key].init_args, **data[key].init_kwargs, square=False, padding=True)
177
  else:
178
+ net = Generator(*data[key].init_args,
179
+ **data[key].init_kwargs)
180
  net.load_state_dict(data[key].state_dict())
181
  net.to(self._device)
182
  except:
 
217
  return x
218
 
219
  def init_network(self, res,
220
+ pkl=None,
221
+ w0_seed=0,
222
+ w_load=None,
223
+ w_plus=True,
224
+ noise_mode='const',
225
+ trunc_psi=0.7,
226
+ trunc_cutoff=None,
227
+ input_transform=None,
228
+ lr=0.001,
229
+ **kwargs
230
+ ):
231
  # Dig up network details.
232
  self.pkl = pkl
233
  G = self.get_network(pkl, 'G_ema')
234
  self.G = G
235
  res.img_resolution = G.img_resolution
236
  res.num_ws = G.num_ws
237
+ res.has_noise = any('noise_const' in name for name,
238
+ _buf in G.synthesis.named_buffers())
239
+ res.has_input_transform = (
240
+ hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))
241
+ res.stop = False
242
  # Set input transform.
243
  if res.has_input_transform:
244
  m = np.eye(3)
 
255
 
256
  if self.w_load is None:
257
  # Generate random latents.
258
+ z = torch.from_numpy(np.random.RandomState(w0_seed).randn(
259
+ 1, 512)).to(self._device, dtype=self._dtype)
260
 
261
  # Run mapping network.
262
  label = torch.zeros([1, G.c_dim], device=self._device)
263
+ w = G.mapping(z, label, truncation_psi=trunc_psi,
264
+ truncation_cutoff=trunc_cutoff)
265
  else:
266
  w = self.w_load.clone().to(self._device)
267
 
 
285
  print(' Remain feat_refs and points0_pt')
286
 
287
  def _render_drag_impl(self, res,
288
+ points=[],
289
+ targets=[],
290
+ mask=None,
291
+ lambda_mask=10,
292
+ reg=0,
293
+ feature_idx=5,
294
+ r1=3,
295
+ r2=12,
296
+ random_seed=0,
297
+ noise_mode='const',
298
+ trunc_psi=0.7,
299
+ force_fp32=False,
300
+ layer_name=None,
301
+ sel_channels=3,
302
+ base_channel=0,
303
+ img_scale_db=0,
304
+ img_normalize=False,
305
+ untransform=False,
306
+ is_drag=False,
307
+ reset=False,
308
+ to_pil=False,
309
+ **kwargs
310
+ ):
311
  G = self.G
312
  ws = self.w
313
  if ws.dim() == 2:
314
+ ws = ws.unsqueeze(1).repeat(1, 6, 1)
315
+ ws = torch.cat([ws[:, :6, :], self.w0[:, 6:, :]], dim=1)
316
  if hasattr(self, 'points'):
317
  if len(points) != len(self.points):
318
  reset = True
 
323
 
324
  # Run synthesis network.
325
  label = torch.zeros([1, G.c_dim], device=self._device)
326
+ img, feat = G(ws, label, truncation_psi=trunc_psi,
327
+ noise_mode=noise_mode, input_is_w=True, return_feature=True)
328
 
329
  h, w = G.img_resolution, G.img_resolution
330
 
 
332
  X = torch.linspace(0, h, h)
333
  Y = torch.linspace(0, w, w)
334
  xx, yy = torch.meshgrid(X, Y)
335
+ feat_resize = F.interpolate(
336
+ feat[feature_idx], [h, w], mode='bilinear')
337
  if self.feat_refs is None:
338
+ self.feat0_resize = F.interpolate(
339
+ feat[feature_idx].detach(), [h, w], mode='bilinear')
340
  self.feat_refs = []
341
  for point in points:
342
  py, px = round(point[0]), round(point[1])
343
+ self.feat_refs.append(self.feat0_resize[:, :, py, px])
344
+ self.points0_pt = torch.Tensor(points).unsqueeze(
345
+ 0).to(self._device) # 1, N, 2
346
 
347
  # Point tracking with feature matching
348
  with torch.no_grad():
 
352
  down = min(point[0] + r + 1, h)
353
  left = max(point[1] - r, 0)
354
  right = min(point[1] + r + 1, w)
355
+ feat_patch = feat_resize[:, :, up:down, left:right]
356
+ L2 = torch.linalg.norm(
357
+ feat_patch - self.feat_refs[j].reshape(1, -1, 1, 1), dim=1)
358
+ _, idx = torch.min(L2.view(1, -1), -1)
359
  width = right - left
360
+ point = [idx.item() // width + up, idx.item() %
361
+ width + left]
362
  points[j] = point
363
 
364
  res.points = [[point[0], point[1]] for point in points]
 
367
  loss_motion = 0
368
  res.stop = True
369
  for j, point in enumerate(points):
370
+ direction = torch.Tensor(
371
+ [targets[j][1] - point[1], targets[j][0] - point[0]])
372
  if torch.linalg.norm(direction) > max(2 / 512 * h, 2):
373
  res.stop = False
374
  if torch.linalg.norm(direction) > 1:
375
+ distance = (
376
+ (xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
377
  relis, reljs = torch.where(distance < round(r1 / 512 * h))
378
+ direction = direction / \
379
+ (torch.linalg.norm(direction) + 1e-7)
380
  gridh = (relis-direction[1]) / (h-1) * 2 - 1
381
  gridw = (reljs-direction[0]) / (w-1) * 2 - 1
382
+ grid = torch.stack(
383
+ [gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)
384
+ target = F.grid_sample(
385
+ feat_resize.float(), grid, align_corners=True).squeeze(2)
386
+ loss_motion += F.l1_loss(
387
+ feat_resize[:, :, relis, reljs], target.detach())
388
 
389
  loss = loss_motion
390
  if mask is not None:
391
  if mask.min() == 0 and mask.max() == 1:
392
  mask_usq = mask.to(self._device).unsqueeze(0).unsqueeze(0)
393
+ loss_fix = F.l1_loss(
394
+ feat_resize * mask_usq, self.feat0_resize * mask_usq)
395
  loss += lambda_mask * loss_fix
396
 
397
  loss += reg * F.l1_loss(ws, self.w0) # latent code regularization
 
403
  # Scale and convert to uint8.
404
  img = img[0]
405
  if img_normalize:
406
+ img = img / img.norm(float('inf'),
407
+ dim=[1, 2], keepdim=True).clip(1e-8, 1e8)
408
  img = img * (10 ** (img_scale_db / 20))
409
+ img = (img * 127.5 + 128).clamp(0,
410
+ 255).to(torch.uint8).permute(1, 2, 0)
411
  if to_pil:
412
  from PIL import Image
413
  img = img.cpu().numpy()
 
415
  res.image = img
416
  res.w = ws.detach().cpu().numpy()
417
 
418
+ # ----------------------------------------------------------------------------