"""Gradio helpers for caching, downloading etc."""

import concurrent.futures
import contextlib
import datetime
import functools
import logging
import os
import shutil
import threading
import time

import huggingface_hub
import jax
import numpy as np
import psutil


def should_mock():
  """Returns `True` if `MOCK_MODEL=yes` is set in environment."""
  return os.environ.get('MOCK_MODEL') == 'yes'


@contextlib.contextmanager
def timed(name, start_message=False):
  """Emits "Timed {name}: .1f secs" message to INFO logs.

  Yields a dict whose `'secs'` entry is `None` while the `with` block runs and
  holds the elapsed wall-clock seconds once the block has exited (also on
  exceptions).

  Args:
    name: Label used in the log messages.
    start_message: If true, additionally logs "Timing {name}..." on entry.
  """
  t0 = time.monotonic()
  # `dt` is kept for backward compatibility; `secs` is the key that gets set.
  timing = dict(dt=None, secs=None)
  try:
    if start_message:
      logging.info('Timing %s...', name)
    yield timing
  finally:
    timing['secs'] = time.monotonic() - t0
    logging.info('Timed %s: %.1f secs', name, timing['secs'])


def synced(f):
  """Syncs calls to `f` with a `threading.Lock()`.

  Each call logs how long it waited to acquire the lock before running `f`.
  """
  lock = threading.Lock()

  @functools.wraps(f)
  def wrapper(*args, **kw):
    t0 = time.monotonic()
    with lock:
      lock_dt = time.monotonic() - t0
      logging.info('synced wait: %.1f secs', lock_dt)
      return f(*args, **kw)

  return wrapper


_warmed_up = set()  # Names whose warmup function completed without error.
_warmup_function = None


def set_warmup_function(warmup_function):
  """Registers `warmup_function(name)`, run after each successful download."""
  global _warmup_function
  _warmup_function = warmup_function


_lock = threading.Lock()  # Guards `_scheduled`.
_scheduled = {}  # name -> (repo, filenames, revision), in registration order.
_download_secs = 0
_warmup_secs = 0
_loading_secs = 0
_done = {}  # name -> tuple of local paths (None when mocked).
_failed = {}  # name -> error message of the failed download.


def _next_scheduled():
  """Returns the oldest `(name, (repo, filenames, revision))`, or `None`."""
  # Peek under the lock: `register_download()` may insert concurrently, and
  # iterating a dict while it changes size raises `RuntimeError`, which would
  # silently kill the background download thread.
  with _lock:
    return next(iter(_scheduled.items()), None)


def _warmup(name):
  """Runs the registered warmup function on `name`, recording time/success."""
  global _warmup_secs
  with timed(f'warming up {name}', True) as t:
    try:
      _warmup_function(name)
      _warmed_up.add(name)
    except Exception:  # pylint: disable=broad-exception-caught
      logging.exception('Could not warmup "%s"!', name)
  _warmup_secs += t['secs']


def _do_download():
  """Downloading files, to be started in background thread.

  Loops forever: takes the oldest scheduled entry, downloads all its files
  from the HF hub (or sleeps when mocked), records the resulting paths in
  `_done` (or the error in `_failed`), submits a warmup to a single-worker
  executor if a warmup function is registered, then unschedules the entry.
  """
  global _download_secs
  executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
  while True:
    item = _next_scheduled()
    if item is None:
      time.sleep(1)
      continue

    name, (repo, filenames, revision) = item
    logging.info('Downloading "%s" %s/%s/%s...', name, repo, filenames, revision)
    with timed(f'downloading {name}', True) as t:
      if should_mock():
        logging.warning('Mocking loading')
        time.sleep(10.)
        _done[name] = None
      else:
        try:
          _done[name] = tuple(
              huggingface_hub.hf_hub_download(
                  repo_id=repo, filename=filename, revision=revision)
              for filename in filenames)
        except Exception as e:  # pylint: disable=broad-exception-caught
          logging.exception('Could not download "%s" from hub!', name)
          _failed[name] = str(e)
          with _lock:
            _scheduled.pop(name)
          continue

    if _warmup_function:
      # Warmups run in their own thread so downloads are not blocked by them.
      executor.submit(_warmup, name)

    _download_secs += t['secs']
    with _lock:
      _scheduled.pop(name)


def register_download(name, repo, filenames, revision='main'):
  """Will cause download of `filenames` from HF `repo` in background thread.

  `filenames` is iterated, downloading one file per element at `revision`.
  Registering a `name` that is still pending is a no-op.
  """
  with _lock:
    if name not in _scheduled:
      _scheduled[name] = (repo, filenames, revision)


def _hms(secs):
  """Formats `secs=3700` to `"1:01:40"` (the hours part is omitted if zero)."""
  secs = int(secs)
  h = secs // 3600
  m = (secs - h * 3600) // 60
  s = secs % 60
  return (f'{h}:' if h else '') + f'{m:02}:{s:02}'


