Spaces:
Runtime error
Runtime error
edit util.py and gradio_demo.py
Browse files- FONT/modules/util.py +12 -4
FONT/modules/util.py
CHANGED
@@ -306,9 +306,9 @@ class AT_net(nn.Module):
|
|
306 |
|
307 |
|
308 |
class AT_net2(nn.Module):
|
309 |
-
def __init__(self):
|
310 |
super(AT_net2, self).__init__()
|
311 |
-
|
312 |
down_blocks = []
|
313 |
for i in range(8):
|
314 |
down_blocks.append(DownBlock2d(3 if i == 0 else 2 * (2 ** i),
|
@@ -372,8 +372,16 @@ class AT_net2(nn.Module):
|
|
372 |
|
373 |
|
374 |
def forward(self, example_image, audio, pose, jaco_net, weight):
|
375 |
-
|
376 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
377 |
outs = example_image
|
378 |
for down_block in self.down_blocks:
|
379 |
outs = down_block(outs)
|
|
|
306 |
|
307 |
|
308 |
class AT_net2(nn.Module):
|
309 |
+
def __init__(self, device):
|
310 |
super(AT_net2, self).__init__()
|
311 |
+
self.device = device
|
312 |
down_blocks = []
|
313 |
for i in range(8):
|
314 |
down_blocks.append(DownBlock2d(3 if i == 0 else 2 * (2 ** i),
|
|
|
372 |
|
373 |
|
374 |
def forward(self, example_image, audio, pose, jaco_net, weight):
|
375 |
+
|
376 |
+
hidden_ele1 = torch.zeros(3, audio.size(0), 256)
|
377 |
+
hidden_ele2 = torch.zeros(3, audio.size(0), 256)
|
378 |
+
|
379 |
+
if self.device == 'cuda':
|
380 |
+
hidden_ele1 = hidden_ele1.cuda()
|
381 |
+
hidden_ele2 = hidden_ele2.cuda()
|
382 |
+
|
383 |
+
|
384 |
+
hidden = (torch.autograd.Variable(hidden_ele1), torch.autograd.Variable(hidden_ele2) )
|
385 |
outs = example_image
|
386 |
for down_block in self.down_blocks:
|
387 |
outs = down_block(outs)
|