r3gm commited on
Commit
4c2c548
1 Parent(s): 3c8d6fa

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +29 -15
utils.py CHANGED
@@ -7,11 +7,14 @@ from constants import (
7
  HF_TOKEN,
8
  MODEL_TYPE_CLASS,
9
  DIRECTORY_LORAS,
 
10
  )
11
  from huggingface_hub import HfApi
 
12
  from diffusers import DiffusionPipeline
13
  from huggingface_hub import model_info as model_info_data
14
  from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings
 
15
  from pathlib import PosixPath
16
  from unidecode import unidecode
17
  import urllib.parse
@@ -283,10 +286,15 @@ def get_model_type(repo_id: str):
283
  api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
284
  default = "SD 1.5"
285
  try:
286
- model = api.model_info(repo_id=repo_id, timeout=5.0)
287
- tags = model.tags
288
- for tag in tags:
289
- if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
 
 
 
 
 
290
  except Exception:
291
  return default
292
  return default
@@ -371,17 +379,23 @@ def download_diffuser_repo(repo_name: str, model_type: str, revision: str = "mai
371
  if len(variant_filenames):
372
  variant = "fp16"
373
 
374
- cached_folder = DiffusionPipeline.download(
375
- pretrained_model_name=repo_name,
376
- force_download=False,
377
- token=token,
378
- revision=revision,
379
- # mirror="https://hf-mirror.com",
380
- variant=variant,
381
- use_safetensors=True,
382
- trust_remote_code=False,
383
- timeout=5.0,
384
- )
 
 
 
 
 
 
385
 
386
  if isinstance(cached_folder, PosixPath):
387
  cached_folder = cached_folder.as_posix()
 
7
  HF_TOKEN,
8
  MODEL_TYPE_CLASS,
9
  DIRECTORY_LORAS,
10
+ DIFFUSECRAFT_CHECKPOINT_NAME,
11
  )
12
  from huggingface_hub import HfApi
13
+ from huggingface_hub import snapshot_download
14
  from diffusers import DiffusionPipeline
15
  from huggingface_hub import model_info as model_info_data
16
  from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings
17
+ from stablepy.diffusers_vanilla.utils import checkpoint_model_type
18
  from pathlib import PosixPath
19
  from unidecode import unidecode
20
  import urllib.parse
 
286
  api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
287
  default = "SD 1.5"
288
  try:
289
+ if os.path.exists(repo_id):
290
+ tag = checkpoint_model_type(repo_id)
291
+ return DIFFUSECRAFT_CHECKPOINT_NAME[tag]
292
+ else:
293
+ model = api.model_info(repo_id=repo_id, timeout=5.0)
294
+ tags = model.tags
295
+ for tag in tags:
296
+ if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
297
+
298
  except Exception:
299
  return default
300
  return default
 
379
  if len(variant_filenames):
380
  variant = "fp16"
381
 
382
+ if model_type == "FLUX":
383
+ cached_folder = snapshot_download(
384
+ repo_id=repo_name,
385
+ allow_patterns="transformer/*"
386
+ )
387
+ else:
388
+ cached_folder = DiffusionPipeline.download(
389
+ pretrained_model_name=repo_name,
390
+ force_download=False,
391
+ token=token,
392
+ revision=revision,
393
+ # mirror="https://hf-mirror.com",
394
+ variant=variant,
395
+ use_safetensors=True,
396
+ trust_remote_code=False,
397
+ timeout=5.0,
398
+ )
399
 
400
  if isinstance(cached_folder, PosixPath):
401
  cached_folder = cached_folder.as_posix()