cocktailpeanut commited on
Commit
43f619a
1 Parent(s): bbbbd95
Files changed (1) hide show
  1. tokenflow_pnp.py +1 -9
tokenflow_pnp.py CHANGED
@@ -21,14 +21,6 @@ logging.set_verbosity_error()
21
 
22
  VAE_BATCH_SIZE = 10
23
 
24
- if torch.cuda.is_available():
25
- device = "cuda"
26
- elif torch.backends.mps.is_available():
27
- device = "mps"
28
- else:
29
- device = "cpu"
30
- to = torch.float16 if device == 'cuda' else torch.float32
31
-
32
  class TokenFlow(nn.Module):
33
  def __init__(self, config,
34
  pipe,
@@ -275,7 +267,7 @@ class TokenFlow(nn.Module):
275
  denoised_latent = self.scheduler.step(noise_pred, t, x)['prev_sample']
276
  return denoised_latent
277
 
278
- @torch.autocast(dtype=to, device_type=device)
279
  def batched_denoise_step(self, x, t, indices):
280
  batch_size = self.config["batch_size"]
281
  denoised_latents = []
 
21
 
22
  VAE_BATCH_SIZE = 10
23
 
 
 
 
 
 
 
 
 
24
  class TokenFlow(nn.Module):
25
  def __init__(self, config,
26
  pipe,
 
267
  denoised_latent = self.scheduler.step(noise_pred, t, x)['prev_sample']
268
  return denoised_latent
269
 
270
+ @torch.autocast(dtype=torch.float16, device_type='cuda')
271
  def batched_denoise_step(self, x, t, indices):
272
  batch_size = self.config["batch_size"]
273
  denoised_latents = []