Spaces: Running on Zero

adamelliotfields committed
Commit • 6ad0411
Parent(s): 0d34381

Clean up loader

Files changed:
- .gitignore +1 -0
- app.py +15 -17
- lib/__init__.py +0 -2
- lib/config.py +16 -11
- lib/inference.py +21 -21
- lib/loader.py +126 -177
- lib/utils.py +1 -23
- requirements.txt +2 -3
.gitignore
CHANGED
@@ -1,2 +1,3 @@
 __pycache__/
 .venv/
+app.log
app.py
CHANGED
@@ -2,7 +2,14 @@ import argparse
 
 import gradio as gr
 
-from lib import …
+from lib import (
+    Config,
+    # disable_progress_bars,
+    download_repo_files,
+    generate,
+    read_file,
+    read_json,
+)
 
 # Update refresh button hover text
 seed_js = """
@@ -55,28 +62,19 @@ random_prompt_js = f"""
 
 
 # Transform the raw inputs before generation
-async def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
+def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
     if len(args) > 0:
         prompt = args[0]
     else:
         prompt = None
     if prompt is None or prompt.strip() == "":
         raise gr.Error("You must enter a prompt")
-
     try:
-        if Config.ZERO_GPU:
-            progress((0, 100), desc="ZeroGPU init")
-            …
-        images = await async_call(
-            generate,
-            *args,
-            Error=gr.Error,
-            Info=gr.Info,
-            progress=progress,
-        )
+        # if Config.ZERO_GPU:
+        #     progress((0, 100), desc="ZeroGPU init")
+        images = generate(*args, Error=gr.Error, Info=gr.Info, progress=progress)
     except RuntimeError:
         raise gr.Error("Error: Please try again")
-
     return images
 
 
@@ -259,7 +257,7 @@ with gr.Blocks(
             )
             use_refiner = gr.Checkbox(
                 elem_classes=["checkbox"],
-                label="…",
+                label="Use refiner",
                 value=False,
             )
 
@@ -322,8 +320,8 @@ if __name__ == "__main__":
     parser.add_argument("-p", "--port", type=int, metavar="INT", default=7860)
     args = parser.parse_args()
 
-    disable_progress_bars()
-    for repo_id, allow_patterns in Config.HF_MODELS.items():
+    # disable_progress_bars()
+    for repo_id, allow_patterns in Config.HF_REPOS.items():
        download_repo_files(repo_id, allow_patterns, token=Config.HF_TOKEN)
 
    # https://www.gradio.app/docs/gradio/interface#interface-queue
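With the anyio-based async_call wrapper gone, generate_fn is a plain synchronous handler and Gradio's queue is what serializes requests. A minimal sketch of how such a handler is typically wired into a Blocks app; the component names and the stub body below are illustrative, not taken from this Space:

import gradio as gr

def generate_fn(prompt, progress=gr.Progress(track_tqdm=True)):
    # stand-in for the real handler in app.py, which calls generate(...)
    if prompt is None or prompt.strip() == "":
        raise gr.Error("You must enter a prompt")
    return f"would generate an image for: {prompt}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    output = gr.Textbox(label="Result")
    button = gr.Button("Generate")
    button.click(fn=generate_fn, inputs=[prompt], outputs=[output])

# queue() serializes requests so a single GPU worker is not oversubscribed
demo.queue().launch()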
lib/__init__.py
CHANGED
@@ -1,7 +1,6 @@
 from .config import Config
 from .inference import generate
 from .utils import (
-    async_call,
     disable_progress_bars,
     download_repo_files,
     read_file,
@@ -10,7 +9,6 @@ from .utils import (
 
 __all__ = [
     "Config",
-    "async_call",
     "disable_progress_bars",
     "download_repo_files",
     "generate",
lib/config.py
CHANGED
@@ -56,15 +56,12 @@ _sdxl_files = [
     "tokenizer/vocab.json",
 ]
 
+_sdxl_files_with_vae = [*_sdxl_files, "vae_1_0/config.json"]
+
 # Using namespace instead of dataclass for simplicity
 Config = SimpleNamespace(
     HF_TOKEN=os.environ.get("HF_TOKEN", None),
     ZERO_GPU=import_module("spaces").config.Config.zero_gpu,
-    HF_MODELS={
-        "segmind/Segmind-Vega": [*_sdxl_files],
-        "stabilityai/stable-diffusion-xl-base-1.0": [*_sdxl_files, "vae_1_0/config.json"],
-        "stabilityai/stable-diffusion-xl-refiner-1.0": [*_sdxl_refiner_files],
-    },
     PIPELINES={
         "txt2img": StableDiffusionXLPipeline,
         "img2img": StableDiffusionXLImg2ImgPipeline,
@@ -77,13 +74,21 @@ Config = SimpleNamespace(
         "SG161222/RealVisXL_V5.0",
         "stabilityai/stable-diffusion-xl-base-1.0",
     ],
-    MODEL_CHECKPOINTS={
-        …
+    HF_REPOS={
+        "ai-forever/Real-ESRGAN": ["RealESRGAN_x2.pth", "RealESRGAN_x4.pth"],
+        "cyberdelia/CyberRealsticXL": ["CyberRealisticXLPlay_V1.0.safetensors"],
+        "fluently/Fluently-XL-Final": ["FluentlyXL-Final.safetensors"],
+        "madebyollin/sdxl-vae-fp16-fix": ["config.json", "diffusion_pytorch_model.fp16.safetensors"],
+        "segmind/Segmind-Vega": _sdxl_files,
+        "SG161222/RealVisXL_V5.0": ["RealVisXL_V5.0_fp16.safetensors"],
+        "stabilityai/stable-diffusion-xl-base-1.0": _sdxl_files_with_vae,
+        "stabilityai/stable-diffusion-xl-refiner-1.0": _sdxl_refiner_files,
     },
+    SINGLE_FILE_MODELS=[
+        "cyberdelia/cyberrealsticxl",
+        "fluently/fluently-xl-final",
+        "sg161222/realvisxl_v5.0",
+    ],
     VAE_MODEL="madebyollin/sdxl-vae-fp16-fix",
     REFINER_MODEL="stabilityai/stable-diffusion-xl-refiner-1.0",
     SCHEDULER="Euler",
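HF_REPOS pairs each repository id with the allow_patterns passed to the download helper, so only the listed files are fetched at startup. A small sketch of that pattern using huggingface_hub.snapshot_download (which lib/utils.py imports); the two repos are copied from the config above, while the helper's exact body is an assumption:

from huggingface_hub import snapshot_download

HF_REPOS = {
    "ai-forever/Real-ESRGAN": ["RealESRGAN_x2.pth", "RealESRGAN_x4.pth"],
    "madebyollin/sdxl-vae-fp16-fix": ["config.json", "diffusion_pytorch_model.fp16.safetensors"],
}

# Download only the files matching allow_patterns into the local HF cache
for repo_id, allow_patterns in HF_REPOS.items():
    snapshot_download(repo_id, allow_patterns=allow_patterns)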
lib/inference.py
CHANGED
@@ -9,7 +9,7 @@ from spaces import GPU
 from .config import Config
 from .loader import Loader
 from .logger import Logger
-from .utils import …
+from .utils import cuda_collect, safe_progress, timer
 
 
 # Dynamic signature for the GPU duration function; max 60s per image
@@ -55,6 +55,11 @@ def generate(
     Info=None,
     progress=None,
 ):
+    KIND = "txt2img"
+    CURRENT_STEP = 0
+    CURRENT_IMAGE = 1
+    EMBEDDINGS_TYPE = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
+
     start = time.perf_counter()
     log = Logger("generate")
     log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}...")
@@ -69,11 +74,6 @@ def generate(
     if seed is None or seed < 0:
         seed = int(datetime.now().timestamp() * 1e6) % (2**64)
 
-    KIND = "txt2img"
-    CURRENT_STEP = 0
-    CURRENT_IMAGE = 1
-    EMBEDDINGS_TYPE = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
-
     # custom progress bar for multiple images
     def callback_on_step_end(pipeline, step, timestep, latents):
         nonlocal CURRENT_IMAGE, CURRENT_STEP
@@ -107,29 +107,29 @@ def generate(
         progress,
     )
 
-    if loader.pipe is None:
-        raise Error(f"Error loading {model}")
-
-    pipe = loader.pipe
     refiner = loader.refiner
+    pipeline = loader.pipeline
     upscaler = loader.upscaler
 
+    if pipeline is None:
+        raise Error(f"Error loading {model}")
+
     # prompt embeds for base and refiner
     compel_1 = Compel(
-        text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
-        tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
+        text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
+        tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
         requires_pooled=[False, True],
         returned_embeddings_type=EMBEDDINGS_TYPE,
-        dtype_for_device_getter=lambda _: pipe.dtype,
-        device=pipe.device,
+        dtype_for_device_getter=lambda _: pipeline.dtype,
+        device=pipeline.device,
     )
     compel_2 = Compel(
-        text_encoder=[pipe.text_encoder_2],
-        tokenizer=[pipe.tokenizer_2],
+        text_encoder=[pipeline.text_encoder_2],
+        tokenizer=[pipeline.tokenizer_2],
         requires_pooled=[True],
         returned_embeddings_type=EMBEDDINGS_TYPE,
-        dtype_for_device_getter=lambda _: pipe.dtype,
-        device=pipe.device,
+        dtype_for_device_getter=lambda _: pipeline.dtype,
+        device=pipeline.device,
     )
 
     images = []
@@ -138,7 +138,7 @@ def generate(
 
     for i in range(num_images):
         try:
-            generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
+            generator = torch.Generator(device=pipeline.device).manual_seed(current_seed)
             conditioning_1, pooled_1 = compel_1([positive_prompt, negative_prompt])
             conditioning_2, pooled_2 = compel_2([positive_prompt, negative_prompt])
         except PromptParser.ParsingException:
@@ -186,7 +186,7 @@ def generate(
             refiner_kwargs["callback_on_step_end"] = callback_on_step_end
 
         try:
-            image = pipe(**pipe_kwargs).images[0]
+            image = pipeline(**pipe_kwargs).images[0]
             if use_refiner:
                 refiner_kwargs["image"] = image
                 image = refiner(**refiner_kwargs).images[0]
@@ -207,7 +207,7 @@ def generate(
         safe_progress(progress, i + 1, num_images, desc=msg)
 
     # Flush memory after generating
-    clear_cuda_cache()
+    cuda_collect()
 
     end = time.perf_counter()
     msg = f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {end - start:.2f}s"
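The compel_1/compel_2 embeddings built above are what eventually reach the SDXL pipeline as prompt_embeds/pooled_prompt_embeds; the pipe_kwargs construction itself is outside this hunk. The standalone sketch below only illustrates that pattern — the model id matches the Config default, while the prompts and step count are made up:

import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

compel = Compel(
    text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
    tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
    requires_pooled=[False, True],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
)

# One embedding row per prompt: index 0 is the positive prompt, index 1 the negative
conditioning, pooled = compel(["a cinematic photo of a fox++", "blurry, low quality"])

image = pipeline(
    prompt_embeds=conditioning[0:1],
    negative_prompt_embeds=conditioning[1:2],
    pooled_prompt_embeds=pooled[0:1],
    negative_pooled_prompt_embeds=pooled[1:2],
    num_inference_steps=25,
).images[0]
image.save("fox.png")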
lib/loader.py
CHANGED
@@ -1,5 +1,4 @@
-import gc
-from threading import Lock
+# import gc
 
 import torch
 from DeepCache import DeepCacheSDHelper
@@ -8,194 +7,125 @@ from diffusers.models import AutoencoderKL
 from .config import Config
 from .logger import Logger
 from .upscaler import RealESRGAN
-from .utils import …
+from .utils import cuda_collect, timer
 
 
 class Loader:
-    …
-        cls._instance = super().__new__(cls)
-        cls._instance.pipe = None
-        cls._instance.model = None
-        cls._instance.refiner = None
-        cls._instance.upscaler = None
-        cls._instance.log = Logger("Loader")
-        return cls._instance
-
-    def _should_unload_refiner(self, refiner=False):
-        if self.refiner is None:
-            return False
-        if not refiner:
-            return True
-        return False
+    def __init__(self):
+        self.model = ""
+        self.refiner = None
+        self.pipeline = None
+        self.upscaler = None
+        self.log = Logger("Loader")
 
-    def …
-        …
-            return True
-        return False
+    def should_unload_refiner(self, use_refiner=False):
+        return self.refiner is not None and not use_refiner
 
-    def …
-        …
+    def should_unload_upscaler(self, scale=1):
+        return self.upscaler is not None and self.upscaler.scale != scale
+
+    def should_unload_deepcache(self, interval=1):
+        has_deepcache = hasattr(self.pipeline, "deepcache")
         if has_deepcache and interval == 1:
             return True
-        if has_deepcache and self.pipe.deepcache.params["cache_interval"] != interval:
+        if has_deepcache and self.pipeline.deepcache.params["cache_interval"] != interval:
             return True
         return False
 
-    def …
-        …
+    def should_unload_pipeline(self, model=""):
+        return self.pipeline is not None and self.model.lower() != model.lower()
+
+    def should_load_refiner(self, use_refiner=False):
+        return self.refiner is None and use_refiner
+
+    def should_load_upscaler(self, scale=1):
+        return self.upscaler is None and scale > 1
+
+    def should_load_deepcache(self, interval=1):
+        has_deepcache = hasattr(self.pipeline, "deepcache")
+        if not has_deepcache and interval != 1:
+            return True
+        if has_deepcache and self.pipeline.deepcache.params["cache_interval"] != interval:
            return True
        return False
 
-    def …
-        …
+    def should_load_pipeline(self):
+        return self.pipeline is None
+
+    def unload(self, model, use_refiner, deepcache_interval, scale):
+        needs_gc = False
+
+        if self.should_unload_deepcache(deepcache_interval):
+            self.log.info("Disabling DeepCache")
+            self.pipeline.deepcache.disable()
+            delattr(self.pipeline, "deepcache")
+            if self.refiner:
+                self.refiner.deepcache.disable()
+                delattr(self.refiner, "deepcache")
+
+        if self.should_unload_refiner(use_refiner):
             with timer("Unloading refiner"):
                 self.refiner.to("cpu", silence_dtype_warnings=True)
+            self.refiner = None
+            needs_gc = True
 
-        …
-        if self.upscaler is not None:
+        if self.should_unload_upscaler(scale):
             with timer(f"Unloading {self.upscaler.scale}x upscaler"):
                 self.upscaler.to("cpu")
+            self.upscaler = None
+            needs_gc = True
 
-        …
-        if self.pipe.deepcache is not None:
-            self.log.info("Disabling DeepCache")
-            self.pipe.deepcache.disable()
-            delattr(self.pipe, "deepcache")
-        if self.refiner is not None:
-            if hasattr(self.refiner, "deepcache"):
-                self.refiner.deepcache.disable()
-                delattr(self.refiner, "deepcache")
-
-    def _unload_pipeline(self):
-        if self.pipe is not None:
+        if self.should_unload_pipeline(model):
             with timer(f"Unloading {self.model}"):
-                self.pipe.to("cpu", silence_dtype_warnings=True)
-                if self.refiner is not None:
+                self.pipeline.to("cpu", silence_dtype_warnings=True)
+                if self.refiner:
                     self.refiner.vae = None
                     self.refiner.scheduler = None
                     self.refiner.tokenizer_2 = None
                     self.refiner.text_encoder_2 = None
-        …
-        # Flush cache and run garbage collector
-        clear_cuda_cache()
-        for component in to_unload:
-            setattr(self, component, None)
-        gc.collect()
-
-    def _should_load_refiner(self, refiner=False):
-        if self.refiner is None and refiner:
-            return True
-        return False
-
-    def _should_load_upscaler(self, scale=1):
-        if self.upscaler is None and scale > 1:
-            return True
-        return False
-
-    def _should_load_deepcache(self, interval=1):
-        has_deepcache = hasattr(self.pipe, "deepcache")
-        if not has_deepcache and interval != 1:
-            return True
-        if has_deepcache and self.pipe.deepcache.params["cache_interval"] != interval:
-            return True
-        return False
-
-    def _should_load_pipeline(self):
-        if self.pipe is None:
-            return True
-        return False
-
-    def _load_refiner(self, refiner, progress, **kwargs):
-        if self._should_load_refiner(refiner):
-            model = Config.REFINER_MODEL
-            pipeline = Config.PIPELINES["img2img"]
-            try:
-                with timer(f"Loading {model}"):
-                    self.refiner = pipeline.from_pretrained(model, **kwargs).to("cuda")
-            except Exception as e:
-                self.log.error(f"Error loading {model}: {e}")
-                self.refiner = None
-                return
+            self.pipeline = None
+            self.model = None
+            needs_gc = True
+
+        if needs_gc:
+            cuda_collect()
+            # gc.collect()
+
+    def load_refiner(self, refiner_kwargs={}, progress=None):
+        model = Config.REFINER_MODEL
+        try:
+            with timer(f"Loading {model}"):
+                Pipeline = Config.PIPELINES["img2img"]
+                self.refiner = Pipeline.from_pretrained(model, **refiner_kwargs).to("cuda")
+        except Exception as e:
+            self.log.error(f"Error loading {model}: {e}")
+            self.refiner = None
+            return
         if self.refiner is not None:
             self.refiner.set_progress_bar_config(disable=progress is not None)
 
-    def …
-        if self._should_load_upscaler(scale):
+    def load_upscaler(self, scale=1):
+        if self.should_load_upscaler(scale):
             try:
                 with timer(f"Loading {scale}x upscaler"):
-                    self.upscaler = RealESRGAN(scale, device=self.pipe.device)
+                    self.upscaler = RealESRGAN(scale, device=self.pipeline.device)
                     self.upscaler.load_weights()
             except Exception as e:
                 self.log.error(f"Error loading {scale}x upscaler: {e}")
                 self.upscaler = None
 
-    def …
-        if self._should_load_deepcache(interval):
+    def load_deepcache(self, interval=1):
+        if self.should_load_deepcache(interval):
             self.log.info("Enabling DeepCache")
-            self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
-            self.pipe.deepcache.set_params(cache_interval=interval)
-            self.pipe.deepcache.enable()
-            if self.refiner is not None:
+            self.pipeline.deepcache = DeepCacheSDHelper(pipe=self.pipeline)
+            self.pipeline.deepcache.set_params(cache_interval=interval)
+            self.pipeline.deepcache.enable()
+            if self.refiner:
                 self.refiner.deepcache = DeepCacheSDHelper(pipe=self.refiner)
                 self.refiner.deepcache.set_params(cache_interval=interval)
                 self.refiner.deepcache.enable()
 
-    def …
-        pipeline = Config.PIPELINES[kind]
-        if self._should_load_pipeline():
-            try:
-                with timer(f"Loading {model}"):
-                    self.model = model
-                    if model.lower() in Config.MODEL_CHECKPOINTS.keys():
-                        self.pipe = pipeline.from_single_file(
-                            f"https://huggingface.co/{model}/{Config.MODEL_CHECKPOINTS[model.lower()]}",
-                            **kwargs,
-                        ).to("cuda")
-                    else:
-                        self.pipe = pipeline.from_pretrained(model, **kwargs).to("cuda")
-                    if self.refiner is not None:
-                        self.refiner.vae = self.pipe.vae
-                        self.refiner.scheduler = self.pipe.scheduler
-                        self.refiner.tokenizer_2 = self.pipe.tokenizer_2
-                        self.refiner.text_encoder_2 = self.pipe.text_encoder_2
-                        self.refiner.to(self.pipe.device)
-            except Exception as e:
-                self.log.error(f"Error loading {model}: {e}")
-                self.model = None
-                self.pipe = None
-                self.refiner = None
-                return
-        if not isinstance(self.pipe, pipeline):
-            self.pipe = pipeline.from_pipe(self.pipe).to("cuda")
-        if self.pipe is not None:
-            self.pipe.set_progress_bar_config(disable=progress is not None)
-
-    def load(self, kind, model, scheduler, deepcache, scale, karras, refiner, progress):
+    def load(self, kind, model, scheduler, deepcache_interval, scale, use_karras, use_refiner, progress):
         scheduler_kwargs = {
             "beta_start": 0.00085,
             "beta_end": 0.012,
@@ -205,14 +135,13 @@ class Loader:
         }
 
         if scheduler not in ["DDIM", "Euler a"]:
-            scheduler_kwargs["use_karras_sigmas"] = karras
+            scheduler_kwargs["use_karras_sigmas"] = use_karras
 
-        # https://github.com/huggingface/diffusers/blob/8a3f0c1/scripts/convert_original_stable_diffusion_to_diffusers.py#L939
         if scheduler == "DDIM":
             scheduler_kwargs["clip_sample"] = False
             scheduler_kwargs["set_alpha_to_one"] = False
 
-        if model.lower() not in Config.MODEL_CHECKPOINTS.keys():
+        if model.lower() not in Config.SINGLE_FILE_MODELS:
             variant = "fp16"
         else:
             variant = None
@@ -226,47 +155,67 @@ class Loader:
             "vae": AutoencoderKL.from_pretrained(Config.VAE_MODEL, torch_dtype=dtype),
         }
 
-        self.…
-        …
+        self.unload(model, use_refiner, deepcache_interval, scale)
+
+        Pipeline = Config.PIPELINES[kind]
+        Scheduler = Config.SCHEDULERS[scheduler]
+
+        try:
+            with timer(f"Loading {model}"):
+                self.model = model
+                if model.lower() in Config.SINGLE_FILE_MODELS:
+                    checkpoint = Config.HF_REPOS[model][0]
+                    self.pipeline = Pipeline.from_single_file(
+                        f"https://huggingface.co/{model}/{checkpoint}",
+                        **pipe_kwargs,
+                    ).to("cuda")
+                else:
+                    self.pipeline = Pipeline.from_pretrained(model, **pipe_kwargs).to("cuda")
+        except Exception as e:
+            self.log.error(f"Error loading {model}: {e}")
+            self.model = None
+            self.pipeline = None
             return
 
-        …
+        if not isinstance(self.pipeline, Pipeline):
+            self.pipeline = Pipeline.from_pipe(self.pipeline).to("cuda")
+
+        if self.pipeline is not None:
+            self.pipeline.set_progress_bar_config(disable=progress is not None)
+
+        # Check and update scheduler if necessary
+        same_scheduler = isinstance(self.pipeline.scheduler, Scheduler)
         same_karras = (
-            not hasattr(self.pipe.scheduler.config, "use_karras_sigmas")
-            or self.pipe.scheduler.config.use_karras_sigmas == karras
+            not hasattr(self.pipeline.scheduler.config, "use_karras_sigmas")
+            or self.pipeline.scheduler.config.use_karras_sigmas == use_karras
        )
 
-        # same model, different scheduler
         if self.model.lower() == model.lower():
             if not same_scheduler:
                 self.log.info(f"Enabling {scheduler}")
             if not same_karras:
-                self.log.info(f"{'Enabling' if karras else 'Disabling'} Karras sigmas")
+                self.log.info(f"{'Enabling' if use_karras else 'Disabling'} Karras sigmas")
             if not same_scheduler or not same_karras:
-                self.pipe.scheduler = …
+                self.pipeline.scheduler = Scheduler(**scheduler_kwargs)
                 if self.refiner is not None:
-                    self.refiner.scheduler = self.pipe.scheduler
+                    self.refiner.scheduler = self.pipeline.scheduler
 
-        if self._should_load_refiner(refiner):
-            # https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/model_index.json
+        if self.should_load_refiner(use_refiner):
             refiner_kwargs = {
                 "variant": "fp16",
                 "torch_dtype": dtype,
                 "add_watermarker": False,
                 "requires_aesthetics_score": True,
                 "force_zeros_for_empty_prompt": False,
-                "vae": self.pipe.vae,
-                "scheduler": self.pipe.scheduler,
-                "tokenizer_2": self.pipe.tokenizer_2,
-                "text_encoder_2": self.pipe.text_encoder_2,
+                "vae": self.pipeline.vae,
+                "scheduler": self.pipeline.scheduler,
+                "tokenizer_2": self.pipeline.tokenizer_2,
+                "text_encoder_2": self.pipeline.text_encoder_2,
             }
-            self._load_refiner(refiner, progress, **refiner_kwargs)
+            self.load_refiner(refiner_kwargs, progress)
 
-        if self._should_load_deepcache(deepcache):
-            self._load_deepcache(deepcache)
+        if self.should_load_deepcache(deepcache_interval):
+            self.load_deepcache(deepcache_interval)
 
-        if self._should_load_upscaler(scale):
-            self._load_upscaler(scale)
+        if self.should_load_upscaler(scale):
+            self.load_upscaler(scale)
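With the __new__ singleton removed, a Loader is now simply instantiated, driven through load(...), and then read via its pipeline/refiner/upscaler attributes. The real call site is in lib/inference.py (only partially shown above), so the argument values below are illustrative; the order follows the new load() signature:

from lib.loader import Loader

loader = Loader()
# kind, model, scheduler, deepcache_interval, scale, use_karras, use_refiner, progress
loader.load(
    "txt2img",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "Euler",
    1,      # DeepCache disabled when the interval is 1
    1,      # no Real-ESRGAN upscaling when scale is 1
    False,  # Karras sigmas off
    False,  # no refiner
    None,   # no Gradio progress object
)

pipeline = loader.pipeline
refiner = loader.refiner
upscaler = loader.upscaler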
lib/utils.py
CHANGED
@@ -1,24 +1,13 @@
 import functools
-import inspect
 import json
 import time
 from contextlib import contextmanager
-from typing import Callable, TypeVar
 
-import anyio
 import torch
-from anyio import Semaphore
 from diffusers.utils import logging as diffusers_logging
 from huggingface_hub._snapshot_download import snapshot_download
 from huggingface_hub.utils import are_progress_bars_disabled
 from transformers import logging as transformers_logging
-from typing_extensions import ParamSpec
-
-T = TypeVar("T")
-P = ParamSpec("P")
-
-MAX_CONCURRENT_THREADS = 1
-MAX_THREADS_GUARD = Semaphore(MAX_CONCURRENT_THREADS)
 
 
 @contextmanager
@@ -61,7 +50,7 @@ def safe_progress(progress, current=0, total=0, desc=""):
     progress((current, total), desc=desc)
 
 
-def clear_cuda_cache():
+def cuda_collect():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
@@ -83,14 +72,3 @@ def download_repo_files(repo_id, allow_patterns, token=None):
     if was_disabled:
         disable_progress_bars()
     return snapshot_path
-
-
-# Like the original but supports args and kwargs instead of a dict
-# https://github.com/huggingface/huggingface-inference-toolkit/blob/0.2.0/src/huggingface_inference_toolkit/async_utils.py
-async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
-    async with MAX_THREADS_GUARD:
-        sig = inspect.signature(fn)
-        bound_args = sig.bind(*args, **kwargs)
-        bound_args.apply_defaults()
-        partial_fn = functools.partial(fn, **bound_args.arguments)
-        return await anyio.to_thread.run_sync(partial_fn)
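timer is the other utils helper the loader leans on (with timer("Unloading refiner"): and similar). Its body is not part of this diff, so the version below is only a sketch of what such a context manager usually looks like, built from time.perf_counter and contextlib.contextmanager; the real one may log through Logger instead of printing:

import time
from contextlib import contextmanager

@contextmanager
def timer(message="Operation"):
    # Report how long the wrapped block took, even if it raises
    start = time.perf_counter()
    try:
        yield
    finally:
        end = time.perf_counter()
        print(f"{message} took {end - start:.2f}s")

with timer("Unloading refiner"):
    time.sleep(0.1)  # stand-in for refiner.to("cpu")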
requirements.txt
CHANGED
@@ -1,4 +1,3 @@
-anyio==4.6.1
 compel==2.0.3
 deepcache==0.1.1
 diffusers==0.30.3
@@ -8,5 +7,5 @@ hf-transfer
 numpy==1.26.4
 ruff==0.6.9
 spaces==0.30.4
-torch==2.…
-torchvision==0.…
+torch==2.4.0
+torchvision==0.19.0