|
|
|
"""Misc utility functions.""" |
|
|
|
import os |
|
import hashlib |
|
|
|
from torch.hub import download_url_to_file |
|
|
|
# Public names exported by `from <this module> import *`.
__all__ = [
    # Constants.
    'REPO_NAME',
    # Infix-operator helper and shell helper.
    'Infix',
    'print_and_execute',
    # File extension / format helpers.
    'check_file_ext',
    'IMAGE_EXTENSIONS',
    'VIDEO_EXTENSIONS',
    'MEDIA_EXTENSIONS',
    'parse_file_format',
    # Cache directory and download helpers.
    'set_cache_dir',
    'get_cache_dir',
    'download_url',
]

# Repository name; `get_cache_dir()` appends it to the cache root as the name
# of the repository-specific sub-folder.
REPO_NAME = 'Hammer'
|
|
|
|
|
class Infix(object):
    """Helper class to create custom infix operators.

    When using it, make sure to put the operator between `<<` and `>>`.
    `<< INFIX_OP_NAME >>` should be considered as a whole operator.

    Evaluating `left << op` does not modify `op`. It returns a new `Infix`
    that remembers `left`, and `>> right` on that object calls the wrapped
    function. Because no state is stored on the shared operator, it can be
    used in nested expressions and stays usable after the wrapped function
    raises an exception.

    Examples:

        # Use `Infix` to create infix operators directly.
        add = Infix(lambda a, b: a + b)
        1 << add >> 2  # gives 3
        1 << add >> 2 << add >> 3  # gives 6
        1 << add >> (2 << add >> 3)  # nested use, gives 6

        # Use `Infix` as a decorator.
        @Infix
        def mul(a, b):
            return a * b
        2 << mul >> 4  # gives 8
        2 << mul >> 3 << mul >> 7  # gives 42

    Attributes:
        function: The two-argument callable implementing the operator.
        left_value: The left operand captured by `<<`. It is `None` on an
            operator that has not been bound to a left operand.
    """

    def __init__(self, function):
        self.function = function
        self.left_value = None

    def __rlshift__(self, left_value):
        """Handles `left_value << self` by binding the left operand.

        Returns:
            A new `Infix` wrapping the same function with `left_value`
            captured. `self` is left untouched.
        """
        # Bind on a copy rather than on `self`: the same operator object may
        # be mid-evaluation elsewhere in the expression (nested use).
        bound = Infix(self.function)
        bound.left_value = left_value
        return bound

    def __rshift__(self, right_value):
        """Handles `self >> right_value` by applying the wrapped function.

        Returns:
            Whatever `function(left_value, right_value)` returns.
        """
        return self.function(self.left_value, right_value)
|
|
|
|
|
def print_and_execute(cmd):
    """Prints and executes a system command.

    NOTE: `cmd` is handed to the system shell as-is via `os.system()`, so it
    must never be built from untrusted input.

    Args:
        cmd: Command to be executed.

    Returns:
        The value returned by `os.system(cmd)`, which is `0` when the command
        succeeds. Callers that do not care about failures may ignore it.
    """
    # Flush so the echoed command shows up before the command's own output
    # even when stdout is buffered (e.g. redirected to a file).
    print(cmd, flush=True)
    return os.system(cmd)
|
|
|
|
|
def check_file_ext(filename, *ext_list):
    """Tells whether `filename` carries one of the given extensions.

    The comparison is case-insensitive, and each extension may be given with
    or without the leading dot (`'jpg'` and `'.jpg'` are equivalent). Only the
    last extension of the file's basename is considered.

    NOTE: If `ext_list` is empty, this function will always return `False`.

    Args:
        filename: Filename to check.
        *ext_list: Extensions to accept.

    Returns:
        `True` if the extension of `filename` is one of `ext_list`, otherwise
        `False`.
    """
    if not ext_list:
        return False

    # Normalize every candidate to the lower-case, dot-prefixed form.
    accepted = set()
    for candidate in ext_list:
        dotted = candidate if candidate.startswith('.') else '.' + candidate
        accepted.add(dotted.lower())

    _, suffix = os.path.splitext(os.path.basename(filename))
    return suffix.lower() in accepted
|
|
|
|
|
|
|
# Lower-case, dot-prefixed extensions treated as still images.
IMAGE_EXTENSIONS = (
    '.bmp',
    '.ppm', '.pgm',
    '.jpeg', '.jpg', '.jpe', '.jp2',
    '.png',
    '.webp',
    '.tiff', '.tif',
)

# Lower-case, dot-prefixed extensions treated as videos.
VIDEO_EXTENSIONS = (
    '.avi', '.mkv',
    '.mp4', '.m4v', '.mov',
    '.webm', '.flv',
    '.rmvb', '.rm',
    '.3gp',
)

