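# NOTE: the viewer this file was copied from showed the status "Spaces: Runtime error"; that text is page residue, not part of the module.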
from collections import deque
import hashlib
import os
import shutil
import subprocess
import time
class WeightsDownloadCache:
    """Disk-backed LRU cache of downloaded weights, keyed by a hash of the URL."""

    def __init__(
        self, min_disk_free: int = 10 * (2**30), base_dir: str = "/src/weights-cache"
    ):
        """
        WeightsDownloadCache is meant to track and download weights files as fast
        as possible, while ensuring there's enough disk space.

        It tries to keep the most recently used weights files in the cache, so
        ensure you call ensure() on the weights each time you use them.

        It will not re-download weights files that are already tracked by this
        instance. Tracking lives in memory only: a new instance starts with an
        empty LRU list, whatever is already present under ``base_dir``.

        :param min_disk_free: Minimum disk space required to start download, in bytes.
        :param base_dir: The base directory to store weights files.
        :raises FileExistsError: If ``base_dir`` exists but is not a directory.
        """
        self.min_disk_free = min_disk_free
        self.base_dir = base_dir
        self._hits = 0
        self._misses = 0

        # Least Recently Used (LRU) order of cached paths: the left end is the
        # oldest entry, the right end the most recently used one.
        self.lru_paths = deque()

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(base_dir, exist_ok=True)

    def _remove_least_recent(self) -> None:
        """
        Remove the least recently used weights file from the cache and disk.

        :raises IndexError: If the LRU list is empty.
        """
        oldest = self.lru_paths.popleft()
        self._rm_disk(oldest)

    def cache_info(self) -> str:
        """
        Get cache information.

        :return: A one-line summary with hit/miss counters, the base directory
            and the number of paths currently tracked.
        """
        return (
            f"CacheInfo(hits={self._hits}, misses={self._misses}, "
            f"base_dir='{self.base_dir}', currsize={len(self.lru_paths)})"
        )

    def _rm_disk(self, path: str) -> None:
        """
        Remove a weights file or directory from disk.

        Paths that are neither a regular file nor a directory (including paths
        that do not exist) are left alone.

        :param path: Path to remove.
        """
        if os.path.isfile(path):
            os.remove(path)
        elif os.path.isdir(path):
            shutil.rmtree(path)

    def _has_enough_space(self) -> bool:
        """
        Check if there's enough disk space.

        Also prints the current number of free bytes.

        :return: True if at least min_disk_free bytes are free on the filesystem
            holding base_dir, False otherwise.
        """
        disk_usage = shutil.disk_usage(self.base_dir)
        print(f"Free disk space: {disk_usage.free}")
        return disk_usage.free >= self.min_disk_free

    def ensure(self, url: str) -> str:
        """
        Ensure weights file is in the cache and return its path.

        This also updates the LRU cache to mark the weights as recently used.

        :param url: URL to download weights file from, if not in cache.
        :return: Path to weights.
        """
        path = self.weights_path(url)

        if path in self.lru_paths:
            # Hit: drop the entry here so the append below moves it to the
            # most-recently-used end.
            self._hits += 1
            self.lru_paths.remove(path)
        else:
            # Miss: the path is only tracked once the download has succeeded;
            # a failed download propagates and leaves the LRU list untouched.
            self._misses += 1
            self.download_weights(url, path)

        self.lru_paths.append(path)  # Add file to end of cache
        return path

    def weights_path(self, url: str) -> str:
        """
        Generate path to store a weights file based on a hash of the URL.

        :param url: URL to download weights file from.
        :return: base_dir joined with the first 16 hex characters of the
            SHA-256 digest of the URL.
        """
        hashed_url = hashlib.sha256(url.encode()).hexdigest()
        short_hash = hashed_url[:16]  # Use the first 16 characters of the hash
        return os.path.join(self.base_dir, short_hash)

    def download_weights(self, url: str, dest: str) -> None:
        """
        Download weights file from a URL, ensuring there's enough disk space.

        Least recently used entries are evicted until the free-space threshold
        is met or nothing is left to evict; the download is attempted either
        way. The download itself is delegated to the external ``pget`` command.

        :param url: URL to download weights file from.
        :param dest: Path to store weights file.
        :raises subprocess.CalledProcessError: If ``pget`` exits with a non-zero status.
        :raises OSError: If ``pget`` cannot be started (e.g. it is not installed).
        """
        print("Ensuring enough disk space...")
        while not self._has_enough_space() and len(self.lru_paths) > 0:
            self._remove_least_recent()

        print(f"Downloading weights: {url}")

        # perf_counter is monotonic, so the reported duration cannot be skewed
        # by wall-clock adjustments.
        st = time.perf_counter()
        # maybe retry with the real url if this doesn't work
        try:
            output = subprocess.check_output(["pget", "-x", url, dest], close_fds=True)
        except subprocess.CalledProcessError as e:
            # If download fails, clean up and re-raise exception
            print(e.output)
            self._rm_disk(dest)
            raise
        except BaseException:
            # Any other failure (pget missing, interrupt, ...) may also leave a
            # partial download behind. It is not tracked in the LRU list, so it
            # would never be evicted; remove it before propagating.
            self._rm_disk(dest)
            raise
        print(output)
        print(f"Downloaded weights in {time.perf_counter() - st} seconds")