Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
import os | |
import sys | |
try: | |
from torch.hub import _download_url_to_file | |
from torch.hub import urlparse | |
from torch.hub import HASH_REGEX | |
except ImportError: | |
from torch.utils.model_zoo import _download_url_to_file | |
from torch.utils.model_zoo import urlparse | |
from torch.utils.model_zoo import HASH_REGEX | |
from maskrcnn_benchmark.utils.comm import is_main_process | |
from maskrcnn_benchmark.utils.comm import synchronize | |
# very similar to https://github.com/pytorch/pytorch/blob/master/torch/utils/model_zoo.py | |
# but with a few improvements and modifications | |
def cache_url(url, model_dir='model', progress=True):
    r"""Download the file at ``url`` into a local cache and return its path.

    The file is stored as ``<model_dir>/<filename>``, where ``<filename>`` is
    the basename of the URL path. It is downloaded only if that path does not
    exist yet. The file is NOT deserialized; only its local path is returned.

    If the filename contains a hash part matched by ``HASH_REGEX`` (the
    ``filename-<hash>.ext`` convention), that part is handed to the downloader
    as the expected SHA256 prefix. Matches shorter than 6 characters are
    ignored (no verification), because Detectron/Caffe2 names such as
    ``R-50.pkl`` fit the pattern without carrying a real hash.

    In a distributed run only the main process (``is_main_process()``)
    downloads. All processes then call ``synchronize()``, which is expected to
    act as a barrier, so no rank returns before the file has been written.

    Args:
        url (string): URL of the file to download.
        model_dir (string, optional): directory in which to save the file;
            created if it does not exist. Defaults to ``'model'`` (relative
            to the current working directory). If ``None``, uses
            ``$TORCH_MODEL_ZOO`` when set, otherwise ``$TORCH_HOME/models``
            where ``$TORCH_HOME`` defaults to ``~/.torch``.
        progress (bool, optional): whether or not to display a progress bar
            to stderr.

    Returns:
        str: path of the cached file.

    Example:
        >>> cached_file = maskrcnn_benchmark.utils.model_zoo.cache_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch"))
        model_dir = os.getenv("TORCH_MODEL_ZOO", os.path.join(torch_home, "models"))
    # exist_ok makes a separate existence check unnecessary and tolerates
    # several processes creating the directory at the same time.
    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if filename == "model_final.pkl":
        # Pre-trained Caffe2 models from Detectron all share this filename, so
        # use the full URL path (with "/" replaced by "_") to keep them apart.
        filename = parts.path.replace("/", "_")
    cached_file = os.path.join(model_dir, filename)

    # Only the main process downloads; otherwise every rank would write the
    # same file concurrently.
    if not os.path.exists(cached_file) and is_main_process():
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        match = HASH_REGEX.search(filename)
        hash_prefix = match.group(1) if match is not None else None
        # Caffe2 names like "R-50.pkl" match the hash pattern without a real
        # SHA256 prefix, so skip verification for short matches.
        if hash_prefix is not None and len(hash_prefix) < 6:
            hash_prefix = None
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    # Non-main processes wait here until the main process has finished.
    synchronize()
    return cached_file