|
|
|
import os |
|
import sys |
|
import time |
|
import yaml |
|
import hashlib |
|
import argparse |
|
|
|
from six.moves import urllib |
|
|
|
# Keys that the YAML frontmatter of a model's readme.md must define.
required_keys = [
    'caffemodel',      # file name of the model binary, relative to the model directory
    'caffemodel_url',  # URL the model binary is downloaded from
    'sha1',            # expected SHA-1 hex digest of the model binary
]
|
|
|
|
|
def reporthook(count, block_size, total_size):
    """Write a single-line download progress indicator to stdout.

    Intended as the ``reporthook`` callback of ``urlretrieve``. Adapted from
    http://blog.moleculea.com/2012/10/04/urlretrieve-progres-indicator/

    Args:
        count: Number of blocks transferred so far. A call with ``count == 0``
            (re)starts the timer and prints nothing.
        block_size: Size of one block in bytes.
        total_size: Total download size in bytes. ``urlretrieve`` may pass -1
            when the size is unknown; any value <= 0 is treated as unknown and
            the percentage is omitted instead of dividing by it.
    """
    global start_time

    if count == 0:
        start_time = time.time()
        return

    # Guard against division by zero when the first block arrives instantly.
    duration = (time.time() - start_time) or 0.01
    progress_size = int(count * block_size)
    size_known = total_size > 0
    if size_known:
        # The final block is usually partial, so count * block_size can
        # overshoot the real size; clamp so the display never exceeds 100%.
        progress_size = min(progress_size, total_size)
    speed = int(progress_size / (1024 * duration))
    megabytes = progress_size // (1024 * 1024)

    if size_known:
        percent = int(progress_size * 100 / total_size)
        sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
                         (percent, megabytes, speed, duration))
    else:
        sys.stdout.write("\r...%d MB, %d KB/s, %d seconds passed" %
                         (megabytes, speed, duration))
    sys.stdout.flush()
|
|
|
|
|
def parse_readme_frontmatter(dirname):
    """Read and validate the YAML frontmatter of ``<dirname>/readme.md``.

    The frontmatter is the text between the first two lines that, after
    stripping surrounding whitespace, consist exactly of ``---``.

    Args:
        dirname: Path to a Caffe model directory.

    Returns:
        A ``(dirname, frontmatter)`` tuple, where ``frontmatter`` is the
        parsed YAML mapping.

    Raises:
        IOError/OSError: if readme.md cannot be opened.
        ValueError: if the ``---`` delimiters are missing, the frontmatter is
            not a mapping, or any key in ``required_keys`` is absent.
        yaml.YAMLError: if the frontmatter is not valid YAML.
    """
    readme_filename = os.path.join(dirname, 'readme.md')
    with open(readme_filename) as f:
        lines = [line.strip() for line in f]

    try:
        top = lines.index('---')
        bottom = lines.index('---', top + 1)
    except ValueError:
        raise ValueError(
            '{}: expected YAML frontmatter between two "---" lines'.format(
                readme_filename))

    # safe_load: plain yaml.load without an explicit Loader is unsafe on
    # untrusted input and is a TypeError on PyYAML >= 6.
    frontmatter = yaml.safe_load('\n'.join(lines[top + 1:bottom]))

    # Empty frontmatter parses to None; a scalar or list is also unusable.
    if not isinstance(frontmatter, dict):
        raise ValueError(
            '{}: frontmatter must be a YAML mapping'.format(readme_filename))

    # Explicit check instead of assert, which is stripped under ``python -O``.
    missing = [key for key in required_keys if key not in frontmatter]
    if missing:
        raise ValueError(
            '{}: frontmatter is missing required key(s): {}'.format(
                readme_filename, ', '.join(missing)))

    return dirname, frontmatter
|
|
|
|
|
def valid_dirname(dirname):
    """``argparse`` type converter for the model-directory argument.

    Delegates to ``parse_readme_frontmatter`` and returns its result
    unchanged. Any exception it raises is printed as ``ERROR: <message>`` on
    stdout and converted into ``argparse.ArgumentTypeError``, which argparse
    reports as a usage error rather than a traceback.
    """
    try:
        parsed = parse_readme_frontmatter(dirname)
    except Exception as err:
        print('ERROR: {}'.format(err))
        raise argparse.ArgumentTypeError(
            'Must be valid Caffe model directory with a correct readme.md')
    return parsed
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Download trained model binary.')
    # valid_dirname returns parse_readme_frontmatter's (dirname, frontmatter)
    # tuple, so args.dirname is that tuple rather than a plain string.
    parser.add_argument('dirname', type=valid_dirname)
    args = parser.parse_args()

    dirname, frontmatter = args.dirname
    model_filename = os.path.join(dirname, frontmatter['caffemodel'])

    def model_checks_out(filename=model_filename, sha1=frontmatter['sha1']):
        """Return True if the SHA-1 hex digest of ``filename`` matches ``sha1``."""
        digest = hashlib.sha1()
        with open(filename, 'rb') as f:
            # Hash in 1 MiB chunks so a large model binary is not read into
            # memory in one piece.
            for chunk in iter(lambda: f.read(1 << 20), b''):
                digest.update(chunk)
        # hexdigest() is lowercase; normalize the expected value too. str()
        # covers YAML having parsed an all-digit checksum as an int.
        return digest.hexdigest() == str(sha1).strip().lower()

    if os.path.exists(model_filename) and model_checks_out():
        print("Model already exists.")
        sys.exit(0)

    urllib.request.urlretrieve(
        frontmatter['caffemodel_url'], model_filename, reporthook)
    # reporthook redraws its progress line with '\r' and never terminates it;
    # end the line so later output does not get glued onto it.
    sys.stdout.write('\n')
    sys.stdout.flush()

    if not model_checks_out():
        sys.stderr.write(
            'ERROR: model did not download correctly! Run this again.\n')
        sys.exit(1)
|
|