paligemma-cpu-gguf / gradio_helpers.py
abetlen's picture
Update gradio_helpers.py
318300d verified
"""Gradio helpers for caching, downloading etc."""
import concurrent.futures
import contextlib
import datetime
import functools
import logging
import os
import shutil
import threading
import time
import huggingface_hub
import numpy as np
import psutil
def should_mock():
    """Tells whether model loading should be faked.

    True exactly when the environment variable MOCK_MODEL equals 'yes'
    (case-sensitive); False otherwise, including when it is unset.
    """
    mock_flag = os.getenv('MOCK_MODEL', '')
    return mock_flag == 'yes'
@contextlib.contextmanager
def timed(name, start_message=False):
    """Times the enclosed block and logs "Timed {name}: {secs:.1f} secs" at INFO.

    Args:
      name: Label used in the log messages.
      start_message: If True, also logs "Timing {name}..." before the block runs.

    Yields:
      A dict whose 'secs' entry is None while the block runs. It is set to the
      elapsed `time.monotonic()` seconds once the block exits, even when the
      block raised.
    """
    t0 = time.monotonic()
    # Pre-populate the key that is actually filled in below ('secs'; the old
    # code seeded an unused 'dt' key), so reading it inside the block gives
    # None instead of raising KeyError.
    timing = {'secs': None}
    try:
        if start_message:
            logging.info('Timing %s...', name)
        yield timing
    finally:
        timing['secs'] = time.monotonic() - t0
        logging.info('Timed %s: %.1f secs', name, timing['secs'])
def synced(f):
    """Decorator that serializes every call to `f` through one shared lock.

    The time each call spent waiting to acquire the lock is logged at INFO
    level before `f` runs.
    """
    guard = threading.Lock()

    @functools.wraps(f)
    def wrapper(*args, **kw):
        requested_at = time.monotonic()
        guard.acquire()
        try:
            waited = time.monotonic() - requested_at
            logging.info('synced wait: %.1f secs', waited)
            return f(*args, **kw)
        finally:
            guard.release()

    return wrapper
# Names for which the warmup function returned without raising.
_warmed_up = set()
# Callable taking a model name, run after each finished download; None = off.
_warmup_function = None


def set_warmup_function(warmup_function):
    """Installs `warmup_function(name)` to run after each finished download.

    Passing None disables warmups for subsequent downloads.
    """
    global _warmup_function  # pylint: disable=global-statement
    _warmup_function = warmup_function
# Serializes mutations of `_scheduled` between `register_download()` and the
# background downloader thread.
_lock = threading.Lock()
# name -> (repo, filenames, revision) for downloads not finished yet, kept in
# registration order (the downloader always takes the first entry).
_scheduled = dict()
# name -> tuple of downloaded local paths (None when mocked).
_done = dict()
# name -> error message of downloads that raised.
_failed = dict()
# Accumulated seconds spent downloading and warming up, respectively.
_download_secs = 0
_warmup_secs = 0
# NOTE(review): `_loading_secs` is not updated anywhere in this file's code.
_loading_secs = 0
def _do_download():
    """Background downloader loop, to be started in a daemon thread; never returns.

    Repeatedly takes the first entry of `_scheduled` and handles it:

    * Downloads each of its files with `huggingface_hub.hf_hub_download()`,
      or sleeps 10 s instead when `should_mock()`.
    * Stores the tuple of returned paths in `_done[name]` (None when mocked),
      or stores the error message in `_failed[name]` if a download raised.
    * Removes the entry from `_scheduled` in both cases.

    After a successful download, if a warmup function is installed,
    `_warmup_function(name)` is submitted to a single-worker executor so
    downloading continues while the warmup runs.
    """
    global _download_secs
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    # Defined once, outside the loop; `name` is passed explicitly on submit.
    def warmup(name):
        global _warmup_secs
        with timed(f'warming up {name}', True) as t:
            try:
                _warmup_function(name)
                _warmed_up.add(name)
            except Exception:  # pylint: disable=broad-exception-caught
                logging.exception('Could not warmup "%s"!', name)
        _warmup_secs += t['secs']

    while True:
        # `register_download()` inserts into `_scheduled` from other threads.
        # Read the first entry under the same lock so the dict cannot change
        # size mid-iteration, which would raise and kill this thread.
        with _lock:
            item = next(iter(_scheduled.items()), None)
        if item is None:
            time.sleep(1)
            continue

        name, (repo, filenames, revision) = item
        logging.info('Downloading "%s" %s/%s/%s...', name, repo, filenames, revision)
        with timed(f'downloading {name}', True) as t:
            if should_mock():
                logging.warning('Mocking loading')
                time.sleep(10.)
                _done[name] = None
            else:
                try:
                    _done[name] = tuple(
                        huggingface_hub.hf_hub_download(
                            repo_id=repo, filename=filename, revision=revision)
                        for filename in filenames)
                except Exception as e:  # pylint: disable=broad-exception-caught
                    logging.exception('Could not download "%s" from hub!', name)
                    _failed[name] = str(e)
                    with _lock:
                        _scheduled.pop(name)
                    continue

        if _warmup_function:
            executor.submit(warmup, name)
        # Only successful (or mocked) downloads count towards the timing that
        # `downloads_status()` uses for its remaining-time estimate.
        _download_secs += t['secs']
        with _lock:
            _scheduled.pop(name)
