ciCic committed on
Commit 1e44006
1 Parent(s): b8a62eb
Files changed (1)
  1. app.py +2 -3
app.py CHANGED
@@ -6,15 +6,14 @@ from diffusers import AutoencoderTiny
 from torchvision.transforms.functional import to_pil_image, center_crop, resize, to_tensor
 
 device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
-d_type = torch.float32 if device == 'mps' else torch.float16
 
 model_id = "madebyollin/taesd"
-vae = AutoencoderTiny.from_pretrained(model_id, safetensors=True, torch_dtype=d_type).to(device)
+vae = AutoencoderTiny.from_pretrained(model_id, safetensors=True).to(device)
 
 
 @torch.no_grad()
 def decode(image):
-    t = to_tensor(image).unsqueeze(0).to(device, dtype=d_type)
+    t = to_tensor(image).unsqueeze(0).to(device)
     unscaled_t = vae.unscale_latents(t)
     reconstructed = vae.decoder(unscaled_t).clamp(0, 1)
     return to_pil_image(reconstructed[0])
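
For context, here is a minimal standalone sketch of what the code does after this commit: the VAE now loads with its default float32 weights, so the input tensor is moved to the device without an explicit dtype cast. The encode() helper, the round trip at the bottom, and the file names input.png / roundtrip.png are hypothetical illustrations of how a latent-preview image for decode() might be produced and consumed; they are not part of this commit or repository.

# Minimal sketch (assumptions noted above); decode() mirrors the updated app.py.
import torch
from PIL import Image
from diffusers import AutoencoderTiny
from torchvision.transforms.functional import to_pil_image, to_tensor

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", safetensors=True).to(device)

@torch.no_grad()
def encode(image):
    # Hypothetical counterpart: RGB image -> 4-channel TAESD latents packed
    # into [0, 1] -> small RGBA preview image.
    t = to_tensor(image).unsqueeze(0).to(device)
    latents = vae.scale_latents(vae.encoder(t)).clamp(0, 1)
    return to_pil_image(latents[0])

@torch.no_grad()
def decode(image):
    # Same logic as the updated app.py: no d_type, everything stays in float32.
    t = to_tensor(image).unsqueeze(0).to(device)
    unscaled_t = vae.unscale_latents(t)
    reconstructed = vae.decoder(unscaled_t).clamp(0, 1)
    return to_pil_image(reconstructed[0])

if __name__ == "__main__":
    source = Image.open("input.png").convert("RGB")  # hypothetical input file
    latent_preview = encode(source)                  # RGBA image holding the latents
    decode(latent_preview).save("roundtrip.png")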