Commit 33a8da6
Parent(s): 1cf330c

Reduce the usage of GPU (#20)

- Reduce the usage of GPU (d76675441b3ffd721192a61fdeb81cbb31fc9f6d)

Co-authored-by: Fabrice TIERCELIN <Fabrice-TIERCELIN@users.noreply.huggingface.co>
app.py CHANGED

@@ -1,6 +1,8 @@
 import gradio as gr
 import torch
 import os
+import random
+import spaces
 from glob import glob
 from pathlib import Path
 from typing import Optional
@@ -9,9 +11,6 @@ from diffusers import StableVideoDiffusionPipeline
 from diffusers.utils import export_to_video
 from PIL import Image

-import random
-import spaces
-
 fps25Pipe = StableVideoDiffusionPipeline.from_pretrained(
     "vdo/stable-video-diffusion-img2vid-xt-1-1", torch_dtype=torch.float16, variant="fp16"
 )
@@ -24,8 +23,7 @@ fps14Pipe.to("cuda")

 max_64_bit_int = 2**63 - 1

-
-def sample(
+def animate(
     image: Image,
     seed: Optional[int] = 42,
     randomize_seed: bool = True,
@@ -35,7 +33,6 @@ def sample(
     decoding_t: int = 3,
     frame_format: str = "webp",
     version: str = "auto",
-    device: str = "cuda",
     output_folder: str = "outputs",
 ):
     if image.mode == "RGBA":
@@ -43,20 +40,47 @@ def sample(

     if randomize_seed:
         seed = random.randint(0, max_64_bit_int)
-    generator = torch.manual_seed(seed)
+
+    frames = animate_on_gpu(
+        image,
+        seed,
+        randomize_seed,
+        motion_bucket_id,
+        fps_id,
+        noise_aug_strength,
+        decoding_t,
+        frame_format,
+        version
+    )

     os.makedirs(output_folder, exist_ok=True)
     base_count = len(glob(os.path.join(output_folder, "*.mp4")))
     video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")

-    if version == "svdxt" or (14 < fps_id and version != "svd"):
-        frames = fps25Pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25).frames[0]
-    else:
-        frames = fps14Pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25).frames[0]
     export_to_video(frames, video_path, fps=fps_id)

     return video_path, gr.update(value=video_path, visible=True), gr.update(label="Generated frames in *." + frame_format + " format", format = frame_format, value = frames, visible=True), seed

+@spaces.GPU(duration=120)
+def animate_on_gpu(
+    image: Image,
+    seed: Optional[int] = 42,
+    randomize_seed: bool = True,
+    motion_bucket_id: int = 127,
+    fps_id: int = 6,
+    noise_aug_strength: float = 0.1,
+    decoding_t: int = 3,
+    frame_format: str = "webp",
+    version: str = "auto"
+):
+    generator = torch.manual_seed(seed)
+
+    if version == "svdxt" or (14 < fps_id and version != "svd"):
+        return fps25Pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25).frames[0]
+    else:
+        return fps14Pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25).frames[0]
+
+
 def resize_image(image, output_size=(1024, 576)):
     # Calculate aspect ratios
     target_aspect = output_size[0] / output_size[1]  # Aspect ratio of the desired size
@@ -117,7 +141,7 @@ with gr.Blocks() as demo:
     gallery = gr.Gallery(label="Generated frames", visible=False)

     image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
-    generate_btn.click(fn=
+    generate_btn.click(fn=animate, inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id, noise_aug_strength, decoding_t, frame_format, version], outputs=[video, download_button, gallery, seed], api_name="video")

     gr.Examples(
         examples=[
@@ -127,7 +151,7 @@ with gr.Blocks() as demo:
         ],
         inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id, noise_aug_strength, decoding_t, frame_format, version],
         outputs=[video, download_button, gallery, seed],
-        fn=
+        fn=animate,
         run_on_click=True,
         cache_examples=False,
     )
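The heart of the change is the ZeroGPU pattern used on Hugging Face Spaces: `@spaces.GPU(duration=120)` attaches a GPU only while the decorated function runs, so seeding, file I/O and the Gradio return values stay on CPU in `animate`, and only the diffusion pipeline calls in `animate_on_gpu` hold the device. A minimal, self-contained sketch of the same split follows; the function names other than the decorator are illustrative and the pipeline call is stubbed out.

```python
import random

import spaces   # Hugging Face Spaces SDK; provides the @spaces.GPU decorator
import torch

max_64_bit_int = 2**63 - 1


@spaces.GPU(duration=120)            # a GPU is attached only while this call runs
def frames_on_gpu(image, seed: int, num_frames: int = 25):
    # Everything that needs CUDA belongs here; this stands in for animate_on_gpu,
    # where the real app calls fps25Pipe / fps14Pipe with this generator.
    generator = torch.manual_seed(seed)          # deterministic generation per seed
    return [image] * num_frames                  # placeholder for pipe(...).frames[0]


def animate(image, seed: int = 42, randomize_seed: bool = True):
    # CPU-side wrapper: seeding, file I/O and UI updates stay outside the
    # GPU window, which is what reduces GPU usage on the Space.
    if randomize_seed:
        seed = random.randint(0, max_64_bit_int)
    frames = frames_on_gpu(image, seed)
    # ...export_to_video(frames, video_path, fps=...) and return UI updates here...
    return frames, seed
```

On a ZeroGPU Space the decorator queues for a device and releases it when the function returns; on classic GPU hardware it is effectively a no-op, so the same code runs in both setups.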
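Because `generate_btn.click(...)` keeps `api_name="video"`, the rewritten handler is still reachable as a named endpoint. A hedged sketch of calling it with `gradio_client` is below; the Space id is a placeholder and the argument values simply repeat the defaults visible in the diff.

```python
from gradio_client import Client, handle_file   # handle_file ships with recent gradio_client releases

client = Client("user/space-name")               # placeholder Space id
result = client.predict(
    handle_file("input.png"),   # image (older clients may expect a plain filepath instead)
    42,                         # seed
    True,                       # randomize_seed
    127,                        # motion_bucket_id
    6,                          # fps_id
    0.1,                        # noise_aug_strength
    3,                          # decoding_t
    "webp",                     # frame_format
    "auto",                     # version
    api_name="/video",
)
print(result)                   # (video_path, download-button update, gallery update, seed)
```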