adamelliotfields commited on
Commit
65d64be
·
verified ·
1 Parent(s): 39a6792

Move to utils

Browse files
Files changed (5) hide show
  1. app.py +3 -1
  2. lib/__init__.py +1 -2
  3. lib/config.py +20 -0
  4. lib/download.py +0 -38
  5. lib/utils.py +12 -0
app.py CHANGED
@@ -475,7 +475,9 @@ if __name__ == "__main__":
475
  args = parser.parse_args()
476
 
477
  # download to hub cache
478
- download_repo_files()
 
 
479
 
480
  # https://www.gradio.app/docs/gradio/interface#interface-queue
481
  demo.queue().launch(
 
475
  args = parser.parse_args()
476
 
477
  # download to hub cache
478
+ for repo_id, allow_patterns in Config.DOWNLOAD_FILES.items():
479
+ print(f"Downloading {repo_id}...")
480
+ download_repo_files(repo_id, allow_patterns, token=Config.HF_TOKEN)
481
 
482
  # https://www.gradio.app/docs/gradio/interface#interface-queue
483
  demo.queue().launch(
lib/__init__.py CHANGED
@@ -1,9 +1,8 @@
1
  from .config import Config
2
- from .download import download_repo_files
3
  from .inference import generate
4
  from .loader import Loader
5
  from .upscaler import RealESRGAN
6
- from .utils import async_call, load_json, read_file
7
 
8
  __all__ = [
9
  "Config",
 
1
  from .config import Config
 
2
  from .inference import generate
3
  from .loader import Loader
4
  from .upscaler import RealESRGAN
5
+ from .utils import async_call, download_repo_files, load_json, read_file
6
 
7
  __all__ = [
8
  "Config",
lib/config.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from types import SimpleNamespace
2
 
3
  from diffusers import (
@@ -12,6 +13,25 @@ from diffusers import (
12
  )
13
 
14
  Config = SimpleNamespace(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  MONO_FONTS=["monospace"],
16
  SANS_FONTS=[
17
  "sans-serif",
 
1
+ import os
2
  from types import SimpleNamespace
3
 
4
  from diffusers import (
 
13
  )
14
 
15
  Config = SimpleNamespace(
16
+ HF_TOKEN=os.environ.get("HF_TOKEN", None),
17
+ DOWNLOAD_FILES={
18
+ "Lykon/dreamshaper-8": [
19
+ "feature_extractor/preprocessor_config.json",
20
+ "safety_checker/config.json",
21
+ "scheduler/scheduler_config.json",
22
+ "text_encoder/config.json",
23
+ "text_encoder/model.fp16.safetensors",
24
+ "tokenizer/merges.txt",
25
+ "tokenizer/special_tokens_map.json",
26
+ "tokenizer/tokenizer_config.json",
27
+ "tokenizer/vocab.json",
28
+ "unet/config.json",
29
+ "unet/diffusion_pytorch_model.fp16.safetensors",
30
+ "vae/config.json",
31
+ "vae/diffusion_pytorch_model.fp16.safetensors",
32
+ "model_index.json",
33
+ ],
34
+ },
35
  MONO_FONTS=["monospace"],
36
  SANS_FONTS=[
37
  "sans-serif",
lib/download.py DELETED
@@ -1,38 +0,0 @@
1
- import os
2
-
3
- from huggingface_hub._snapshot_download import snapshot_download
4
-
5
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
6
-
7
- SPACES_ZERO_GPU = os.environ.get("SPACES_ZERO_GPU", "").lower() == "true"
8
-
9
- REPO = "Lykon/dreamshaper-8"
10
-
11
- FILES = [
12
- "feature_extractor/preprocessor_config.json",
13
- "safety_checker/config.json",
14
- "scheduler/scheduler_config.json",
15
- "text_encoder/config.json",
16
- "text_encoder/model.fp16.safetensors",
17
- "tokenizer/merges.txt",
18
- "tokenizer/special_tokens_map.json",
19
- "tokenizer/tokenizer_config.json",
20
- "tokenizer/vocab.json",
21
- "unet/config.json",
22
- "unet/diffusion_pytorch_model.fp16.safetensors",
23
- "vae/config.json",
24
- "vae/diffusion_pytorch_model.fp16.safetensors",
25
- "model_index.json",
26
- ]
27
-
28
-
29
- def download_repo_files():
30
- global REPO, FILES
31
- return snapshot_download(
32
- repo_id=REPO,
33
- repo_type="model",
34
- revision="main",
35
- token=HF_TOKEN,
36
- allow_patterns=FILES,
37
- ignore_patterns=None,
38
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lib/utils.py CHANGED
@@ -5,6 +5,7 @@ from typing import Callable, TypeVar
5
 
6
  import anyio
7
  from anyio import Semaphore
 
8
  from typing_extensions import ParamSpec
9
 
10
  T = TypeVar("T")
@@ -26,6 +27,17 @@ def read_file(path: str) -> str:
26
  return file.read()
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
29
  # like the original but supports args and kwargs instead of a dict
30
  # https://github.com/huggingface/huggingface-inference-toolkit/blob/0.2.0/src/huggingface_inference_toolkit/async_utils.py
31
  async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
 
5
 
6
  import anyio
7
  from anyio import Semaphore
8
+ from huggingface_hub._snapshot_download import snapshot_download
9
  from typing_extensions import ParamSpec
10
 
11
  T = TypeVar("T")
 
27
  return file.read()
28
 
29
 
30
+ def download_repo_files(repo_id, allow_patterns, token=None):
31
+ return snapshot_download(
32
+ repo_id=repo_id,
33
+ repo_type="model",
34
+ revision="main",
35
+ token=token,
36
+ allow_patterns=allow_patterns,
37
+ ignore_patterns=None,
38
+ )
39
+
40
+
41
  # like the original but supports args and kwargs instead of a dict
42
  # https://github.com/huggingface/huggingface-inference-toolkit/blob/0.2.0/src/huggingface_inference_toolkit/async_utils.py
43
  async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: