Try to use 2 models: one optimized for 25 f/s, another for 14 f/s

#17
Files changed (1) hide show
  1. app.py +11 -3
app.py CHANGED
@@ -12,10 +12,15 @@ from PIL import Image
12
  import random
13
  import spaces
14
 
15
- pipe = StableVideoDiffusionPipeline.from_pretrained(
16
  "vdo/stable-video-diffusion-img2vid-xt-1-1", torch_dtype=torch.float16, variant="fp16"
17
  )
18
- pipe.to("cuda")
 
 
 
 
 
19
 
20
  max_64_bit_int = 2**63 - 1
21
 
@@ -44,7 +49,10 @@ def sample(
44
  base_count = len(glob(os.path.join(output_folder, "*.mp4")))
45
  video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
46
 
47
- frames = pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25).frames[0]
 
 
 
48
  export_to_video(frames, video_path, fps=fps_id)
49
 
50
  return video_path, gr.update(label="Generated frames in *." + frame_format + " format", format = frame_format, value = frames), seed
 
12
  import random
13
  import spaces
14
 
15
+ fps25Pipe = StableVideoDiffusionPipeline.from_pretrained(
16
  "vdo/stable-video-diffusion-img2vid-xt-1-1", torch_dtype=torch.float16, variant="fp16"
17
  )
18
+ fps25Pipe.to("cuda")
19
+
20
+ fps14Pipe = StableVideoDiffusionPipeline.from_pretrained(
21
+ "stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16"
22
+ )
23
+ fps14Pipe.to("cuda")
24
 
25
  max_64_bit_int = 2**63 - 1
26
 
 
49
  base_count = len(glob(os.path.join(output_folder, "*.mp4")))
50
  video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
51
 
52
+ if 14 < fps_id:
53
+ frames = fps25Pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25).frames[0]
54
+ else:
55
+ frames = fps14Pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25).frames[0]
56
  export_to_video(frames, video_path, fps=fps_id)
57
 
58
  return video_path, gr.update(label="Generated frames in *." + frame_format + " format", format = frame_format, value = frames), seed