John6666 commited on
Commit
b57ab2a
·
verified ·
1 Parent(s): 0dae5e8

Upload dc.py

Browse files
Files changed (1) hide show
  1. dc.py +126 -304
dc.py CHANGED
@@ -59,202 +59,60 @@ from stablepy import logger
59
  logger.setLevel(logging.DEBUG)
60
 
61
  from env import (
62
- HF_TOKEN, hf_read_token, # to use only for private repos
63
  CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
64
  HF_LORA_ESSENTIAL_PRIVATE_REPO, HF_VAE_PRIVATE_REPO,
65
  HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
66
- directory_models, directory_loras, directory_vaes, directory_embeds,
67
- directory_embeds_sdxl, directory_embeds_positive_sdxl,
68
- load_diffusers_format_model, download_model_list, download_lora_list,
69
- download_vae_list, download_embeds)
70
-
71
- PREPROCESSOR_CONTROLNET = {
72
- "openpose": [
73
- "Openpose",
74
- "None",
75
- ],
76
- "scribble": [
77
- "HED",
78
- "PidiNet",
79
- "None",
80
- ],
81
- "softedge": [
82
- "PidiNet",
83
- "HED",
84
- "HED safe",
85
- "PidiNet safe",
86
- "None",
87
- ],
88
- "segmentation": [
89
- "UPerNet",
90
- "None",
91
- ],
92
- "depth": [
93
- "DPT",
94
- "Midas",
95
- "None",
96
- ],
97
- "normalbae": [
98
- "NormalBae",
99
- "None",
100
- ],
101
- "lineart": [
102
- "Lineart",
103
- "Lineart coarse",
104
- "Lineart (anime)",
105
- "None",
106
- "None (anime)",
107
- ],
108
- "lineart_anime": [
109
- "Lineart",
110
- "Lineart coarse",
111
- "Lineart (anime)",
112
- "None",
113
- "None (anime)",
114
- ],
115
- "shuffle": [
116
- "ContentShuffle",
117
- "None",
118
- ],
119
- "canny": [
120
- "Canny",
121
- "None",
122
- ],
123
- "mlsd": [
124
- "MLSD",
125
- "None",
126
- ],
127
- "ip2p": [
128
- "ip2p"
129
- ],
130
- "recolor": [
131
- "Recolor luminance",
132
- "Recolor intensity",
133
- "None",
134
- ],
135
- "tile": [
136
- "Mild Blur",
137
- "Moderate Blur",
138
- "Heavy Blur",
139
- "None",
140
- ],
141
- }
142
-
143
- TASK_STABLEPY = {
144
- 'txt2img': 'txt2img',
145
- 'img2img': 'img2img',
146
- 'inpaint': 'inpaint',
147
- # 'canny T2I Adapter': 'sdxl_canny_t2i', # NO HAVE STEP CALLBACK PARAMETERS SO NOT WORKS WITH DIFFUSERS 0.29.0
148
- # 'sketch T2I Adapter': 'sdxl_sketch_t2i',
149
- # 'lineart T2I Adapter': 'sdxl_lineart_t2i',
150
- # 'depth-midas T2I Adapter': 'sdxl_depth-midas_t2i',
151
- # 'openpose T2I Adapter': 'sdxl_openpose_t2i',
152
- 'openpose ControlNet': 'openpose',
153
- 'canny ControlNet': 'canny',
154
- 'mlsd ControlNet': 'mlsd',
155
- 'scribble ControlNet': 'scribble',
156
- 'softedge ControlNet': 'softedge',
157
- 'segmentation ControlNet': 'segmentation',
158
- 'depth ControlNet': 'depth',
159
- 'normalbae ControlNet': 'normalbae',
160
- 'lineart ControlNet': 'lineart',
161
- 'lineart_anime ControlNet': 'lineart_anime',
162
- 'shuffle ControlNet': 'shuffle',
163
- 'ip2p ControlNet': 'ip2p',
164
- 'optical pattern ControlNet': 'pattern',
165
- 'recolor ControlNet': 'recolor',
166
- 'tile ControlNet': 'tile',
167
- }
168
-
169
- TASK_MODEL_LIST = list(TASK_STABLEPY.keys())
170
-
171
- UPSCALER_DICT_GUI = {
172
- None: None,
173
- "Lanczos": "Lanczos",
174
- "Nearest": "Nearest",
175
- 'Latent': 'Latent',
176
- 'Latent (antialiased)': 'Latent (antialiased)',
177
- 'Latent (bicubic)': 'Latent (bicubic)',
178
- 'Latent (bicubic antialiased)': 'Latent (bicubic antialiased)',
179
- 'Latent (nearest)': 'Latent (nearest)',
180
- 'Latent (nearest-exact)': 'Latent (nearest-exact)',
181
- "RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
182
- "RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
183
- "RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
184
- "RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
185
- "realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
186
- "realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
187
- "realesr-general-wdn-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
188
- "4x-UltraSharp": "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
189
- "4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
190
- "Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
191
- "AnimeSharp4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
192
- "lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
193
- "RealisticRescaler4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
194
- "NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth"
195
- }
196
-
197
- UPSCALER_KEYS = list(UPSCALER_DICT_GUI.keys())
198
-
199
-
200
- def get_model_list(directory_path):
201
- model_list = []
202
- valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
203
-
204
- for filename in os.listdir(directory_path):
205
- if os.path.splitext(filename)[1] in valid_extensions:
206
- # name_without_extension = os.path.splitext(filename)[0]
207
- file_path = os.path.join(directory_path, filename)
208
- # model_list.append((name_without_extension, file_path))
209
- model_list.append(file_path)
210
- print('\033[34mFILE: ' + file_path + '\033[0m')
211
- return model_list
212
 
