fffiloni commited on
Commit
5b6bab6
·
verified ·
1 Parent(s): 8706ce3

Update app_gradio.py

Browse files
Files changed (1) hide show
  1. app_gradio.py +63 -47
app_gradio.py CHANGED
@@ -117,47 +117,67 @@ pipe = TextToVideoSDPipelineModded.from_pretrained(
117
  @torch.no_grad()
118
  def process_video(num_frames, num_seeds, generator, exp_dir, load_name, caption, lambda_):
119
  pipe_inversion.to(device)
120
- id_latents = invert(pipe_inversion, inv, load_name).to(device, dtype=dtype)
 
 
 
 
 
 
121
  latents = id_latents.repeat(num_seeds, 1, 1, 1, 1)
122
  generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(num_seeds)]
123
- video_frames = pipe(
124
- prompt=caption,
125
- negative_prompt="",
126
- num_frames=num_frames,
127
- num_inference_steps=25,
128
- inv_latents=latents,
129
- guidance_scale=9,
130
- generator=generator,
131
- lambda_=lambda_,
132
- ).frames
133
-
 
 
 
 
 
 
134
  gifs = []
135
- for seed in range(num_seeds):
136
- vid_name = f"{exp_dir}/mp4_logs/vid_{os.path.basename(load_name)[:-4]}-rand{seed}.mp4"
137
- gif_name = f"{exp_dir}/gif_logs/vid_{os.path.basename(load_name)[:-4]}-rand{seed}.gif"
138
-
139
- os.makedirs(os.path.dirname(vid_name), exist_ok=True)
140
- os.makedirs(os.path.dirname(gif_name), exist_ok=True)
141
-
142
- video_path = export_to_video(video_frames[seed], output_video_path=vid_name)
143
- VideoFileClip(vid_name).write_gif(gif_name)
144
-
145
- with Image.open(gif_name) as im:
146
- frames = load_frames(im)
 
147
 
148
- frames_collect = np.empty((0, 1024, 1024), int)
149
- for frame in frames:
150
- frame = cv2.resize(frame, (1024, 1024))[:, :, :3]
151
- frame = cv2.cvtColor(255 - frame, cv2.COLOR_RGB2GRAY)
152
- _, frame = cv2.threshold(255 - frame, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
153
- frames_collect = np.append(frames_collect, [frame], axis=0)
154
 
155
- save_gif(frames_collect, gif_name)
156
- gifs.append(gif_name)
 
 
 
157
 
158
  return gifs
159
 
160
  def generate_output(image, apply_filter, prompt: str, num_seeds: int = 3, lambda_value: float = 0.5, progress=gr.Progress(track_tqdm=True)) -> List[str]:
 
 
 
 
161
  if prompt is None:
162
  raise gr.Error("You forgot to describe the motion !")
163
  """Main function to generate output GIFs"""
@@ -175,21 +195,17 @@ def generate_output(image, apply_filter, prompt: str, num_seeds: int = 3, lambda
175
 
176
  image.save(temp_image_path)
177
 
178
- try:
179
- # Attempt to process video
180
- generated_gifs = process_video(
181
- num_frames=10,
182
- num_seeds=num_seeds,
183
- generator=None,
184
- exp_dir=exp_dir,
185
- load_name=temp_image_path,
186
- caption=prompt,
187
- lambda_=1 - lambda_value
188
- )
189
- except Exception as e:
190
- torch.cuda.empty_cache() # Clear CUDA cache in case of failure
191
- gc.collect()
192
- raise gr.Error(f"Video processing failed: {str(e)}") from e
193
 
194
  if apply_filter:
195
  try:
 
117
  @torch.no_grad()
118
  def process_video(num_frames, num_seeds, generator, exp_dir, load_name, caption, lambda_):
119
  pipe_inversion.to(device)
120
+ try:
121
+ id_latents = invert(pipe_inversion, inv, load_name).to(device, dtype=dtype)
122
+ except Exception as e:
123
+ torch.cuda.empty_cache() # Clear CUDA cache in case of failure
124
+ gc.collect()
125
+ raise gr.Error(f"Invert latents failed: {str(e)}") from e
126
+
127
  latents = id_latents.repeat(num_seeds, 1, 1, 1, 1)
128
  generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(num_seeds)]
129
+
130
+ try:
131
+ video_frames = pipe(
132
+ prompt=caption,
133
+ negative_prompt="",
134
+ num_frames=num_frames,
135
+ num_inference_steps=25,
136
+ inv_latents=latents,
137
+ guidance_scale=9,
138
+ generator=generator,
139
+ lambda_=lambda_,
140
+ ).frames
141
+ except Exception as e:
142
+ torch.cuda.empty_cache()
143
+ gc.collect()
144
+ raise RuntimeError(f"Failed to process video: {e}") from e
145
+
146
  gifs = []
147
+ try:
148
+ for seed in range(num_seeds):
149
+ vid_name = f"{exp_dir}/mp4_logs/vid_{os.path.basename(load_name)[:-4]}-rand{seed}.mp4"
150
+ gif_name = f"{exp_dir}/gif_logs/vid_{os.path.basename(load_name)[:-4]}-rand{seed}.gif"
151
+
152
+ os.makedirs(os.path.dirname(vid_name), exist_ok=True)
153
+ os.makedirs(os.path.dirname(gif_name), exist_ok=True)
154
+
155
+ video_path = export_to_video(video_frames[seed], output_video_path=vid_name)
156
+ VideoFileClip(vid_name).write_gif(gif_name)
157
+
158
+ with Image.open(gif_name) as im:
159
+ frames = load_frames(im)
160
 
161
+ frames_collect = np.empty((0, 1024, 1024), int)
162
+ for frame in frames:
163
+ frame = cv2.resize(frame, (1024, 1024))[:, :, :3]
164
+ frame = cv2.cvtColor(255 - frame, cv2.COLOR_RGB2GRAY)
165
+ _, frame = cv2.threshold(255 - frame, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
166
+ frames_collect = np.append(frames_collect, [frame], axis=0)
167
 
168
+ save_gif(frames_collect, gif_name)
169
+ gifs.append(gif_name)
170
+ except Exception as e:
171
+ torch.cuda.empty_cache()
172
+ raise RuntimeError(f"Failed during GIF generation: {e}") from e
173
 
174
  return gifs
175
 
176
  def generate_output(image, apply_filter, prompt: str, num_seeds: int = 3, lambda_value: float = 0.5, progress=gr.Progress(track_tqdm=True)) -> List[str]:
177
+ gc.collect()
178
+ torch.cuda.empty_cache()
179
+ torch.cuda.ipc_collect()
180
+
181
  if prompt is None:
182
  raise gr.Error("You forgot to describe the motion !")
183
  """Main function to generate output GIFs"""
 
195
 
196
  image.save(temp_image_path)
197
 
198
+
199
+ # Attempt to process video
200
+ generated_gifs = process_video(
201
+ num_frames=10,
202
+ num_seeds=num_seeds,
203
+ generator=None,
204
+ exp_dir=exp_dir,
205
+ load_name=temp_image_path,
206
+ caption=prompt,
207
+ lambda_=1 - lambda_value
208
+ )
 
 
 
 
209
 
210
  if apply_filter:
211
  try: