adamelliotfields commited on
Commit
083766b
·
verified ·
1 Parent(s): bb42c8d

Terminal progress bar improvements

Browse files
Files changed (4) hide show
  1. app.py +2 -15
  2. lib/__init__.py +11 -1
  3. lib/config.py +15 -1
  4. lib/utils.py +20 -1
app.py CHANGED
@@ -1,23 +1,10 @@
1
  import argparse
2
  import json
3
- import os
4
  import random
5
- from warnings import filterwarnings
6
 
7
  import gradio as gr
8
- from diffusers.utils import logging as diffusers_logging
9
- from transformers import logging as transformers_logging
10
 
11
- from lib import Config, async_call, download_repo_files, generate, read_file
12
-
13
- filterwarnings("ignore", category=FutureWarning, module="diffusers")
14
- filterwarnings("ignore", category=FutureWarning, module="transformers")
15
-
16
- diffusers_logging.set_verbosity_error()
17
- transformers_logging.set_verbosity_error()
18
-
19
- diffusers_logging.disable_progress_bar()
20
- transformers_logging.disable_progress_bar()
21
 
22
  # the CSS `content` attribute expects a string so we need to wrap the number in quotes
23
  refresh_seed_js = """
@@ -337,7 +324,7 @@ if __name__ == "__main__":
337
  parser.add_argument("-p", "--port", type=int, metavar="INT", default=7860)
338
  args = parser.parse_args()
339
 
340
- # download to hub cache
341
  for repo_id, allow_patterns in Config.HF_MODELS.items():
342
  download_repo_files(repo_id, allow_patterns, token=Config.HF_TOKEN)
343
 
 
1
  import argparse
2
  import json
 
3
  import random
 
4
 
5
  import gradio as gr
 
 
6
 
7
+ from lib import Config, async_call, disable_progress_bars, download_repo_files, generate, read_file
 
 
 
 
 
 
 
 
 
8
 
9
  # the CSS `content` attribute expects a string so we need to wrap the number in quotes
10
  refresh_seed_js = """
 
324
  parser.add_argument("-p", "--port", type=int, metavar="INT", default=7860)
325
  args = parser.parse_args()
326
 
327
+ disable_progress_bars()
328
  for repo_id, allow_patterns in Config.HF_MODELS.items():
329
  download_repo_files(repo_id, allow_patterns, token=Config.HF_TOKEN)
330
 
lib/__init__.py CHANGED
@@ -2,15 +2,25 @@ from .config import Config
2
  from .inference import generate
3
  from .loader import Loader
4
  from .upscaler import RealESRGAN
5
- from .utils import async_call, download_civit_file, download_repo_files, load_json, read_file
 
 
 
 
 
 
 
 
6
 