213
  ## BEGIN MOD
214
  from modutils import (to_list, list_uniq, list_sub, get_model_id_list, get_tupled_embed_list,
215
  get_tupled_model_list, get_lora_model_list, download_private_repo, download_things)
216
 
217
  # - **Download Models**
218
- download_model = ", ".join(download_model_list)
219
  # - **Download VAEs**
220
- download_vae = ", ".join(download_vae_list)
221
  # - **Download LoRAs**
222
- download_lora = ", ".join(download_lora_list)
223
 
224
- #download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, directory_loras, True)
225
- download_private_repo(HF_VAE_PRIVATE_REPO, directory_vaes, False)
226
 
227
- load_diffusers_format_model = list_uniq(load_diffusers_format_model + get_model_id_list())
228
  ## END MOD
229
 
230
  # Download stuffs
231
  for url in [url.strip() for url in download_model.split(',')]:
232
  if not os.path.exists(f"./models/{url.split('/')[-1]}"):
233
- download_things(directory_models, url, HF_TOKEN, CIVITAI_API_KEY)
234
  for url in [url.strip() for url in download_vae.split(',')]:
235
  if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
236
- download_things(directory_vaes, url, HF_TOKEN, CIVITAI_API_KEY)
237
  for url in [url.strip() for url in download_lora.split(',')]:
238
  if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
239
- download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
240
 
241
  # Download Embeddings
242
  for url_embed in download_embeds:
243
  if not os.path.exists(f"./embedings/{url_embed.split('/')[-1]}"):
244
- download_things(directory_embeds, url_embed, HF_TOKEN, CIVITAI_API_KEY)
245
 
246
  # Build list models
247
- embed_list = get_model_list(directory_embeds)
248
- model_list = get_model_list(directory_models)
249
  model_list = load_diffusers_format_model + model_list
250
  ## BEGIN MOD
251
  lora_model_list = get_lora_model_list()
252
- vae_model_list = get_model_list(directory_vaes)
253
  vae_model_list.insert(0, "None")
254
 
255
- #download_private_repo(HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, directory_embeds_sdxl, False)
256
- #download_private_repo(HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO, directory_embeds_positive_sdxl, False)
257
- embed_sdxl_list = get_model_list(directory_embeds_sdxl) + get_model_list(directory_embeds_positive_sdxl)
258
 
259
  def get_embed_list(pipeline_name):
260
  return get_tupled_embed_list(embed_sdxl_list if pipeline_name == "StableDiffusionXLPipeline" else embed_list)
@@ -262,99 +120,13 @@ def get_embed_list(pipeline_name):
262
 
263
  print('\033[33m🏁 Download and listing of valid models completed.\033[0m')
264
 