# All media extensions: `.gif` first, then the image and video extensions.
MEDIA_EXTENSIONS = ('.gif',) + IMAGE_EXTENSIONS + VIDEO_EXTENSIONS
|
|
|
|
|
def parse_file_format(path):
    """Parses the file format of a given path.

    The format is basically decided by the (case-insensitive) extension of the
    path. `dir` is returned if the given path is a directory.

    Parsable file formats:

    - dir: an existing directory, or any path ending with `/`.
    - lmdb: a folder whose name ends with `lmdb`.
    - zip: with `.zip` extension.
    - tar: with `.tar` / `.tgz` / `.tar.gz` extension.
    - txt: with `.txt` / `.text` extension, OR an existing file without
        extension (e.g. LICENSE).
    - json: with `.json` extension.
    - jpg: with `.jpeg` / `.jpg` / `.jpe` extension.
    - png: with `.png` extension.

    Args:
        path: The path to the file to parse format from.

    Returns:
        A lower-case string, indicating the file format, or `None` if the
        format cannot be successfully parsed.
    """
    # Folders: either present on disk, or spelled with a trailing slash.
    if os.path.isdir(path) or path.endswith('/'):
        is_lmdb = path.rstrip('/').lower().endswith('lmdb')
        return 'lmdb' if is_lmdb else 'dir'

    # An existing file with no extension at all is treated as plain text.
    if os.path.isfile(path) and os.path.splitext(path)[1] == '':
        return 'txt'

    lowered = path.lower()

    # `.tar.gz` spans two suffixes, so `splitext()` alone cannot detect it.
    if lowered.endswith('.tar.gz'):
        return 'tar'

    format_by_extension = {
        '.zip': 'zip',
        '.tar': 'tar',
        '.tgz': 'tar',
        '.txt': 'txt',
        '.text': 'txt',
        '.json': 'json',
        '.jpeg': 'jpg',
        '.jpg': 'jpg',
        '.jpe': 'jpg',
        '.png': 'png',
    }
    # Unknown extensions fall through to `None`.
    return format_by_extension.get(os.path.splitext(lowered)[1])
|
|
|
|
|
# User-selected cache directory. `None` means "use the default location".
_cache_dir = None


def set_cache_dir(directory=None):
    """Sets the global cache directory.

    The cache directory can be used to save files that are shared across
    jobs. By default it is `~/.cache/`; call this function to redirect it, or
    pass `None` to switch back to the default.

    Args:
        directory: The target directory used to cache files, as a string. If
            set as `None`, the cache directory will be reset back to default.
            (default: None)

    Raises:
        AssertionError: If `directory` is neither `None` nor a string.
    """
    global _cache_dir
    is_valid = directory is None or isinstance(directory, str)
    assert is_valid, 'Invalid directory!'
    _cache_dir = directory
|
|
|
|
|
def get_cache_dir(use_repo_name=True):
    """Gets the global cache directory.

    The cache root is `~/.cache` unless it has been redirected with
    `set_cache_dir()`. This function only builds the path; it does not create
    the directory.

    Args:
        use_repo_name: Whether to append a folder named `REPO_NAME` to the
            cache root and return that as the actual cache directory.
            (default: True)

    Returns:
        A string, representing the global cache directory.
    """
    root = _cache_dir
    if root is None:
        # Nothing configured: fall back to `.cache` under the user's home.
        root = os.path.join(os.path.expanduser('~'), '.cache')
    return os.path.join(root, REPO_NAME) if use_repo_name else root
|
|
|
|
|
def download_url(url, path=None, filename=None, sha256=None):
    """Downloads file from URL.

    This function downloads a file from given URL, and executes Hash check if
    needed. If the target file already exists, the download is skipped and
    only the hash check (if requested) is executed on the existing file.

    Args:
        url: The URL to download file from.
        path: Path (directory) to save the downloaded file. If set as `None`,
            the cache directory will be used. Please see `get_cache_dir()` for
            more details. (default: None)
        filename: The name to save the file. If set as `None`, this name will
            be automatically parsed from the given URL, as
            `os.path.basename(url)`. (default: None)
        sha256: The expected sha256 (hex digest) of the downloaded file. If
            set as `None`, the hash check will be skipped. Otherwise, this
            function will check whether the sha256 of the downloaded file
            matches this field. The comparison ignores letter case and
            surrounding whitespace. (default: None)

    Returns:
        A two-element tuple, where the first term is the full path of the
        downloaded file, and the second term indicate the hash check result.
        `True` means hash check passes, `False` means hash check fails,
        while `None` means no hash check is executed.
    """
    if path is None:
        path = get_cache_dir()
    if filename is None:
        filename = os.path.basename(url)
    save_path = os.path.join(path, filename)

    if not os.path.exists(save_path):
        print(f'Downloading URL `{url}` to path `{save_path}` ...')
        os.makedirs(path, exist_ok=True)
        download_url_to_file(url, save_path, hash_prefix=None, progress=True)

    check_result = None
    if sha256 is not None:
        # Hash in fixed-size chunks so that large files (e.g. checkpoints)
        # are never loaded into memory as a whole.
        chunk_size = 1 << 20
        file_hash = hashlib.sha256()
        with open(save_path, 'rb') as f:
            for chunk in iter(lambda: f.read(chunk_size), b''):
                file_hash.update(chunk)
        # `hexdigest()` is lower-case; normalize the expected value so that
        # an upper-case or padded digest is not reported as a mismatch.
        expected = sha256.strip().lower() if isinstance(sha256, str) else sha256
        check_result = (file_hash.hexdigest() == expected)

    return save_path, check_result
|
|