Update gradio_app.py
gradio_app.py  CHANGED  (+141 -140)
@@ -132,152 +132,153 @@ def untranspose(tensor):
 def get_image(image1, prompt, image2, dim_steps=50, ddim_eta=1., fs=None, seed=123, \
               unconditional_guidance_scale=1.0, cfg_img=None, text_input=False, multiple_cond_cfg=False, \
               loop=False, interp=False, timestep_spacing='uniform', guidance_rescale=0.0, noise_shape=[72, 108], n_samples=1, **kwargs):
-
-    seed_everything(seed)
-    video_size = (576, 1024)
-    transform = transforms.Compose([
-        transforms.Resize(min(video_size)),
-        transforms.CenterCrop(video_size),
-        # transforms.ToTensor(),
-        # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
-        ])
-
-    image1 = torch.from_numpy(image1).permute(2, 0, 1).float().cuda()
-    input_h, input_w = image1.shape[1:]
-    image1 = (image1 / 255. - 0.5) * 2
-
-    image2 = torch.from_numpy(image2).permute(2, 0, 1).float().cuda()
-    input_h, input_w = image2.shape[1:]
-    image2 = (image2 / 255. - 0.5) * 2
-
-    [... old lines 154-156 not recoverable from this view ...]
-    image_tensor2 = transform(image2).unsqueeze(1) # [c,1,h,w]
-    frame_tensor1 = repeat(image_tensor1, 'c t h w -> c (repeat t) h w', repeat=8)
-    frame_tensor2 = repeat(image_tensor2, 'c t h w -> c (repeat t) h w', repeat=8)
-    videos = torch.cat([frame_tensor1, frame_tensor2], dim=1).unsqueeze(0)
-    # frame_tensor = torch.cat([frame_tensor1, frame_tensor1], dim=1)
-    # _, filename = os.path.split(file_list[idx*2])
-
-    global model
-    model.cuda()
-
-    [... old lines 167-170 not recoverable from this view ...]
-    if not text_input:
-        prompts = [""]*batch_size
-
-    img = videos[:,:,0] #bchw
-    img_emb = model.embedder(img) ## blc
-    img_emb = model.image_proj_model(img_emb)
-
-    cond_emb = model.get_learned_conditioning(prompts)
-    cond = {"c_crossattn": [torch.cat([cond_emb,img_emb], dim=1)]}
-    if model.model.conditioning_key == 'hybrid':
-        z, hs = get_latent_z_with_hidden_states(model, videos) # b c t h w
-        if loop or interp:
-            img_cat_cond = torch.zeros_like(z)
-            img_cat_cond[:,:,0,:,:] = z[:,:,0,:,:]
-            img_cat_cond[:,:,-1,:,:] = z[:,:,-1,:,:]
-        else:
-            img_cat_cond = z[:,:,:1,:,:]
-            img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2])
-        cond["c_concat"] = [img_cat_cond] # b c 1 h w
-    [... old lines 190-280 not recoverable from this view ...]
+
+    with torch.no_grad():
+        seed_everything(seed)
+        video_size = (576, 1024)
+        transform = transforms.Compose([
+            transforms.Resize(min(video_size)),
+            transforms.CenterCrop(video_size),
+            # transforms.ToTensor(),
+            # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+            ])
+
+        image1 = torch.from_numpy(image1).permute(2, 0, 1).float().cuda()
+        input_h, input_w = image1.shape[1:]
+        image1 = (image1 / 255. - 0.5) * 2
+
+        image2 = torch.from_numpy(image2).permute(2, 0, 1).float().cuda()
+        input_h, input_w = image2.shape[1:]
+        image2 = (image2 / 255. - 0.5) * 2
+
+        # image1 = Image.open(file_list[2*idx]).convert('RGB')
+        image_tensor1 = transform(image1).unsqueeze(1) # [c,1,h,w]
+        # image2 = Image.open(file_list[2*idx+1]).convert('RGB')
+        image_tensor2 = transform(image2).unsqueeze(1) # [c,1,h,w]
+        frame_tensor1 = repeat(image_tensor1, 'c t h w -> c (repeat t) h w', repeat=8)
+        frame_tensor2 = repeat(image_tensor2, 'c t h w -> c (repeat t) h w', repeat=8)
+        videos = torch.cat([frame_tensor1, frame_tensor2], dim=1).unsqueeze(0)
+        # frame_tensor = torch.cat([frame_tensor1, frame_tensor1], dim=1)
+        # _, filename = os.path.split(file_list[idx*2])
+
+        global model
+        model.cuda()
+
+        ddim_sampler = DDIMSampler(model) if not multiple_cond_cfg else DDIMSampler_multicond(model)
+        batch_size = 1
+        fs = torch.tensor([fs], dtype=torch.long, device=model.device)
+
+        if not text_input:
+            prompts = [""]*batch_size
+
+        img = videos[:,:,0] #bchw
+        img_emb = model.embedder(img) ## blc
+        img_emb = model.image_proj_model(img_emb)
+
+        cond_emb = model.get_learned_conditioning(prompts)
+        cond = {"c_crossattn": [torch.cat([cond_emb,img_emb], dim=1)]}
+        if model.model.conditioning_key == 'hybrid':
+            z, hs = get_latent_z_with_hidden_states(model, videos) # b c t h w
+            if loop or interp:
+                img_cat_cond = torch.zeros_like(z)
+                img_cat_cond[:,:,0,:,:] = z[:,:,0,:,:]
+                img_cat_cond[:,:,-1,:,:] = z[:,:,-1,:,:]
+            else:
+                img_cat_cond = z[:,:,:1,:,:]
+                img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2])
+            cond["c_concat"] = [img_cat_cond] # b c 1 h w
+
+        if unconditional_guidance_scale != 1.0:
+            if model.uncond_type == "empty_seq":
+                prompts = batch_size * [""]
+                uc_emb = model.get_learned_conditioning(prompts)
+            elif model.uncond_type == "zero_embed":
+                uc_emb = torch.zeros_like(cond_emb)
+            uc_img_emb = model.embedder(torch.zeros_like(img)) ## b l c
+            uc_img_emb = model.image_proj_model(uc_img_emb)
+            uc = {"c_crossattn": [torch.cat([uc_emb,uc_img_emb],dim=1)]}
+            if model.model.conditioning_key == 'hybrid':
+                uc["c_concat"] = [img_cat_cond]
+        else:
+            uc = None
+        #
+        # for i, h in enumerate(hs):
+        #     print("h:", h.shape)
+        #     hs[i] = hs[i][:,:,0,:,:].unsqueeze(2)
+        additional_decode_kwargs = {'ref_context': hs}
+        # additional_decode_kwargs = {'ref_context': None}
+
+        ## we need one more unconditioning image=yes, text=""
+        if multiple_cond_cfg and cfg_img != 1.0:
+            uc_2 = {"c_crossattn": [torch.cat([uc_emb,img_emb],dim=1)]}
+            if model.model.conditioning_key == 'hybrid':
+                uc_2["c_concat"] = [img_cat_cond]
+            kwargs.update({"unconditional_conditioning_img_nonetext": uc_2})
+        else:
+            kwargs.update({"unconditional_conditioning_img_nonetext": None})
+
+        z0 = None
+        cond_mask = None
+
+        batch_variants = []
+        for _ in range(n_samples):
+
+            if z0 is not None:
+                cond_z0 = z0.clone()
+                kwargs.update({"clean_cond": True})
+            else:
+                cond_z0 = None
+            if ddim_sampler is not None:
+
+                samples, _ = ddim_sampler.sample(S=ddim_steps,
+                                                 conditioning=cond,
+                                                 batch_size=batch_size,
+                                                 shape=noise_shape,
+                                                 verbose=False,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=uc,
+                                                 eta=ddim_eta,
+                                                 cfg_img=cfg_img,
+                                                 mask=cond_mask,
+                                                 x0=cond_z0,
+                                                 fs=fs,
+                                                 timestep_spacing=timestep_spacing,
+                                                 guidance_rescale=guidance_rescale,
+                                                 **kwargs
+                                                 )
+
+            ## reconstruct from latent to pixel space
+            batch_images = model.decode_first_stage(samples, **additional_decode_kwargs)
+
+            index = list(range(samples.shape[2]))
+            del index[1]
+            del index[-2]
+            samples = samples[:,:,index,:,:]
+            ## reconstruct from latent to pixel space
+            batch_images_middle = model.decode_first_stage(samples, **additional_decode_kwargs)
+            batch_images[:,:,batch_images.shape[2]//2-1:batch_images.shape[2]//2+1] = batch_images_middle[:,:,batch_images.shape[2]//2-2:batch_images.shape[2]//2]
+
+            batch_variants.append(batch_images)
+        ## variants, batch, c, t, h, w
+        batch_variants = torch.stack(batch_variants)
+        # return batch_variants.permute(1, 0, 2, 3, 4, 5)
+
+        prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
+        prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
+        prompt_str=prompt_str[:40]
+        if len(prompt_str) == 0:
+            prompt_str = 'empty_prompt'
+
+        result_dir = "./tmp/"
+        save_videos(batch_image, result_dir, filenames=[prompt_str], fps=8)
+        print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
+        model = model.cpu()
+        saved_result_dir = os.path.join(result_dir, f"{prompt_str}.mp4")
+        print("result saved to:", saved_result_dir)
+        return saved_result_dir