Spaces:
Sleeping
Sleeping
"""Gradio helpers for caching, downloading etc.""" | |
import concurrent.futures | |
import contextlib | |
import datetime | |
import functools | |
import logging | |
import os | |
import shutil | |
import threading | |
import time | |
import huggingface_hub | |
import numpy as np | |
import psutil | |
def should_mock():
    """Tells whether model loading should be faked.

    Mocking is enabled only when the environment variable `MOCK_MODEL` is set
    to exactly the string `yes`.
    """
    flag = os.getenv('MOCK_MODEL')
    return flag == 'yes'
@contextlib.contextmanager
def timed(name, start_message=False):
    """Context manager that logs "Timed {name}: X.X secs" at INFO level on exit.

    Yields a dict. Once the `with` block has exited (normally or through an
    exception), its `'secs'` entry holds the elapsed `time.monotonic()` seconds;
    reading `'secs'` while still inside the block raises `KeyError`.

    Args:
      name: Label used in the log messages.
      start_message: If true, also logs "Timing {name}..." on entry.
    """
    t0 = time.monotonic()
    # 'dt' is never filled in; it is kept only in case callers look for the key.
    timing = dict(dt=None)
    try:
        if start_message:
            logging.info('Timing %s...', name)
        yield timing
    finally:
        timing['secs'] = time.monotonic() - t0
        logging.info('Timed %s: %.1f secs', name, timing['secs'])
def synced(f):
    """Decorator that serializes all calls to `f` through one `threading.Lock()`.

    Each call logs (INFO) how long it waited to acquire the lock. The wrapper
    carries `f`'s metadata (`__name__`, `__doc__`, `__wrapped__`, ...).
    """
    lock = threading.Lock()

    @functools.wraps(f)  # Preserve f's identity for logging/introspection.
    def wrapper(*args, **kw):
        t0 = time.monotonic()
        with lock:
            lock_dt = time.monotonic() - t0
            logging.info('synced wait: %.1f secs', lock_dt)
            return f(*args, **kw)

    return wrapper
# Serializes access to `_scheduled`, which is mutated both by
# `register_download()` and by the background download loop.
_lock = threading.Lock()
# Pending downloads in registration order: name -> (repo, filenames, revision).
_scheduled = {}
# Finished downloads (name -> result) and failures (name -> error message).
_done, _failed = {}, {}
# Cumulative seconds spent downloading / warming up. NOTE(review):
# `_loading_secs` is not referenced by the code visible here -- confirm it is
# still needed.
_download_secs = _warmup_secs = _loading_secs = 0
# Names whose warmup function completed without raising.
_warmed_up = set()
# Optional hook called as `_warmup_function(name)`; see `set_warmup_function()`.
_warmup_function = None


def set_warmup_function(warmup_function):
    """Registers the warmup hook, called as `warmup_function(name)`.

    Passing `None` (or any falsy value) disables warmups.
    """
    global _warmup_function
    _warmup_function = warmup_function
def _run_warmup(name):
    """Calls the registered warmup function for `name`, recording success and time."""
    global _warmup_secs
    with timed(f'warming up {name}', True) as t:
        try:
            _warmup_function(name)
            _warmed_up.add(name)
        except Exception:  # pylint: disable=broad-exception-caught
            logging.exception('Could not warmup "%s"!', name)
    # `t['secs']` is only set after the `with timed(...)` block has exited.
    _warmup_secs += t['secs']