265
- msg_inc_vae = (
266
- "Use the right VAE for your model to maintain image quality. The wrong"
267
- " VAE can lead to poor results, like blurriness in the generated images."
268
- )
269
-
270
- SDXL_TASK = [k for k, v in TASK_STABLEPY.items() if v in SDXL_TASKS]
271
- SD_TASK = [k for k, v in TASK_STABLEPY.items() if v in SD15_TASKS]
272
- FLUX_TASK = list(TASK_STABLEPY.keys())[:3] + [k for k, v in TASK_STABLEPY.items() if v in FLUX_CN_UNION_MODES.keys()]
273
-
274
- MODEL_TYPE_TASK = {
275
- "SD 1.5": SD_TASK,
276
- "SDXL": SDXL_TASK,
277
- "FLUX": FLUX_TASK,
278
- }
279
-
280
- MODEL_TYPE_CLASS = {
281
- "diffusers:StableDiffusionPipeline": "SD 1.5",
282
- "diffusers:StableDiffusionXLPipeline": "SDXL",
283
- "diffusers:FluxPipeline": "FLUX",
284
- }
285
-
286
- POST_PROCESSING_SAMPLER = ["Use same sampler"] + scheduler_names[:-2]
287
-
288
- def extract_parameters(input_string):
289
- parameters = {}
290
- input_string = input_string.replace("\n", "")
291
-
292
- if "Negative prompt:" not in input_string:
293
- if "Steps:" in input_string:
294
- input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
295
- else:
296
- print("Invalid metadata")
297
- parameters["prompt"] = input_string
298
- return parameters
299
-
300
- parm = input_string.split("Negative prompt:")
301
- parameters["prompt"] = parm[0].strip()
302
- if "Steps:" not in parm[1]:
303
- print("Steps not detected")
304
- parameters["neg_prompt"] = parm[1].strip()
305
- return parameters
306
- parm = parm[1].split("Steps:")
307
- parameters["neg_prompt"] = parm[0].strip()
308
- input_string = "Steps:" + parm[1]
309
-
310
- # Extracting Steps
311
- steps_match = re.search(r'Steps: (\d+)', input_string)
312
- if steps_match:
313
- parameters['Steps'] = int(steps_match.group(1))
314
-
315
- # Extracting Size
316
- size_match = re.search(r'Size: (\d+x\d+)', input_string)
317
- if size_match:
318
- parameters['Size'] = size_match.group(1)
319
- width, height = map(int, parameters['Size'].split('x'))
320
- parameters['width'] = width
321
- parameters['height'] = height
322
-
323
- # Extracting other parameters
324
- other_parameters = re.findall(r'(\w+): (.*?)(?=, \w+|$)', input_string)
325
- for param in other_parameters:
326
- parameters[param[0]] = param[1].strip('"')
327
-
328
- return parameters
329
-
330
- def get_model_type(repo_id: str):
331
- api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
332
- default = "SD 1.5"
333
- try:
334
- model = api.model_info(repo_id=repo_id, timeout=5.0)
335
- tags = model.tags
336
- for tag in tags:
337
- if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
338
- except Exception:
339
- return default
340
- return default
341
-
342
  ## BEGIN MOD
343
  class GuiSD:
344
- def __init__(self):
345
  self.model = None
346
-
347
- print("Loading model...")
348
- self.model = Model_Diffusers(
349
- base_model_id="Lykon/dreamshaper-8",
350
- task_name="txt2img",
351
- vae_model=None,
352
- type_model_precision=torch.float16,
353
- retain_task_model_in_cache=False,
354
- device="cpu",
355
- )
356
- self.model.load_beta_styles()
357
- #self.model.device = torch.device("cpu") #
358
 
359
  def infer_short(self, model, pipe_params, progress=gr.Progress(track_tqdm=True)):
360
  #progress(0, desc="Start inference...")
@@ -368,31 +140,86 @@ class GuiSD:
368
  return img
369
 
370
  def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
371
-
372
- #yield f"Loading model: {model_name}"
373
-
374
  vae_model = vae_model if vae_model != "None" else None
