Spaces:
Runtime error
Runtime error
cocktailpeanut
commited on
Commit
•
13d3a21
1
Parent(s):
30b3bf7
update
Browse files- tokenflow_pnp.py +2 -1
tokenflow_pnp.py
CHANGED
@@ -28,6 +28,7 @@ elif torch.backends.mps.is_available():
|
|
28 |
device = "mps"
|
29 |
else:
|
30 |
device = "cpu"
|
|
|
31 |
|
32 |
class TokenFlow(nn.Module):
|
33 |
def __init__(self, config,
|
@@ -117,7 +118,7 @@ class TokenFlow(nn.Module):
|
|
117 |
depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
|
118 |
depth_maps.append(depth_map)
|
119 |
|
120 |
-
return torch.cat(depth_maps).to(
|
121 |
|
122 |
def get_pnp_inversion_prompt(self):
|
123 |
inv_prompts_path = os.path.join(str(Path(self.latents_path).parent), 'inversion_prompt.txt')
|
|
|
28 |
device = "mps"
|
29 |
else:
|
30 |
device = "cpu"
|
31 |
+
to = torch.float16 if self.device == 'cuda' else torch.float32
|
32 |
|
33 |
class TokenFlow(nn.Module):
|
34 |
def __init__(self, config,
|
|
|
118 |
depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
|
119 |
depth_maps.append(depth_map)
|
120 |
|
121 |
+
return torch.cat(depth_maps).to(to).to(self.device)
|
122 |
|
123 |
def get_pnp_inversion_prompt(self):
|
124 |
inv_prompts_path = os.path.join(str(Path(self.latents_path).parent), 'inversion_prompt.txt')
|