# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
"""ONNX Model Hub

This implements the python client for the ONNX model hub.
"""

import hashlib
import json
import os
import sys
import tarfile
from io import BytesIO
from os.path import join
from typing import IO, Any, Dict, List, Optional, Set, Tuple, cast
from urllib.error import HTTPError
from urllib.request import urlopen

import onnx

if "ONNX_HOME" in os.environ: | |
_ONNX_HUB_DIR = join(os.environ["ONNX_HOME"], "hub") | |
elif "XDG_CACHE_HOME" in os.environ: | |
_ONNX_HUB_DIR = join(os.environ["XDG_CACHE_HOME"], "onnx", "hub") | |
else: | |
_ONNX_HUB_DIR = join(os.path.expanduser("~"), ".cache", "onnx", "hub") | |
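
# Illustration of the cache-directory resolution above (paths are examples,
# not defaults shipped with the library):
#   ONNX_HOME=/opt/onnx        -> /opt/onnx/hub
#   XDG_CACHE_HOME=/var/cache  -> /var/cache/onnx/hub
#   neither variable set       -> ~/.cache/onnx/hub
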
class ModelInfo:
    """A class to represent a model's properties and metadata in the ONNX Hub.

    It extracts the model name, path, SHA, tags, etc. from the passed-in
    raw_model_info dict.

    Attributes:
        model: The name of the model.
        model_path: The path to the model, relative to the model zoo (https://github.com/onnx/models/) repo root.
        metadata: Additional metadata of the model, such as the size of the model, IO ports, etc.
        model_sha: The SHA256 digest of the model file.
        tags: A set of tags associated with the model.
        opset: The opset version of the model.
    """

    def __init__(self, raw_model_info: Dict[str, Any]) -> None:
        """Initializer.

        Args:
            raw_model_info: A JSON dict containing the model info.
        """
        self.model = cast(str, raw_model_info["model"])
        self.model_path = cast(str, raw_model_info["model_path"])
        self.metadata: Dict[str, Any] = cast(Dict[str, Any], raw_model_info["metadata"])
        self.model_sha: Optional[str] = None
        if "model_sha" in self.metadata:
            self.model_sha = cast(str, self.metadata["model_sha"])
        self.tags: Set[str] = set()
        if "tags" in self.metadata:
            self.tags = set(cast(List[str], self.metadata["tags"]))
        self.opset = cast(int, raw_model_info["opset_version"])
        self.raw_model_info: Dict[str, Any] = raw_model_info

    def __str__(self) -> str:
        return f"ModelInfo(model={self.model}, opset={self.opset}, path={self.model_path}, metadata={self.metadata})"

    def __repr__(self) -> str:
        return self.__str__()
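
# A minimal construction sketch for ModelInfo; the manifest entry below is
# hypothetical, not copied from the real ONNX_HUB_MANIFEST.json:
#
#   info = ModelInfo(
#       {
#           "model": "ExampleNet",
#           "model_path": "vision/classification/examplenet/model/examplenet-9.onnx",
#           "opset_version": 9,
#           "metadata": {"model_sha": "abc123...", "tags": ["vision"]},
#       }
#   )
#   assert info.opset == 9 and "vision" in info.tags
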
def set_dir(new_dir: str) -> None:
    """Sets the current ONNX hub cache location.

    Args:
        new_dir: Location of new model hub cache.
    """
    global _ONNX_HUB_DIR  # noqa: PLW0603
    _ONNX_HUB_DIR = new_dir


def get_dir() -> str:
    """Gets the current ONNX hub cache location.

    Returns:
        The location of the ONNX hub model cache.
    """
    return _ONNX_HUB_DIR

def _parse_repo_info(repo: str) -> Tuple[str, str, str]:
    """Gets the repo owner, name and ref from a repo specification string."""
    repo_owner = repo.split(":")[0].split("/")[0]
    repo_name = repo.split(":")[0].split("/")[1]
    if ":" in repo:
        repo_ref = repo.split(":")[1]
    else:
        repo_ref = "main"
    return repo_owner, repo_name, repo_ref
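
# Expected behavior of _parse_repo_info (illustrative inputs):
#   _parse_repo_info("onnx/models:main")  -> ("onnx", "models", "main")
#   _parse_repo_info("someuser/somerepo") -> ("someuser", "somerepo", "main")
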
def _verify_repo_ref(repo: str) -> bool:
    """Verifies whether the given model repo can be trusted.

    A model repo can be trusted if it matches onnx/models:main.
    """
    repo_owner, repo_name, repo_ref = _parse_repo_info(repo)
    return (repo_owner == "onnx") and (repo_name == "models") and (repo_ref == "main")

def _get_base_url(repo: str, lfs: bool = False) -> str:
    """Gets the base GitHub URL from a repo specification string.

    Args:
        repo: The location of the model repo in format
            "user/repo[:branch]". If no branch is specified, defaults to
            "main".
        lfs: Whether the URL is for downloading LFS models.

    Returns:
        The base GitHub URL for downloading.
    """
    repo_owner, repo_name, repo_ref = _parse_repo_info(repo)
    if lfs:
        return f"https://media.githubusercontent.com/media/{repo_owner}/{repo_name}/{repo_ref}/"
    return f"https://raw.githubusercontent.com/{repo_owner}/{repo_name}/{repo_ref}/"
def _download_file(url: str, file_name: str) -> None:
    """Downloads the file with the specified file_name from the url.

    Args:
        url: The URL of the download link.
        file_name: The file name for the downloaded file.
    """
    chunk_size = 16384  # 1024 * 16
    with urlopen(url) as response, open(file_name, "wb") as f:
        # Read progressively in chunk_size blocks so huge models are not
        # loaded into memory all at once.
        while True:
            chunk = response.read(chunk_size)
            if not chunk:
                break
            f.write(chunk)

def list_models(
    repo: str = "onnx/models:main",
    model: Optional[str] = None,
    tags: Optional[List[str]] = None,
) -> List[ModelInfo]:
    """Gets the list of model info consistent with a given name and tags.

    Args:
        repo: The location of the model repo in format
            "user/repo[:branch]". If no branch is specified, defaults to
            "main".
        model: The name of the model to search for. If `None`, will
            return all models with matching tags.
        tags: A list of tags to filter models by. If `None`, will return
            all models with a matching name.

    Returns:
        A list of matching ``ModelInfo`` entries.
    """
    base_url = _get_base_url(repo)
    manifest_url = base_url + "ONNX_HUB_MANIFEST.json"
    try:
        with urlopen(manifest_url) as response:
            manifest: List[ModelInfo] = [
                ModelInfo(info) for info in json.load(cast(IO[str], response))
            ]
    except HTTPError as e:
        raise AssertionError(f"Could not find manifest at {manifest_url}") from e

    # Filter by model name first (case-insensitive).
    matching_models = (
        manifest
        if model is None
        else [m for m in manifest if m.model.lower() == model.lower()]
    )

    # Then filter by tags (also case-insensitive).
    if tags is None:
        return matching_models

    canonical_tags = {t.lower() for t in tags}
    matching_info_list: List[ModelInfo] = []
    for m in matching_models:
        model_tags = {t.lower() for t in m.tags}
        if len(canonical_tags.intersection(model_tags)) > 0:
            matching_info_list.append(m)
    return matching_info_list
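
# Usage sketch for list_models; the model name and tag are illustrative and
# must exist in the target manifest:
#   all_models = list_models()
#   mnist_all = list_models(model="mnist")        # name match, any tags
#   vision_models = list_models(tags=["vision"])  # tag match, any name
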
def get_model_info(
    model: str, repo: str = "onnx/models:main", opset: Optional[int] = None
) -> ModelInfo:
    """Gets the model info matching the given name and opset.

    Args:
        model: The name of the onnx model in the manifest. Matching is
            case-insensitive.
        repo: The location of the model repo in format
            "user/repo[:branch]". If no branch is specified, defaults to
            "main".
        opset: The opset of the model to get. The default of `None` will
            return the model with the largest opset.

    Returns:
        The matching ``ModelInfo``.
    """
    matching_models = list_models(repo, model)
    if not matching_models:
        raise AssertionError(f"No models found with name {model}")

    if opset is None:
        # Highest available opset first.
        selected_models = sorted(matching_models, key=lambda m: -m.opset)
    else:
        selected_models = [m for m in matching_models if m.opset == opset]
        if not selected_models:
            valid_opsets = [m.opset for m in matching_models]
            raise AssertionError(
                f"{model} has no version with opset {opset}. Valid opsets: {valid_opsets}"
            )
    return selected_models[0]
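
# Usage sketch for get_model_info; the model name and opset are illustrative:
#   latest = get_model_info("mnist")           # highest available opset
#   pinned = get_model_info("mnist", opset=8)  # raises AssertionError if absent
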
def load(
    model: str,
    repo: str = "onnx/models:main",
    opset: Optional[int] = None,
    force_reload: bool = False,
    silent: bool = False,
) -> Optional[onnx.ModelProto]:
    """Downloads a model by name from the onnx model hub.

    Args:
        model: The name of the onnx model in the manifest. Matching is
            case-insensitive.
        repo: The location of the model repo in format
            "user/repo[:branch]". If no branch is specified, defaults to
            "main".
        opset: The opset of the model to download. The default of `None`
            automatically chooses the largest opset.
        force_reload: Whether to force the model to re-download even if
            it is already found in the cache.
        silent: Whether to suppress the warning message if the repo is
            not trusted.

    Returns:
        ModelProto, or None if the user declines to download from an
        untrusted repo.
    """
    selected_model = get_model_info(model, repo, opset)

    # Prefix the cached file name with the model SHA so a stale cache entry
    # is not reused after the model in the hub changes.
    local_model_path_arr = selected_model.model_path.split("/")
    if selected_model.model_sha is not None:
        local_model_path_arr[-1] = (
            f"{selected_model.model_sha}_{local_model_path_arr[-1]}"
        )
    local_model_path = join(_ONNX_HUB_DIR, os.sep.join(local_model_path_arr))

    if force_reload or not os.path.exists(local_model_path):
        if not _verify_repo_ref(repo) and not silent:
            msg = (
                f"The model repo specification {repo} is not trusted and may "
                "contain security vulnerabilities. Only continue if you trust "
                "this repo."
            )
            print(msg, file=sys.stderr)
            print("Continue? [y/n]")
            if input().lower() != "y":
                return None

        os.makedirs(os.path.dirname(local_model_path), exist_ok=True)
        lfs_url = _get_base_url(repo, True)
        print(f"Downloading {model} to local path {local_model_path}")
        _download_file(lfs_url + selected_model.model_path, local_model_path)
    else:
        print(f"Using cached {model} model from {local_model_path}")

    with open(local_model_path, "rb") as f:
        model_bytes = f.read()

    if selected_model.model_sha is not None:
        downloaded_sha = hashlib.sha256(model_bytes).hexdigest()
        if downloaded_sha != selected_model.model_sha:
            raise AssertionError(
                f"The cached model {selected_model.model} has SHA256 {downloaded_sha} "
                f"while the checksum should be {selected_model.model_sha}. "
                "The model in the hub may have been updated. Use force_reload to "
                "download the model from the model hub."
            )

    return onnx.load(cast(IO[bytes], BytesIO(model_bytes)))
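
# Usage sketch for load; the model name is illustrative:
#   model_proto = load("mnist")
#   if model_proto is not None:  # None only if an untrusted download is declined
#       onnx.checker.check_model(model_proto)
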
def download_model_with_test_data(
    model: str,
    repo: str = "onnx/models:main",
    opset: Optional[int] = None,
    force_reload: bool = False,
    silent: bool = False,
) -> Optional[str]:
    """Downloads a model along with test data by name from the onnx model hub
    and returns the directory to which the files have been extracted.

    Args:
        model: The name of the onnx model in the manifest. Matching is
            case-insensitive.
        repo: The location of the model repo in format
            "user/repo[:branch]". If no branch is specified, defaults to
            "main".
        opset: The opset of the model to download. The default of `None`
            automatically chooses the largest opset.
        force_reload: Whether to force the model to re-download even if
            it is already found in the cache.
        silent: Whether to suppress the warning message if the repo is
            not trusted.

    Returns:
        str, or None if the user declines to download from an untrusted
        repo.
    """
    selected_model = get_model_info(model, repo, opset)

    # Prefix the cached archive name with its SHA so a stale cache entry
    # is not reused after the archive in the hub changes.
    local_model_with_data_path_arr = selected_model.metadata[
        "model_with_data_path"
    ].split("/")
    model_with_data_sha = selected_model.metadata.get("model_with_data_sha")
    if model_with_data_sha is not None:
        local_model_with_data_path_arr[-1] = (
            f"{model_with_data_sha}_{local_model_with_data_path_arr[-1]}"
        )
    local_model_with_data_path = join(
        _ONNX_HUB_DIR, os.sep.join(local_model_with_data_path_arr)
    )

    if force_reload or not os.path.exists(local_model_with_data_path):
        if not _verify_repo_ref(repo) and not silent:
            msg = (
                f"The model repo specification {repo} is not trusted and may "
                "contain security vulnerabilities. Only continue if you trust "
                "this repo."
            )
            print(msg, file=sys.stderr)
            print("Continue? [y/n]")
            if input().lower() != "y":
                return None

        os.makedirs(os.path.dirname(local_model_with_data_path), exist_ok=True)
        lfs_url = _get_base_url(repo, True)
        print(f"Downloading {model} to local path {local_model_with_data_path}")
        _download_file(
            lfs_url + selected_model.metadata["model_with_data_path"],
            local_model_with_data_path,
        )
    else:
        print(f"Using cached {model} model from {local_model_with_data_path}")

    with open(local_model_with_data_path, "rb") as f:
        model_with_data_bytes = f.read()

    if model_with_data_sha is not None:
        downloaded_sha = hashlib.sha256(model_with_data_bytes).hexdigest()
        if downloaded_sha != model_with_data_sha:
            raise AssertionError(
                f"The cached model {selected_model.model} has SHA256 {downloaded_sha} "
                f"while the checksum should be {model_with_data_sha}. "
                "The model in the hub may have been updated. Use force_reload to "
                "download the model from the model hub."
            )

    with tarfile.open(local_model_with_data_path) as model_with_data_zipped:
        # Extract next to the archive, stripping the ".tar.gz" suffix from
        # the directory name.
        local_model_with_data_dir_path = local_model_with_data_path[
            : -len(".tar.gz")
        ]
        model_with_data_zipped.extractall(local_model_with_data_dir_path)

    model_with_data_path = join(
        local_model_with_data_dir_path,
        os.listdir(local_model_with_data_dir_path)[0],
    )
    return model_with_data_path
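
# Usage sketch for download_model_with_test_data; the model name is
# illustrative, and the assumed archive layout (a .onnx file next to
# test_data_set_* folders) follows the model zoo convention:
#   extracted = download_model_with_test_data("mnist")
#   if extracted is not None:
#       print(os.listdir(extracted))
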
def load_composite_model(
    network_model: str,
    preprocessing_model: str,
    network_repo: str = "onnx/models:main",
    preprocessing_repo: str = "onnx/models:main",
    opset: Optional[int] = None,
    force_reload: bool = False,
    silent: bool = False,
) -> Optional[onnx.ModelProto]:
    """Builds a composite model including data preprocessing by downloading a
    network and a preprocessing model and combining them into a single model.

    Args:
        network_model: The name of the onnx model in the manifest.
        preprocessing_model: The name of the preprocessing model.
        network_repo: The location of the model repo in format
            "user/repo[:branch]". If no branch is specified, defaults to
            "main".
        preprocessing_repo: The location of the preprocessing model repo in
            format "user/repo[:branch]". If no branch is specified, defaults
            to "main".
        opset: The opset of the model to download. The default of `None`
            automatically chooses the largest opset.
        force_reload: Whether to force the model to re-download even if
            it is already found in the cache.
        silent: Whether to suppress the warning message if the repo is
            not trusted.

    Returns:
        ModelProto, or None if the user declines to download from an
        untrusted repo.
    """
    preprocessing = load(
        preprocessing_model, preprocessing_repo, opset, force_reload, silent
    )
    if preprocessing is None:
        raise RuntimeError(
            f"Could not load the preprocessing model: {preprocessing_model}"
        )
    network = load(network_model, network_repo, opset, force_reload, silent)
    if network is None:
        raise RuntimeError(f"Could not load the network model: {network_model}")

    # Collect the opset version of every domain imported by each model.
    all_domains: Set[str] = set()
    domains_to_version_network: Dict[str, int] = {}
    domains_to_version_preprocessing: Dict[str, int] = {}
    for opset_import_entry in network.opset_import:
        domain = (
            "ai.onnx" if opset_import_entry.domain == "" else opset_import_entry.domain
        )
        all_domains.add(domain)
        domains_to_version_network[domain] = opset_import_entry.version
    for opset_import_entry in preprocessing.opset_import:
        domain = (
            "ai.onnx" if opset_import_entry.domain == "" else opset_import_entry.domain
        )
        all_domains.add(domain)
        domains_to_version_preprocessing[domain] = opset_import_entry.version

    preprocessing_opset_version = -1
    network_opset_version = -1
    for domain in all_domains:
        if domain == "ai.onnx":
            preprocessing_opset_version = domains_to_version_preprocessing[domain]
            network_opset_version = domains_to_version_network[domain]
        elif (
            domain in domains_to_version_preprocessing
            and domain in domains_to_version_network
            and domains_to_version_preprocessing[domain]
            != domains_to_version_network[domain]
        ):
            raise ValueError(
                f"Cannot merge {preprocessing_model} and {network_model} because they contain "
                f"different opset versions for domain {domain} "
                f"({domains_to_version_preprocessing[domain]} and "
                f"{domains_to_version_network[domain]}). Only the default domain can be "
                "automatically converted to the highest version of the two."
            )

    # Convert the model with the lower default-domain opset up to the higher
    # one so the two models can be merged.
    if preprocessing_opset_version > network_opset_version:
        network = onnx.version_converter.convert_version(
            network, preprocessing_opset_version
        )
        network.ir_version = preprocessing.ir_version
        onnx.checker.check_model(network)
    elif network_opset_version > preprocessing_opset_version:
        preprocessing = onnx.version_converter.convert_version(
            preprocessing, network_opset_version
        )
        preprocessing.ir_version = network.ir_version
        onnx.checker.check_model(preprocessing)

    # Wire each preprocessing output into the corresponding network input.
    io_map = [
        (out_entry.name, in_entry.name)
        for out_entry, in_entry in zip(preprocessing.graph.output, network.graph.input)
    ]
    model_with_preprocessing = onnx.compose.merge_models(
        preprocessing, network, io_map=io_map
    )
    return model_with_preprocessing
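
# Usage sketch for load_composite_model; both model names are illustrative
# and must exist in the hub manifest:
#   composite = load_composite_model("ResNet50", "ResNet50-preprocessing")
#   if composite is not None:
#       onnx.save(composite, "resnet50_with_preprocessing.onnx")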