375
  model_type = get_model_type(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
 
377
  if vae_model:
378
  vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
379
  if model_type != vae_type:
380
- gr.Warning(msg_inc_vae)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
- self.model.device = torch.device("cpu")
383
- dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
384
-
385
- self.model.load_pipe(
386
- model_name,
387
- task_name=TASK_STABLEPY[task],
388
- vae_model=vae_model if vae_model != "None" else None,
389
- type_model_precision=dtype_model,
390
- retain_task_model_in_cache=False,
391
- )
392
  #yield f"Model loaded: {model_name}"
393
 
394
  #@spaces.GPU
395
- @torch.inference_mode()
396
  def generate_pipeline(
397
  self,
398
  prompt,
@@ -497,23 +324,24 @@ class GuiSD:
497
  mode_ip2,
498
  scale_ip2,
499
  pag_scale,
500
- #progress=gr.Progress(track_tqdm=True),
501
  ):
502
- #progress(0, desc="Preparing inference...")
503
-
 
504
  vae_model = vae_model if vae_model != "None" else None
505
  loras_list = [lora1, lora2, lora3, lora4, lora5]
506
  vae_msg = f"VAE: {vae_model}" if vae_model else ""
507
  msg_lora = ""
508
 
509
- print("Config model:", model_name, vae_model, loras_list)
510
-
511
  ## BEGIN MOD
 
512
  prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
513
  global lora_model_list
514
  lora_model_list = get_lora_model_list()
515
  ## END MOD
516
 
 
 
517
  task = TASK_STABLEPY[task]
518
 
519
  params_ip_img = []
@@ -536,6 +364,9 @@ class GuiSD:
536
  params_ip_mode.append(modeip)
537
  params_ip_scale.append(scaleip)
538
 
 
 
 
539
  if task != "txt2img" and not image_control:
540
  raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")
541
 
@@ -665,18 +496,17 @@ class GuiSD:
665
  }
666
 
667
  self.model.device = torch.device("cuda:0")
668
- if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5 and loras_list != [""] * 5:
669
  self.model.pipe.transformer.to(self.model.device)
670
  print("transformer to cuda")
671
 
672
- #progress(1, desc="Inference preparation completed. Starting inference...")
673
-
674
- info_state = "" # for yield version
675
  return self.infer_short(self.model, pipe_params), info_state
676
  ## END MOD
677
 
 
678
  def dynamic_gpu_duration(func, duration, *args):
679
 
 
680
  @spaces.GPU(duration=duration)
681
  def wrapped_func():
682
  return func(*args)
@@ -696,7 +526,7 @@ def sd_gen_generate_pipeline(*args):
696
  load_lora_cpu = args[-3]
697
  generation_args = args[:-3]
698
  lora_list = [
699
- None if item == "None" or item == "" else item
700
  for item in [args[7], args[9], args[11], args[13], args[15]]
701
  ]
702
  lora_status = [None] * 5
@@ -706,7 +536,7 @@ def sd_gen_generate_pipeline(*args):
706
  msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."
707
 
708
  #if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
709
- # yield None, msg_load_lora
710
 
711
  # Load lora in CPU
712
  if load_lora_cpu:
@@ -732,14 +562,16 @@ def sd_gen_generate_pipeline(*args):
732
  )
733
  gr.Info(f"LoRAs in cache: {lora_cache_msg}")
734
 
735
- msg_request = f"Requesting {gpu_duration_arg}s. of GPU time"
 
736
  gr.Info(msg_request)
737
  print(msg_request)
738
-
739
- # yield from sd_gen.generate_pipeline(*generation_args)
740
 
741
  start_time = time.time()
742
 
 
 
743
  return dynamic_gpu_duration(
744
  sd_gen.generate_pipeline,
745
  gpu_duration_arg,
@@ -747,31 +579,19 @@ def sd_gen_generate_pipeline(*args):
747
  )
748
 
749
  end_time = time.time()
 
 
 
 
750
 
751
  if verbose_arg:
752
- execution_time = end_time - start_time
753
- msg_task_complete = (
754
- f"GPU task complete in: {round(execution_time, 0) + 1} seconds"
755
- )
756
  gr.Info(msg_task_complete)
757
  print(msg_task_complete)
758
 
759
- def extract_exif_data(image):
760
- if image is None: return ""
761
 
762
- try:
763
- metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
764
-
765
- for key in metadata_keys:
766
- if key in image.info:
767
- return image.info[key]
768
 
769
- return str(image.info)
770
-
771
- except Exception as e:
772
- return f"Error extracting metadata: {str(e)}"
773
-
774
- @spaces.GPU(duration=20)
775
  def esrgan_upscale(image, upscaler_name, upscaler_size):
