MohamedRashad committed
Commit 844cfac
1 Parent(s): 5cf0d23

chore: Refactor CUDA device usage in app.py

Files changed (1):
  1. app.py +3 -7
app.py CHANGED
@@ -25,6 +25,9 @@ from diffusers_vdm.pipeline import LatentVideoDiffusionPipeline
 from diffusers_vdm.utils import resize_and_center_crop, save_bcthw_as_mp4
 import spaces
 
+# Disable gradients globally
+torch.set_grad_enabled(False)
+
 class ModifiedUNet(UNet2DConditionModel):
     @classmethod
     def from_config(cls, *args, **kwargs):
@@ -68,7 +71,6 @@ def find_best_bucket(h, w, options):
     return best_bucket
 
 
-@torch.inference_mode()
 def encode_cropped_prompt_77tokens(txt: str):
     cond_ids = tokenizer(txt,
                          padding="max_length",
@@ -79,7 +81,6 @@ def encode_cropped_prompt_77tokens(txt: str):
     return text_cond
 
 
-@torch.inference_mode()
 def pytorch2numpy(imgs):
     results = []
     for x in imgs:
@@ -90,7 +91,6 @@ def pytorch2numpy(imgs):
     return results
 
 
-@torch.inference_mode()
 def numpy2pytorch(imgs):
     h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
     h = h.movedim(-1, 1)
@@ -103,14 +103,12 @@ def resize_without_crop(image, target_width, target_height):
     return np.array(resized_image)
 
 
-@torch.inference_mode()
 @spaces.GPU()
 def interrogator_process(x):
     image_description = wd14tagger.default_interrogator(x)
     return image_description, image_description
 
 
-@torch.inference_mode()
 @spaces.GPU()
 def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg,
             progress=gr.Progress()):
@@ -147,7 +145,6 @@ def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed,
     return pixels
 
 
-@torch.inference_mode()
 def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=7.5, fs=3, progress_tqdm=None):
     random.seed(seed)
     np.random.seed(seed)
@@ -198,7 +195,6 @@ def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=
     return video, image_1, image_2
 
 
-@torch.inference_mode()
 @spaces.GPU(duration=360)
 def process_video(keyframes, prompt, steps, cfg, fps, seed, progress=gr.Progress()):
     result_frames = []
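
In short: despite the "CUDA device usage" title, the diff replaces seven per-function @torch.inference_mode() decorators with a single module-level torch.set_grad_enabled(False) call (+3/-7 lines). The two PyTorch switches are close but not identical; the sketch below contrasts them. It is a minimal illustration using only documented PyTorch APIs, and the double helper is hypothetical, not part of app.py.

import torch

# Global switch, as added in this commit: every op from here on skips
# autograd tracking until grad mode is explicitly re-enabled.
torch.set_grad_enabled(False)

x = torch.ones(3, requires_grad=True)
y = x * 2
print(y.requires_grad)  # False: no graph is recorded under the global switch

# The removed decorator is stricter: inference_mode() also skips view and
# version-counter tracking, so its outputs can never re-enter autograd,
# but it only covers the decorated function.
@torch.inference_mode()
def double(t):  # hypothetical helper, for illustration only
    return t * 2

print(double(x).is_inference())  # True: an inference tensor

# Unlike inference_mode() outputs, the global switch is reversible locally:
with torch.enable_grad():
    w = x * 2
    print(w.requires_grad)  # True inside this block

For a pure-inference Space the net effect is the same (no gradient graphs are built), and one global call also covers the @spaces.GPU()-wrapped entry points without stacking decorators; the trade-off is giving up inference_mode()'s extra bookkeeping savings. Whether that motivated the commit is not stated in the diff.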