import argparse
import os
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional, Tuple

from optimum.exporters.tasks import TasksManager
from optimum.exporters.onnx import OnnxConfigWithPast, export, validate_model_outputs
from transformers import AutoConfig, AutoTokenizer, is_torch_available
from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download, get_repo_discussions
from huggingface_hub.file_download import repo_folder_name


def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
    """Return the first open pull request on ``model_id`` whose title is ``pr_title``.

    This is best-effort: if the repository discussions cannot be listed,
    ``None`` is returned instead of raising.

    Args:
        api: Hub client used to list the repository discussions.
        model_id: Repository id on the Hub, e.g. ``"gpt2"``.
        pr_title: Exact PR title to match.

    Returns:
        The matching open pull request, or ``None`` if none matches or the
        listing failed.
    """
    try:
        discussions = api.get_repo_discussions(repo_id=model_id)
    except Exception:
        # Deliberately best-effort: not being able to list discussions
        # should not prevent a conversion.
        return None
    for discussion in discussions:
        if (
            discussion.status == "open"
            and discussion.is_pull_request
            and discussion.title == pr_title
        ):
            return discussion
    return None


def convert_onnx(
    model_id: str,
    task: str,
    folder: str,
    pad_token_id: Optional[int] = None,
) -> List:
    """Export ``model_id`` to ONNX inside ``folder`` and build the Hub upload operations.

    Args:
        model_id: Repository id of the PyTorch model on the Hub.
        task: Task name passed to optimum's ``TasksManager``.
        folder: Existing local directory; ``model.onnx`` is written there,
            along with any other files the exporter produces.
        pad_token_id: Pad token id to set on the model config when the export
            needs one and the config lacks it. When ``None``, it is read from
            the model's tokenizer instead.

    Returns:
        A list of ``CommitOperationAdd``, one per regular non-hidden file in
        ``folder``. A single file is uploaded at the repository root;
        several files are uploaded under ``onnx/``.

    Raises:
        ValueError: If a pad token id is needed, none was given, and the
            tokenizer could not be loaded.
    """
    # Load the PyTorch model for the requested task.
    model = TasksManager.get_model_from_task(task, model_id, framework="pt")
    model_type = model.config.model_type.replace("_", "-")
    model_name = getattr(model, "name", None)

    onnx_config_constructor = TasksManager.get_exporter_config_constructor(
        model_type, "onnx", task=task, model_name=model_name
    )
    onnx_config = onnx_config_constructor(model.config)

    needs_pad_token_id = (
        isinstance(onnx_config, OnnxConfigWithPast)
        and getattr(model.config, "pad_token_id", None) is None
        # NOTE(review): optimum task names are hyphenated in some versions;
        # accept both spellings so this check is not silently skipped.
        and task in ("sequence_classification", "sequence-classification")
    )
    if needs_pad_token_id:
        if pad_token_id is not None:
            model.config.pad_token_id = pad_token_id
        else:
            try:
                tok = AutoTokenizer.from_pretrained(model_id)
                model.config.pad_token_id = tok.pad_token_id
            except Exception as e:
                raise ValueError(
                    "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
                ) from e

    # Export with the config's default opset (no override is exposed here).
    opset = onnx_config.DEFAULT_ONNX_OPSET
    output = Path(folder).joinpath("model.onnx")
    onnx_inputs, onnx_outputs = export(
        model,
        onnx_config,
        opset,
        output,
    )

    atol = onnx_config.ATOL_FOR_VALIDATION
    if isinstance(atol, dict):
        # Per-task tolerances are looked up without the "-with-past" suffix.
        atol = atol[task.replace("-with-past", "")]

    try:
        validate_model_outputs(onnx_config, model, output, onnx_outputs, atol)
        print(f"All good, model saved at: {output}")
    except ValueError as e:
        # A validation mismatch is reported, but the exported files are
        # still uploaded below.
        print(f"An error occured, but the model was saved at: {output.as_posix()}")
        print(f"Validation error: {e}")

    # Compute the upload list once, so the root-vs-onnx/ decision and the
    # uploaded files agree (only regular, non-hidden files).
    file_names = sorted(
        name
        for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name)) and not name.startswith(".")
    )
    # Hub paths always use "/", independent of the local OS path separator.
    prefix = "" if len(file_names) == 1 else "onnx/"
    return [
        CommitOperationAdd(
            path_in_repo=f"{prefix}{file_name}",
            path_or_fileobj=os.path.join(folder, file_name),
        )
        for file_name in file_names
    ]


def convert(
    api: "HfApi",
    model_id: str,
    task: str,
    force: bool = False,
    pad_token_id: Optional[int] = None,
) -> Tuple[str, Optional["CommitInfo"]]:
    """Convert a Hub model to ONNX and open a pull request with the result.

    Args:
        api: Hub client used to query the repository and create the commit.
        model_id: Repository id of the model on the Hub.
        task: Task name. ``"auto"`` (or ``None``) infers it from the model.
        force: Convert and open the PR even if an ONNX file or an open PR
            with the same title already exists.
        pad_token_id: Forwarded to ``convert_onnx``, for models that need an
            explicit pad token.

    Returns:
        ``("0", commit_info)`` on success. If task inference fails, returns
        ``(error_message, None)``.

    Raises:
        Exception: If the model is already converted, or already has an open
            conversion PR, and ``force`` is false.
    """
    pr_title = "Adding ONNX file of this model"
    info = api.model_info(model_id)
    filenames = set(s.rfilename for s in info.siblings)

    if task is None or task == "auto":
        try:
            task = TasksManager.infer_task_from_model(model_id)
        except Exception as e:
            return f"### Error: {e}. Please pass explicitely the task as it could not be infered.", None

    # TemporaryDirectory removes everything written below it on exit.
    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)

        pr = previous_pr(api, model_id, pr_title)
        # convert_onnx uploads to the repo root or under onnx/, so check both.
        already_converted = bool({"model.onnx", "onnx/model.onnx"} & filenames)
        if already_converted and not force:
            raise Exception(f"Model {model_id} is already converted, skipping..")
        if pr is not None and not force:
            url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
            raise Exception(f"Model {model_id} already has an open PR check out {url}")

        operations = convert_onnx(model_id, task, folder, pad_token_id=pad_token_id)
        new_pr = api.create_commit(
            repo_id=model_id,
            operations=operations,
            commit_message=pr_title,
            create_pr=True,
        )
    return "0", new_pr


if __name__ == "__main__":
    DESCRIPTION = """
    Simple utility tool to convert automatically a model on the hub to onnx format.
    It is PyTorch exclusive for now.
    It works by downloading the weights (PT), converting them locally, and uploading them back
    as a PR on the hub.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )
    parser.add_argument(
        "--task",
        type=str,
        # "auto" makes convert() infer the task instead of passing None on.
        default="auto",
        help="The task the model is performing",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Create the PR even if it already exists of if the model was already converted.",
    )
    parser.add_argument(
        "--pad_token_id",
        type=int,
        default=None,
        help="Pad token id to use when the export needs one and it cannot be read from the tokenizer.",
    )
    args = parser.parse_args()
    api = HfApi()
    status, _ = convert(
        api,
        args.model_id,
        task=args.task,
        force=args.force,
        pad_token_id=args.pad_token_id,
    )
    # convert() reports task-inference failures through its return value.
    if status != "0":
        print(status)