776
  if image is None: return None
777
 
@@ -793,9 +613,11 @@ def esrgan_upscale(image, upscaler_name, upscaler_size):
793
 
794
  return image_path
795
 
 
796
  dynamic_gpu_duration.zerogpu = True
797
  sd_gen_generate_pipeline.zerogpu = True
798
 
 
799
  from pathlib import Path
800
  from PIL import Image
801
  import random, json
@@ -1027,14 +849,14 @@ def update_lora_dict(path: str):
1027
  def download_lora(dl_urls: str):
1028
  global loras_url_to_path_dict
1029
  dl_path = ""
1030
- before = get_local_model_list(directory_loras)
1031
  urls = []
1032
  for url in [url.strip() for url in dl_urls.split(',')]:
1033
- local_path = f"{directory_loras}/{url.split('/')[-1]}"
1034
  if not Path(local_path).exists():
1035
- download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
1036
  urls.append(url)
1037
- after = get_local_model_list(directory_loras)
1038
  new_files = list_sub(after, before)
1039
  i = 0
1040
  for file in new_files:
 
59
  logger.setLevel(logging.DEBUG)
60
 
61
  from env import (
62
+ HF_TOKEN, HF_READ_TOKEN, # to use only for private repos
63
  CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
64
  HF_LORA_ESSENTIAL_PRIVATE_REPO, HF_VAE_PRIVATE_REPO,
65
  HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
66
+ DIRECTORY_MODELS, DIRECTORY_LORAS, DIRECTORY_VAES, DIRECTORY_EMBEDS,
67
+ DIRECTORY_EMBEDS_SDXL, DIRECTORY_EMBEDS_POSITIVE_SDXL,
68
+ LOAD_DIFFUSERS_FORMAT_MODEL, DOWNLOAD_MODEL_LIST, DOWNLOAD_LORA_LIST,
69
+ DOWNLOAD_VAE_LIST, download_embeds)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  ## BEGIN MOD
72
  from modutils import (to_list, list_uniq, list_sub, get_model_id_list, get_tupled_embed_list,
73
  get_tupled_model_list, get_lora_model_list, download_private_repo, download_things)
74
 
75
  # - **Download Models**
76
+ download_model = ", ".join(DOWNLOAD_MODEL_LIST)
77
  # - **Download VAEs**
78
+ download_vae = ", ".join(DOWNLOAD_VAE_LIST)
79
  # - **Download LoRAs**
80
+ download_lora = ", ".join(DOWNLOAD_LORA_LIST)
81
 
82
+ #download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, DIRECTORY_LORAS, True)
83
+ download_private_repo(HF_VAE_PRIVATE_REPO, DIRECTORY_VAES, False)
84
 
85
+ load_diffusers_format_model = list_uniq(LOAD_DIFFUSERS_FORMAT_MODEL + get_model_id_list())
86
  ## END MOD
87
 
88
  # Download stuffs
89
  for url in [url.strip() for url in download_model.split(',')]:
90
  if not os.path.exists(f"./models/{url.split('/')[-1]}"):
91
+ download_things(DIRECTORY_MODELS, url, HF_TOKEN, CIVITAI_API_KEY)
92
  for url in [url.strip() for url in download_vae.split(',')]:
93
  if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
94
+ download_things(DIRECTORY_VAES, url, HF_TOKEN, CIVITAI_API_KEY)
95
  for url in [url.strip() for url in download_lora.split(',')]:
96
  if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
97
+ download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)
98
 
99
  # Download Embeddings
100
  for url_embed in download_embeds:
101
  if not os.path.exists(f"./embedings/{url_embed.split('/')[-1]}"):
102
+ download_things(DIRECTORY_EMBEDS, url_embed, HF_TOKEN, CIVITAI_API_KEY)
103
 
104
  # Build list models
105
+ embed_list = get_model_list(DIRECTORY_EMBEDS)
106
+ model_list = get_model_list(DIRECTORY_MODELS)
107
  model_list = load_diffusers_format_model + model_list
108
  ## BEGIN MOD
109
  lora_model_list = get_lora_model_list()
110
+ vae_model_list = get_model_list(DIRECTORY_VAES)
111
  vae_model_list.insert(0, "None")
