fbnnb committed on
Commit
73b78df
·
verified ·
1 Parent(s): 4bc9607

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +346 -348
gradio_app.py CHANGED
@@ -1,349 +1,347 @@
1
- import os, argparse
2
- import sys
3
- import gradio as gr
4
- # from scripts.gradio.i2v_test_application import Image2Video
5
- sys.path.insert(1, os.path.join(sys.path[0], 'lvdm'))
6
- import spaces
7
-
8
-
9
- import os
10
- import time
11
- from omegaconf import OmegaConf
12
- import torch
13
- from scripts.evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z
14
- from utils.utils import instantiate_from_config
15
- from huggingface_hub import hf_hub_download
16
- from einops import repeat
17
- import torchvision.transforms as transforms
18
- from pytorch_lightning import seed_everything
19
- from einops import rearrange
20
- from cldm.model import load_state_dict
21
- import cv2
22
-
23
- import torch
24
- print("cuda available:", torch.cuda.is_available())
25
-
26
-
27
- from huggingface_hub import snapshot_download
28
- import os
29
-
30
-
31
-
32
- def download_model():
33
- REPO_ID = 'fbnnb/TC_sketch'
34
- filename_list = ['tc_sketch.pt']
35
- tar_dir = './checkpoints/tooncrafter_1024_interp_sketch/'
36
- if not os.path.exists(tar_dir):
37
- os.makedirs(tar_dir)
38
- for filename in filename_list:
39
- local_file = os.path.join(tar_dir, filename)
40
- if not os.path.exists(local_file):
41
- hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=tar_dir, local_dir_use_symlinks=False)
42
- print("downloaded")
43
-
44
-
45
- def get_latent_z_with_hidden_states(model, videos):
46
- b, c, t, h, w = videos.shape
47
- x = rearrange(videos, 'b c t h w -> (b t) c h w')
48
- encoder_posterior, hidden_states = model.first_stage_model.encode(x, return_hidden_states=True)
49
-
50
- hidden_states_first_last = []
51
- ### use only the first and last hidden states
52
- for hid in hidden_states:
53
- hid = rearrange(hid, '(b t) c h w -> b c t h w', t=t)
54
- hid_new = torch.cat([hid[:, :, 0:1], hid[:, :, -1:]], dim=2)
55
- hidden_states_first_last.append(hid_new)
56
-
57
- z = model.get_first_stage_encoding(encoder_posterior).detach()
58
- z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
59
- return z, hidden_states_first_last
60
-
61
-
62
-
63
- def extract_frames(video_path):
64
- # 動画ファイルを読み込む
65
- cap = cv2.VideoCapture(video_path)
66
-
67
- frame_list = []
68
- frame_num = 0
69
-
70
- while True:
71
- # フレームを読み込む
72
- ret, frame = cap.read()
73
- if not ret:
74
- break
75
-
76
- # フレームをリストに追加
77
- frame_list.append(frame)
78
- frame_num += 1
79
-
80
- print("load video length:", len(frame_list))
81
- # 動画ファイルを閉じる
82
- cap.release()
83
-
84
- return frame_list
85
-
86
-
87
- resolution = '576_1024'
88
- resolution = (576, 1024)
89
- download_model()
90
- print("after download model")
91
- result_dir = "./results/"
92
- if not os.path.exists(result_dir):
93
- os.mkdir(result_dir)
94
-
95
- #ToonCrafterModel
96
- ckpt_path='checkpoints/tooncrafter_1024_interp_sketch/tc_sketch.pt'
97
- config_file='configs/inference_1024_v1.0.yaml'
98
- config = OmegaConf.load(config_file)
99
- model_config = config.pop("model", OmegaConf.create())
100
- model_config['params']['unet_config']['params']['use_checkpoint']=False
101
-
102
- model = instantiate_from_config(model_config)
103
- assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
104
- # ckpt_path = "/group/40005/gzhiwang/tc_sketch.pt"
105
- ckpt_path = "/group/40034/gzhiwang/tc_sketch.pt"
106
- model = load_model_checkpoint(model, ckpt_path)
107
- model.eval()
108
-
109
- # cn_model.load_state_dict(load_state_dict(cn_ckpt_path, location='cpu'))
110
- # cn_model.eval()
111
-
112
- # model.control_model = cn_model
113
- # model_list.append(model)
114
-
115
- save_fps = 8
116
- print("resolution:", resolution)
117
- print("init done.")
118
-
119
- def transpose_if_needed(tensor):
120
- h = tensor.shape[-2]
121
- w = tensor.shape[-1]
122
- if h > w:
123
- tensor = tensor.permute(0, 2, 1)
124
- return tensor
125
-
126
- def untranspose(tensor):
127
- ndim = tensor.ndim
128
- return tensor.transpose(ndim-1, ndim-2)
129
-
130
- @spaces.GPU(duration=200)
131
- def get_image(image, sketch, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, control_scale=0.6):
132
- print("enter fn")
133
- # control_frames = extract_frames(frame_guides)
134
- print("extract frames")
135
- seed_everything(seed)
136
- transform = transforms.Compose([
137
- transforms.Resize(min(resolution)),
138
- transforms.CenterCrop(resolution),
139
- ])
140
- print("before empty cache")
141
- torch.cuda.empty_cache()
142
- print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
143
- start = time.time()
144
- gpu_id=0
145
- if steps > 60:
146
- steps = 60
147
-
148
- global model
149
- # model = model_list[gpu_id]
150
- model = model.cuda()
151
-
152
- batch_size=1
153
- channels = model.model.diffusion_model.out_channels
154
- frames = model.temporal_length
155
- h, w = resolution[0] // 8, resolution[1] // 8
156
- noise_shape = [batch_size, channels, frames, h, w]
157
-
158
- # text cond
159
- transposed = False
160
- with torch.no_grad(), torch.cuda.amp.autocast():
161
- text_emb = model.get_learned_conditioning([prompt])
162
- print("before control")
163
- #control cond
164
- # if frame_guides is not None:
165
- # cn_videos = []
166
- # for frame in control_frames:
167
- # frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
168
- # frame = cv2.bitwise_not(frame)
169
- # cn_tensor = torch.from_numpy(frame).unsqueeze(2).permute(2, 0, 1).float().to(model.device)
170
-
171
- # #cn_tensor = (cn_tensor / 255. - 0.5) * 2
172
- # cn_tensor = ( cn_tensor/255.0 )
173
- # cn_tensor = transpose_if_needed(cn_tensor)
174
- # cn_tensor_resized = transform(cn_tensor) #3,h,w
175
-
176
- # cn_video = cn_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
177
- # cn_videos.append(cn_video)
178
-
179
- # cn_videos = torch.cat(cn_videos, dim=2)
180
- # if cn_videos.shape[2] > frames:
181
- # idxs = []
182
- # for i in range(frames):
183
- # index = int((i + 0.5) * cn_videos.shape[2] / frames)
184
- # idxs.append(min(index, cn_videos.shape[2] - 1))
185
- # cn_videos = cn_videos[:, :, idxs, :, :]
186
- # print("cn_videos.shape after slicing", cn_videos.shape)
187
- # model_list = []
188
- # for model in model_list:
189
- # model.control_scale = control_scale
190
- # model_list.append(model)
191
-
192
- # else:
193
- cn_videos = None
194
-
195
- print("image cond")
196
-
197
- # img cond
198
- img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
199
- input_h, input_w = img_tensor.shape[1:]
200
- img_tensor = (img_tensor / 255. - 0.5) * 2
201
- img_tensor = transpose_if_needed(img_tensor)
202
-
203
- image_tensor_resized = transform(img_tensor) #3,h,w
204
- videos = image_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
205
- print("get latent z")
206
- # z = get_latent_z(model, videos) #bc,1,hw
207
- videos = repeat(videos, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
208
-
209
- if sketch is not None:
210
- img_tensor2 = torch.from_numpy(sketch).permute(2, 0, 1).float().to(model.device)
211
- img_tensor2 = (img_tensor2 / 255. - 0.5) * 2
212
- img_tensor2 = transpose_if_needed(img_tensor2)
213
- image_tensor_resized2 = transform(img_tensor2) #3,h,w
214
- videos2 = image_tensor_resized2.unsqueeze(0).unsqueeze(2) # bchw
215
- videos2 = repeat(videos2, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
216
-
217
- videos = torch.cat([videos, videos2], dim=2)
218
- else:
219
- videos = torch.cat([videos, videos], dim=2)
220
-
221
- z, hs = get_latent_z_with_hidden_states(model, videos)
222
-
223
- img_tensor_repeat = torch.zeros_like(z)
224
-
225
- img_tensor_repeat[:,:,:1,:,:] = z[:,:,:1,:,:]
226
- img_tensor_repeat[:,:,-1:,:,:] = z[:,:,-1:,:,:]
227
-
228
- print("image embedder")
229
- cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
230
- img_emb = model.image_proj_model(cond_images)
231
-
232
- imtext_cond = torch.cat([text_emb, img_emb], dim=1)
233
-
234
- fs = torch.tensor([fs], dtype=torch.long, device=model.device)
235
- # print("cn videos:",cn_videos.shape, "img emb:", img_emb.shape)
236
- cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat], "control_cond": cn_videos}
237
-
238
- print("before sample loop")
239
- ## inference
240
- batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale, hs=hs)
241
-
242
- ## remove the last frame
243
- if image2 is None:
244
- batch_samples = batch_samples[:,:,:,:-1,...]
245
- ## b,samples,c,t,h,w
246
- prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
247
- prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
248
- prompt_str=prompt_str[:40]
249
- if len(prompt_str) == 0:
250
- prompt_str = 'empty_prompt'
251
-
252
- global result_dir
253
- global save_fps
254
- if input_h > input_w:
255
- batch_samples = untranspose(batch_samples)
256
-
257
- save_videos(batch_samples, result_dir, filenames=[prompt_str], fps=save_fps)
258
- print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
259
- model = model.cpu()
260
- saved_result_dir = os.path.join(result_dir, f"{prompt_str}.mp4")
261
- print("result saved to:", saved_result_dir)
262
- return saved_result_dir
263
-
264
-
265
- # @spaces.GPU
266
-
267
-
268
-
269
- i2v_examples_interp_1024 = [
270
- ['prompts/1024_interp/frame_000000.jpg', 'prompts/1024_interp/frame_000041.jpg', 'a cat is eating', 50, 7.5, 1.0, 10, 123]
271
- ]
272
-
273
-
274
-
275
-
276
- def dynamicrafter_demo(result_dir='./tmp/', res=1024):
277
- if res == 1024:
278
- resolution = '576_1024'
279
- css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height:576px}"""
280
- elif res == 512:
281
- resolution = '320_512'
282
- css = """#input_img {max-width: 512px !important} #output_vid {max-width: 512px; max-height: 320px} #input_img2 {max-width: 512px !important} #output_vid {max-width: 512px; max-height: 320px}"""
283
- elif res == 256:
284
- resolution = '256_256'
285
- css = """#input_img {max-width: 256px !important} #output_vid {max-width: 256px; max-height: 256px}"""
286
- else:
287
- raise NotImplementedError(f"Unsupported resolution: {res}")
288
- # image2video = Image2Video(result_dir, resolution=resolution)
289
- with gr.Blocks(analytics_enabled=False, css=css) as dynamicrafter_iface:
290
-
291
-
292
-
293
- with gr.Tab(label='ToonCrafter_320x512'):
294
- with gr.Column():
295
- with gr.Row():
296
- with gr.Column():
297
- with gr.Row():
298
- i2v_input_image = gr.Image(label="Input Image1",elem_id="input_img")
299
- # frame_guides = gr.Video(label="Input Guidance",elem_id="input_guidance", autoplay=True,show_share_button=True)
300
- with gr.Row():
301
- i2v_input_text = gr.Text(label='Prompts')
302
- with gr.Row():
303
- i2v_seed = gr.Slider(label='Random Seed', minimum=0, maximum=50000, step=1, value=123)
304
- i2v_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0, elem_id="i2v_eta")
305
- i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5, elem_id="i2v_cfg_scale")
306
- with gr.Row():
307
- i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
308
- i2v_motion = gr.Slider(minimum=5, maximum=30, step=1, elem_id="i2v_motion", label="FPS", value=10)
309
- control_scale = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, elem_id="i2v_ctrl_scale", label="control_scale", value=0.6)
310
- i2v_end_btn = gr.Button("Generate")
311
- with gr.Column():
312
- with gr.Row():
313
- i2v_input_sketch = gr.Image(label="Input End SKetch",elem_id="input_img2")
314
- with gr.Row():
315
- i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True)
316
-
317
- gr.Examples(examples=i2v_examples_interp_1024,
318
- inputs=[i2v_input_image, i2v_input_sketch, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed, control_scale],
319
- outputs=[i2v_output_video],
320
- fn = get_image,
321
- cache_examples=False,
322
- )
323
- i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_sketch, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed, control_scale],
324
- outputs=[i2v_output_video],
325
- fn = get_image
326
- )
327
-
328
-
329
- return dynamicrafter_iface
330
-
331
-
332
- def get_parser():
333
- parser = argparse.ArgumentParser()
334
- return parser
335
-
336
-
337
- if __name__ == "__main__":
338
- parser = get_parser()
339
- args = parser.parse_args()
340
-
341
- result_dir = os.path.join('./', 'results')
342
- dynamicrafter_iface = dynamicrafter_demo(result_dir)
343
- dynamicrafter_iface.queue(max_size=12)
344
- print("launching...")
345
- # dynamicrafter_iface.launch(max_threads=1, share=True)
346
-
347
- dynamicrafter_iface.launch(server_name='0.0.0.0', server_port=12345)
348
- # dynamicrafter_iface.launch()
349
  # print("launched...")
 
