gavinyuan commited on
Commit
127df95
1 Parent(s): 82206b3

update: mouth helper

Browse files
Files changed (1) hide show
  1. inference/tricks.py +1 -1
inference/tricks.py CHANGED
@@ -65,7 +65,7 @@ class Trick(object):
65
 
66
  @staticmethod
67
  def arr_to_tensor(arr, norm: bool = True):
68
- tensor = torch.tensor(arr, dtype=torch.float).cuda() / 255 # in [0,1]
69
  tensor = (tensor - 0.5) / 0.5 if norm else tensor # in [-1,1]
70
  tensor = tensor.permute(0, 3, 1, 2)
71
  return tensor
 
65
 
66
  @staticmethod
67
  def arr_to_tensor(arr, norm: bool = True):
68
+ tensor = torch.tensor(arr, dtype=torch.float).to(global_device) / 255 # in [0,1]
69
  tensor = (tensor - 0.5) / 0.5 if norm else tensor # in [-1,1]
70
  tensor = tensor.permute(0, 3, 1, 2)
71
  return tensor