fffiloni's picture
Migrated from GitHub
2252f3d verified
raw
history blame
4.98 kB
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""IO utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from six.moves import cPickle as pickle
import hashlib
import logging
import os
import re
import sys
try:
from urllib.request import urlopen
except ImportError: #python2
from urllib2 import urlopen
logger = logging.getLogger(__name__)
_DETECTRON_S3_BASE_URL = 'https://s3-us-west-2.amazonaws.com/detectron'
def save_object(obj, file_name):
"""Save a Python object by pickling it."""
file_name = os.path.abspath(file_name)
with open(file_name, 'wb') as f:
pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
def cache_url(url_or_file, cache_dir):
"""Download the file specified by the URL to the cache_dir and return the
path to the cached file. If the argument is not a URL, simply return it as
is.
"""
is_url = re.match(r'^(?:http)s?://', url_or_file, re.IGNORECASE) is not None
if not is_url:
return url_or_file
url = url_or_file
# assert url.startswith(_DETECTRON_S3_BASE_URL), \
# ('Detectron only automatically caches URLs in the Detectron S3 '
# 'bucket: {}').format(_DETECTRON_S3_BASE_URL)
#
# cache_file_path = url.replace(_DETECTRON_S3_BASE_URL, cache_dir)
Len_filename = len(url.split('/')[-1])
BASE_URL = url[0:-Len_filename - 1]
#
cache_file_path = url.replace(BASE_URL, cache_dir)
if os.path.exists(cache_file_path):
# assert_cache_file_is_ok(url, cache_file_path)
return cache_file_path
cache_file_dir = os.path.dirname(cache_file_path)
if not os.path.exists(cache_file_dir):
os.makedirs(cache_file_dir)
logger.info('Downloading remote file {} to {}'.format(url, cache_file_path))
download_url(url, cache_file_path)
# assert_cache_file_is_ok(url, cache_file_path)
return cache_file_path
def assert_cache_file_is_ok(url, file_path):
"""Check that cache file has the correct hash."""
# File is already in the cache, verify that the md5sum matches and
# return local path
cache_file_md5sum = _get_file_md5sum(file_path)
ref_md5sum = _get_reference_md5sum(url)
assert cache_file_md5sum == ref_md5sum, \
('Target URL {} appears to be downloaded to the local cache file '
'{}, but the md5 hash of the local file does not match the '
'reference (actual: {} vs. expected: {}). You may wish to delete '
'the cached file and try again to trigger automatic '
'download.').format(url, file_path, cache_file_md5sum, ref_md5sum)
def _progress_bar(count, total):
"""Report download progress.
Credit:
https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
"""
bar_len = 60
filled_len = int(round(bar_len * count / float(total)))
percents = round(100.0 * count / float(total), 1)
bar = '=' * filled_len + '-' * (bar_len - filled_len)
sys.stdout.write(' [{}] {}% of {:.1f}MB file \r'.format(bar, percents, total / 1024 / 1024))
sys.stdout.flush()
if count >= total:
sys.stdout.write('\n')
def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
"""Download url and write it to dst_file_path.
Credit:
https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook
"""
response = urlopen(url)
total_size = response.info().getheader('Content-Length').strip()
total_size = int(total_size)
bytes_so_far = 0
with open(dst_file_path, 'wb') as f:
while 1:
chunk = response.read(chunk_size)
bytes_so_far += len(chunk)
if not chunk:
break
if progress_hook:
progress_hook(bytes_so_far, total_size)
f.write(chunk)
return bytes_so_far
def _get_file_md5sum(file_name):
"""Compute the md5 hash of a file."""
hash_obj = hashlib.md5()
with open(file_name, 'r') as f:
hash_obj.update(f.read())
return hash_obj.hexdigest()
def _get_reference_md5sum(url):
"""By convention the md5 hash for url is stored in url + '.md5sum'."""
url_md5sum = url + '.md5sum'
md5sum = urlopen(url_md5sum).read().strip()
return md5sum