cocktailpeanut committed on
Commit
30b3bf7
1 Parent(s): 21e2861

float16/float32

Browse files
Files changed (3) hide show
  1. app.py +6 -4
  2. preprocess_utils.py +7 -6
  3. tokenflow_pnp.py +13 -5
app.py CHANGED
@@ -16,18 +16,20 @@ else:
16
  device = "cpu"
17
  model_id = "stabilityai/stable-diffusion-2-1-base"
18
 
 
 
19
  # components for the Preprocessor
20
  scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
21
  vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", revision="fp16",
22
- torch_dtype=torch.float16).to(device)
23
  tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
24
  text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision="fp16",
25
- torch_dtype=torch.float16).to(device)
26
  unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", revision="fp16",
27
- torch_dtype=torch.float16).to(device)
28
 
29
  # pipe for TokenFlow
30
- tokenflow_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
31
  if device == "cuda":
32
  tokenflow_pipe.enable_xformers_memory_efficient_attention()
33
 
 
16
  device = "cpu"
17
  model_id = "stabilityai/stable-diffusion-2-1-base"
18
 
19
+ to = torch.float16 if self.device == 'cuda' else torch.float32
20
+
21
  # components for the Preprocessor
22
  scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
23
  vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", revision="fp16",
24
+ torch_dtype=to).to(device)
25
  tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
26
  text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision="fp16",
27
+ torch_dtype=to).to(device)
28
  unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", revision="fp16",
29
+ torch_dtype=to).to(device)
30
 
31
  # pipe for TokenFlow
32
+ tokenflow_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=to).to(device)
33
  if device == "cuda":
34
  tokenflow_pipe.enable_xformers_memory_efficient_attention()
35
 
preprocess_utils.py CHANGED
@@ -29,6 +29,7 @@ class Preprocess(nn.Module):
29
  super().__init__()
30
 
31
  self.device = device
 
32
  self.sd_version = opt["sd_version"]
33
  self.use_depth = False
34
  self.config = opt
@@ -73,9 +74,9 @@ class Preprocess(nn.Module):
73
 
74
  if self.sd_version == 'ControlNet':
75
  from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
76
- controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16).to(self.device)
77
  control_pipe = StableDiffusionControlNetPipeline.from_pretrained(
78
- "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
79
  ).to(self.device)
80
  self.unet = control_pipe.unet
81
  self.controlnet = control_pipe.controlnet
@@ -124,7 +125,7 @@ class Preprocess(nn.Module):
124
  depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
125
  depth_maps.append(depth_map)
126
 
127
- return torch.cat(depth_maps).to(self.device).to(torch.float16)
128
 
129
  @torch.no_grad()
130
  def get_canny_cond(self):
@@ -139,7 +140,7 @@ class Preprocess(nn.Module):
139
  image = np.concatenate([image, image, image], axis=2)
140
  image = torch.from_numpy((image.astype(np.float32) / 255.0))
141
  canny_cond.append(image)
142
- canny_cond = torch.stack(canny_cond).permute(0, 3, 1, 2).to(self.device).to(torch.float16)
143
  return canny_cond
144
 
145
  def controlnet_pred(self, latent_model_input, t, text_embed_input, controlnet_cond):
@@ -211,9 +212,9 @@ class Preprocess(nn.Module):
211
  frames = [frame.resize((512, 512), resample=Image.Resampling.LANCZOS) for frame in frames]
212
  else:
213
  frames = self.config["frames"][:n_frames]
214
- frames = torch.stack([T.ToTensor()(frame) for frame in frames]).to(torch.float16).to(self.device)
215
  # encode to latents
216
- latents = self.encode_imgs(frames, deterministic=True).to(torch.float16).to(self.device)
217
  print("frames", frames.shape)
218
  print("latents", latents.shape)
219
 
 
29
  super().__init__()
30
 
31
  self.device = device
32
+ self.to = torch.float16 if self.device == 'cuda' else torch.float32
33
  self.sd_version = opt["sd_version"]
34
  self.use_depth = False
35
  self.config = opt
 
74
 
75
  if self.sd_version == 'ControlNet':
76
  from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
77
+ controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=self.to).to(self.device)
78
  control_pipe = StableDiffusionControlNetPipeline.from_pretrained(
79
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=self.to
80
  ).to(self.device)
81
  self.unet = control_pipe.unet
82
  self.controlnet = control_pipe.controlnet
 
125
  depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
126
  depth_maps.append(depth_map)
127
 
128
+ return torch.cat(depth_maps).to(self.device).to(self.to)
129
 
130
  @torch.no_grad()
131
  def get_canny_cond(self):
 
140
  image = np.concatenate([image, image, image], axis=2)
141
  image = torch.from_numpy((image.astype(np.float32) / 255.0))
142
  canny_cond.append(image)
143
+ canny_cond = torch.stack(canny_cond).permute(0, 3, 1, 2).to(self.device).to(self.to)
144
  return canny_cond
145
 
146
  def controlnet_pred(self, latent_model_input, t, text_embed_input, controlnet_cond):
 
212
  frames = [frame.resize((512, 512), resample=Image.Resampling.LANCZOS) for frame in frames]
213
  else:
214
  frames = self.config["frames"][:n_frames]
215
+ frames = torch.stack([T.ToTensor()(frame) for frame in frames]).to(self.to).to(self.device)
216
  # encode to latents
