Yw22 commited on
Commit
5e15821
·
1 Parent(s): 5e39dac
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -237,7 +237,7 @@ def points_to_flows(track_points, model_length, height, width):
237
  class ImageConductor:
238
  def __init__(self, device, unet_path, image_controlnet_path, flow_controlnet_path, height, width, model_length, lora_rank=64):
239
  self.device = device
240
- tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer").to(device)
241
  text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").to(device)
242
  vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to(device)
243
  inference_config = OmegaConf.load("configs/inference/inference.yaml")
@@ -338,8 +338,8 @@ class ImageConductor:
338
  os.makedirs(os.path.join(output_dir, "control_flows"), exist_ok=True)
339
  trajs_video = vis_flow_to_video(controlnet_flows, num_frames=self.model_length) # T-1 x H x W x 3
340
  torchvision.io.write_video(f'{output_dir}/control_flows/sample-{id}-train_flow.mp4', trajs_video, fps=8, video_codec='h264', options={'crf': '10'})
341
- controlnet_flows = torch.from_numpy(controlnet_flows)[None].to(controlnet_images)[:, :self.model_length, ...]
342
- controlnet_flows = rearrange(controlnet_flows, "b f h w c-> b c f h w")
343
 
344
  dreambooth_model_path = DREAM_BOOTH.get(personalized, '')
345
  lora_model_path = LORA.get(personalized, '')
 
237
  class ImageConductor:
238
  def __init__(self, device, unet_path, image_controlnet_path, flow_controlnet_path, height, width, model_length, lora_rank=64):
239
  self.device = device
240
+ tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
241
  text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").to(device)
242
  vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to(device)
243
  inference_config = OmegaConf.load("configs/inference/inference.yaml")
 
338
  os.makedirs(os.path.join(output_dir, "control_flows"), exist_ok=True)
339
  trajs_video = vis_flow_to_video(controlnet_flows, num_frames=self.model_length) # T-1 x H x W x 3
340
  torchvision.io.write_video(f'{output_dir}/control_flows/sample-{id}-train_flow.mp4', trajs_video, fps=8, video_codec='h264', options={'crf': '10'})
341
+ controlnet_flows = torch.from_numpy(controlnet_flows)[None][:, :self.model_length, ...]
342
+ controlnet_flows = rearrange(controlnet_flows, "b f h w c-> b c f h w").to(device)
343
 
344
  dreambooth_model_path = DREAM_BOOTH.get(personalized, '')
345
  lora_model_path = LORA.get(personalized, '')