VIVEK JAYARAM commited on
Commit
3288b70
·
1 Parent(s): adba228
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -18,6 +18,7 @@ from diffusers import DiffusionPipeline
18
  model = None
19
  ddim_scheduler = None
20
  model_type = None
 
21
 
22
  def load_image(image_path):
23
  """Process input image to tensor format."""
@@ -40,11 +41,13 @@ def process_image(image_choice, noise_sigma, operator_key, T, K):
40
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
 
42
  # Initialize model inside GPU-decorated function
43
- global model, ddim_scheduler, model_type
44
- if model is None:
 
 
45
  model_type = "diffusers"
46
- model_name = "google/ddpm-celebahq-256" if "Celeb" in image_choice else "google/ddpm-church-256"
47
  model = DiffusionPipeline.from_pretrained(model_name).to(device).unet
 
48
  ddim_scheduler = DDIMScheduler(
49
  num_train_timesteps=1000,
50
  beta_start=0.0001,
 
18
  model = None
19
  ddim_scheduler = None
20
  model_type = None
21
+ curr_model_name = None
22
 
23
  def load_image(image_path):
24
  """Process input image to tensor format."""
 
41
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
 
43
  # Initialize model inside GPU-decorated function
44
+ global model, curr_model_name, ddim_scheduler, model_type
45
+ model_name = "google/ddpm-celebahq-256" if "Celeb" in image_choice else "google/ddpm-church-256"
46
+
47
+ if model is None or curr_model_name != model_name:
48
  model_type = "diffusers"
 
49
  model = DiffusionPipeline.from_pretrained(model_name).to(device).unet
50
+ curr_model_name = model_name
51
  ddim_scheduler = DDIMScheduler(
52
  num_train_timesteps=1000,
53
  beta_start=0.0001,