fbnnb committed on
Commit ae7105e · verified · 1 Parent(s): df4181d

Update gradio_app.py

Files changed (1)
  1. gradio_app.py +141 -140
gradio_app.py CHANGED
@@ -132,152 +132,153 @@ def untranspose(tensor):
 def get_image(image1, prompt, image2, ddim_steps=50, ddim_eta=1., fs=None, seed=123, \
     unconditional_guidance_scale=1.0, cfg_img=None, text_input=False, multiple_cond_cfg=False, \
     loop=False, interp=False, timestep_spacing='uniform', guidance_rescale=0.0, noise_shape=[72, 108], n_samples=1, **kwargs):
-
-    seed_everything(seed)
-    video_size = (576, 1024)
-    transform = transforms.Compose([
-        transforms.Resize(min(video_size)),
-        transforms.CenterCrop(video_size),
-        # transforms.ToTensor(),
-        # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
-        ])
-
-    image1 = torch.from_numpy(image1).permute(2, 0, 1).float().cuda()
-    input_h, input_w = image1.shape[1:]
-    image1 = (image1 / 255. - 0.5) * 2
-
-    image2 = torch.from_numpy(image2).permute(2, 0, 1).float().cuda()
-    input_h, input_w = image2.shape[1:]
-    image2 = (image2 / 255. - 0.5) * 2
-
-    # image1 = Image.open(file_list[2*idx]).convert('RGB')
-    image_tensor1 = transform(image1).unsqueeze(1) # [c,1,h,w]
-    # image2 = Image.open(file_list[2*idx+1]).convert('RGB')
-    image_tensor2 = transform(image2).unsqueeze(1) # [c,1,h,w]
-    frame_tensor1 = repeat(image_tensor1, 'c t h w -> c (repeat t) h w', repeat=8)
-    frame_tensor2 = repeat(image_tensor2, 'c t h w -> c (repeat t) h w', repeat=8)
-    videos = torch.cat([frame_tensor1, frame_tensor2], dim=1).unsqueeze(0)
-    # frame_tensor = torch.cat([frame_tensor1, frame_tensor1], dim=1)
-    # _, filename = os.path.split(file_list[idx*2])
-
-    global model
-    model.cuda()
-
-    ddim_sampler = DDIMSampler(model) if not multiple_cond_cfg else DDIMSampler_multicond(model)
-    batch_size = 1
-    fs = torch.tensor([fs], dtype=torch.long, device=model.device)
-
-    if not text_input:
-        prompts = [""]*batch_size
-
-    img = videos[:,:,0] #bchw
-    img_emb = model.embedder(img) ## blc
-    img_emb = model.image_proj_model(img_emb)
-
-    cond_emb = model.get_learned_conditioning(prompts)
-    cond = {"c_crossattn": [torch.cat([cond_emb,img_emb], dim=1)]}
-    if model.model.conditioning_key == 'hybrid':
-        z, hs = get_latent_z_with_hidden_states(model, videos) # b c t h w
-        if loop or interp:
-            img_cat_cond = torch.zeros_like(z)
-            img_cat_cond[:,:,0,:,:] = z[:,:,0,:,:]
-            img_cat_cond[:,:,-1,:,:] = z[:,:,-1,:,:]
-        else:
-            img_cat_cond = z[:,:,:1,:,:]
-            img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2])
-        cond["c_concat"] = [img_cat_cond] # b c 1 h w
-
-    if unconditional_guidance_scale != 1.0:
-        if model.uncond_type == "empty_seq":
-            prompts = batch_size * [""]
-            uc_emb = model.get_learned_conditioning(prompts)
-        elif model.uncond_type == "zero_embed":
-            uc_emb = torch.zeros_like(cond_emb)
-        uc_img_emb = model.embedder(torch.zeros_like(img)) ## b l c
-        uc_img_emb = model.image_proj_model(uc_img_emb)
-        uc = {"c_crossattn": [torch.cat([uc_emb,uc_img_emb],dim=1)]}
-        if model.model.conditioning_key == 'hybrid':
-            uc["c_concat"] = [img_cat_cond]
-    else:
-        uc = None
-    #
-    # for i, h in enumerate(hs):
-    #     print("h:", h.shape)
-    #     hs[i] = hs[i][:,:,0,:,:].unsqueeze(2)
-    additional_decode_kwargs = {'ref_context': hs}
-    # additional_decode_kwargs = {'ref_context': None}
-
-    ## we need one more unconditioning image=yes, text=""
-    if multiple_cond_cfg and cfg_img != 1.0:
-        uc_2 = {"c_crossattn": [torch.cat([uc_emb,img_emb],dim=1)]}
-        if model.model.conditioning_key == 'hybrid':
-            uc_2["c_concat"] = [img_cat_cond]
-        kwargs.update({"unconditional_conditioning_img_nonetext": uc_2})
-    else:
-        kwargs.update({"unconditional_conditioning_img_nonetext": None})
-
-    z0 = None
-    cond_mask = None
-
-    batch_variants = []
-    for _ in range(n_samples):
-
-        if z0 is not None:
-            cond_z0 = z0.clone()
-            kwargs.update({"clean_cond": True})
-        else:
-            cond_z0 = None
-        if ddim_sampler is not None:
-
-            samples, _ = ddim_sampler.sample(S=ddim_steps,
-                                             conditioning=cond,
-                                             batch_size=batch_size,
-                                             shape=noise_shape,
-                                             verbose=False,
-                                             unconditional_guidance_scale=unconditional_guidance_scale,
-                                             unconditional_conditioning=uc,
-                                             eta=ddim_eta,
-                                             cfg_img=cfg_img,
-                                             mask=cond_mask,
-                                             x0=cond_z0,
-                                             fs=fs,
-                                             timestep_spacing=timestep_spacing,
-                                             guidance_rescale=guidance_rescale,
-                                             **kwargs
-                                             )
-
-        ## reconstruct from latent to pixel space
-        batch_images = model.decode_first_stage(samples, **additional_decode_kwargs)
-
-        index = list(range(samples.shape[2]))
-        del index[1]
-        del index[-2]
-        samples = samples[:,:,index,:,:]
-        ## reconstruct from latent to pixel space
-        batch_images_middle = model.decode_first_stage(samples, **additional_decode_kwargs)
-        batch_images[:,:,batch_images.shape[2]//2-1:batch_images.shape[2]//2+1] = batch_images_middle[:,:,batch_images.shape[2]//2-2:batch_images.shape[2]//2]
-
-        batch_variants.append(batch_images)
-    ## variants, batch, c, t, h, w
-    batch_variants = torch.stack(batch_variants)
-    # return batch_variants.permute(1, 0, 2, 3, 4, 5)
-
-    prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
-    prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
-    prompt_str = prompt_str[:40]
-    if len(prompt_str) == 0:
-        prompt_str = 'empty_prompt'
-
-    result_dir = "./tmp/"
-    save_videos(batch_images, result_dir, filenames=[prompt_str], fps=8)
-    print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
-    model = model.cpu()
-    saved_result_dir = os.path.join(result_dir, f"{prompt_str}.mp4")
-    print("result saved to:", saved_result_dir)
-    return saved_result_dir
+    with torch.no_grad():
+        seed_everything(seed)
+        video_size = (576, 1024)
+        transform = transforms.Compose([
+            transforms.Resize(min(video_size)),
+            transforms.CenterCrop(video_size),
+            # transforms.ToTensor(),
+            # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+            ])
+
+        image1 = torch.from_numpy(image1).permute(2, 0, 1).float().cuda()
+        input_h, input_w = image1.shape[1:]
+        image1 = (image1 / 255. - 0.5) * 2
+
+        image2 = torch.from_numpy(image2).permute(2, 0, 1).float().cuda()
+        input_h, input_w = image2.shape[1:]
+        image2 = (image2 / 255. - 0.5) * 2
+
+        # image1 = Image.open(file_list[2*idx]).convert('RGB')
+        image_tensor1 = transform(image1).unsqueeze(1) # [c,1,h,w]
+        # image2 = Image.open(file_list[2*idx+1]).convert('RGB')
+        image_tensor2 = transform(image2).unsqueeze(1) # [c,1,h,w]
+        frame_tensor1 = repeat(image_tensor1, 'c t h w -> c (repeat t) h w', repeat=8)
+        frame_tensor2 = repeat(image_tensor2, 'c t h w -> c (repeat t) h w', repeat=8)
+        videos = torch.cat([frame_tensor1, frame_tensor2], dim=1).unsqueeze(0)
+        # frame_tensor = torch.cat([frame_tensor1, frame_tensor1], dim=1)
+        # _, filename = os.path.split(file_list[idx*2])
+
+        global model
+        model.cuda()
+
+        ddim_sampler = DDIMSampler(model) if not multiple_cond_cfg else DDIMSampler_multicond(model)
+        batch_size = 1
+        fs = torch.tensor([fs], dtype=torch.long, device=model.device)
+
+        if not text_input:
+            prompts = [""]*batch_size
+
+        img = videos[:,:,0] #bchw
+        img_emb = model.embedder(img) ## blc
+        img_emb = model.image_proj_model(img_emb)
+
+        cond_emb = model.get_learned_conditioning(prompts)
+        cond = {"c_crossattn": [torch.cat([cond_emb,img_emb], dim=1)]}
+        if model.model.conditioning_key == 'hybrid':
+            z, hs = get_latent_z_with_hidden_states(model, videos) # b c t h w
+            if loop or interp:
+                img_cat_cond = torch.zeros_like(z)
+                img_cat_cond[:,:,0,:,:] = z[:,:,0,:,:]
+                img_cat_cond[:,:,-1,:,:] = z[:,:,-1,:,:]
+            else:
+                img_cat_cond = z[:,:,:1,:,:]
+                img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2])
+            cond["c_concat"] = [img_cat_cond] # b c 1 h w
+
+        if unconditional_guidance_scale != 1.0:
+            if model.uncond_type == "empty_seq":
+                prompts = batch_size * [""]
+                uc_emb = model.get_learned_conditioning(prompts)
+            elif model.uncond_type == "zero_embed":
+                uc_emb = torch.zeros_like(cond_emb)
+            uc_img_emb = model.embedder(torch.zeros_like(img)) ## b l c
+            uc_img_emb = model.image_proj_model(uc_img_emb)
+            uc = {"c_crossattn": [torch.cat([uc_emb,uc_img_emb],dim=1)]}
+            if model.model.conditioning_key == 'hybrid':
+                uc["c_concat"] = [img_cat_cond]
+        else:
+            uc = None
+        #
+        # for i, h in enumerate(hs):
+        #     print("h:", h.shape)
+        #     hs[i] = hs[i][:,:,0,:,:].unsqueeze(2)
+        additional_decode_kwargs = {'ref_context': hs}
+        # additional_decode_kwargs = {'ref_context': None}
+
+        ## we need one more unconditioning image=yes, text=""
+        if multiple_cond_cfg and cfg_img != 1.0:
+            uc_2 = {"c_crossattn": [torch.cat([uc_emb,img_emb],dim=1)]}
+            if model.model.conditioning_key == 'hybrid':
+                uc_2["c_concat"] = [img_cat_cond]
+            kwargs.update({"unconditional_conditioning_img_nonetext": uc_2})
+        else:
+            kwargs.update({"unconditional_conditioning_img_nonetext": None})
+
+        z0 = None
+        cond_mask = None
+
+        batch_variants = []
+        for _ in range(n_samples):
+
+            if z0 is not None:
+                cond_z0 = z0.clone()
+                kwargs.update({"clean_cond": True})
+            else:
+                cond_z0 = None
+            if ddim_sampler is not None:
+
+                samples, _ = ddim_sampler.sample(S=ddim_steps,
+                                                 conditioning=cond,
+                                                 batch_size=batch_size,
+                                                 shape=noise_shape,
+                                                 verbose=False,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=uc,
+                                                 eta=ddim_eta,
+                                                 cfg_img=cfg_img,
+                                                 mask=cond_mask,
+                                                 x0=cond_z0,
+                                                 fs=fs,
+                                                 timestep_spacing=timestep_spacing,
+                                                 guidance_rescale=guidance_rescale,
+                                                 **kwargs
+                                                 )
+
+            ## reconstruct from latent to pixel space
+            batch_images = model.decode_first_stage(samples, **additional_decode_kwargs)
+
+            index = list(range(samples.shape[2]))
+            del index[1]
+            del index[-2]
+            samples = samples[:,:,index,:,:]
+            ## reconstruct from latent to pixel space
+            batch_images_middle = model.decode_first_stage(samples, **additional_decode_kwargs)
+            batch_images[:,:,batch_images.shape[2]//2-1:batch_images.shape[2]//2+1] = batch_images_middle[:,:,batch_images.shape[2]//2-2:batch_images.shape[2]//2]
+
+            batch_variants.append(batch_images)
+        ## variants, batch, c, t, h, w
+        batch_variants = torch.stack(batch_variants)
+        # return batch_variants.permute(1, 0, 2, 3, 4, 5)
+
+        prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
+        prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
+        prompt_str = prompt_str[:40]
+        if len(prompt_str) == 0:
+            prompt_str = 'empty_prompt'
+
+        result_dir = "./tmp/"
+        save_videos(batch_images, result_dir, filenames=[prompt_str], fps=8)
+        print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
+        model = model.cpu()
+        saved_result_dir = os.path.join(result_dir, f"{prompt_str}.mp4")
+        print("result saved to:", saved_result_dir)
+        return saved_result_dir
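
Note on the change: the only functional edit in this hunk is that the whole body of get_image is re-indented under with torch.no_grad():, so the entire preprocessing, conditioning, DDIM sampling, and decoding path runs without gradient tracking. The following is a minimal, self-contained sketch of that pattern (a toy torch.nn.Linear stands in for the app's own model; none of these names come from the repo), showing what the context manager changes:

import torch

# Toy module standing in for the real video diffusion model (illustration only).
model = torch.nn.Linear(16, 4)
x = torch.randn(2, 16)

# Without no_grad(), outputs of modules with trainable parameters carry a grad_fn,
# and intermediate activations are retained for a potential backward pass.
y_train = model(x)
print(y_train.requires_grad)  # True

# Inside no_grad(), the same forward pass records no autograd graph, which lowers
# memory use during pure inference; that is the point of wrapping get_image's body.
with torch.no_grad():
    y_infer = model(x)
print(y_infer.requires_grad)  # False

torch.no_grad() only disables gradient tracking; it does not alter the numerical result of the forward pass, so the generated videos are unchanged while peak GPU memory during sampling and decoding goes down.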