# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""IO utilities.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
from __future__ import unicode_literals | |
from six.moves import cPickle as pickle | |
import hashlib | |
import logging | |
import os | |
import re | |
import sys | |
try: | |
from urllib.request import urlopen | |
except ImportError: #python2 | |
from urllib2 import urlopen | |
logger = logging.getLogger(__name__) | |
_DETECTRON_S3_BASE_URL = 'https://s3-us-west-2.amazonaws.com/detectron' | |


def save_object(obj, file_name):
    """Save a Python object by pickling it."""
    file_name = os.path.abspath(file_name)
    with open(file_name, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
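
# Illustrative usage (the path and object below are hypothetical, not part of
# this module): save_object round-trips with pickle.load, e.g.
#
#   save_object({'epoch': 10, 'lr': 0.02}, '/tmp/checkpoint.pkl')
#   with open('/tmp/checkpoint.pkl', 'rb') as f:
#       ckpt = pickle.load(f)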


def cache_url(url_or_file, cache_dir):
    """Download the file specified by the URL to the cache_dir and return the
    path to the cached file. If the argument is not a URL, simply return it as
    is.
    """
    is_url = re.match(r'^(?:http)s?://', url_or_file, re.IGNORECASE) is not None
    if not is_url:
        return url_or_file

    url = url_or_file
    # assert url.startswith(_DETECTRON_S3_BASE_URL), \
    #     ('Detectron only automatically caches URLs in the Detectron S3 '
    #      'bucket: {}').format(_DETECTRON_S3_BASE_URL)
    #
    # cache_file_path = url.replace(_DETECTRON_S3_BASE_URL, cache_dir)

    # Map the URL into cache_dir by replacing everything up to the file name,
    # so that URLs outside the Detectron S3 bucket can also be cached
    filename_len = len(url.split('/')[-1])
    base_url = url[0:-filename_len - 1]
    cache_file_path = url.replace(base_url, cache_dir)
    if os.path.exists(cache_file_path):
        # assert_cache_file_is_ok(url, cache_file_path)
        return cache_file_path

    cache_file_dir = os.path.dirname(cache_file_path)
    if not os.path.exists(cache_file_dir):
        os.makedirs(cache_file_dir)

    logger.info('Downloading remote file {} to {}'.format(url, cache_file_path))
    download_url(url, cache_file_path)
    # assert_cache_file_is_ok(url, cache_file_path)
    return cache_file_path
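
# Illustrative mapping (hypothetical file name): the trailing path component
# of the URL is mirrored under cache_dir, e.g.
#
#   cache_url(_DETECTRON_S3_BASE_URL + '/model_final.pkl', '/tmp/cache')
#   -> downloads to, and returns, '/tmp/cache/model_final.pkl'
#   cache_url('/local/model_final.pkl', '/tmp/cache')
#   -> returns '/local/model_final.pkl' unchanged (not a URL)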


def assert_cache_file_is_ok(url, file_path):
    """Check that the cached file has the correct hash."""
    # The file is already in the cache; verify that its md5sum matches the
    # reference before trusting the local path
    cache_file_md5sum = _get_file_md5sum(file_path)
    ref_md5sum = _get_reference_md5sum(url)
    assert cache_file_md5sum == ref_md5sum, \
        ('Target URL {} appears to be downloaded to the local cache file '
         '{}, but the md5 hash of the local file does not match the '
         'reference (actual: {} vs. expected: {}). You may wish to delete '
         'the cached file and try again to trigger automatic '
         'download.').format(url, file_path, cache_file_md5sum, ref_md5sum)
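
# Illustrative use (hypothetical URL and path): raises AssertionError when the
# local copy is corrupt or stale, e.g.
#
#   assert_cache_file_is_ok(_DETECTRON_S3_BASE_URL + '/model_final.pkl',
#                           '/tmp/cache/model_final.pkl')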


def _progress_bar(count, total):
    """Report download progress.
    Credit:
    https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
    """
    bar_len = 60
    filled_len = int(round(bar_len * count / float(total)))
    percents = round(100.0 * count / float(total), 1)
    bar = '=' * filled_len + '-' * (bar_len - filled_len)

    sys.stdout.write(
        ' [{}] {}% of {:.1f}MB file \r'.format(bar, percents, total / 1024 / 1024)
    )
    sys.stdout.flush()
    if count >= total:
        sys.stdout.write('\n')
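
# Sample rendering: halfway through a 1 MiB download,
# _progress_bar(524288, 1048576) writes
#
#   [==============================------------------------------] 50.0% of 1.0MB file
#
# and the trailing carriage return lets the next call redraw the same line.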


def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
    """Download url and write it to dst_file_path.
    Credit:
    https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
    """
    response = urlopen(url)
    # response.headers supports .get() on both Python 2 (mimetools.Message)
    # and Python 3 (http.client.HTTPMessage); .getheader() is Python 2 only
    total_size = response.headers.get('Content-Length').strip()
    total_size = int(total_size)
    bytes_so_far = 0

    with open(dst_file_path, 'wb') as f:
        while True:
            chunk = response.read(chunk_size)
            if not chunk:
                break
            bytes_so_far += len(chunk)
            if progress_hook:
                progress_hook(bytes_so_far, total_size)
            f.write(chunk)

    return bytes_so_far
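
# Illustrative usage (hypothetical URL and path): pass progress_hook=None to
# download silently, e.g.
#
#   n_bytes = download_url('https://example.com/data.bin', '/tmp/data.bin',
#                          progress_hook=None)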


def _get_file_md5sum(file_name):
    """Compute the md5 hash of a file."""
    hash_obj = hashlib.md5()
    # Open in binary mode: md5 must be computed over bytes, and text mode
    # breaks on Python 3 (and on non-text files)
    with open(file_name, 'rb') as f:
        hash_obj.update(f.read())
    return hash_obj.hexdigest()


def _get_reference_md5sum(url):
    """By convention the md5 hash for url is stored in url + '.md5sum'."""
    url_md5sum = url + '.md5sum'
    # Decode the response so the result compares equal to the str returned by
    # hexdigest() on Python 3
    md5sum = urlopen(url_md5sum).read().strip().decode('utf-8')
    return md5sum
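
# Illustrative example of the convention (hypothetical file name): for
#   https://s3-us-west-2.amazonaws.com/detectron/model_final.pkl
# the reference hash is fetched from
#   https://s3-us-west-2.amazonaws.com/detectron/model_final.pkl.md5sum
# whose body is the hex digest that _get_file_md5sum must reproduce.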