def _do_download():
    """Download loop, meant to run forever on a background thread.

    Each iteration takes the oldest entry of `_scheduled`, downloads all of its
    files from the Hugging Face Hub (or just sleeps when `should_mock()`), and
    stores the tuple of local paths in `_done[name]` (`None` when mocked). On
    failure, it stores the error message in `_failed[name]` instead. After a
    successful download, the registered warmup function (if any) is submitted to
    a single-worker executor. The entry is then removed from `_scheduled`.
    """
    global _download_secs
    # One worker: warmups run one at a time, in download order, without
    # blocking subsequent downloads.
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    while True:
        with _lock:  # `_scheduled` is mutated concurrently by register_download().
            item = next(iter(_scheduled.items()), None)
        if item is None:
            time.sleep(1)
            continue
        name, (repo, filenames, revision) = item
        logging.info('Downloading "%s" %s/%s/%s...', name, repo, filenames, revision)
        with timed(f'downloading {name}', True) as t:
            if should_mock():
                logging.warning('Mocking loading')
                time.sleep(10.)
                _done[name] = None
            else:
                try:
                    # Materialize eagerly: a generator would postpone the actual
                    # downloads (and their errors) to whoever first iterates the
                    # value, outside this thread and this try/except, and could
                    # only be consumed once.
                    _done[name] = tuple(
                        huggingface_hub.hf_hub_download(
                            repo_id=repo, filename=filename, revision=revision)
                        for filename in filenames)
                except Exception as e:  # pylint: disable=broad-exception-caught
                    logging.exception('Could not download "%s" from hub!', name)
                    _failed[name] = str(e)
                    with _lock:
                        _scheduled.pop(name)
                    continue
        # `t['secs']` is only set after the `with timed(...)` block has exited.
        _download_secs += t['secs']
        if _warmup_function:
            executor.submit(_run_warmup, name)
        with _lock:
            _scheduled.pop(name)
def register_download(name, repo, filenames, revision='main'):
    """Schedules a background download of `filenames` from HF Hub `repo`.

    The download is tracked under `name` and uses `revision` of the repo.
    Registering a `name` that is still scheduled is a no-op: the existing
    entry is kept unchanged.
    """
    with _lock:
        _scheduled.setdefault(name, (repo, filenames, revision))
def _hms(secs):
    """Formats a duration in seconds as `[h:]mm:ss`, truncating fractions.

    The hours field is omitted when zero and is not zero-padded, e.g.
    `_hms(3700) == "1:01:40"` and `_hms(100) == "01:40"`.
    """
    hours, rest = divmod(int(secs), 3600)
    minutes, seconds = divmod(rest, 60)
    return (f'{hours}:' if hours else '') + f'{minutes:02}:{seconds:02}'
def downloads_status():
    """Returns a one-line summary of download, warmup and failure counts.

    Example: "Downloaded 2 in 00:30, 1 00:15 remaining, warmed up 2 in 00:04".
    The remaining-time estimate (mean download time so far times the number of
    still-scheduled entries) is only shown once something has been downloaded.
    """
    pieces = [f'Downloaded {len(_done)}']
    if _done:
        pieces[0] += f' in {_hms(_download_secs)}'
    if _scheduled:
        eta = ''
        if _done:
            eta = f' {_hms(_download_secs/len(_done)*len(_scheduled))}'
        pieces.append(f'{len(_scheduled)}{eta} remaining')
    if _warmup_function:
        pieces.append(f'warmed up {len(_warmed_up)} in {_hms(_warmup_secs)}')
    if _failed:
        pieces.append(f'{len(_failed)} failed')
    return ', '.join(pieces)
def get_paths():
    """Returns a snapshot dict of the finished downloads registered so far.

    Maps each `name` passed to `register_download()` to whatever the download
    loop recorded for it (`None` when mocked). The dict is a shallow copy, so
    later downloads do not change it.
    """
    return _done.copy()
# Background downloader, started at import time. As a daemon thread it does
# not keep the interpreter alive at exit.
_download_thread = threading.Thread(target=_do_download, daemon=True)
_download_thread.start()

# (estimated_secs, real_secs) pairs from previous loads. Seeded with (10, 10)
# so the first default estimate is 10 secs and the real/estimated ratio is 1.
_estimated_real = [(10, 10)]
# Values cached by `get_memory_cache()`; dict insertion order doubles as
# recency order (least recently accessed first).
_memory_cache = {}
def get_with_progress(getter, secs, progress, step=0.1):
    """Runs `getter()` and returns its result, animating `progress` meanwhile.

    Without a `progress` object, `getter()` is simply called inline. Otherwise
    `getter()` runs on a worker thread while `progress.tqdm` is stepped through
    `ceil(secs / step)` ticks. Each tick sleeps `step` seconds until the result
    is ready; the remaining ticks then pass without sleeping.
    """
    if progress is None:
        return getter()
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = pool.submit(getter)
        n_ticks = int(np.ceil(secs/step))
        for _ in progress.tqdm(list(range(n_ticks)), desc='read'):
            if pending.done():
                continue
            time.sleep(step)
        return pending.result()
