zhangyang-0123 committed
Commit 5e20c42 · 1 Parent(s): 7ad3113
Files changed (1)
  1. app.py +12 -8
app.py CHANGED
@@ -14,6 +14,7 @@ from diffusers import StableDiffusionXLPipeline
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
+
 def get_model_param_summary(model, verbose=False):
     params_dict = dict()
     overall_params = 0
@@ -50,7 +51,7 @@ class GradioArgs:
         if self.ratio is None:
             self.ratio = [0.68, 0.88]
 
-@spaces.GPU
+
 def prune_model(pipe, hookers):
     # remove parameters in attention blocks
     cross_attn_hooker = hookers[0]
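
The hunk above drops `@spaces.GPU` from `prune_model` (the next hunk does the same for `binary_mask_eval`), leaving the decorator only on the inference path. Assuming this app runs on a Hugging Face ZeroGPU Space — which the `spaces` decorator suggests — a GPU is attached only while a decorated function executes, so setup work that can stay on CPU should not be decorated. A minimal sketch of that contract (illustrative function names, not from app.py):

```python
import spaces
import torch

@spaces.GPU  # ZeroGPU attaches a CUDA device only for the duration of this call
def needs_cuda(x: torch.Tensor) -> torch.Tensor:
    return (x.to("cuda") * 2).cpu()

def cpu_setup(x: torch.Tensor) -> torch.Tensor:
    # Undecorated: runs on CPU and consumes no GPU time.
    return x * 2
```
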
@@ -91,18 +92,18 @@ def prune_model(pipe, hookers):
     ffn_hook.clear_hooks()
     return pipe
 
-@spaces.GPU
+
 def binary_mask_eval(args):
     # load sdxl model
     pipe = StableDiffusionXLPipeline.from_pretrained(
         "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
-    ).to(device)
+    ).to("cpu")
 
     torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
     mask_pipe, hookers = create_pipeline(
         pipe,
         args.model,
-        device,
+        "cpu",
         torch_dtype,
         args.ckpt,
         binary=args.binary,
@@ -132,7 +133,7 @@ def binary_mask_eval(args):
     # reload the original model
     pipe = StableDiffusionXLPipeline.from_pretrained(
         "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
-    ).to(device)
+    ).to("cpu")
 
     # get model param summary
     print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
@@ -140,12 +141,15 @@ def binary_mask_eval(args):
     print("prune complete")
     return pipe, pruned_pipe
 
+
 @spaces.GPU
 def generate_images(prompt, seed, steps, pipe, pruned_pipe):
+    pipe.to("cuda")
+    pruned_pipe.to("cuda")
     # Run the model and return images directly
-    g_cpu = torch.Generator(device).manual_seed(seed)
+    g_cpu = torch.Generator("cuda").manual_seed(seed)
     original_image = pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
-    g_cpu = torch.Generator(device).manual_seed(seed)
+    g_cpu = torch.Generator("cuda").manual_seed(seed)
     ecodiff_image = pruned_pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
     return original_image, ecodiff_image
 
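`generate_images` keeps its decorator and now moves both pipelines onto the freshly attached GPU at the top of the call; the two `torch.Generator` objects are likewise created on `"cuda"`. Re-seeding before the second call is what makes the comparison fair: each pipeline draws its initial latent noise from the generator, so a fresh generator with the same seed hands the original and pruned models identical starting noise. A small self-contained check of that property (assumes a CUDA device is available):

```python
import torch

# Generators are device-specific, so seed on the device you sample on.
g1 = torch.Generator("cuda").manual_seed(44)
g2 = torch.Generator("cuda").manual_seed(44)
a = torch.randn(4, generator=g1, device="cuda")
b = torch.randn(4, generator=g2, device="cuda")
assert torch.equal(a, b)  # same seed, same device -> identical noise
```
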
 
@@ -177,8 +181,8 @@ def create_demo():
     with gr.Row():
         model_choice = gr.Dropdown(choices=["SDXL"], value="SDXL", label="Model", scale=1.2)
         pruning_ratio = gr.Dropdown(choices=["20%"], value="20%", label="Pruning Ratio", scale=1.2)
-        prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
         status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")], scale=1)
+        prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
     with gr.Row():
         prompt = gr.Textbox(label="Prompt", value="A clock tower floating in a sea of clouds", scale=3)
         seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
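
The final hunk is purely layout: inside a `gr.Row()`, components render left to right in creation order, so creating `prune_btn` after `status_label` moves the button to the right of the status indicator. A stripped-down sketch of the reordered row (hypothetical standalone demo, `scale` arguments omitted):

```python
import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():  # children appear left-to-right in creation order
        model_choice = gr.Dropdown(choices=["SDXL"], value="SDXL", label="Model")
        pruning_ratio = gr.Dropdown(choices=["20%"], value="20%", label="Pruning Ratio")
        status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")])
        prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary")

demo.launch()
```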
 