Spaces:
Running
on
Zero
Running
on
Zero
Update utils.py
Browse files
utils.py
CHANGED
@@ -7,11 +7,14 @@ from constants import (
|
|
7 |
HF_TOKEN,
|
8 |
MODEL_TYPE_CLASS,
|
9 |
DIRECTORY_LORAS,
|
|
|
10 |
)
|
11 |
from huggingface_hub import HfApi
|
|
|
12 |
from diffusers import DiffusionPipeline
|
13 |
from huggingface_hub import model_info as model_info_data
|
14 |
from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings
|
|
|
15 |
from pathlib import PosixPath
|
16 |
from unidecode import unidecode
|
17 |
import urllib.parse
|
@@ -283,10 +286,15 @@ def get_model_type(repo_id: str):
|
|
283 |
api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
|
284 |
default = "SD 1.5"
|
285 |
try:
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
|
|
|
|
|
|
|
|
|
|
290 |
except Exception:
|
291 |
return default
|
292 |
return default
|
@@ -371,17 +379,23 @@ def download_diffuser_repo(repo_name: str, model_type: str, revision: str = "mai
|
|
371 |
if len(variant_filenames):
|
372 |
variant = "fp16"
|
373 |
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
|
386 |
if isinstance(cached_folder, PosixPath):
|
387 |
cached_folder = cached_folder.as_posix()
|
|
|
7 |
HF_TOKEN,
|
8 |
MODEL_TYPE_CLASS,
|
9 |
DIRECTORY_LORAS,
|
10 |
+
DIFFUSECRAFT_CHECKPOINT_NAME,
|
11 |
)
|
12 |
from huggingface_hub import HfApi
|
13 |
+
from huggingface_hub import snapshot_download
|
14 |
from diffusers import DiffusionPipeline
|
15 |
from huggingface_hub import model_info as model_info_data
|
16 |
from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings
|
17 |
+
from stablepy.diffusers_vanilla.utils import checkpoint_model_type
|
18 |
from pathlib import PosixPath
|
19 |
from unidecode import unidecode
|
20 |
import urllib.parse
|
|
|
286 |
api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
|
287 |
default = "SD 1.5"
|
288 |
try:
|
289 |
+
if os.path.exists(repo_id):
|
290 |
+
tag = checkpoint_model_type(repo_id)
|
291 |
+
return DIFFUSECRAFT_CHECKPOINT_NAME[tag]
|
292 |
+
else:
|
293 |
+
model = api.model_info(repo_id=repo_id, timeout=5.0)
|
294 |
+
tags = model.tags
|
295 |
+
for tag in tags:
|
296 |
+
if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
|
297 |
+
|
298 |
except Exception:
|
299 |
return default
|
300 |
return default
|
|
|
379 |
if len(variant_filenames):
|
380 |
variant = "fp16"
|
381 |
|
382 |
+
if model_type == "FLUX":
|
383 |
+
cached_folder = snapshot_download(
|
384 |
+
repo_id=repo_name,
|
385 |
+
allow_patterns="transformer/*"
|
386 |
+
)
|
387 |
+
else:
|
388 |
+
cached_folder = DiffusionPipeline.download(
|
389 |
+
pretrained_model_name=repo_name,
|
390 |
+
force_download=False,
|
391 |
+
token=token,
|
392 |
+
revision=revision,
|
393 |
+
# mirror="https://hf-mirror.com",
|
394 |
+
variant=variant,
|
395 |
+
use_safetensors=True,
|
396 |
+
trust_remote_code=False,
|
397 |
+
timeout=5.0,
|
398 |
+
)
|
399 |
|
400 |
if isinstance(cached_folder, PosixPath):
|
401 |
cached_folder = cached_folder.as_posix()
|