Spaces: Running on Zero
adamelliotfields committed: Update loader

Files changed:
- lib/inference.py  +6 -7
- lib/loader.py     +127 -115
lib/inference.py CHANGED

@@ -7,7 +7,7 @@ from compel.prompt_parser import PromptParser
 from gradio import Error, Info, Progress
 from spaces import GPU
 
-from .loader import Loader
+from .loader import get_loader
 from .logger import Logger
 from .utils import cuda_collect, get_output_types, timer
 
@@ -28,7 +28,7 @@ def generate(
     num_images=1,
     use_karras=False,
     use_refiner=False,
-    progress=Progress(track_tqdm=True),
+    _=Progress(track_tqdm=True),
 ):
     if not torch.cuda.is_available():
         raise Error("CUDA not available")
@@ -43,7 +43,7 @@ def generate(
     log = Logger("generate")
     log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}...")
 
-    loader = Loader()
+    loader = get_loader()
     loader.load(
         KIND,
         model,
@@ -52,7 +52,6 @@ def generate(
         scale,
         use_karras,
         use_refiner,
-        progress,
     )
 
     refiner = loader.refiner
@@ -143,9 +142,6 @@ def generate(
         seed = images[i][1]
         images[i] = (image, seed)
 
-    # Flush cache after generating
-    cuda_collect()
-
     end = time.perf_counter()
     msg = f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {end - start:.2f}s"
     log.info(msg)
@@ -153,4 +149,7 @@ def generate(
     if Info:
         Info(msg)
 
+    # Flush cache before returning
+    cuda_collect()
+
     return images
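A quick illustration of the new _=Progress(track_tqdm=True) parameter (a sketch with assumed names, not part of the commit): when a Gradio event handler declares a Progress object as a default argument, Gradio injects it at call time, and track_tqdm=True mirrors any tqdm bars created inside the handler (such as the diffusers denoising loop) into the UI. That is why generate() no longer threads a progress argument through loader.load().

import time

import gradio as gr
from tqdm import tqdm


def fake_generate(prompt, _=gr.Progress(track_tqdm=True)):
    # tqdm loops inside the handler are picked up automatically
    for _step in tqdm(range(20), desc="Denoising"):
        time.sleep(0.05)
    return f"done: {prompt}"


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    output = gr.Textbox(label="Result")
    gr.Button("Generate").click(fake_generate, inputs=[prompt], outputs=[output])

if __name__ == "__main__":
    demo.launch()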
lib/loader.py CHANGED

@@ -5,52 +5,64 @@ from diffusers.models import AutoencoderKL
 from .config import Config
 from .logger import Logger
 from .upscaler import RealESRGAN
-from .utils import ...
+from .utils import timer
 
 
 class Loader:
     def __init__(self):
         self.model = ""
+        self.vae = None
         self.refiner = None
         self.pipeline = None
         self.upscaler = None
         self.log = Logger("Loader")
+        self.device = torch.device("cuda")  # always called in CUDA context
 
-    def should_unload_refiner(self, use_refiner=False):
-        return self.refiner is not None and not use_refiner
-
-    def should_unload_upscaler(self, scale=1):
-        return self.upscaler is not None and self.upscaler.scale != scale
-
-    def should_unload_deepcache(self, interval=1):
+    def should_unload_deepcache(self, cache_interval=1):
         has_deepcache = hasattr(self.pipeline, "deepcache")
-        if has_deepcache and interval == 1:
+        if has_deepcache and cache_interval == 1:
             return True
-        if has_deepcache and self.pipeline.deepcache.params["cache_interval"] != interval:
+        if has_deepcache and self.pipeline.deepcache.params["cache_interval"] != cache_interval:
             return True
         return False
 
+    def should_unload_upscaler(self, scale=1):
+        return self.upscaler is not None and self.upscaler.scale != scale
+
+    def should_unload_refiner(self, use_refiner=False):
+        return self.refiner is not None and not use_refiner
+
     def should_unload_pipeline(self, model=""):
         return self.pipeline is not None and self.model != model
 
-    def ...
-        ...
+    def should_load_deepcache(self, cache_interval=1):
+        has_deepcache = hasattr(self.pipeline, "deepcache")
+        if not has_deepcache and cache_interval > 1:
+            return True
+        return False
 
     def should_load_upscaler(self, scale=1):
         return self.upscaler is None and scale > 1
 
-    def ...
-        ...
-        ...
+    def should_load_refiner(self, use_refiner=False):
+        return self.refiner is None and use_refiner
+
+    def should_load_pipeline(self, pipeline_id=""):
+        if self.pipeline is None:
             return True
-        if ...:
+        if not isinstance(self.pipeline, Config.PIPELINES[pipeline_id]):
             return True
         return False
 
-    def ...
-        ...
+    def should_load_scheduler(self, cls, use_karras=False):
+        has_karras = hasattr(self.pipeline.scheduler.config, "use_karras_sigmas")
+        if not isinstance(self.pipeline.scheduler, cls):
+            return True
+        if has_karras and self.pipeline.scheduler.config.use_karras_sigmas != use_karras:
+            return True
+        return False
 
-    def ...
+    def unload_all(self, model, deepcache_interval, scale, use_refiner):
         if self.should_unload_deepcache(deepcache_interval):
             self.log.info("Disabling DeepCache")
             self.pipeline.deepcache.disable()
@@ -59,14 +71,14 @@ class Loader:
             self.refiner.deepcache.disable()
             delattr(self.refiner, "deepcache")
 
-        if self.should_unload_refiner(use_refiner):
-            self.log.info("Unloading refiner")
-            self.refiner = None
-
         if self.should_unload_upscaler(scale):
             self.log.info("Unloading upscaler")
             self.upscaler = None
 
+        if self.should_unload_refiner(use_refiner):
+            self.log.info("Unloading refiner")
+            self.refiner = None
+
         if self.should_unload_pipeline(model):
             self.log.info(f"Unloading {self.model}")
             if self.refiner:
@@ -75,58 +87,81 @@ class Loader:
                 self.refiner.tokenizer_2 = None
                 self.refiner.text_encoder_2 = None
             self.pipeline = None
-            self.model = ...
+            self.model = ""
 
-        ...
+    def load_deepcache(self, interval=1):
+        self.log.info("Enabling DeepCache")
+        self.pipeline.deepcache = DeepCacheSDHelper(pipe=self.pipeline)
+        self.pipeline.deepcache.set_params(cache_interval=interval)
+        self.pipeline.deepcache.enable()
+        if self.refiner:
+            self.refiner.deepcache = DeepCacheSDHelper(pipe=self.refiner)
+            self.refiner.deepcache.set_params(cache_interval=interval)
+            self.refiner.deepcache.enable()
+
+    def load_upscaler(self, scale=1):
+        with timer(f"Loading {scale}x upscaler", logger=self.log.info):
+            self.upscaler = RealESRGAN(scale, device=self.device)
+            self.upscaler.load_weights()
 
-    def load_refiner(self, ...):
+    def load_refiner(self):
         model = Config.REFINER_MODEL
-        ...
+        with timer(f"Loading {model}", logger=self.log.info):
+            refiner_kwargs = {
+                "variant": "fp16",
+                "torch_dtype": self.pipeline.dtype,
+                "add_watermarker": False,
+                "requires_aesthetics_score": True,
+                "force_zeros_for_empty_prompt": False,
+                "vae": self.pipeline.vae,
+                "scheduler": self.pipeline.scheduler,
+                "tokenizer_2": self.pipeline.tokenizer_2,
+                "text_encoder_2": self.pipeline.text_encoder_2,
+            }
+            Pipeline = Config.PIPELINES["img2img"]
+            self.refiner = Pipeline.from_pretrained(model, **refiner_kwargs).to(self.device)
+            self.refiner.set_progress_bar_config(disable=True)
+
+    def load_pipeline(self, pipeline_id, model, **kwargs):
+        Pipeline = Config.PIPELINES[pipeline_id]
+
+        # Load VAE first
+        if self.vae is None:
+            self.vae = AutoencoderKL.from_pretrained(
+                Config.VAE_MODEL,
+                torch_dtype=torch.float32,  # vae is full-precision
+            ).to(self.device)
+
+        kwargs["vae"] = self.vae
+
+        # Load from scratch
+        if self.pipeline is None:
+            with timer(f"Loading {model} ({pipeline_id})", logger=self.log.info):
+                if model in Config.SINGLE_FILE_MODELS:
+                    checkpoint = Config.HF_REPOS[model][0]
+                    self.pipeline = Pipeline.from_single_file(
+                        f"https://huggingface.co/{model}/{checkpoint}",
+                        **kwargs,
+                    ).to(self.device)
+                else:
+                    self.pipeline = Pipeline.from_pretrained(model, **kwargs).to(self.device)
 
-    def load_upscaler(self, scale=1):
-        ...
-            self.upscaler = RealESRGAN(scale, device=self.pipeline.device)
-            self.upscaler.load_weights()
-        except Exception as e:
-            self.log.error(f"Error loading {scale}x upscaler: {e}")
-            self.upscaler = None
+        # Change to a different one
+        else:
+            with timer(f"Changing pipeline to {pipeline_id}", logger=self.log.info):
+                self.pipeline = Pipeline.from_pipe(self.pipeline).to(self.device)
 
-    def load_deepcache(self, ...):
-        ...
-            self.refiner.deepcache.enable()
+        # Update model and disable terminal progress bars
+        self.model = model
+        self.pipeline.set_progress_bar_config(disable=True)
+
+    def load_scheduler(self, cls, use_karras=False, **kwargs):
+        self.log.info(f"Loading {cls.__name__}{' with Karras' if use_karras else ''}")
+        self.pipeline.scheduler = cls(**kwargs)
+        if self.refiner is not None:
+            self.refiner.scheduler = self.pipeline.scheduler
 
-    def load(self, ...):
-        Pipeline = Config.PIPELINES[kind]
+    def load(self, pipeline_id, model, scheduler, deepcache_interval, scale, use_karras, use_refiner):
         Scheduler = Config.SCHEDULERS[scheduler]
 
         scheduler_kwargs = {
@@ -137,70 +172,47 @@ class Loader:
             "steps_offset": 1,
         }
 
+        if scheduler not in ["Euler a"]:
+            scheduler_kwargs["use_karras_sigmas"] = use_karras
+
         pipeline_kwargs = {
            "torch_dtype": torch.float16,
            "add_watermarker": False,
-            "scheduler": ...,
-            "vae": AutoencoderKL.from_pretrained(Config.VAE_MODEL, torch_dtype=torch.float16),
+            "scheduler": Scheduler(**scheduler_kwargs),
         }
 
-        ...
-            scheduler_kwargs["use_karras_sigmas"] = use_karras
-
+        # Single-file models don't need a variant
         if model not in Config.SINGLE_FILE_MODELS:
             pipeline_kwargs["variant"] = "fp16"
         else:
             pipeline_kwargs["variant"] = None
 
         # Unload
-        ...
+        self.unload_all(model, deepcache_interval, scale, use_refiner)
 
         # Load
-        ...
-            self.model = model
-            if model in Config.SINGLE_FILE_MODELS:
-                checkpoint = Config.HF_REPOS[model][0]
-                self.pipeline = Pipeline.from_single_file(
-                    f"https://huggingface.co/{model}/{checkpoint}",
-                    **pipeline_kwargs,
-                ).to("cuda")
-            else:
-                self.pipeline = Pipeline.from_pretrained(model, **pipeline_kwargs).to("cuda")
-        except Exception as e:
-            self.log.error(f"Error loading {model}: {e}")
-            self.model = None
-            self.pipeline = None
-            return
-
-        if not isinstance(self.pipeline, Pipeline):
-            self.pipeline = Pipeline.from_pipe(self.pipeline).to("cuda")
-
-        if self.pipeline is not None:
-            self.pipeline.set_progress_bar_config(disable=progress is not None)
-
-        # Check and update scheduler if necessary
-        same_scheduler = isinstance(self.pipeline.scheduler, Scheduler)
-        same_karras = (
-            not hasattr(self.pipeline.scheduler.config, "use_karras_sigmas")
-            or self.pipeline.scheduler.config.use_karras_sigmas == use_karras
-        )
-
-        if self.model == model:
-            if not same_scheduler:
-                self.log.info(f"Enabling {scheduler}")
-            if not same_karras:
-                self.log.info(f"{'Enabling' if use_karras else 'Disabling'} Karras sigmas")
-            if not same_scheduler or not same_karras:
-                self.pipeline.scheduler = Scheduler(**scheduler_kwargs)
-            if self.refiner is not None:
-                self.refiner.scheduler = self.pipeline.scheduler
+        if self.should_load_pipeline(pipeline_id):
+            self.load_pipeline(pipeline_id, model, **pipeline_kwargs)
 
         if self.should_load_refiner(use_refiner):
-            self.load_refiner(...)
+            self.load_refiner()
+
+        if self.should_load_scheduler(Scheduler, use_karras):
+            self.load_scheduler(Scheduler, use_karras, **scheduler_kwargs)
 
         if self.should_load_deepcache(deepcache_interval):
             self.load_deepcache(deepcache_interval)
 
         if self.should_load_upscaler(scale):
             self.load_upscaler(scale)
+
+
+# Get a singleton or a new instance of the Loader
+def get_loader(singleton=False):
+    if not singleton:
+        return Loader()
+    else:
+        if not hasattr(get_loader, "_instance"):
+            get_loader._instance = Loader()
+        assert isinstance(get_loader._instance, Loader)
+        return get_loader._instance