def downloads_status():
  """Returns string representation of download stats.

  The remaining-time estimate extrapolates the mean time per completed
  download to the number of still-scheduled entries.
  """
  done_t = remaining_t = ''
  if _done:
    done_t = f' in {_hms(_download_secs)}'
    remaining_t = f' {_hms(_download_secs/len(_done)*len(_scheduled))}'
  status = f'Downloaded {len(_done)}{done_t}'
  if _scheduled:
    status += f', {len(_scheduled)}{remaining_t} remaining'
  if _warmup_function:
    status += f', warmed up {len(_warmed_up)} in {_hms(_warmup_secs)}'
  if _failed:
    status += f', {len(_failed)} failed'
  return status


def get_paths():
  """Returns dictionary `name` to `path` from previous `register_download()`."""
  return dict(_done)


_download_thread = threading.Thread(target=_do_download)
_download_thread.daemon = True
_download_thread.start()

# (estimated_secs, real_secs) pairs of past loads; seeded with one entry so
# the real/estimated ratio is always defined.
_estimated_real = [(10, 10)]
_memory_cache = {}  # Insertion order == least-recently-accessed first.


def get_with_progress(getter, secs, progress, step=0.1):
  """Returns result from `getter` while showing a progress bar.

  `getter()` runs in a worker thread while `progress.tqdm` ticks about every
  `step` seconds for up to `secs` seconds; once the result is ready the
  remaining ticks no longer sleep. With `progress=None`, `getter()` is simply
  called directly.
  """
  if progress is None:
    return getter()
  with concurrent.futures.ThreadPoolExecutor() as executor:
    future = executor.submit(getter)
    for _ in progress.tqdm(list(range(int(np.ceil(secs/step)))), desc='read'):
      if not future.done():
        time.sleep(step)
  return future.result()


def _get_array_sizes(tree):
  """Returns `nbytes` of every leaf of pytree `tree` (0 if a leaf lacks it)."""
  return [getattr(x, 'nbytes', 0) for x in jax.tree_util.tree_leaves(tree)]


def get_memory_cache(
    key, getter, max_cache_size_bytes, progress=None, estimated_secs=None
):
  """Keeps cache below specified size by removing elements not last accessed.

  Returns the cached value for `key`, or calls `getter()` (with a progress bar
  when `progress` is given) and caches the result. Entries are evicted in
  least-recently-accessed order until the summed `nbytes` of all cached leaves
  fits into `max_cache_size_bytes`; the entry just added is never evicted. A
  falsy `max_cache_size_bytes` returns the value without caching it.

  `estimated_secs` (default: mean of past estimates) is rescaled by the
  observed real/estimated ratio of past loads to calibrate the progress bar.
  """
  if key in _memory_cache:
    _memory_cache[key] = _memory_cache.pop(key)  # Updates "last accessed" order
    return _memory_cache[key]

  est, real = zip(*_estimated_real)
  if estimated_secs is None:
    estimated_secs = sum(est) / len(est)
  estimated_secs *= sum(real) / sum(est)
  with timed(f'loading {key}') as t:
    value = get_with_progress(getter, estimated_secs, progress)
  _estimated_real.append((estimated_secs, t['secs']))

  if not max_cache_size_bytes:
    return value

  _memory_cache[key] = value
  sz = sum(_get_array_sizes(list(_memory_cache.values())))
  logging.info('New memory cache size=%.1f MB', sz/1e6)
  while sz > max_cache_size_bytes:
    k, v = next(iter(_memory_cache.items()))
    if k == key:
      break
    s = sum(_get_array_sizes(v))
    logging.info('Removing %s from memory cache (%.1f MB)', k, s/1e6)
    _memory_cache.pop(k)
    sz -= s

  return value


def get_memory_cache_info():
  """Returns number of items and total size in bytes."""
  sizes = _get_array_sizes(_memory_cache)
  return len(_memory_cache), sum(sizes)


def get_system_info():
  """Returns string describing system's RAM/disk status.

  Totals are divided by the `HOST_COLOCATION` env var (default 1).
  """
  host_colocation = int(os.environ.get('HOST_COLOCATION', '1'))
  vm = psutil.virtual_memory()
  du = shutil.disk_usage('.')
  return (
      f'RAM {vm.used / 1e9:.1f}/{vm.total / host_colocation / 1e9:.1f}G, '
      f'disk {du.used / 1e9:.1f}/{du.total / host_colocation / 1e9:.1f}G'
  )


def get_status(include_system_info=True):
  """Returns string about download/memory/system status."""
  mc_len, mc_sz = get_memory_cache_info()
  # Skip the seed entry of `_estimated_real`; only actual loads are summed.
  mc_t = _hms(sum(real for _, real in _estimated_real[1:]))
  return (
      'Timestamp: '
      + datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
      + ' – Model stats: '
      + downloads_status()
      + ', '
      + f'memory-cached {mc_len} ({mc_sz/1e9:.1f}G) in {mc_t}'
      + (' – System: ' + get_system_info() if include_system_info else '')
  )