cocktailpeanut commited on
Commit
13d3a21
1 Parent(s): 30b3bf7
Files changed (1) hide show
  1. tokenflow_pnp.py +2 -1
tokenflow_pnp.py CHANGED
@@ -28,6 +28,7 @@ elif torch.backends.mps.is_available():
28
  device = "mps"
29
  else:
30
  device = "cpu"
 
31
 
32
  class TokenFlow(nn.Module):
33
  def __init__(self, config,
@@ -117,7 +118,7 @@ class TokenFlow(nn.Module):
117
  depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
118
  depth_maps.append(depth_map)
119
 
120
- return torch.cat(depth_maps).to(self.to).to(self.device)
121
 
122
  def get_pnp_inversion_prompt(self):
123
  inv_prompts_path = os.path.join(str(Path(self.latents_path).parent), 'inversion_prompt.txt')
 
28
  device = "mps"
29
  else:
30
  device = "cpu"
31
+ to = torch.float16 if self.device == 'cuda' else torch.float32
32
 
33
  class TokenFlow(nn.Module):
34
  def __init__(self, config,
 
118
  depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
119
  depth_maps.append(depth_map)
120
 
121
+ return torch.cat(depth_maps).to(to).to(self.device)
122
 
123
  def get_pnp_inversion_prompt(self):
124
  inv_prompts_path = os.path.join(str(Path(self.latents_path).parent), 'inversion_prompt.txt')