"""Helpers for locating files locally, fetching them from S3, and loading them."""
import os
import pickle

import boto3
import joblib
import numpy as np

from rlkit.launchers.conf import LOCAL_LOG_DIR, AWS_S3_PATH

# File-type tags accepted by ``load_local_or_remote_file``.
PICKLE = 'pickle'
NUMPY = 'numpy'
JOBLIB = 'joblib'


def local_path_from_s3_or_local_path(filename):
    """Resolve ``filename`` to a path on the local filesystem.

    Lookup order:

    1. ``filename`` exactly as given, if it is an existing file;
    2. ``filename`` joined onto ``LOCAL_LOG_DIR``, if that is an existing file;
    3. otherwise whatever ``sync_down(filename)`` returns, which is ``None``
       when the download fails.
    """
    relative_filename = os.path.join(LOCAL_LOG_DIR, filename)
    if os.path.isfile(filename):
        return filename
    if os.path.isfile(relative_filename):
        return relative_filename
    return sync_down(filename)


def sync_down(path, check_exists=True):
    """Download ``path`` (relative to ``AWS_S3_PATH``) to a local file.

    The destination is ``/tmp/<path>`` when ``/.dockerenv`` exists and
    ``<LOCAL_LOG_DIR>/<path>`` otherwise.

    :param path: Path of the file relative to ``AWS_S3_PATH``.
    :param check_exists: If true and the destination file already exists,
        return it without downloading again.
    :return: The local path of the file, or ``None`` if the download raised.
    """
    is_docker = os.path.isfile("/.dockerenv")
    if is_docker:
        local_path = "/tmp/%s" % (path)
    else:
        local_path = "%s/%s" % (LOCAL_LOG_DIR, path)

    if check_exists and os.path.isfile(local_path):
        return local_path

    local_dir = os.path.dirname(local_path)
    os.makedirs(local_dir, exist_ok=True)

    if is_docker:
        # Imported lazily: only needed (and only looked up) inside docker.
        from doodad.ec2.autoconfig import AUTOCONFIG
        os.environ["AWS_ACCESS_KEY_ID"] = AUTOCONFIG.aws_access_key()
        os.environ["AWS_SECRET_ACCESS_KEY"] = AUTOCONFIG.aws_access_secret()

    full_s3_path = os.path.join(AWS_S3_PATH, path)
    bucket_name, bucket_relative_path = split_s3_full_path(full_s3_path)
    try:
        bucket = boto3.resource('s3').Bucket(bucket_name)
        bucket.download_file(bucket_relative_path, local_path)
    except Exception as e:
        # Best-effort download: report the failure and signal it with None
        # instead of propagating the exception.
        local_path = None
        print("Failed to sync! path: ", path)
        print("Exception: ", e)
    return local_path


def split_s3_full_path(s3_path):
    """
    Split "s3://foo/bar/baz" into "foo" and "bar/baz"

    :param s3_path: Full S3 URL containing ``//`` after the scheme.
    :return: Tuple ``(bucket_name, directory_path)``.
    :raises IndexError: If ``s3_path`` contains no ``//``.
    """
    bucket_name_and_directories = s3_path.split('//')[1]
    bucket_name, *directories = bucket_name_and_directories.split('/')
    directory_path = '/'.join(directories)
    return bucket_name, directory_path


def load_local_or_remote_file(filepath, file_type=None):
    """Load an object from a local file, downloading it from S3 if needed.

    :param filepath: Path understood by ``local_path_from_s3_or_local_path``.
    :param file_type: One of ``PICKLE``, ``NUMPY`` or ``JOBLIB``. When
        ``None``, the type is inferred from the extension: ``npy`` means
        ``NUMPY`` and anything else means ``PICKLE``. Any unrecognised value
        is loaded as a pickle.
    :return: The deserialized object.
    :raises FileNotFoundError: If the file is not available locally and
        could not be downloaded.
    """
    local_path = local_path_from_s3_or_local_path(filepath)
    if local_path is None:
        # sync_down() signals a failed download with None; fail with a clear
        # message instead of an opaque error on ``None`` further down.
        raise FileNotFoundError(
            "Could not find %r locally or download it from S3" % (filepath,)
        )

    if file_type is None:
        extension = local_path.rsplit('.', 1)[-1]
        file_type = NUMPY if extension == 'npy' else PICKLE
    # An explicitly requested file_type is honoured as given. (Previously it
    # was unconditionally overwritten with PICKLE, which made the NUMPY and
    # JOBLIB branches unreachable for explicit requests.)

    if file_type == NUMPY:
        # Passing the path lets numpy own the file handle and close it once
        # the data has been read.
        loaded = np.load(local_path)
    elif file_type == JOBLIB:
        loaded = joblib.load(local_path)
    else:
        # NOTE(security): unpickling executes arbitrary code; only load files
        # that come from a trusted location.
        with open(local_path, "rb") as f:
            loaded = pickle.load(f)
    print("loaded", local_path)
    return loaded


if __name__ == "__main__":
    p = sync_down("ashvin/vae/new-point2d/run0/id1/params.pkl")
    print("got", p)