daddyjin commited on
Commit
0ed986f
1 Parent(s): ec68eda

edit util.py and gradio_demo.py

Browse files
Files changed (1) hide show
  1. FONT/modules/util.py +12 -4
FONT/modules/util.py CHANGED
@@ -306,9 +306,9 @@ class AT_net(nn.Module):
306
 
307
 
308
  class AT_net2(nn.Module):
309
- def __init__(self):
310
  super(AT_net2, self).__init__()
311
-
312
  down_blocks = []
313
  for i in range(8):
314
  down_blocks.append(DownBlock2d(3 if i == 0 else 2 * (2 ** i),
@@ -372,8 +372,16 @@ class AT_net2(nn.Module):
372
 
373
 
374
  def forward(self, example_image, audio, pose, jaco_net, weight):
375
- hidden = ( torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()),
376
- torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()))
 
 
 
 
 
 
 
 
377
  outs = example_image
378
  for down_block in self.down_blocks:
379
  outs = down_block(outs)
 
306
 
307
 
308
  class AT_net2(nn.Module):
309
+ def __init__(self, device):
310
  super(AT_net2, self).__init__()
311
+ self.device = device
312
  down_blocks = []
313
  for i in range(8):
314
  down_blocks.append(DownBlock2d(3 if i == 0 else 2 * (2 ** i),
 
372
 
373
 
374
  def forward(self, example_image, audio, pose, jaco_net, weight):
375
+
376
+ hidden_ele1 = torch.zeros(3, audio.size(0), 256)
377
+ hidden_ele2 = torch.zeros(3, audio.size(0), 256)
378
+
379
+ if self.device == 'cuda':
380
+ hidden_ele1 = hidden_ele1.cuda()
381
+ hidden_ele2 = hidden_ele2.cuda()
382
+
383
+
384
+ hidden = (torch.autograd.Variable(hidden_ele1), torch.autograd.Variable(hidden_ele2) )
385
  outs = example_image
386
  for down_block in self.down_blocks:
387
  outs = down_block(outs)