fffiloni commited on
Commit
8706ce3
·
verified ·
1 Parent(s): a24bb36

Update app_gradio.py

Browse files
Files changed (1) hide show
  1. app_gradio.py +7 -0
app_gradio.py CHANGED
@@ -25,6 +25,7 @@ from gifs_filter import filter
25
  import subprocess
26
  import uuid
27
  import tempfile
 
28
 
29
  from huggingface_hub import snapshot_download
30
 
@@ -187,6 +188,7 @@ def generate_output(image, apply_filter, prompt: str, num_seeds: int = 3, lambda
187
  )
188
  except Exception as e:
189
  torch.cuda.empty_cache() # Clear CUDA cache in case of failure
 
190
  raise gr.Error(f"Video processing failed: {str(e)}") from e
191
 
192
  if apply_filter:
@@ -194,12 +196,17 @@ def generate_output(image, apply_filter, prompt: str, num_seeds: int = 3, lambda
194
  print("APPLYING FILTER")
195
  # Attempt to apply filtering
196
  filtered_gifs = filter(generated_gifs, temp_image_path)
 
 
197
  return filtered_gifs, filtered_gifs
198
  except Exception as e:
199
  torch.cuda.empty_cache() # Clear CUDA cache in case of failure
 
200
  raise gr.Error(f"Filtering failed: {str(e)}") from e
201
  else:
202
  print("NOT APPLYING FILTER")
 
 
203
  return generated_gifs, generated_gifs
204
 
205
  def generate_output_from_sketchpad(image, apply_filter, prompt: str, num_seeds: int = 3, lambda_value: float = 0.5, progress=gr.Progress(track_tqdm=True)):
 
25
  import subprocess
26
  import uuid
27
  import tempfile
28
+ import gc
29
 
30
  from huggingface_hub import snapshot_download
31
 
 
188
  )
189
  except Exception as e:
190
  torch.cuda.empty_cache() # Clear CUDA cache in case of failure
191
+ gc.collect()
192
  raise gr.Error(f"Video processing failed: {str(e)}") from e
193
 
194
  if apply_filter:
 
196
  print("APPLYING FILTER")
197
  # Attempt to apply filtering
198
  filtered_gifs = filter(generated_gifs, temp_image_path)
199
+ torch.cuda.empty_cache() # Clear CUDA cache in case of failure
200
+ gc.collect()
201
  return filtered_gifs, filtered_gifs
202
  except Exception as e:
203
  torch.cuda.empty_cache() # Clear CUDA cache in case of failure
204
+ gc.collect()
205
  raise gr.Error(f"Filtering failed: {str(e)}") from e
206
  else:
207
  print("NOT APPLYING FILTER")
208
+ torch.cuda.empty_cache() # Clear CUDA cache in case of failure
209
+ gc.collect()
210
  return generated_gifs, generated_gifs
211
 
212
  def generate_output_from_sketchpad(image, apply_filter, prompt: str, num_seeds: int = 3, lambda_value: float = 0.5, progress=gr.Progress(track_tqdm=True)):