import importlib.util
import json
import logging
import os
import shutil
import tarfile
import tempfile
from functools import partial
from hashlib import md5, sha256
from pathlib import Path
from typing import Any, BinaryIO, Dict, List, Optional, Union
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile

import huggingface_hub
import requests
from filelock import FileLock
from tqdm import tqdm
from transformers.utils.hub import cached_file as hf_cached_file

from relik.common.log import get_logger

WEIGHTS_NAME = "weights.pt"
ONNX_WEIGHTS_NAME = "weights.onnx"
CONFIG_NAME = "config.yaml"
LABELS_NAME = "labels.json"

SAPIENZANLP_USER_NAME = "riccorl"
SAPIENZANLP_HF_MODEL_REPO_URL = "riccorl/{model_id}"
SAPIENZANLP_HF_MODEL_REPO_ARCHIVE_URL = (
    f"{SAPIENZANLP_HF_MODEL_REPO_URL}/resolve/main/model.zip"
)

# The cache directory can be overridden with the SAPIENZANLP_CACHE_DIR environment
# variable. Wrap the value in `Path` so it is always a `Path`, even when it comes
# from the environment as a plain string (it is joined with `/` later on).
SAPIENZANLP_CACHE_DIR = Path(
    os.getenv("SAPIENZANLP_CACHE_DIR", Path.home() / ".sapienzanlp")
)
SAPIENZANLP_DATE_FORMAT = "%Y-%m-%d %H-%M-%S"

logger = get_logger(__name__)


def sapienzanlp_model_urls(model_id: str) -> str:
    """
    Returns the URL for a possible SapienzaNLP valid model.

    Args:
        model_id (:obj:`str`):
            A SapienzaNLP model id.

    Returns:
        :obj:`str`: The URL for the model id.
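
    Example (with a hypothetical model id)::

        >>> sapienzanlp_model_urls("my-model")
        'riccorl/my-model'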
    """
    # A model id that already contains a "/" is assumed to be a full
    # `user/model` repo id and is returned unchanged.
    if "/" in model_id:
        return model_id
    return SAPIENZANLP_HF_MODEL_REPO_URL.format(model_id=model_id)


def is_package_available(package_name: str) -> bool:
    """
    Check if a package is available.

    Args:
        package_name (`str`): The name of the package to check.

    Returns:
        `bool`: `True` if the package is installed, `False` otherwise.
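
    Example (assuming `requests`, a dependency of this module, is installed)::

        >>> is_package_available("requests")
        True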
    """
    return importlib.util.find_spec(package_name) is not None


def load_json(path: Union[str, Path]) -> Any:
    """
    Load the JSON file at the given path.

    Args:
        path (`Union[str, Path]`): The path to the JSON file to load.

    Returns:
        `Any`: The parsed content of the JSON file.
    """
    with open(path, encoding="utf8") as f:
        return json.load(f)


def dump_json(document: Any, path: Union[str, Path], indent: Optional[int] = None):
    """
    Dump the input to a JSON file.

    Args:
        document (`Any`): The document to dump.
        path (`Union[str, Path]`): The path to dump the document to.
        indent (`Optional[int]`): The indent to use for the JSON file.
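
    Example (round-trip with `load_json`, using a hypothetical temp path)::

        >>> dump_json({"a": 1}, "/tmp/example.json", indent=2)
        >>> load_json("/tmp/example.json")
        {'a': 1}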
    """
    with open(path, "w", encoding="utf8") as outfile:
        json.dump(document, outfile, indent=indent)


def get_md5(path: Path):
    """
    Get the MD5 hash of the file at the given path.
    """
    with path.open("rb") as fin:
        data = fin.read()
    return md5(data).hexdigest()


def file_exists(path: Union[str, os.PathLike]) -> bool:
    """
    Check if the file at :obj:`path` exists.

    Args:
        path (:obj:`str`, :obj:`os.PathLike`):
            Path to check.

    Returns:
        :obj:`bool`: :obj:`True` if the file exists.
    """
    return Path(path).exists()


def dir_exists(path: Union[str, os.PathLike]) -> bool:
    """
    Check if the directory at :obj:`path` exists.

    Args:
        path (:obj:`str`, :obj:`os.PathLike`):
            Path to check.

    Returns:
        :obj:`bool`: :obj:`True` if the directory exists.
    """
    return Path(path).is_dir()


def is_remote_url(url_or_filename: Union[str, Path]):
    """
    Returns :obj:`True` if the input path is a URL.

    Args:
        url_or_filename (:obj:`str`, :obj:`Path`):
            Path to check.

    Returns:
        :obj:`bool`: :obj:`True` if the input path is a URL, :obj:`False` otherwise.
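
    Example::

        >>> is_remote_url("https://huggingface.co/riccorl/model")
        True
        >>> is_remote_url("/path/to/a/local/model")
        False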
    """
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def url_to_filename(resource: str, etag: Optional[str] = None) -> str:
    """
    Convert a `resource` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the resource's, delimited
    by a period.
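
    Example (the mapping is deterministic, so the same inputs always
    produce the same filename)::

        >>> url_to_filename("https://example.com/model.zip") == url_to_filename(
        ...     "https://example.com/model.zip"
        ... )
        True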
    """
    resource_bytes = resource.encode("utf-8")
    resource_hash = sha256(resource_bytes)
    filename = resource_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode("utf-8")
        etag_hash = sha256(etag_bytes)
        filename += "." + etag_hash.hexdigest()

    return filename


def download_resource(
    url: str,
    temp_file: BinaryIO,
    headers: Optional[Dict[str, str]] = None,
):
    """
    Download a remote file and write it to `temp_file`.
    """
    if headers is None:
        headers = {}

    r = requests.get(url, stream=True, headers=headers)
    r.raise_for_status()
    content_length = r.headers.get("Content-Length")
    total = int(content_length) if content_length is not None else None
    progress = tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        desc="Downloading",
        disable=logger.level in [logging.NOTSET],
    )
    # Stream the response body in chunks so that large files are never fully
    # loaded into memory.
    for chunk in r.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


def download_and_cache(
    url: Union[str, Path],
    cache_dir: Optional[Union[str, Path]] = None,
    force_download: bool = False,
):
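    """
    Download the resource at `url` into `cache_dir` (the default SapienzaNLP cache
    directory if not given) and return the path to the cached file. If the file is
    already in the cache, the download is skipped unless `force_download` is `True`.
    """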
    if cache_dir is None:
        cache_dir = SAPIENZANLP_CACHE_DIR
    if isinstance(url, Path):
        url = str(url)

    # Create the cache directory if it doesn't exist.
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Check whether the resource requires authentication: on a 401, retrieve the
    # HuggingFace token and retry with it. Other HTTP errors surface through the
    # ETag request below.
    headers = {}
    try:
        r = requests.head(url, allow_redirects=False, timeout=10)
        r.raise_for_status()
    except requests.exceptions.HTTPError:
        if r.status_code == 401:
            hf_token = huggingface_hub.HfFolder.get_token()
            if hf_token is None:
                raise ValueError(
                    "You need to login to HuggingFace to download this model "
                    "(use the `huggingface-cli login` command)"
                )
            headers["Authorization"] = f"Bearer {hf_token}"

    # Get the ETag to add to the filename, if it exists.
    etag = None
    try:
        # `allow_redirects=False` so that a redirect shows up as a 3xx response
        # and we can follow it explicitly below; with redirects followed, the
        # 3xx check would never trigger.
        r = requests.head(url, allow_redirects=False, timeout=10, headers=headers)
        r.raise_for_status()
        # We favor a custom header indicating the ETag of the linked resource,
        # and fall back to the regular ETag header.
        etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
        if etag is None:
            raise OSError(
                "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
            )
        # In case of a redirect, download directly from the final location, so
        # we fetch the exact version the HEAD request saw.
        if 300 <= r.status_code <= 399:
            url = r.headers["Location"]
    except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
        # Actually raise for these subclasses of ConnectionError.
        raise
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
        # Otherwise, the connection is down; proceed with `etag = None`.
        pass

    # Build the hashed filename from the URL (and the ETag, when available).
    filename = url_to_filename(url, etag)

    # Get the path the file will have in the cache.
    cache_path = cache_dir / filename

    # The file is already in the cache: return its path, unless a re-download
    # is forced.
    if file_exists(cache_path) and not force_download:
        logger.info(
            f"{url} found in cache, set `force_download=True` to force the download"
        )
        return cache_path

    cache_path = str(cache_path)
    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):
        # If the file is there now, another process downloaded it while we
        # were waiting for the lock: use it.
        if file_exists(cache_path) and not force_download:
            return cache_path

        # Download to a temporary file first, then move it into place once the
        # download completes, so the cache never contains partial files.
        temp_file_manager = partial(
            tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False
        )

        with temp_file_manager() as temp_file:
            logger.info(
                f"{url} not found in cache or `force_download` set to `True`, downloading to {temp_file.name}"
            )
            download_resource(url, temp_file, headers)

        logger.info(f"storing {url} in cache at {cache_path}")
        os.replace(temp_file.name, cache_path)

        # NamedTemporaryFile creates files with hardwired 0600 permissions;
        # relax them to the user's default, respecting the current umask.
        umask = os.umask(0o666)
        os.umask(umask)
        os.chmod(cache_path, 0o666 & ~umask)

        logger.info(f"creating metadata file for {cache_path}")
        meta = {"url": url}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path


def download_from_hf(
    path_or_repo_id: Union[str, Path],
    filenames: Optional[List[str]],
    cache_dir: Optional[Union[str, Path]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
):
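    """
    Download `filenames` from the HuggingFace Hub repo (or local path)
    `path_or_repo_id` and return the directory containing the downloaded files.
    The remaining arguments are forwarded to `transformers.utils.hub.cached_file`.
    """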
    if isinstance(path_or_repo_id, Path):
        path_or_repo_id = str(path_or_repo_id)

    downloaded_paths = []
    for filename in filenames:
        downloaded_path = hf_cached_file(
            path_or_repo_id,
            filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            use_auth_token=use_auth_token,
            revision=revision,
            local_files_only=local_files_only,
            subfolder=subfolder,
        )
        downloaded_paths.append(downloaded_path)

    # All the files live in the same cached snapshot, so the parent of any
    # downloaded file is the model folder.
    probably_the_folder = Path(downloaded_paths[0]).parent
    return probably_the_folder


def model_name_or_path_resolver(model_name_or_dir: Union[str, os.PathLike]) -> str:
    """
    Resolve a model name or directory to a model archive name or directory.

    Args:
        model_name_or_dir (:obj:`str` or :obj:`os.PathLike`):
            A model name or directory.

    Returns:
        :obj:`str`: The model archive name or directory.
    """
    if is_remote_url(model_name_or_dir):
        # It is a URL pointing to a remote archive: use it as-is.
        model_archive = model_name_or_dir
    elif Path(model_name_or_dir).is_dir() or Path(model_name_or_dir).is_file():
        # It is a local directory or file: use it as-is.
        model_archive = model_name_or_dir
    else:
        # Otherwise, treat it as a SapienzaNLP model id and build the
        # corresponding HuggingFace repo id.
        model_archive = sapienzanlp_model_urls(model_name_or_dir)

    return model_archive


def from_cache(
    url_or_filename: Union[str, Path],
    cache_dir: Optional[Union[str, Path]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    filenames: Optional[List[str]] = None,
) -> Path:
    """
    Given something that could be either a local path or a URL (or a SapienzaNLP model id),
    determine which one it is and return a path to the corresponding file.

    Args:
        url_or_filename (:obj:`str` or :obj:`Path`):
            A path to a local file or a URL (or a SapienzaNLP model id).
        cache_dir (:obj:`str` or :obj:`Path`, `optional`):
            Path to a directory in which a downloaded file will be cached.
        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to re-download the file even if it already exists.
        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to delete incompletely received files. Attempts to resume the
            download if such a file exists.
        proxies (:obj:`Dict[str, str]`, `optional`):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        use_auth_token (:obj:`Union[bool, str]`, `optional`):
            Optional string or boolean to use as Bearer token for remote files. If :obj:`True`, will get
            token from :obj:`~transformers.hf_api.HfApi`. If :obj:`str`, will use that string as token.
        revision (:obj:`str`, `optional`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id,
            since we use a git-based system for storing models and other artifacts on huggingface.co,
            so ``revision`` can be any identifier allowed by git.
        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to only look at local files (i.e., do not try to download the file).
        subfolder (:obj:`str`, `optional`):
            In case the relevant file is in a subfolder of the URL, specify it here.
        filenames (:obj:`List[str]`, `optional`):
            List of filenames to look for in the directory structure.

    Returns:
        :obj:`Path`: Path to the cached file (or to the extracted directory, for archives).
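
    Example (illustrative, not executed: assumes network access and a
    hypothetical model id)::

        >>> model_dir = from_cache("my-model", filenames=[CONFIG_NAME])  # doctest: +SKIP
        >>> (model_dir / CONFIG_NAME).exists()  # doctest: +SKIP
        True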
    """
    url_or_filename = model_name_or_path_resolver(url_or_filename)

    if cache_dir is None:
        cache_dir = SAPIENZANLP_CACHE_DIR

    if file_exists(url_or_filename):
        logger.info(f"{url_or_filename} is a local path or file")
        output_path = url_or_filename
    elif is_remote_url(url_or_filename):
        # It is a URL: download it and get it from the cache.
        output_path = download_and_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
        )
    else:
        # Otherwise, treat it as a HuggingFace Hub repo id.
        if filenames is None:
            filenames = [WEIGHTS_NAME, CONFIG_NAME, LABELS_NAME]
        output_path = download_from_hf(
            url_or_filename,
            filenames,
            cache_dir,
            force_download,
            resume_download,
            proxies,
            use_auth_token,
            revision,
            local_files_only,
            subfolder,
        )

    # If the output is a directory or a regular (non-archive) file, return it
    # as-is: only zip and tar archives need extraction.
    if dir_exists(output_path) or (
        not is_zipfile(output_path) and not tarfile.is_tarfile(output_path)
    ):
        return Path(output_path)

    # The path points to an archive: extract it (if not already extracted)
    # and return the extracted directory.
    logger.info("Extracting compressed archive")
    output_dir, output_file = os.path.split(output_path)
    output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
    output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

    # Already extracted: do not extract again unless a re-download is forced.
    if (
        os.path.isdir(output_path_extracted)
        and os.listdir(output_path_extracted)
        and not force_download
    ):
        return Path(output_path_extracted)

    # Prevent parallel extractions of the same archive with a lock.
    lock_path = output_path + ".lock"
    with FileLock(lock_path):
        shutil.rmtree(output_path_extracted, ignore_errors=True)
        os.makedirs(output_path_extracted)
        if is_zipfile(output_path):
            with ZipFile(output_path, "r") as zip_file:
                zip_file.extractall(output_path_extracted)
        elif tarfile.is_tarfile(output_path):
            with tarfile.open(output_path) as tar_file:
                tar_file.extractall(output_path_extracted)
        else:
            raise EnvironmentError(
                f"Archive format of {output_path} could not be identified"
            )

    # Remove the lock file, it is no longer needed.
    os.remove(lock_path)

    return Path(output_path_extracted)


def is_str_a_path(maybe_path: str) -> bool:
    """
    Check if a string is a path.

    Args:
        maybe_path (`str`): The string to check.

    Returns:
        `bool`: `True` if the string is a path, `False` otherwise.
    """
    # Check if the path exists as given.
    if Path(maybe_path).exists():
        return True
    # Check if the path exists relative to the current working directory.
    if Path(os.path.join(os.getcwd(), maybe_path)).exists():
        return True
    return False


def relative_to_absolute_path(path: str) -> os.PathLike:
    """
    Convert a relative path to an absolute path.

    Args:
        path (`str`): The relative path to convert.

    Returns:
        `os.PathLike`: The absolute path.
    """
    if not is_str_a_path(path):
        raise ValueError(f"{path} is not a path")
    if Path(path).exists():
        return Path(path).absolute()
    if Path(os.path.join(os.getcwd(), path)).exists():
        return Path(os.path.join(os.getcwd(), path)).absolute()
    raise ValueError(f"{path} is not a path")


def to_config(object_to_save: Any) -> Dict[str, Any]:
    """
    Convert an object to a dictionary, recursively replacing any nested object
    with a ``_target_``-style dictionary of its class path and public attributes.

    Args:
        object_to_save (`Any`): The object to convert.

    Returns:
        `Dict[str, Any]`: The dictionary representation of the object.
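
    Example (illustrative, not executed: the ``_target_`` module path depends
    on where the class is defined)::

        >>> class Foo:
        ...     def __init__(self):
        ...         self.x = 1
        >>> to_config(Foo())  # doctest: +SKIP
        {'_target_': '__main__.Foo', 'x': 1}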
    """

    def obj_to_dict(obj):
        match obj:
            case dict():
                # Recurse into dictionary values.
                data = {}
                for k, v in obj.items():
                    data[k] = obj_to_dict(v)
                return data

            case list() | tuple():
                # Recurse into sequence items.
                return [obj_to_dict(x) for x in obj]

            case object(__dict__=_):
                # Generic object: record its import path as `_target_` and
                # recurse into its public attributes.
                data = {
                    "_target_": f"{obj.__class__.__module__}.{obj.__class__.__name__}",
                }
                for k, v in obj.__dict__.items():
                    if not k.startswith("_"):
                        data[k] = obj_to_dict(v)
                return data

            case _:
                # Plain value (int, str, etc.): keep as-is.
                return obj

    return obj_to_dict(object_to_save)


def get_callable_from_string(callable_fn: str) -> Any:
    """
    Get a callable from its fully qualified string name.

    Args:
        callable_fn (`str`):
            The string representation of the callable, e.g. ``"os.path.join"``.

    Returns:
        `Any`: The callable.
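
    Example (the expected output assumes a POSIX path separator)::

        >>> get_callable_from_string("os.path.join")("a", "b")
        'a/b'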
    """
    # Split into module path and attribute name.
    module_name, function_name = callable_fn.rsplit(".", 1)
    # Import the module and fetch the attribute from it.
    module = importlib.import_module(module_name)
    return getattr(module, function_name)