Spaces: Running on Zero
Commit · 5e20c42
Parent(s): 7ad3113
modify
app.py CHANGED

@@ -14,6 +14,7 @@ from diffusers import StableDiffusionXLPipeline
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
+
 def get_model_param_summary(model, verbose=False):
     params_dict = dict()
     overall_params = 0
@@ -50,7 +51,7 @@ class GradioArgs:
         if self.ratio is None:
             self.ratio = [0.68, 0.88]
 
-
+
 def prune_model(pipe, hookers):
     # remove parameters in attention blocks
     cross_attn_hooker = hookers[0]
@@ -91,18 +92,18 @@ def prune_model(pipe, hookers):
     ffn_hook.clear_hooks()
     return pipe
 
-
+
 def binary_mask_eval(args):
     # load sdxl model
     pipe = StableDiffusionXLPipeline.from_pretrained(
         "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
-    ).to(device)
+    ).to("cpu")
 
     torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
     mask_pipe, hookers = create_pipeline(
         pipe,
         args.model,
-        device,
+        "cpu",
         torch_dtype,
         args.ckpt,
         binary=args.binary,
@@ -132,7 +133,7 @@ def binary_mask_eval(args):
     # reload the original model
     pipe = StableDiffusionXLPipeline.from_pretrained(
         "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
-    ).to(device)
+    ).to("cpu")
 
     # get model param summary
     print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
@@ -140,12 +141,15 @@ def binary_mask_eval(args):
     print("prune complete")
     return pipe, pruned_pipe
 
+
 @spaces.GPU
 def generate_images(prompt, seed, steps, pipe, pruned_pipe):
+    pipe.to("cuda")
+    pruned_pipe.to("cuda")
     # Run the model and return images directly
-    g_cpu = torch.Generator(device).manual_seed(seed)
+    g_cpu = torch.Generator("cuda").manual_seed(seed)
     original_image = pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
-    g_cpu = torch.Generator(device).manual_seed(seed)
+    g_cpu = torch.Generator("cuda").manual_seed(seed)
     ecodiff_image = pruned_pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
     return original_image, ecodiff_image
 
@@ -177,8 +181,8 @@ def create_demo():
         with gr.Row():
            model_choice = gr.Dropdown(choices=["SDXL"], value="SDXL", label="Model", scale=1.2)
            pruning_ratio = gr.Dropdown(choices=["20%"], value="20%", label="Pruning Ratio", scale=1.2)
-           prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
            status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")], scale=1)
+           prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
         with gr.Row():
            prompt = gr.Textbox(label="Prompt", value="A clock tower floating in a sea of clouds", scale=3)
            seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
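Taken together, the diff follows the usual ZeroGPU recipe: both pipelines are built on CPU at startup (a ZeroGPU Space has no GPU attached outside of `@spaces.GPU` calls) and are moved to CUDA only inside the decorated handler, with a freshly seeded CUDA generator per call. Below is a minimal sketch of that pattern, assuming the `spaces` package and a ZeroGPU runtime; the `generate` function name is illustrative, not from the Space:

# Minimal sketch of the ZeroGPU pattern in this commit (illustrative, not the full app.py).
import spaces
import torch
from diffusers import StableDiffusionXLPipeline

# Build on CPU at import time; ZeroGPU only attaches a GPU inside @spaces.GPU calls.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cpu")

@spaces.GPU
def generate(prompt, seed, steps):
    # The GPU exists only for the duration of this call, so move the model here.
    pipe.to("cuda")
    # Seed a CUDA generator so the same seed reproduces the same image.
    g = torch.Generator("cuda").manual_seed(seed)
    return pipe(prompt=prompt, generator=g, num_inference_steps=steps).images[0]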
|