def _get_array_sizes(tree):
    """Returns the `nbytes` of every leaf of pytree `tree` (0 for leaves without it)."""
    # Local import: `jax` is not imported at the top of this module, and the
    # top-level `jax.tree_leaves` alias is deprecated/removed in newer JAX.
    import jax  # pylint: disable=g-import-not-at-top
    return [getattr(x, 'nbytes', 0) for x in jax.tree_util.tree_leaves(tree)]
def get_memory_cache(
    key, getter, max_cache_size_bytes, progress=None, estimated_secs=None
):
    """Returns `getter()` for `key`, keeping recent results in a memory cache.

    Keeps the cache below `max_cache_size_bytes` by removing the entries that
    were accessed least recently.

    Args:
      key: Cache key.
      getter: Zero-argument callable producing the value on a cache miss.
      max_cache_size_bytes: Size budget (sum of leaf `nbytes`). A falsy value
        disables caching: the loaded value is returned without being stored.
      progress: Optional progress object forwarded to `get_with_progress()`.
      estimated_secs: Optional load-time estimate for the progress bar. Defaults
        to the mean of previous estimates. Either way it is scaled by the
        historical real/estimated ratio.

    Returns:
      The cached value on a hit, otherwise the freshly loaded `getter()` result.
    """
    if key in _memory_cache:
        _memory_cache[key] = _memory_cache.pop(key)  # Updates "last accessed" order
        return _memory_cache[key]

    est, real = zip(*_estimated_real)
    if estimated_secs is None:
        estimated_secs = sum(est) / len(est)
    # Correct the estimate by how far off previous estimates were.
    estimated_secs *= sum(real) / sum(est)
    with timed(f'loading {key}') as t:
        value = get_with_progress(getter, estimated_secs, progress)
    # `t['secs']` is only set after the `with timed(...)` block has exited.
    _estimated_real.append((estimated_secs, t['secs']))

    if not max_cache_size_bytes:
        return value

    _memory_cache[key] = value
    sz = sum(_get_array_sizes(list(_memory_cache.values())))
    logging.info('New memory cache size=%.1f MB', sz/1e6)
    while sz > max_cache_size_bytes:
        k, v = next(iter(_memory_cache.items()))  # Least recently accessed.
        if k == key:
            # Everything older is gone; keep the new entry even if it alone
            # exceeds the budget.
            break
        s = sum(_get_array_sizes(v))
        logging.info('Removing %s from memory cache (%.1f MB)', k, s/1e6)
        _memory_cache.pop(k)
        sz -= s
    return value
def get_memory_cache_info():
    """Returns `(number_of_items, total_size_in_bytes)` of the memory cache.

    The size is the sum of `nbytes` over all pytree leaves of the cached values
    (leaves without `nbytes` count as 0).
    """
    # Flatten the values, not the dict: JAX sorts dict keys when flattening,
    # which fails for keys that are not mutually comparable. This also matches
    # how `get_memory_cache()` measures the cache.
    sizes = _get_array_sizes(list(_memory_cache.values()))
    return len(_memory_cache), sum(sizes)
def get_system_info():
    """Describes used/total RAM and disk space (of the current directory) in GB.

    Both totals are divided by the integer `HOST_COLOCATION` environment
    variable (default 1). Presumably this is the number of apps sharing the
    host -- TODO confirm.
    """
    share = int(os.environ.get('HOST_COLOCATION', '1'))
    mem = psutil.virtual_memory()
    disk = shutil.disk_usage('.')
    ram_part = f'RAM {mem.used / 1e9:.1f}/{mem.total / share / 1e9:.1f}G, '
    disk_part = f'disk {disk.used / 1e9:.1f}/{disk.total / share / 1e9:.1f}G'
    return ram_part + disk_part
def get_status(include_system_info=True):
    """Returns a one-line status report.

    The report contains the current timestamp, download stats, memory-cache
    stats (item count, size and cumulative load time), and, when
    `include_system_info` is true, the RAM/disk summary.
    """
    cache_items, cache_bytes = get_memory_cache_info()
    # Skip the seed entry: it is not a real load.
    cache_time = _hms(sum(real for _, real in _estimated_real[1:]))
    pieces = [
        'Timestamp: ',
        datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        ' – Model stats: ',
        downloads_status(),
        ', ',
        f'memory-cached {cache_items} ({cache_bytes/1e9:.1f}G) in {cache_time}',
    ]
    if include_system_info:
        pieces.append(' – System: ' + get_system_info())
    return ''.join(pieces)