7
  __all__ = [
8
  "Config",
9
  "Loader",
10
  "RealESRGAN",
11
  "async_call",
 
12
  "download_civit_file",
13
  "download_repo_files",
 
14
  "generate",
15
  "load_json",
16
  "read_file",
 
2
  from .inference import generate
3
  from .loader import Loader
4
  from .upscaler import RealESRGAN
5
+ from .utils import (
6
+ async_call,
7
+ disable_progress_bars,
8
+ download_civit_file,
9
+ download_repo_files,
10
+ enable_progress_bars,
11
+ load_json,
12
+ read_file,
13
+ )
14
 
15
  __all__ = [
16
  "Config",
17
  "Loader",
18
  "RealESRGAN",
19
  "async_call",
20
+ "disable_progress_bars",
21
  "download_civit_file",
22
  "download_repo_files",
23
+ "enable_progress_bars",
24
  "generate",
25
  "load_json",
26
  "read_file",
lib/config.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
2
  from importlib import import_module
 
3
  from types import SimpleNamespace
 
4
 
5
  from diffusers import (
6
  DDIMScheduler,
@@ -11,9 +13,21 @@ from diffusers import (
11
  StableDiffusionXLImg2ImgPipeline,
12
  StableDiffusionXLPipeline,
13
  )
 
 
14
 
15
  # improved GPU handling and progress bars; set before importing spaces
16
- os.environ["ZEROGPU_V2"] = "true"
 
 
 
 
 
 
 
 
 
 
17
 
18
  _sdxl_refiner_files = [
19
  "scheduler/scheduler_config.json",
 
1
  import os
2
  from importlib import import_module
3
+ from importlib.util import find_spec
4
  from types import SimpleNamespace
5
+ from warnings import filterwarnings
6
 
7
  from diffusers import (
8
  DDIMScheduler,
 
13
  StableDiffusionXLImg2ImgPipeline,
14
  StableDiffusionXLPipeline,
15
  )
16
+ from diffusers.utils import logging as diffusers_logging
17
+ from transformers import logging as transformers_logging
18
 
19
  # improved GPU handling and progress bars; set before importing spaces
20
+ os.environ["ZEROGPU_V2"] = "1"
21
+
22
+ # use Rust downloader
23
+ if find_spec("hf_transfer"):
24
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
25
+
26
+ filterwarnings("ignore", category=FutureWarning, module="diffusers")
27
+ filterwarnings("ignore", category=FutureWarning, module="transformers")
28
+
29
+ diffusers_logging.set_verbosity_error()
30
+ transformers_logging.set_verbosity_error()
31
 
32
  _sdxl_refiner_files = [
33
  "scheduler/scheduler_config.json",
lib/utils.py CHANGED
@@ -7,7 +7,10 @@ from typing import Callable, TypeVar
7
  import anyio
8
  import httpx
9
  from anyio import Semaphore
 
10
  from huggingface_hub._snapshot_download import snapshot_download
 
 
11
  from typing_extensions import ParamSpec
12
 
13
  T = TypeVar("T")
@@ -29,8 +32,21 @@ def read_file(path: str) -> str:
29
  return file.read()
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
32
  def download_repo_files(repo_id, allow_patterns, token=None):
33
- return snapshot_download(
 
 
34
  repo_id=repo_id,
35
  repo_type="model",
36
  revision="main",
@@ -38,6 +54,9 @@ def download_repo_files(repo_id, allow_patterns, token=None):
38
  allow_patterns=allow_patterns,
39
  ignore_patterns=None,
40
  )
 
 
 
41
 
42
 
43
  def download_civit_file(lora_id, version_id, file_path=".", token=None):
 
7
  import anyio
8
  import httpx
9
  from anyio import Semaphore
10
+ from diffusers.utils import logging as diffusers_logging
11
  from huggingface_hub._snapshot_download import snapshot_download
12
+ from huggingface_hub.utils import are_progress_bars_disabled
13
+ from transformers import logging as transformers_logging
14
  from typing_extensions import ParamSpec
15
 
16
  T = TypeVar("T")
 
32
  return file.read()
33
 
34
 
35
+ def disable_progress_bars():
36
+ transformers_logging.disable_progress_bar()
37
+ diffusers_logging.disable_progress_bar()
38
+
39
+
40
+ def enable_progress_bars():
41
+ # warns if `HF_HUB_DISABLE_PROGRESS_BARS` env var is not None
42
+ transformers_logging.enable_progress_bar()
43
+ diffusers_logging.enable_progress_bar()
44
+
45
+
46
  def download_repo_files(repo_id, allow_patterns, token=None):
47
+ was_disabled = are_progress_bars_disabled()
48
+ enable_progress_bars()
49
+ snapshot_path = snapshot_download(
50
  repo_id=repo_id,
51
  repo_type="model",
52
  revision="main",
 
54
  allow_patterns=allow_patterns,
55
  ignore_patterns=None,
56
  )
57
+ if was_disabled:
58
+ disable_progress_bars()
59
+ return snapshot_path
60
 
61
 
62
  def download_civit_file(lora_id, version_id, file_path=".", token=None):