Spaces:
Paused
Paused
Anton Forsman
commited on
Commit
·
9efd988
1
Parent(s):
475978d
fix device issue
Browse files- inference.py +2 -1
- unet.py +1 -0
inference.py
CHANGED
@@ -11,7 +11,8 @@ from unet import Unet, ConditionalUnet
|
|
11 |
|
12 |
from diffusion import GaussianDiffusion, DiffusionImageAPI
|
13 |
|
14 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
15 |
|
16 |
def inference1():
|
17 |
# new image from web page
|
|
|
11 |
|
12 |
from diffusion import GaussianDiffusion, DiffusionImageAPI
|
13 |
|
14 |
+
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
+
device = torch.device("mps")
|
16 |
|
17 |
def inference1():
|
18 |
# new image from web page
|
unet.py
CHANGED
@@ -414,6 +414,7 @@ class ConditionalUnet(nn.Module):
|
|
414 |
def forward(self, x, t, cond=None):
|
415 |
cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
416 |
cond = cond.unsqueeze(0)
|
|
|
417 |
# cond: (batch_size, n), where n is the number of classes that we are conditioning on
|
418 |
t = self.unet.time_encoding(t)
|
419 |
|
|
|
414 |
def forward(self, x, t, cond=None):
|
415 |
cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
416 |
cond = cond.unsqueeze(0)
|
417 |
+
cond = cond.to(self.device)
|
418 |
# cond: (batch_size, n), where n is the number of classes that we are conditioning on
|
419 |
t = self.unet.time_encoding(t)
|
420 |
|