Anton Forsman commited on
Commit
9efd988
·
1 Parent(s): 475978d

fix device issue

Browse files
Files changed (2) hide show
  1. inference.py +2 -1
  2. unet.py +1 -0
inference.py CHANGED
@@ -11,7 +11,8 @@ from unet import Unet, ConditionalUnet
11
 
12
  from diffusion import GaussianDiffusion, DiffusionImageAPI
13
 
14
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
15
 
16
  def inference1():
17
  # new image from web page
 
11
 
12
  from diffusion import GaussianDiffusion, DiffusionImageAPI
13
 
14
+ #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ device = torch.device("mps")
16
 
17
  def inference1():
18
  # new image from web page
unet.py CHANGED
@@ -414,6 +414,7 @@ class ConditionalUnet(nn.Module):
414
  def forward(self, x, t, cond=None):
415
  cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
416
  cond = cond.unsqueeze(0)
 
417
  # cond: (batch_size, n), where n is the number of classes that we are conditioning on
418
  t = self.unet.time_encoding(t)
419
 
 
414
  def forward(self, x, t, cond=None):
415
  cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
416
  cond = cond.unsqueeze(0)
417
+ cond = cond.to(self.device)
418
  # cond: (batch_size, n), where n is the number of classes that we are conditioning on
419
  t = self.unet.time_encoding(t)
420