112
 
113
+ #download_private_repo(HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, DIRECTORY_EMBEDS_SDXL, False)
114
+ #download_private_repo(HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO, DIRECTORY_EMBEDS_POSITIVE_SDXL, False)
115
+ embed_sdxl_list = get_model_list(DIRECTORY_EMBEDS_SDXL) + get_model_list(DIRECTORY_EMBEDS_POSITIVE_SDXL)
116
 
117
  def get_embed_list(pipeline_name):
118
  return get_tupled_embed_list(embed_sdxl_list if pipeline_name == "StableDiffusionXLPipeline" else embed_list)
 
120
 
121
  print('\033[33m🏁 Download and listing of valid models completed.\033[0m')
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  ## BEGIN MOD
124
  class GuiSD:
125
+ def __init__(self, stream=True):
126
  self.model = None
127
+ self.status_loading = False
128
+ self.sleep_loading = 4
129
+ self.last_load = datetime.now()
 
 
 
 
 
 
 
 
 
130
 
131
  def infer_short(self, model, pipe_params, progress=gr.Progress(track_tqdm=True)):
132
  #progress(0, desc="Start inference...")
 
140
  return img
141
 
142
  def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
 
 
 
143
  vae_model = vae_model if vae_model != "None" else None
144
  model_type = get_model_type(model_name)
145
+ dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
146
+
147
+ if not os.path.exists(model_name):
148
+ _ = download_diffuser_repo(
149
+ repo_name=model_name,
150
+ model_type=model_type,
151
+ revision="main",
152
+ token=True,
153
+ )
154
+
155
+ for i in range(68):
156
+ if not self.status_loading:
157
+ self.status_loading = True
158
+ if i > 0:
159
+ time.sleep(self.sleep_loading)
160
+ print("Previous model ops...")
161
+ break
162
+ time.sleep(0.5)
163
+ print(f"Waiting queue {i}")
164
+ yield "Waiting queue"
165
+
166
+ self.status_loading = True
167
+
168
+ #yield f"Loading model: {model_name}"
169
 
170
  if vae_model:
171
  vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
172
  if model_type != vae_type:
173
+ gr.Warning(WARNING_MSG_VAE)
174
+
175
+ print("Loading model...")
176
+
177
+ try:
178
+ start_time = time.time()
179
+
180
+ if self.model is None:
181
+ self.model = Model_Diffusers(
182
+ base_model_id=model_name,
183
+ task_name=TASK_STABLEPY[task],
184
+ vae_model=vae_model,
185
+ type_model_precision=dtype_model,
186
+ retain_task_model_in_cache=False,
187
+ device="cpu",
188
+ )
189
+ else:
190
+
191
+ if self.model.base_model_id != model_name:
192
+ load_now_time = datetime.now()
193
+ elapsed_time = (load_now_time - self.last_load).total_seconds()
194
+
195
+ if elapsed_time <= 8:
196
+ print("Waiting for the previous model's time ops...")
197
+ time.sleep(8-elapsed_time)
198
+
199
+ self.model.device = torch.device("cpu")
200
+ self.model.load_pipe(
201
+ model_name,
202
+ task_name=TASK_STABLEPY[task],
203
+ vae_model=vae_model,
204
+ type_model_precision=dtype_model,
205
+ retain_task_model_in_cache=False,
206
+ )
207
+
208
+ end_time = time.time()
209
+ self.sleep_loading = max(min(int(end_time - start_time), 10), 4)
210
+ except Exception as e:
211
+ self.last_load = datetime.now()
212
+ self.status_loading = False
213
+ self.sleep_loading = 4
214
+ raise e
215
+
216
+ self.last_load = datetime.now()
217
+ self.status_loading = False
218
 
 
 
 
 
 
 
 
 
 
 
219
  #yield f"Model loaded: {model_name}"
220
 
221
  #@spaces.GPU
222
+ #@torch.inference_mode()
223
  def generate_pipeline(
224
  self,
225
  prompt,
 
324
  mode_ip2,
325
  scale_ip2,
326
  pag_scale,
 
327
  ):
328
+ info_state = html_template_message("Navigating latent space...")
329
+ #yield info_state, gr.update(), gr.update()
330
+
331
  vae_model = vae_model if vae_model != "None" else None