1
+ import os, argparse
2
+ import sys
3
+ import gradio as gr
4
+ # from scripts.gradio.i2v_test_application import Image2Video
5
+ sys.path.insert(1, os.path.join(sys.path[0], 'lvdm'))
6
+ import spaces
7
+
8
+
9
+ import os
10
+ import time
11
+ from omegaconf import OmegaConf
12
+ import torch
13
+ from scripts.evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z
14
+ from utils.utils import instantiate_from_config
15
+ from huggingface_hub import hf_hub_download
16
+ from einops import repeat
17
+ import torchvision.transforms as transforms
18
+ from pytorch_lightning import seed_everything
19
+ from einops import rearrange
20
+ from cldm.model import load_state_dict
21
+ import cv2
22
+
23
+ import torch
24
+ print("cuda available:", torch.cuda.is_available())
25
+
26
+
27
+ from huggingface_hub import snapshot_download
28
+ import os
29
+
30
+
31
+
32
+ def download_model():
33
+ REPO_ID = 'fbnnb/TC_sketch'
34
+ filename_list = ['tc_sketch.pt']
35
+ tar_dir = './checkpoints/tooncrafter_1024_interp_sketch/'
36
+ if not os.path.exists(tar_dir):
37
+ os.makedirs(tar_dir)
38
+ for filename in filename_list:
39
+ local_file = os.path.join(tar_dir, filename)
40
+ if not os.path.exists(local_file):
41
+ hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=tar_dir, local_dir_use_symlinks=False)
42
+ print("downloaded")
43
+
44
+
45
+ def get_latent_z_with_hidden_states(model, videos):
46
+ b, c, t, h, w = videos.shape
47
+ x = rearrange(videos, 'b c t h w -> (b t) c h w')
48
+ encoder_posterior, hidden_states = model.first_stage_model.encode(x, return_hidden_states=True)
49
+
50
+ hidden_states_first_last = []
51
+ ### use only the first and last hidden states
52
+ for hid in hidden_states:
53
+ hid = rearrange(hid, '(b t) c h w -> b c t h w', t=t)
54
+ hid_new = torch.cat([hid[:, :, 0:1], hid[:, :, -1:]], dim=2)
55
+ hidden_states_first_last.append(hid_new)
56
+
57
+ z = model.get_first_stage_encoding(encoder_posterior).detach()
58
+ z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
59
+ return z, hidden_states_first_last
60
+
61
+
62
+
63
+ def extract_frames(video_path):
64
+ # 動画ファイルを読み込む
65
+ cap = cv2.VideoCapture(video_path)
66
+
67
+ frame_list = []
68
+ frame_num = 0
69
+
70
+ while True:
71
+ # フレームを読み込む
72
+ ret, frame = cap.read()
73
+ if not ret:
74
+ break
75
+
76
+ # フレームをリストに追加
77
+ frame_list.append(frame)
78
+ frame_num += 1
79
+
80
+ print("load video length:", len(frame_list))
81
+ # 動画ファイルを閉じる
82
+ cap.release()
83
+
84
+ return frame_list
85
+
86
+
87
+ resolution = '576_1024'
88
+ resolution = (576, 1024)
89
+ download_model()
90
+ print("after download model")
91
+ result_dir = "./results/"
92
+ if not os.path.exists(result_dir):
93
+ os.mkdir(result_dir)
94
+
95
+ #ToonCrafterModel
96
+ ckpt_path='checkpoints/tooncrafter_1024_interp_sketch/tc_sketch.pt'
97
+ config_file='configs/inference_1024_v1.0.yaml'
98
+ config = OmegaConf.load(config_file)
99
+ model_config = config.pop("model", OmegaConf.create())
100
+ model_config['params']['unet_config']['params']['use_checkpoint']=False
101
+
102
+ model = instantiate_from_config(model_config)
103
+ assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
104
+ model = load_model_checkpoint(model, ckpt_path)
105
+ model.eval()
106
+
107
+ # cn_model.load_state_dict(load_state_dict(cn_ckpt_path, location='cpu'))
108
+ # cn_model.eval()
109
+
110
+ # model.control_model = cn_model
111
+ # model_list.append(model)
112
+
113
+ save_fps = 8
114
+ print("resolution:", resolution)
115
+ print("init done.")
116
+
117
+ def transpose_if_needed(tensor):
118
+ h = tensor.shape[-2]
119
+ w = tensor.shape[-1]
120
+ if h > w:
121
+ tensor = tensor.permute(0, 2, 1)
122
+ return tensor
123
+
124
+ def untranspose(tensor):
125
+ ndim = tensor.ndim
126
+ return tensor.transpose(ndim-1, ndim-2)
127
+
128
+ @spaces.GPU(duration=200)
129
+ def get_image(image, sketch, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, control_scale=0.6):
130
+ print("enter fn")
131
+ # control_frames = extract_frames(frame_guides)
132
+ print("extract frames")
133
+ seed_everything(seed)
134
+ transform = transforms.Compose([
135
+ transforms.Resize(min(resolution)),
136
+ transforms.CenterCrop(resolution),
137
+ ])
138
+ print("before empty cache")
139
+ torch.cuda.empty_cache()
140
+ print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
141
+ start = time.time()
142
+ gpu_id=0
143
+ if steps > 60:
144
+ steps = 60
145
+
146
+ global model
147
+ # model = model_list[gpu_id]
148
+ model = model.cuda()
149
+
150
+ batch_size=1
151
+ channels = model.model.diffusion_model.out_channels
152
+ frames = model.temporal_length
153
+ h, w = resolution[0] // 8, resolution[1] // 8
154
+ noise_shape = [batch_size, channels, frames, h, w]
155
+
156
+ # text cond
157
+ transposed = False
158
+ with torch.no_grad(), torch.cuda.amp.autocast():
159
+ text_emb = model.get_learned_conditioning([prompt])
160
+ print("before control")
161
+ #control cond
162
+ # if frame_guides is not None:
163
+ # cn_videos = []
164
+ # for frame in control_frames:
165
+ # frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
166
+ # frame = cv2.bitwise_not(frame)
167
+ # cn_tensor = torch.from_numpy(frame).unsqueeze(2).permute(2, 0, 1).float().to(model.device)
168
+
169
+ # #cn_tensor = (cn_tensor / 255. - 0.5) * 2
170
+ # cn_tensor = ( cn_tensor/255.0 )
171
+ # cn_tensor = transpose_if_needed(cn_tensor)
172
+ # cn_tensor_resized = transform(cn_tensor) #3,h,w
173
+
174
+ # cn_video = cn_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
175
+ # cn_videos.append(cn_video)
176
+
177
+ # cn_videos = torch.cat(cn_videos, dim=2)
178
+ # if cn_videos.shape[2] > frames:
179
+ # idxs = []
180
+ # for i in range(frames):
181
+ # index = int((i + 0.5) * cn_videos.shape[2] / frames)
182
+ # idxs.append(min(index, cn_videos.shape[2] - 1))
183
+ # cn_videos = cn_videos[:, :, idxs, :, :]
184
+ # print("cn_videos.shape after slicing", cn_videos.shape)
185
+ # model_list = []
186
+ # for model in model_list:
187
+ # model.control_scale = control_scale
188
+ # model_list.append(model)
189
+
190
+ # else:
191
+ cn_videos = None
192
+
193
+ print("image cond")
194
+
195
+ # img cond
196
+ img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
197
+ input_h, input_w = img_tensor.shape[1:]
198
+ img_tensor = (img_tensor / 255. - 0.5) * 2
199
+ img_tensor = transpose_if_needed(img_tensor)
200
+
201
+ image_tensor_resized = transform(img_tensor) #3,h,w
202
+ videos = image_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
203
+ print("get latent z")
204
+ # z = get_latent_z(model, videos) #bc,1,hw
205
+ videos = repeat(videos, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
206
+
207
+ if sketch is not None:
208
+ img_tensor2 = torch.from_numpy(sketch).permute(2, 0, 1).float().to(model.device)
209
+ img_tensor2 = (img_tensor2 / 255. - 0.5) * 2
210
+ img_tensor2 = transpose_if_needed(img_tensor2)
211
+ image_tensor_resized2 = transform(img_tensor2) #3,h,w
212
+ videos2 = image_tensor_resized2.unsqueeze(0).unsqueeze(2) # bchw
213
+ videos2 = repeat(videos2, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
214
+
215
+ videos = torch.cat([videos, videos2], dim=2)
216
+ else:
217
+ videos = torch.cat([videos, videos], dim=2)
218
+
219
+ z, hs = get_latent_z_with_hidden_states(model, videos)
220
+
221
+ img_tensor_repeat = torch.zeros_like(z)
222
+
223
+ img_tensor_repeat[:,:,:1,:,:] = z[:,:,:1,:,:]
224
+ img_tensor_repeat[:,:,-1:,:,:] = z[:,:,-1:,:,:]
225
+
226
+ print("image embedder")
227
+ cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
228
+ img_emb = model.image_proj_model(cond_images)
229
+
230
+ imtext_cond = torch.cat([text_emb, img_emb], dim=1)
231
+
232
+ fs = torch.tensor([fs], dtype=torch.long, device=model.device)
233
+ # print("cn videos:",cn_videos.shape, "img emb:", img_emb.shape)
234
+ cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat], "control_cond": cn_videos}
235
+
236
+ print("before sample loop")
237
+ ## inference
238
+ batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale, hs=hs)
239
+
240
+ ## remove the last frame
241
+ if image2 is None:
242
+ batch_samples = batch_samples[:,:,:,:-1,...]
243
+ ## b,samples,c,t,h,w
244
+ prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
245
+ prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
246
+ prompt_str=prompt_str[:40]
247
+ if len(prompt_str) == 0:
248
+ prompt_str = 'empty_prompt'
249
+
250
+ global result_dir
251
+ global save_fps
252
+ if input_h > input_w:
253
+ batch_samples = untranspose(batch_samples)
254
+
255
+ save_videos(batch_samples, result_dir, filenames=[prompt_str], fps=save_fps)
256
+ print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
257
+ model = model.cpu()
258
+ saved_result_dir = os.path.join(result_dir, f"{prompt_str}.mp4")
259
+ print("result saved to:", saved_result_dir)
260
+ return saved_result_dir
261
+
262
+
263
+ # @spaces.GPU
264
+
265
+
266
+
267
+ i2v_examples_interp_1024 = [
268
+ ['prompts/1024_interp/frame_000000.jpg', 'prompts/1024_interp/frame_000041.jpg', 'a cat is eating', 50, 7.5, 1.0, 10, 123]
269
+ ]
270
+
271
+
272
+
273
+
274
+ def dynamicrafter_demo(result_dir='./tmp/', res=1024):
275
+ if res == 1024:
276
+ resolution = '576_1024'
277
+ css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height:576px}"""
278
+ elif res == 512:
279
+ resolution = '320_512'
280
+ css = """#input_img {max-width: 512px !important} #output_vid {max-width: 512px; max-height: 320px} #input_img2 {max-width: 512px !important} #output_vid {max-width: 512px; max-height: 320px}"""
281
+ elif res == 256:
282
+ resolution = '256_256'
283
+ css = """#input_img {max-width: 256px !important} #output_vid {max-width: 256px; max-height: 256px}"""
284
+ else:
285
+ raise NotImplementedError(f"Unsupported resolution: {res}")
286
+ # image2video = Image2Video(result_dir, resolution=resolution)
287
+ with gr.Blocks(analytics_enabled=False, css=css) as dynamicrafter_iface:
288
+
289
+
290
+
291
+ with gr.Tab(label='ToonCrafter_320x512'):
292
+ with gr.Column():
293
+ with gr.Row():
294
+ with gr.Column():
295
+ with gr.Row():
296
+ i2v_input_image = gr.Image(label="Input Image1",elem_id="input_img")
297
+ # frame_guides = gr.Video(label="Input Guidance",elem_id="input_guidance", autoplay=True,show_share_button=True)
298
+ with gr.Row():
299
+ i2v_input_text = gr.Text(label='Prompts')
300
+ with gr.Row():
301
+ i2v_seed = gr.Slider(label='Random Seed', minimum=0, maximum=50000, step=1, value=123)
302
+ i2v_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0, elem_id="i2v_eta")
303
+ i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5, elem_id="i2v_cfg_scale")
304
+ with gr.Row():
305
+ i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
306
+ i2v_motion = gr.Slider(minimum=5, maximum=30, step=1, elem_id="i2v_motion", label="FPS", value=10)
307
+ control_scale = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, elem_id="i2v_ctrl_scale", label="control_scale", value=0.6)
308
+ i2v_end_btn = gr.Button("Generate")
309
+ with gr.Column():
310
+ with gr.Row():
311
+ i2v_input_sketch = gr.Image(label="Input End SKetch",elem_id="input_img2")
312
+ with gr.Row():
313
+ i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True)
314
+
315
+ gr.Examples(examples=i2v_examples_interp_1024,
316
+ inputs=[i2v_input_image, i2v_input_sketch, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed, control_scale],
317
+ outputs=[i2v_output_video],
318
+ fn = get_image,
319
+ cache_examples=False,
320
+ )
321
+ i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_sketch, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed, control_scale],
322
+ outputs=[i2v_output_video],
323
+ fn = get_image
324
+ )
325
+
326
+
327
+ return dynamicrafter_iface
328
+
329
+
330
+ def get_parser():
331
+ parser = argparse.ArgumentParser()
332
+ return parser
333
+
334
+
335
+ if __name__ == "__main__":
336
+ parser = get_parser()
337
+ args = parser.parse_args()
338
+
339
+ result_dir = os.path.join('./', 'results')
340
+ dynamicrafter_iface = dynamicrafter_demo(result_dir)
341
+ dynamicrafter_iface.queue(max_size=12)
342
+ print("launching...")
343
+ # dynamicrafter_iface.launch(max_threads=1, share=True)
344
+
345
+ dynamicrafter_iface.launch(server_name='0.0.0.0', server_port=12345)
346
+ # dynamicrafter_iface.launch()
 
 
347
  # print("launched...")