Spaces:
Sleeping
Sleeping
File size: 2,499 Bytes
eaf2e33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import joblib
import numpy as np
import pickle
import boto3
from rlkit.launchers.conf import LOCAL_LOG_DIR, AWS_S3_PATH
import os
# File-type tags compared against the `file_type` argument of
# load_local_or_remote_file to select which loader is used.
PICKLE = 'pickle'
NUMPY = 'numpy'
JOBLIB = 'joblib'
def local_path_from_s3_or_local_path(filename):
    """
    Resolve `filename` to a path on the local filesystem.

    Candidates are tried in order: `filename` exactly as given, then
    `filename` joined onto LOCAL_LOG_DIR. The first one that is an existing
    file is returned. If neither exists, the result of `sync_down(filename)`
    is returned instead.
    """
    candidates = (filename, os.path.join(LOCAL_LOG_DIR, filename))
    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate
    # Not available locally under either name: fall back to fetching it.
    return sync_down(filename)
def sync_down(path, check_exists=True):
    """
    Download `path` (relative to AWS_S3_PATH) from S3 to the local disk.

    The destination is "/tmp/<path>" when the file "/.dockerenv" exists,
    otherwise "<LOCAL_LOG_DIR>/<path>". When `check_exists` is true and the
    destination is already a file, it is returned without downloading.

    Returns the local destination path on success, or None when the
    download raised (the failure is printed rather than propagated).
    """
    in_docker = os.path.isfile("/.dockerenv")
    local_path = (
        "/tmp/%s" % (path) if in_docker else "%s/%s" % (LOCAL_LOG_DIR, path)
    )

    if check_exists and os.path.isfile(local_path):
        return local_path

    # Make sure the destination directory exists before downloading into it.
    os.makedirs(os.path.dirname(local_path), exist_ok=True)

    if in_docker:
        # Imported lazily: only needed (and only used) on the Docker path.
        from doodad.ec2.autoconfig import AUTOCONFIG
        os.environ["AWS_ACCESS_KEY_ID"] = AUTOCONFIG.aws_access_key()
        os.environ["AWS_SECRET_ACCESS_KEY"] = AUTOCONFIG.aws_access_secret()

    bucket_name, bucket_relative_path = split_s3_full_path(
        os.path.join(AWS_S3_PATH, path)
    )

    try:
        s3_bucket = boto3.resource('s3').Bucket(bucket_name)
        s3_bucket.download_file(bucket_relative_path, local_path)
    except Exception as err:
        # Best-effort: report the failure and signal it with None.
        print("Failed to sync! path: ", path)
        print("Exception: ", err)
        return None
    return local_path
def split_s3_full_path(s3_path):
    """
    Split "s3://foo/bar/baz" into "foo" and "bar/baz"

    Only the first "//" (the scheme separator) is used to strip the scheme,
    so a key that itself contains "//" is returned intact, e.g.
    "s3://foo/bar//baz" gives ("foo", "bar//baz").

    :param s3_path: Full S3 URL of the form "<scheme>//<bucket>/<key>".
    :return: Tuple (bucket_name, directory_path). directory_path is "" when
        the URL has no key part.
    :raises IndexError: if `s3_path` contains no "//".
    """
    # maxsplit=1 keeps any later "//" as part of the key instead of
    # silently truncating the key at that point.
    bucket_name_and_directories = s3_path.split('//', 1)[1]
    bucket_name, _, directory_path = bucket_name_and_directories.partition('/')
    return bucket_name, directory_path
def load_local_or_remote_file(filepath, file_type=None):
    """
    Load an object from a file, resolving the path via
    local_path_from_s3_or_local_path (which may download it).

    :param filepath: Path handed to local_path_from_s3_or_local_path.
    :param file_type: PICKLE, NUMPY or JOBLIB. When None, the type is
        inferred from the resolved path's extension: "npy" selects NUMPY and
        anything else selects PICKLE. Any unrecognised value is loaded as
        pickle.
    :return: The loaded object.
    :raises FileNotFoundError: if the path could not be resolved to a local
        file (the resolver returned None).
    """
    local_path = local_path_from_s3_or_local_path(filepath)
    if local_path is None:
        # The resolver returns None when the download fails; report that
        # clearly instead of failing later on `None.split` / `open(None)`.
        raise FileNotFoundError(
            "Could not find or download file: %s" % (filepath,)
        )

    # Only infer the type when the caller did not specify one; an explicit
    # file_type must be honoured rather than overwritten with PICKLE.
    if file_type is None:
        extension = local_path.split('.')[-1]
        file_type = NUMPY if extension == 'npy' else PICKLE

    if file_type == NUMPY:
        # Passing the path lets numpy manage the file handle itself.
        loaded = np.load(local_path)
    elif file_type == JOBLIB:
        loaded = joblib.load(local_path)
    else:
        # NOTE(security): unpickling can execute arbitrary code; only load
        # files from sources you trust.
        with open(local_path, "rb") as f:
            loaded = pickle.load(f)
    print("loaded", local_path)
    return loaded
if __name__ == "__main__":
    # Manual check: fetch one known params file and print where it landed
    # (prints None if the download failed).
    p = sync_down("ashvin/vae/new-point2d/run0/id1/params.pkl")
    print("got", p)