def register_download(name, repo, filenames, revision='main'):
    """Schedules download of `filenames` from HF `repo` in the background thread.

    Args:
      name: Key under which the result later appears in `get_paths()`.
      repo: Hugging Face Hub repository id (passed as `repo_id`).
      filenames: Files inside `repo` to download. A single string is treated
        as one file name; otherwise iterating it would yield its characters
        as separate file names. Iterables are copied into a tuple, so later
        mutation by the caller has no effect.
      revision: Passed as `revision` to `hf_hub_download()`.

    Does nothing if `name` is already scheduled and not yet finished.
    """
    if isinstance(filenames, str):
        filenames = (filenames,)
    else:
        filenames = tuple(filenames)
    with _lock:
        if name not in _scheduled:
            _scheduled[name] = (repo, filenames, revision)
def _hms(secs):
"""Formats `secs=3700` to `"01:01:40"`."""
secs = int(secs)
h = secs // 3600
m = (secs - h * 3600) // 60
s = secs % 60
return (f'{h}:' if h else '') + f'{m:02}:{s:02}'
def downloads_status():
    """Summarizes download progress as a single line of text.

    The line gives:

    * the number of finished downloads and their total time;
    * the number still scheduled, with a time estimate (average time per
      finished download times the number remaining) once at least one
      download has finished;
    * warmup and failure counts, when applicable.
    """
    n_done = len(_done)
    n_pending = len(_scheduled)

    summary = f'Downloaded {n_done}'
    if n_done:
        summary += f' in {_hms(_download_secs)}'
    parts = [summary]

    if n_pending:
        eta = f' {_hms(_download_secs / n_done * n_pending)}' if n_done else ''
        parts.append(f'{n_pending}{eta} remaining')
    if _warmup_function:
        parts.append(f'warmed up {len(_warmed_up)} in {_hms(_warmup_secs)}')
    if _failed:
        parts.append(f'{len(_failed)} failed')
    return ', '.join(parts)
def get_paths():
    """Snapshot of finished downloads, keyed by the name given at registration.

    Each value is the tuple of downloaded local paths, or None for a mocked
    download. Failed downloads are absent. A new dict is returned, so later
    downloads do not modify it.
    """
    return _done.copy()
# Start the background downloader as soon as this module is imported. It is a
# daemon thread, so it does not keep the process alive at exit.
_download_thread = threading.Thread(target=_do_download, daemon=True)
_download_thread.start()
# Values cached by `get_memory_cache()`, ordered from least to most recently
# accessed (eviction starts at the front).
_memory_cache = dict()
# (estimated_secs, real_secs) history of cache loads. It is seeded with one
# (10, 10) pair, so the first default estimate is 10 s with a ratio of 1.
_estimated_real = [(10, 10)]
def get_with_progress(getter, secs, progress, step=0.1):
    """Calls `getter()` while animating `progress`, returning its result.

    Without a progress object, `getter()` is simply called directly.
    Otherwise `getter` runs in a worker thread while `progress.tqdm` is driven
    over ceil(secs / step) ticks labelled 'read'. Each tick sleeps `step`
    seconds only while the getter is still running, so the bar finishes early
    once the result is ready. The call then waits for the result and returns
    it; `Future.result()` re-raises any exception from `getter`.
    """
    if progress is None:
        return getter()
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = pool.submit(getter)
        ticks = list(range(int(np.ceil(secs / step))))
        for _ in progress.tqdm(ticks, desc='read'):
            if pending.done():
                continue
            time.sleep(step)
        return pending.result()
