Spaces: Running on Zero
import functools | |
import json | |
import time | |
from contextlib import contextmanager | |
import torch | |
from diffusers.utils import logging as diffusers_logging | |
from huggingface_hub._snapshot_download import snapshot_download | |
from huggingface_hub.utils import are_progress_bars_disabled | |
from transformers import logging as transformers_logging | |
@contextmanager
def timer(message="Operation", logger=print):
    """Time the enclosed ``with`` block and report the elapsed wall-clock time.

    Logs *message* on entry and ``"<message> took <seconds>s"`` on exit. The
    exit line is emitted even when the body raises.

    Args:
        message: Label logged on entry and used as the prefix of the
            elapsed-time line.
        logger: Callable accepting a single string; defaults to ``print``.

    Yields:
        None.
    """
    start = time.perf_counter()
    logger(message)
    try:
        yield
    finally:
        # `finally` guarantees the duration is reported even if the body raises.
        end = time.perf_counter()
        logger(f"{message} took {end - start:.2f}s")
def read_json(path: str) -> str:
    """Load a UTF-8 JSON file and return it re-serialized as pretty-printed text.

    Note: despite the name, this returns a ``str`` (the parsed document dumped
    again with 4-space indentation), not the parsed object.

    Args:
        path: Path to the JSON file.

    Returns:
        The JSON document formatted by ``json.dumps(..., indent=4)``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(path, "r", encoding="utf-8") as file:
        data = json.load(file)
    return json.dumps(data, indent=4)
def read_file(path: str) -> str:
    """Return the entire contents of the UTF-8 text file at *path*."""
    with open(path, encoding="utf-8") as handle:
        contents = handle.read()
    return contents
def disable_progress_bars():
    """Turn off the transformers and diffusers progress bars (in that order)."""
    for lib_logging in (transformers_logging, diffusers_logging):
        lib_logging.disable_progress_bar()
def enable_progress_bars():
    """Turn the transformers and diffusers progress bars back on (in that order).

    Note: this may emit a warning when the ``HF_HUB_DISABLE_PROGRESS_BARS``
    environment variable is set (i.e. is not None).
    """
    for lib_logging in (transformers_logging, diffusers_logging):
        lib_logging.enable_progress_bar()
def safe_progress(progress, current=0, total=0, desc=""):
    """Report ``(current, total)`` with label *desc* to *progress*, if one is given.

    Does nothing when *progress* is None, so callers can pass an optional
    progress callback without checking it themselves.
    """
    if progress is None:
        return
    progress((current, total), desc=desc)
def cuda_collect():
    """Free cached CUDA memory and reset peak-memory stats; a no-op without CUDA.

    In order: empties the caching allocator, collects CUDA IPC memory, resets
    the peak memory counters, then calls ``torch.cuda.synchronize()``.
    """
    if not torch.cuda.is_available():
        return
    cuda = torch.cuda
    cuda.empty_cache()
    cuda.ipc_collect()
    cuda.reset_peak_memory_stats()
    cuda.synchronize()
def download_repo_files(repo_id, allow_patterns, token=None):
    """Download matching files from the ``main`` revision of a Hub model repo.

    Progress bars are force-enabled for the duration of the download. If they
    were disabled beforehand (per ``are_progress_bars_disabled()``), they are
    disabled again afterwards, even if the download raises.

    Args:
        repo_id: Hub repository id of the model.
        allow_patterns: Passed through to ``snapshot_download`` to select
            which files to fetch.
        token: Optional Hub access token, passed through to ``snapshot_download``.

    Returns:
        The local snapshot path returned by ``snapshot_download``.
    """
    was_disabled = are_progress_bars_disabled()
    enable_progress_bars()
    try:
        snapshot_path = snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            revision="main",
            token=token,
            allow_patterns=allow_patterns,
            ignore_patterns=None,
        )
    finally:
        # Restore the caller's progress-bar state on both success and failure;
        # previously an exception left the bars enabled.
        if was_disabled:
            disable_progress_bars()
    return snapshot_path