332
  loras_list = [lora1, lora2, lora3, lora4, lora5]
333
  vae_msg = f"VAE: {vae_model}" if vae_model else ""
334
  msg_lora = ""
335
 
 
 
336
  ## BEGIN MOD
337
+ loras_list = [s if s else "None" for s in loras_list]
338
  prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
339
  global lora_model_list
340
  lora_model_list = get_lora_model_list()
341
  ## END MOD
342
 
343
+ print("Config model:", model_name, vae_model, loras_list)
344
+
345
  task = TASK_STABLEPY[task]
346
 
347
  params_ip_img = []
 
364
  params_ip_mode.append(modeip)
365
  params_ip_scale.append(scaleip)
366
 
367
+ concurrency = 5
368
+ self.model.stream_config(concurrency=concurrency, latent_resize_by=1, vae_decoding=False)
369
+
370
  if task != "txt2img" and not image_control:
371
  raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")
372
 
 
496
  }
497
 
498
  self.model.device = torch.device("cuda:0")
499
+ if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5:
500
  self.model.pipe.transformer.to(self.model.device)
501
  print("transformer to cuda")
502
 
 
 
 
503
  return self.infer_short(self.model, pipe_params), info_state
504
  ## END MOD
505
 
506
+
507
  def dynamic_gpu_duration(func, duration, *args):
508
 
509
+ @torch.inference_mode()
510
  @spaces.GPU(duration=duration)
511
  def wrapped_func():
512
  return func(*args)
 
526
  load_lora_cpu = args[-3]
527
  generation_args = args[:-3]
528
  lora_list = [
529
+ None if item == "None" or item == "" else item # MOD
530
  for item in [args[7], args[9], args[11], args[13], args[15]]
531
  ]
532
  lora_status = [None] * 5
 
536
  msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."
537
 
538
  #if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
539
+ # yield msg_load_lora, gr.update(), gr.update()
540
 
541
  # Load lora in CPU
542
  if load_lora_cpu:
 
562
  )
563
  gr.Info(f"LoRAs in cache: {lora_cache_msg}")
564
 
565
+ msg_request = f"Requesting {gpu_duration_arg}s. of GPU time.\nModel: {sd_gen.model.base_model_id}"
566
+ if verbose_arg:
567
  gr.Info(msg_request)
568
  print(msg_request)
569
+ #yield msg_request.replace("\n", "<br>"), gr.update(), gr.update()
 
570
 
571
  start_time = time.time()
572
 
573
+ # yield from sd_gen.generate_pipeline(*generation_args)
574
+ #yield from dynamic_gpu_duration(
575
  return dynamic_gpu_duration(
576
  sd_gen.generate_pipeline,
577
  gpu_duration_arg,
 
579
  )
580
 
581
  end_time = time.time()
582
+ execution_time = end_time - start_time
583
+ msg_task_complete = (
584
+ f"GPU task complete in: {int(round(execution_time, 0) + 1)} seconds"
585
+ )
586
 
587
  if verbose_arg:
 
 
 
 
588
  gr.Info(msg_task_complete)
589
  print(msg_task_complete)
590
 
591
+ yield msg_task_complete, gr.update(), gr.update()
 
592
 
 
 
 
 
 
 
593
 
594
+ @spaces.GPU(duration=15)
 
 
 
 
 
595
  def esrgan_upscale(image, upscaler_name, upscaler_size):
596
  if image is None: return None
597
 
 
613
 
614
  return image_path
615
 
616
+
617
  dynamic_gpu_duration.zerogpu = True
618
  sd_gen_generate_pipeline.zerogpu = True
619
 
620
+
621
  from pathlib import Path
622
  from PIL import Image
623
  import random, json
 
849
  def download_lora(dl_urls: str):
850
  global loras_url_to_path_dict
851
  dl_path = ""
852
+ before = get_local_model_list(DIRECTORY_LORAS)
853
  urls = []
854
  for url in [url.strip() for url in dl_urls.split(',')]:
855
+ local_path = f"{DIRECTORY_LORAS}/{url.split('/')[-1]}"
856
  if not Path(local_path).exists():
857
+ download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)
858
  urls.append(url)
859
+ after = get_local_model_list(DIRECTORY_LORAS)
860
  new_files = list_sub(after, before)
861
  i = 0
862
  for file in new_files: