Spaces:
Runtime error
Runtime error
import pickle | |
import torch | |
import lzma | |
from pathlib import Path | |
from tops import logger | |
class BaseDetector:
    """Base class for detectors that can cache their detections on disk.

    Subclasses must implement ``forward(im)``, which returns an iterable of
    detection objects. Each detection object must expose
    ``state_dict(after_preprocess=...)`` returning a dict with a ``"cls"``
    entry. That class must provide ``from_state_dict(state_dict=...)``,
    which is used to rebuild detections from the cache.
    """

    def __init__(self, cache_directory: str) -> None:
        """Set up the per-class cache directory.

        Args:
            cache_directory: Root directory for cached detections, or ``None``
                to disable caching. Files are stored in a sub-directory named
                after the concrete detector class.
        """
        # Always define the attribute so that forward_and_cache can skip
        # caching instead of failing with AttributeError.
        self.cache_directory = None
        if cache_directory is not None:
            self.cache_directory = Path(cache_directory, str(self.__class__.__name__))
            self.cache_directory.mkdir(exist_ok=True, parents=True)

    def save_to_cache(self, detection, cache_path: Path, after_preprocess=True):
        """Serialize ``detection`` to an LZMA-compressed torch file.

        Args:
            detection: Iterable of detection objects exposing
                ``state_dict(after_preprocess=...)``.
            cache_path: Destination file.
            after_preprocess: Forwarded to each ``state_dict`` call.

        Data is first written to a temporary sibling file and then moved into
        place. An interrupted write therefore never leaves a truncated file at
        ``cache_path`` that would later be read back as a corrupted cache.
        """
        logger.log(f"Caching detection to: {cache_path}")
        cache_path = Path(cache_path)
        tmp_path = cache_path.with_name(cache_path.name + ".tmp")
        try:
            with lzma.open(tmp_path, "wb") as fp:
                torch.save(
                    [det.state_dict(after_preprocess=after_preprocess) for det in detection],
                    fp,
                    pickle_protocol=pickle.HIGHEST_PROTOCOL,
                )
            tmp_path.replace(cache_path)
        except BaseException:
            # Don't leave a partially written temporary file behind.
            tmp_path.unlink(missing_ok=True)
            raise

    def load_from_cache(self, cache_path: Path):
        """Load detections previously written by ``save_to_cache``.

        Returns:
            A list of detection objects, each rebuilt via
            ``state["cls"].from_state_dict(state_dict=state)``.
        """
        logger.log(f"Loading detection from cache path: {cache_path}")
        with lzma.open(cache_path, "rb") as fp:
            # SECURITY: the cache stores class objects under "cls", so it must
            # be unpickled with full pickle support. Only load cache files this
            # program wrote itself; never point this at untrusted data.
            states = torch.load(fp, **self._full_unpickle_kwargs())
        return [state["cls"].from_state_dict(state_dict=state) for state in states]

    @staticmethod
    def _full_unpickle_kwargs():
        """Return ``torch.load`` kwargs that permit loading arbitrary pickled objects.

        Recent PyTorch releases default ``weights_only`` to True, which rejects
        the class references stored in the cache. Releases that predate the
        parameter do not accept it at all, so pass it only when it exists.
        """
        import inspect

        if "weights_only" in inspect.signature(torch.load).parameters:
            return {"weights_only": False}
        return {}

    def forward_and_cache(self, im: torch.Tensor, cache_id: str, load_cache: bool):
        """Run ``self.forward(im)``, reusing or populating the on-disk cache.

        Args:
            im: Input passed unchanged to ``self.forward``.
            cache_id: Cache key (file stem). ``None`` disables caching.
            load_cache: If True and a cache file exists, return its contents.
                Otherwise detections are recomputed and the cache file is
                (over)written.

        Returns:
            The detections, either from the cache or from ``self.forward``.
        """
        if cache_id is None or self.cache_directory is None:
            # Caching disabled: no key, or detector built without a cache dir.
            return self.forward(im)
        cache_path = self.cache_directory.joinpath(cache_id + ".torch")
        if load_cache and cache_path.is_file():
            try:
                return self.load_from_cache(cache_path)
            except Exception as e:
                # Recover instead of terminating the process: recompute below,
                # which overwrites the unreadable file.
                logger.warn(
                    f"The cache file was corrupted: {cache_path} ({e!r}). Recomputing detections.")
        detections = self.forward(im)
        self.save_to_cache(detections, cache_path)
        return detections