|
import os |
|
import sys |
|
|
|
try: |
|
from urllib import urlretrieve |
|
except ImportError: |
|
from urllib.request import urlretrieve |
|
|
|
import torch |
|
|
|
|
|
def load_url(url, model_dir="./pretrained", map_location=torch.device("cpu")):
    """Download ``url`` into ``model_dir`` (once) and load it with ``torch.load``.

    The cache file name is the last component of the URL *path*, so query
    strings / fragments (e.g. ``?dl=1``) are not part of the file name. If the
    file is already present in ``model_dir`` it is loaded without downloading.

    Args:
        url: Location of the serialized object (any scheme ``urlretrieve``
            supports, e.g. ``http(s)://`` or ``file://``).
        model_dir: Cache directory; created (with parents) if missing.
        map_location: Forwarded to ``torch.load``.

    Returns:
        Whatever ``torch.load`` returns for the cached file.

    Raises:
        ValueError: if no file name can be derived from ``url``.
        OSError / urllib errors: if the directory cannot be created or the
            download fails (no partial file is left in the cache).
    """
    import tempfile

    try:
        from urllib.parse import urlparse
    except ImportError:  # Python 2
        from urlparse import urlparse

    # Use only the path component so "?dl=1"-style query strings do not end up
    # in the cached file name.
    filename = os.path.basename(urlparse(url).path)
    if not filename:
        raise ValueError("Cannot derive a file name from URL: {!r}".format(url))

    # Tolerate a concurrent caller creating the directory between our check and
    # makedirs (exist_ok is unavailable on Python 2, which this file supports).
    if not os.path.isdir(model_dir):
        try:
            os.makedirs(model_dir)
        except OSError:
            if not os.path.isdir(model_dir):
                raise

    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # Download into a unique temp file in the same directory and rename it
        # into place only once complete, so an interrupted download never
        # leaves a truncated file that later calls would treat as cached.
        fd, tmp_path = tempfile.mkstemp(
            dir=model_dir, prefix=filename + ".", suffix=".part"
        )
        os.close(fd)
        replace = getattr(os, "replace", os.rename)  # os.replace is py3.3+
        try:
            urlretrieve(url, tmp_path)
            replace(tmp_path, cached_file)
        finally:
            # Only still present if the download or rename failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # NOTE(security): torch.load unpickles the file; only use trusted URLs.
    return torch.load(cached_file, map_location=map_location)
|
|