add timeout generator
- visualizer_drag_gradio.py +1 -1
- viz/renderer.py +111 -81
visualizer_drag_gradio.py
CHANGED
@@ -562,7 +562,7 @@ with gr.Blocks() as app:
         if IS_SPACE and time.time() - last_time > TIMEOUT:
             print('Timeout break!')
             break
-        if global_state["temporal_params"]["stop"]:
+        if global_state["temporal_params"]["stop"] or global_state['generator_params']["stop"]:
             break
 
         # do drage here!
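The one-line change above lets the render loop exit on the generator's own stop flag (set by the renderer when the drag converges, see `res.stop` below) in addition to the UI stop button. A minimal sketch of the guard pattern, with hypothetical `drag_loop`/`step_fn` and placeholder constants:

import time

IS_SPACE = True   # hosted-Space mode enables the hard timeout
TIMEOUT = 80      # placeholder; seconds allowed before bailing out

def drag_loop(global_state, step_fn):
    last_time = time.time()
    while True:
        if IS_SPACE and time.time() - last_time > TIMEOUT:
            print('Timeout break!')
            break
        # Either the UI stop button (temporal_params) or the renderer's
        # convergence flag (generator_params) now ends the loop.
        if global_state["temporal_params"]["stop"] or global_state['generator_params']["stop"]:
            break
        step_fn()  # one drag-optimization step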
viz/renderer.py
CHANGED
@@ -20,9 +20,10 @@ import torch.nn.functional as F
 import matplotlib.cm
 import dnnlib
 from torch_utils.ops import upfirdn2d
-import legacy
+import legacy  # pylint: disable=import-error
 
-#----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------
+
 
 class CapturedException(Exception):
     def __init__(self, msg=None):
@@ -36,14 +37,16 @@ class CapturedException(Exception):
         assert isinstance(msg, str)
         super().__init__(msg)
 
-#----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------
+
 
 class CaptureSuccess(Exception):
     def __init__(self, out):
         super().__init__()
         self.out = out
 
-#----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------
+
 
 def add_watermark_np(input_image_array, watermark_text="AI Generated"):
     image = Image.fromarray(np.uint8(input_image_array)).convert("RGBA")
@@ -54,8 +57,10 @@ def add_watermark_np(input_image_array, watermark_text="AI Generated"):
     d = ImageDraw.Draw(txt)
 
     text_width, text_height = font.getsize(watermark_text)
-    text_position = (image.size[0] - text_width - 10, image.size[1] - text_height - 10)
-    text_color = (255, 255, 255, 128)  # white color with the alpha channel set to semi-transparent
+    text_position = (image.size[0] - text_width -
+                     10, image.size[1] - text_height - 10)
+    # white color with the alpha channel set to semi-transparent
+    text_color = (255, 255, 255, 128)
 
     # Draw the text onto the text canvas
     d.text(text_position, watermark_text, font=font, fill=text_color)
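The wrapped lines above only reflow the bottom-right anchor computation: 10 px in from each edge, semi-transparent white. A standalone sketch with hypothetical inputs, using the older Pillow `font.getsize` API the file itself relies on (removed in Pillow 10):

from PIL import Image, ImageDraw, ImageFont

img = Image.new("RGBA", (512, 512))   # hypothetical canvas
font = ImageFont.load_default()
text = "AI Generated"

d = ImageDraw.Draw(img)
text_width, text_height = font.getsize(text)  # older Pillow API, as in the diff
# Anchor at bottom-right, 10 px in from each edge.
text_position = (img.size[0] - text_width - 10, img.size[1] - text_height - 10)
d.text(text_position, text, font=font, fill=(255, 255, 255, 128))  # semi-transparent white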
@@ -65,22 +70,24 @@ def add_watermark_np(input_image_array, watermark_text="AI Generated"):
     watermarked_array = np.array(watermarked)
     return watermarked_array
 
-#----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------
+
 
 class Renderer:
     def __init__(self, disable_timing=False):
-        self._device        = torch.device('cuda')
-        self._dtype         = torch.float32 if self._device.type == 'mps' else torch.float64
-        self._pkl_data      = dict()    # {pkl: dict | CapturedException, ...}
-        self._networks      = dict()    # {cache_key: torch.nn.Module, ...}
-        self._pinned_bufs   = dict()    # {(shape, dtype): torch.Tensor, ...}
-        self._cmaps         = dict()    # {name: torch.Tensor, ...}
-        self._is_timing     = False
+        self._device = torch.device('cuda' if torch.cuda.is_available(
+        ) else 'mps' if torch.backends.mps.is_available() else 'cpu')
+        self._dtype = torch.float32 if self._device.type == 'mps' else torch.float64
+        self._pkl_data = dict()  # {pkl: dict | CapturedException, ...}
+        self._networks = dict()  # {cache_key: torch.nn.Module, ...}
+        self._pinned_bufs = dict()  # {(shape, dtype): torch.Tensor, ...}
+        self._cmaps = dict()  # {name: torch.Tensor, ...}
+        self._is_timing = False
         if not disable_timing:
-            self._start_event   = torch.cuda.Event(enable_timing=True)
-            self._end_event     = torch.cuda.Event(enable_timing=True)
+            self._start_event = torch.cuda.Event(enable_timing=True)
+            self._end_event = torch.cuda.Event(enable_timing=True)
         self._disable_timing = disable_timing
-        self._net_layers    = dict()    # {cache_key: [dnnlib.EasyDict, ...], ...}
+        self._net_layers = dict()  # {cache_key: [dnnlib.EasyDict, ...], ...}
 
     def render(self, **args):
         if self._disable_timing:
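The rewritten `__init__` now falls back from CUDA to Apple's MPS backend to CPU. A minimal sketch of the same selection chain, including the dtype fallback that MPS requires:

import torch

# CUDA first, then Metal (Apple Silicon), then CPU.
device = torch.device('cuda' if torch.cuda.is_available()
                      else 'mps' if torch.backends.mps.is_available()
                      else 'cpu')
# MPS has no float64 support, hence the float32 fallback mirrored in the diff.
dtype = torch.float32 if device.type == 'mps' else torch.float64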
@@ -126,7 +133,8 @@ class Renderer:
 
         if self._is_timing and not self._disable_timing:
             self._end_event.synchronize()
-            res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3
+            res.render_time = self._start_event.elapsed_time(
+                self._end_event) * 1e-3
         self._is_timing = False
         return res
 
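For reference, the event pair times GPU work asynchronously; `elapsed_time` returns milliseconds, hence the `* 1e-3`. A minimal sketch (requires a CUDA device):

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
# ... GPU work to be timed ...
end.record()
end.synchronize()                            # wait for the GPU before reading
elapsed_s = start.elapsed_time(end) * 1e-3   # milliseconds -> seconds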
@@ -147,7 +155,8 @@ class Renderer:
             raise data
 
         orig_net = data[key]
-        cache_key = (orig_net, self._device, tuple(sorted(tweak_kwargs.items())))
+        cache_key = (orig_net, self._device, tuple(
+            sorted(tweak_kwargs.items())))
         net = self._networks.get(cache_key, None)
         if net is None:
             try:
@@ -163,9 +172,11 @@ class Renderer:
                 print(data[key].init_args)
                 print(data[key].init_kwargs)
                 if 'stylegan_human' in pkl:
-                    net = Generator(*data[key].init_args, **data[key].init_kwargs, square=False, padding=True)
+                    net = Generator(
+                        *data[key].init_args, **data[key].init_kwargs, square=False, padding=True)
                 else:
-                    net = Generator(*data[key].init_args, **data[key].init_kwargs)
+                    net = Generator(*data[key].init_args,
+                                    **data[key].init_kwargs)
                 net.load_state_dict(data[key].state_dict())
                 net.to(self._device)
             except:
@@ -206,26 +217,28 @@ class Renderer:
         return x
 
     def init_network(self, res,
-        pkl             = None,
-        w0_seed         = 0,
-        w_load          = None,
-        w_plus          = True,
-        noise_mode      = 'const',
-        trunc_psi       = 0.7,
-        trunc_cutoff    = None,
-        input_transform = None,
-        lr              = 0.001,
-        **kwargs
-        ):
+                     pkl=None,
+                     w0_seed=0,
+                     w_load=None,
+                     w_plus=True,
+                     noise_mode='const',
+                     trunc_psi=0.7,
+                     trunc_cutoff=None,
+                     input_transform=None,
+                     lr=0.001,
+                     **kwargs
+                     ):
         # Dig up network details.
         self.pkl = pkl
         G = self.get_network(pkl, 'G_ema')
         self.G = G
         res.img_resolution = G.img_resolution
         res.num_ws = G.num_ws
-        res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers())
-        res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))
-
+        res.has_noise = any('noise_const' in name for name,
+                            _buf in G.synthesis.named_buffers())
+        res.has_input_transform = (
+            hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))
+        res.stop = False
         # Set input transform.
         if res.has_input_transform:
             m = np.eye(3)
@@ -242,11 +255,13 @@ class Renderer:
 
         if self.w_load is None:
             # Generate random latents.
-            z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device, dtype=self._dtype)
+            z = torch.from_numpy(np.random.RandomState(w0_seed).randn(
+                1, 512)).to(self._device, dtype=self._dtype)
 
             # Run mapping network.
             label = torch.zeros([1, G.c_dim], device=self._device)
-            w = G.mapping(z, label, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff)
+            w = G.mapping(z, label, truncation_psi=trunc_psi,
+                          truncation_cutoff=trunc_cutoff)
         else:
             w = self.w_load.clone().to(self._device)
 
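Seeding through a private `np.random.RandomState` keeps the initial latent reproducible per `w0_seed` without touching NumPy's global RNG. A self-contained sketch of just that step (CPU device and float64 stand in for `self._device`/`self._dtype`):

import numpy as np
import torch

w0_seed = 0
device, dtype = torch.device('cpu'), torch.float64
# Same seed -> same z, independent of any other random calls in the app.
z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(device, dtype=dtype)
# w = G.mapping(z, label, truncation_psi=0.7, truncation_cutoff=None)  # as in the diff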
@@ -270,34 +285,34 @@ class Renderer:
             print(' Remain feat_refs and points0_pt')
 
     def _render_drag_impl(self, res,
-        points          = [],
-        targets         = [],
-        mask            = None,
-        lambda_mask     = 10,
-        reg             = 0,
-        feature_idx     = 5,
-        r1              = 3,
-        r2              = 12,
-        random_seed     = 0,
-        noise_mode      = 'const',
-        trunc_psi       = 0.7,
-        force_fp32      = False,
-        layer_name      = None,
-        sel_channels    = 3,
-        base_channel    = 0,
-        img_scale_db    = 0,
-        img_normalize   = False,
-        untransform     = False,
-        is_drag         = False,
-        reset           = False,
-        to_pil          = False,
-        **kwargs
-    ):
+                          points=[],
+                          targets=[],
+                          mask=None,
+                          lambda_mask=10,
+                          reg=0,
+                          feature_idx=5,
+                          r1=3,
+                          r2=12,
+                          random_seed=0,
+                          noise_mode='const',
+                          trunc_psi=0.7,
+                          force_fp32=False,
+                          layer_name=None,
+                          sel_channels=3,
+                          base_channel=0,
+                          img_scale_db=0,
+                          img_normalize=False,
+                          untransform=False,
+                          is_drag=False,
+                          reset=False,
+                          to_pil=False,
+                          **kwargs
+                          ):
         G = self.G
         ws = self.w
         if ws.dim() == 2:
-            ws = ws.unsqueeze(1).repeat(1,6,1)
-            ws = torch.cat([ws[:,:6,:], self.w0[:,6:,:]], dim=1)
+            ws = ws.unsqueeze(1).repeat(1, 6, 1)
+            ws = torch.cat([ws[:, :6, :], self.w0[:, 6:, :]], dim=1)
         if hasattr(self, 'points'):
             if len(points) != len(self.points):
                 reset = True
@@ -308,7 +323,8 @@ class Renderer:
 
         # Run synthesis network.
         label = torch.zeros([1, G.c_dim], device=self._device)
-        img, feat = G(ws, label, truncation_psi=trunc_psi, noise_mode=noise_mode, input_is_w=True, return_feature=True)
+        img, feat = G(ws, label, truncation_psi=trunc_psi,
+                      noise_mode=noise_mode, input_is_w=True, return_feature=True)
 
         h, w = G.img_resolution, G.img_resolution
 
@@ -316,14 +332,17 @@ class Renderer:
         X = torch.linspace(0, h, h)
         Y = torch.linspace(0, w, w)
         xx, yy = torch.meshgrid(X, Y)
-        feat_resize = F.interpolate(feat[feature_idx], [h, w], mode='bilinear')
+        feat_resize = F.interpolate(
+            feat[feature_idx], [h, w], mode='bilinear')
         if self.feat_refs is None:
-            self.feat0_resize = F.interpolate(feat[feature_idx].detach(), [h, w], mode='bilinear')
+            self.feat0_resize = F.interpolate(
+                feat[feature_idx].detach(), [h, w], mode='bilinear')
             self.feat_refs = []
             for point in points:
                 py, px = round(point[0]), round(point[1])
-                self.feat_refs.append(self.feat0_resize[:,:,py,px])
-            self.points0_pt = torch.Tensor(points).unsqueeze(0).to(self._device) # 1, N, 2
+                self.feat_refs.append(self.feat0_resize[:, :, py, px])
+            self.points0_pt = torch.Tensor(points).unsqueeze(
+                0).to(self._device)  # 1, N, 2
 
         # Point tracking with feature matching
        with torch.no_grad():
@@ -333,11 +352,13 @@ class Renderer:
                 down = min(point[0] + r + 1, h)
                 left = max(point[1] - r, 0)
                 right = min(point[1] + r + 1, w)
-                feat_patch = feat_resize[:,:,up:down,left:right]
-                L2 = torch.linalg.norm(feat_patch - self.feat_refs[j].reshape(1,-1,1,1), dim=1)
-                _, idx = torch.min(L2.view(1,-1), -1)
+                feat_patch = feat_resize[:, :, up:down, left:right]
+                L2 = torch.linalg.norm(
+                    feat_patch - self.feat_refs[j].reshape(1, -1, 1, 1), dim=1)
+                _, idx = torch.min(L2.view(1, -1), -1)
                 width = right - left
-                point = [idx.item() // width + up, idx.item() % width + left]
+                point = [idx.item() // width + up, idx.item() %
+                         width + left]
                 points[j] = point
 
         res.points = [[point[0], point[1]] for point in points]
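For context: the tracking step above is a nearest-neighbour search in feature space over an r-neighbourhood of each handle point. A self-contained sketch with toy tensors (shapes and names are illustrative, not the renderer's own state):

import torch

C, h, w, r = 8, 64, 64, 3
feat_resize = torch.randn(1, C, h, w)   # resized feature map
feat_ref = torch.randn(C)               # reference feature for one point
point = [32, 32]

up, down = max(point[0] - r, 0), min(point[0] + r + 1, h)
left, right = max(point[1] - r, 0), min(point[1] + r + 1, w)
feat_patch = feat_resize[:, :, up:down, left:right]
# L2 distance to the reference at every pixel of the patch.
L2 = torch.linalg.norm(feat_patch - feat_ref.reshape(1, -1, 1, 1), dim=1)
_, idx = torch.min(L2.view(1, -1), -1)
width = right - left
# Flat argmin back to 2-D coordinates in the full map.
point = [idx.item() // width + up, idx.item() % width + left]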
@@ -346,24 +367,31 @@ class Renderer:
             loss_motion = 0
             res.stop = True
             for j, point in enumerate(points):
-                direction = torch.Tensor([targets[j][1] - point[1], targets[j][0] - point[0]])
+                direction = torch.Tensor(
+                    [targets[j][1] - point[1], targets[j][0] - point[0]])
                 if torch.linalg.norm(direction) > max(2 / 512 * h, 2):
                     res.stop = False
                 if torch.linalg.norm(direction) > 1:
-                    distance = ((xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
+                    distance = (
+                        (xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5
                     relis, reljs = torch.where(distance < round(r1 / 512 * h))
-                    direction = direction / (torch.linalg.norm(direction) + 1e-7)
+                    direction = direction / \
+                        (torch.linalg.norm(direction) + 1e-7)
                     gridh = (relis-direction[1]) / (h-1) * 2 - 1
                     gridw = (reljs-direction[0]) / (w-1) * 2 - 1
-                    grid = torch.stack([gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)
-                    target = F.grid_sample(feat_resize.float(), grid, align_corners=True).squeeze(2)
-                    loss_motion += F.l1_loss(feat_resize[:,:,relis,reljs], target.detach())
+                    grid = torch.stack(
+                        [gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)
+                    target = F.grid_sample(
+                        feat_resize.float(), grid, align_corners=True).squeeze(2)
+                    loss_motion += F.l1_loss(
+                        feat_resize[:, :, relis, reljs], target.detach())
 
             loss = loss_motion
             if mask is not None:
                 if mask.min() == 0 and mask.max() == 1:
                     mask_usq = mask.to(self._device).unsqueeze(0).unsqueeze(0)
-                    loss_fix = F.l1_loss(feat_resize * mask_usq, self.feat0_resize * mask_usq)
+                    loss_fix = F.l1_loss(
+                        feat_resize * mask_usq, self.feat0_resize * mask_usq)
                     loss += lambda_mask * loss_fix
 
             loss += reg * F.l1_loss(ws, self.w0)  # latent code regularization
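The motion-supervision term pulls the features at (relis, reljs) one unit step along `direction`, using a bilinear lookup via `F.grid_sample`. A toy sketch of that step alone, with made-up shapes and coordinates:

import torch
import torch.nn.functional as F

C, h, w = 8, 64, 64
feat_resize = torch.randn(1, C, h, w, requires_grad=True)
relis = torch.tensor([10, 11])            # row indices inside the radius
reljs = torch.tensor([20, 21])            # column indices inside the radius
direction = torch.tensor([0.6, 0.8])      # unit (dx, dy), as in the diff

# grid_sample expects coordinates normalized to [-1, 1] in (x, y) order.
gridh = (relis - direction[1]) / (h - 1) * 2 - 1
gridw = (reljs - direction[0]) / (w - 1) * 2 - 1
grid = torch.stack([gridw, gridh], dim=-1).unsqueeze(0).unsqueeze(0)  # 1 x 1 x N x 2
target = F.grid_sample(feat_resize.float(), grid, align_corners=True).squeeze(2)
# Detach the target so gradients only flow through the current features.
loss_motion = F.l1_loss(feat_resize[:, :, relis, reljs], target.detach())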
@@ -375,9 +403,11 @@ class Renderer:
         # Scale and convert to uint8.
         img = img[0]
         if img_normalize:
-            img = img / img.norm(float('inf'), dim=[1, 2], keepdim=True).clip(1e-8, 1e8)
+            img = img / img.norm(float('inf'),
+                                 dim=[1, 2], keepdim=True).clip(1e-8, 1e8)
         img = img * (10 ** (img_scale_db / 20))
-        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
+        img = (img * 127.5 + 128).clamp(0,
+                                        255).to(torch.uint8).permute(1, 2, 0)
         if to_pil:
             from PIL import Image
             img = img.cpu().numpy()
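The synthesis output lives in [-1, 1]; `x * 127.5 + 128` maps it to display range. A toy sketch of the conversion:

import torch

img = torch.rand(3, 64, 64) * 2 - 1                       # CHW float in [-1, 1]
img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)   # -> [0, 255]
img = img.permute(1, 2, 0)                                # CHW -> HWC for PIL/NumPy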
@@ -385,4 +415,4 @@ class Renderer:
         res.image = img
         res.w = ws.detach().cpu().numpy()
 
-#----------------------------------------------------------------------------
+# ----------------------------------------------------------------------------