# diffusion-xl/lib/utils.py
import functools
import json
import time
from contextlib import contextmanager

import torch
from diffusers.utils import logging as diffusers_logging
from huggingface_hub import snapshot_download
from huggingface_hub.utils import are_progress_bars_disabled
from transformers import logging as transformers_logging


@contextmanager
def timer(message="Operation", logger=print):
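    """Log `message`, run the wrapped block, then log the elapsed wall-clock time."""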
    start = time.perf_counter()
    logger(message)
    try:
        yield
    finally:
        end = time.perf_counter()
        logger(f"{message} took {end - start:.2f}s")


@functools.lru_cache()
def read_json(path: str) -> str:
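    """Read a JSON file and return it re-serialized as an indented string (cached per path)."""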
    with open(path, "r", encoding="utf-8") as file:
        data = json.load(file)
    return json.dumps(data, indent=4)


@functools.lru_cache()
def read_file(path: str) -> str:
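    """Read a text file as UTF-8 (cached per path)."""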
    with open(path, "r", encoding="utf-8") as file:
        return file.read()


def disable_progress_bars():
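    """Silence the transformers and diffusers tqdm progress bars."""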
    transformers_logging.disable_progress_bar()
    diffusers_logging.disable_progress_bar()


def enable_progress_bars():
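    """Re-enable the transformers and diffusers tqdm progress bars."""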
    # warns if the `HF_HUB_DISABLE_PROGRESS_BARS` env var is set
    transformers_logging.enable_progress_bar()
    diffusers_logging.enable_progress_bar()


def safe_progress(progress, current=0, total=0, desc=""):
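    """Forward progress to a callback only if one was given (e.g. a Gradio
    `Progress` instance, which accepts a `(current, total)` tuple and a `desc`).
    """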
    if progress is not None:
        progress((current, total), desc=desc)


def cuda_collect():
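    """Release cached CUDA memory, reset peak-memory stats, and synchronize;
    a no-op when CUDA is unavailable.
    """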
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()


def download_repo_files(repo_id, allow_patterns, token=None):
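    """Download the files matching `allow_patterns` from a Hub model repo,
    showing progress bars during the download and restoring their previous
    state afterwards. Returns the local snapshot path.
    """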
    was_disabled = are_progress_bars_disabled()
    enable_progress_bars()
    snapshot_path = snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        revision="main",
        token=token,
        allow_patterns=allow_patterns,
        ignore_patterns=None,
    )
    if was_disabled:
        disable_progress_bars()
    return snapshot_path
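

# Illustrative usage sketch (the repo id and patterns below are assumptions
# for the example, not values taken from this project):
#
#   with timer("Downloading scheduler config"):
#       snapshot = download_repo_files(
#           "stabilityai/stable-diffusion-xl-base-1.0",
#           allow_patterns=["scheduler/*.json"],
#       )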