217
+ latents = self.encode_imgs(frames, deterministic=True).to(self.to).to(self.device)
218
  print("frames", frames.shape)
219
  print("latents", latents.shape)
220
 
tokenflow_pnp.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import glob
2
  import os
3
  import numpy as np
@@ -21,6 +22,12 @@ logging.set_verbosity_error()
21
 
22
  VAE_BATCH_SIZE = 10
23
 
 
 
 
 
 
 
24
 
25
  class TokenFlow(nn.Module):
26
  def __init__(self, config,
@@ -31,6 +38,7 @@ class TokenFlow(nn.Module):
31
  super().__init__()
32
  self.config = config
33
  self.device = config["device"]
 
34
 
35
  sd_version = config["sd_version"]
36
  self.sd_version = sd_version
@@ -109,7 +117,7 @@ class TokenFlow(nn.Module):
109
  depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
110
  depth_maps.append(depth_map)
111
 
112
- return torch.cat(depth_maps).to(torch.float16).to(self.device)
113
 
114
  def get_pnp_inversion_prompt(self):
115
  inv_prompts_path = os.path.join(str(Path(self.latents_path).parent), 'inversion_prompt.txt')
@@ -197,16 +205,16 @@ class TokenFlow(nn.Module):
197
  frames = [Image.open(paths[idx]).convert('RGB') for idx in range(self.config["n_frames"])]
198
  if frames[0].size[0] == frames[0].size[1]:
199
  frames = [frame.resize((512, 512), resample=Image.Resampling.LANCZOS) for frame in frames]
200
- frames = torch.stack([T.ToTensor()(frame) for frame in frames]).to(torch.float16).to(self.device)
201
  save_video(frames, f'{self.config["output_path"]}/input_fps10.mp4', fps=10)
202
  save_video(frames, f'{self.config["output_path"]}/input_fps20.mp4', fps=20)
203
  save_video(frames, f'{self.config["output_path"]}/input_fps30.mp4', fps=30)
204
  else:
205
  frames = self.frames
206
  # encode to latents
207
- latents = self.encode_imgs(frames, deterministic=True).to(torch.float16).to(self.device)
208
  # get noise
209
- eps = self.get_ddim_eps(latents, range(self.config["n_frames"])).to(torch.float16).to(self.device)
210
  if not read_from_files:
211
  return None, frames, latents, eps
212
  return paths, frames, latents, eps
@@ -267,7 +275,7 @@ class TokenFlow(nn.Module):
267
  denoised_latent = self.scheduler.step(noise_pred, t, x)['prev_sample']
268
  return denoised_latent
269
 
270
- @torch.autocast(dtype=torch.float16, device_type='cuda')
271
  def batched_denoise_step(self, x, t, indices):
272
  batch_size = self.config["batch_size"]
273
  denoised_latents = []
 
1
+ import torch
2
  import glob
3
  import os
4
  import numpy as np
 
22
 
23
  VAE_BATCH_SIZE = 10
24
 
25
+ if torch.cuda.is_available():
26
+ device = "cuda"
27
+ elif torch.backends.mps.is_available():
28
+ device = "mps"
29
+ else:
30
+ device = "cpu"
31
 
32
  class TokenFlow(nn.Module):
33
  def __init__(self, config,
 
38
  super().__init__()
39
  self.config = config
40
  self.device = config["device"]
41
+ self.to = torch.float16 if self.device == 'cuda' else torch.float32
42
 
43
  sd_version = config["sd_version"]
44
  self.sd_version = sd_version
 
117
  depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
118
  depth_maps.append(depth_map)
119
 
120
+ return torch.cat(depth_maps).to(self.to).to(self.device)
121
 
122
  def get_pnp_inversion_prompt(self):
123
  inv_prompts_path = os.path.join(str(Path(self.latents_path).parent), 'inversion_prompt.txt')
 
205
  frames = [Image.open(paths[idx]).convert('RGB') for idx in range(self.config["n_frames"])]
206
  if frames[0].size[0] == frames[0].size[1]:
207
  frames = [frame.resize((512, 512), resample=Image.Resampling.LANCZOS) for frame in frames]
208
+ frames = torch.stack([T.ToTensor()(frame) for frame in frames]).to(self.to).to(self.device)
209
  save_video(frames, f'{self.config["output_path"]}/input_fps10.mp4', fps=10)
210
  save_video(frames, f'{self.config["output_path"]}/input_fps20.mp4', fps=20)
211
  save_video(frames, f'{self.config["output_path"]}/input_fps30.mp4', fps=30)
212
  else:
213
  frames = self.frames
214
  # encode to latents
215
+ latents = self.encode_imgs(frames, deterministic=True).to(self.to).to(self.device)
216
  # get noise
217
+ eps = self.get_ddim_eps(latents, range(self.config["n_frames"])).to(self.to).to(self.device)
218
  if not read_from_files:
219
  return None, frames, latents, eps
220
  return paths, frames, latents, eps
 
275
  denoised_latent = self.scheduler.step(noise_pred, t, x)['prev_sample']
276
  return denoised_latent
277
 
278
+ @torch.autocast(dtype=self.to, device_type=device)
279
  def batched_denoise_step(self, x, t, indices):
280
  batch_size = self.config["batch_size"]
281
  denoised_latents = []