Enable sfast in ControlNet pipeline
Browse files
- build-run.sh +1 -1
- pipelines/controlnet.py +12 -4
build-run.sh
CHANGED
@@ -13,4 +13,4 @@ if [ -z ${PIPELINE+x} ]; then
|
|
13 |
PIPELINE="controlnet"
|
14 |
fi
|
15 |
echo -e "\033[1;32m\npipeline: $PIPELINE \033[0m"
|
16 |
-
python3 run.py --port 7860 --host 0.0.0.0 --pipeline $PIPELINE
|
|
|
13 |
PIPELINE="controlnet"
|
14 |
fi
|
15 |
echo -e "\033[1;32m\npipeline: $PIPELINE \033[0m"
|
16 |
+
python3 run.py --port 7860 --host 0.0.0.0 --pipeline $PIPELINE --sfast
|
pipelines/controlnet.py
CHANGED
@@ -173,16 +173,24 @@ class Pipeline:
|
|
173 |
self.pipe.vae = AutoencoderTiny.from_pretrained(
|
174 |
taesd_model, torch_dtype=torch_dtype, use_safetensors=True
|
175 |
).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
self.canny_torch = SobelOperator(device=device)
|
177 |
self.pipe.set_progress_bar_config(disable=True)
|
178 |
self.pipe.to(device=device, dtype=torch_dtype)
|
179 |
if device.type != "mps":
|
180 |
self.pipe.unet.to(memory_format=torch.channels_last)
|
181 |
|
182 |
-
# check if computer has less than 64GB of RAM using sys or os
|
183 |
-
if psutil.virtual_memory().total < 64 * 1024**3:
|
184 |
-
self.pipe.enable_attention_slicing()
|
185 |
-
|
186 |
if args.torch_compile:
|
187 |
self.pipe.unet = torch.compile(
|
188 |
self.pipe.unet, mode="reduce-overhead", fullgraph=True
|
|
|
173 |
self.pipe.vae = AutoencoderTiny.from_pretrained(
|
174 |
taesd_model, torch_dtype=torch_dtype, use_safetensors=True
|
175 |
).to(device)
|
176 |
+
|
177 |
+
if args.sfast:
|
178 |
+
from sfast.compilers.stable_diffusion_pipeline_compiler import (
|
179 |
+
compile,
|
180 |
+
CompilationConfig,
|
181 |
+
)
|
182 |
+
|
183 |
+
config = CompilationConfig.Default()
|
184 |
+
config.enable_xformers = True
|
185 |
+
config.enable_triton = True
|
186 |
+
config.enable_cuda_graph = True
|
187 |
+
self.pipe = compile(self.pipe, config=config)
|
188 |
self.canny_torch = SobelOperator(device=device)
|
189 |
self.pipe.set_progress_bar_config(disable=True)
|
190 |
self.pipe.to(device=device, dtype=torch_dtype)
|
191 |
if device.type != "mps":
|
192 |
self.pipe.unet.to(memory_format=torch.channels_last)
|
193 |
|
|
|
|
|
|
|
|
|
194 |
if args.torch_compile:
|
195 |
self.pipe.unet = torch.compile(
|
196 |
self.pipe.unet, mode="reduce-overhead", fullgraph=True
|