Doubiiu commited on
Commit
b5d93b2
·
verified ·
1 Parent(s): 7f31be9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -24
app.py CHANGED
@@ -65,34 +65,34 @@ def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
65
  noise_shape = [batch_size, channels, frames, h, w]
66
 
67
  # text cond
68
- text_emb = model.get_learned_conditioning([prompt])
69
-
70
- # img cond
71
- img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
72
- img_tensor = (img_tensor / 255. - 0.5) * 2
73
-
74
- image_tensor_resized = transform(img_tensor) #3,256,256
75
- videos = image_tensor_resized.unsqueeze(0) # bchw
76
 
77
- z = get_latent_z(model, videos.unsqueeze(2)) #bc,1,hw
 
 
78
 
79
- img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
80
-
81
- cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
82
- img_emb = model.image_proj_model(cond_images)
83
-
84
- imtext_cond = torch.cat([text_emb, img_emb], dim=1)
85
-
86
- fs = torch.tensor([fs], dtype=torch.long, device=model.device)
87
- cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}
88
 
89
- ## inference
90
- with torch.no_grad(), torch.cuda.amp.autocast():
 
 
 
 
 
 
 
91
  batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
92
- ## b,samples,c,t,h,w
93
-
94
- video_path = './output.mp4'
95
- save_videos(batch_samples, './', filenames=['output'], fps=save_fps)
96
  model = model.cpu()
97
  return video_path
98
 
 
65
  noise_shape = [batch_size, channels, frames, h, w]
66
 
67
  # text cond
68
+ with torch.no_grad(), torch.cuda.amp.autocast():
69
+ text_emb = model.get_learned_conditioning([prompt])
 
 
 
 
 
 
70
 
71
+ # img cond
72
+ img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
73
+ img_tensor = (img_tensor / 255. - 0.5) * 2
74
 
75
+ image_tensor_resized = transform(img_tensor) #3,256,256
76
+ videos = image_tensor_resized.unsqueeze(0) # bchw
77
+
78
+ z = get_latent_z(model, videos.unsqueeze(2)) #bc,1,hw
79
+
80
+ img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
 
 
 
81
 
82
+ cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
83
+ img_emb = model.image_proj_model(cond_images)
84
+
85
+ imtext_cond = torch.cat([text_emb, img_emb], dim=1)
86
+
87
+ fs = torch.tensor([fs], dtype=torch.long, device=model.device)
88
+ cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}
89
+
90
+ ## inference
91
  batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
92
+ ## b,samples,c,t,h,w
93
+
94
+ video_path = './output.mp4'
95
+ save_videos(batch_samples, './', filenames=['output'], fps=save_fps)
96
  model = model.cpu()
97
  return video_path
98