Spaces: Running on Zero

MohamedRashad committed
Commit 844cfac · Parent(s): 5cf0d23
chore: Refactor CUDA device usage in app.py

app.py CHANGED
@@ -25,6 +25,9 @@ from diffusers_vdm.pipeline import LatentVideoDiffusionPipeline
 from diffusers_vdm.utils import resize_and_center_crop, save_bcthw_as_mp4
 import spaces
 
+# Disable gradients globally
+torch.set_grad_enabled(False)
+
 class ModifiedUNet(UNet2DConditionModel):
     @classmethod
     def from_config(cls, *args, **kwargs):
@@ -68,7 +71,6 @@ def find_best_bucket(h, w, options):
     return best_bucket
 
 
-@torch.inference_mode()
 def encode_cropped_prompt_77tokens(txt: str):
     cond_ids = tokenizer(txt,
                          padding="max_length",
@@ -79,7 +81,6 @@ def encode_cropped_prompt_77tokens(txt: str):
     return text_cond
 
 
-@torch.inference_mode()
 def pytorch2numpy(imgs):
     results = []
     for x in imgs:
@@ -90,7 +91,6 @@ def pytorch2numpy(imgs):
     return results
 
 
-@torch.inference_mode()
 def numpy2pytorch(imgs):
     h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
     h = h.movedim(-1, 1)
@@ -103,14 +103,12 @@ def resize_without_crop(image, target_width, target_height):
     return np.array(resized_image)
 
 
-@torch.inference_mode()
 @spaces.GPU()
 def interrogator_process(x):
     image_description = wd14tagger.default_interrogator(x)
     return image_description, image_description
 
 
-@torch.inference_mode()
 @spaces.GPU()
 def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg,
             progress=gr.Progress()):
@@ -147,7 +145,6 @@ def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed,
     return pixels
 
 
-@torch.inference_mode()
 def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=7.5, fs=3, progress_tqdm=None):
     random.seed(seed)
     np.random.seed(seed)
@@ -198,7 +195,6 @@ def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=
     return video, image_1, image_2
 
 
-@torch.inference_mode()
 @spaces.GPU(duration=360)
 def process_video(keyframes, prompt, steps, cfg, fps, seed, progress=gr.Progress()):
     result_frames = []
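
The net effect of the commit: the per-function @torch.inference_mode() decorators are dropped and a single module-level torch.set_grad_enabled(False) right after the imports disables autograd for everything in app.py instead, while the @spaces.GPU() decorators on interrogator_process, process, and process_video stay in place. A minimal standalone sketch of the two patterns (the run_* helpers below are made up for illustration and are not part of the Space's code):

import torch

# Old pattern (removed by this commit): autograd is disabled per function.
@torch.inference_mode()
def run_decorated(x: torch.Tensor) -> torch.Tensor:
    return x * 2

# New pattern (added by this commit): one module-level switch turns off
# gradient tracking for everything that runs afterwards.
torch.set_grad_enabled(False)

def run_plain(x: torch.Tensor) -> torch.Tensor:
    return x * 2

x = torch.ones(3, requires_grad=True)
print(run_decorated(x).requires_grad)  # False
print(run_plain(x).requires_grad)      # False

One behavioral difference worth noting: torch.inference_mode() is stricter, since tensors created under it can never be used in autograd later, whereas torch.set_grad_enabled(False) is an ordinary switch that can be flipped back at runtime.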