def _get_array_sizes(tree):
return [getattr(x, 'nbytes', 0) for x in jax.tree_leaves(tree)]
def get_memory_cache(
    key, getter, max_cache_size_bytes, progress=None, estimated_secs=None
):
    """Returns the value for `key`, loading it via `getter()` on a cache miss.

    Entries are kept in access order. After inserting a new value, the least
    recently accessed entries are evicted until the summed `nbytes` of all
    cached values fits into `max_cache_size_bytes`. The entry just inserted is
    never evicted. A falsy `max_cache_size_bytes` disables caching: the value
    is loaded and returned but not stored.

    The load time shown on the progress bar is `estimated_secs` (or the mean
    of past estimates), scaled by the ratio of real to estimated time observed
    so far. Each load appends its (estimate, real seconds) pair to
    `_estimated_real`.
    """
    # Cache hit: re-insert so the entry moves to the most-recently-used end.
    if key in _memory_cache:
        cached = _memory_cache.pop(key)
        _memory_cache[key] = cached
        return cached

    estimates, actuals = zip(*_estimated_real)
    if estimated_secs is None:
        estimated_secs = sum(estimates) / len(estimates)
    with timed(f'loading {key}') as timing:
        estimated_secs = estimated_secs * (sum(actuals) / sum(estimates))
        value = get_with_progress(getter, estimated_secs, progress)
    _estimated_real.append((estimated_secs, timing['secs']))

    if not max_cache_size_bytes:
        return value

    _memory_cache[key] = value
    total = sum(_get_array_sizes(list(_memory_cache.values())))
    logging.info('New memory cache size=%.1f MB', total/1e6)
    # Evict from the least-recently-used end until the cache fits the budget.
    while total > max_cache_size_bytes:
        oldest_key = next(iter(_memory_cache))
        if oldest_key == key:
            break
        oldest_value = _memory_cache.pop(oldest_key)
        freed = sum(_get_array_sizes(oldest_value))
        logging.info('Removing %s from memory cache (%.1f MB)', oldest_key, freed/1e6)
        total -= freed
    return value
def get_memory_cache_info():
    """Returns `(number of cached items, total size in bytes)` of the memory cache.

    The size is the sum of the leaf sizes reported by `_get_array_sizes()`
    for the cached values.
    """
    # Measure a list snapshot of the values, exactly as get_memory_cache()
    # does, rather than the dict itself. This keeps the two size computations
    # consistent and means the dict keys never take part in the traversal.
    sizes = _get_array_sizes(list(_memory_cache.values()))
    return len(_memory_cache), sum(sizes)
def get_system_info():
    """Describes used/total RAM and disk space (of '.') in gigabytes (1e9 bytes).

    Totals are divided by the integer in the HOST_COLOCATION environment
    variable (default 1), giving this process's share of a shared host.
    """
    colocated = int(os.environ.get('HOST_COLOCATION', '1'))

    def usage(used, total):
        return f'{used / 1e9:.1f}/{total / colocated / 1e9:.1f}G'

    memory = psutil.virtual_memory()
    disk = shutil.disk_usage('.')
    return f'RAM {usage(memory.used, memory.total)}, disk {usage(disk.used, disk.total)}'
def get_status(include_system_info=True):
    """Builds a one-line status report.

    The report combines a local timestamp, the download stats, the number and
    size of memory-cached items plus the total time spent loading them, and
    (optionally) RAM/disk usage.
    """
    n_cached, cached_bytes = get_memory_cache_info()
    # Skip the seed entry of `_estimated_real`; the rest are real load times.
    load_time = _hms(sum(real for _, real in _estimated_real[1:]))
    now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    pieces = [
        f'Timestamp: {now}',
        f' – Model stats: {downloads_status()}',
        f', memory-cached {n_cached} ({cached_bytes/1e9:.1f}G) in {load_time}',
    ]
    if include_system_info:
        pieces.append(f' – System: {get_system_info()}')
    return ''.join(pieces)