Spaces:
Build error
Build error
#!/usr/bin/env python | |
from transformers import AutoModel, AutoTokenizer | |
import torch | |
import argparse | |
import os | |
def export_model(model_id, output_dir):
    """Export a Hugging Face model to ONNX and save its tokenizer alongside.

    Loads ``model_id`` through ``AutoModel`` and ``AutoTokenizer``, traces the
    model with dummy ``input_ids``, ``attention_mask`` and ``token_type_ids``
    tensors, and writes ``model.onnx`` plus the tokenizer files into
    ``output_dir``.

    Args:
        model_id: Identifier handed to ``from_pretrained`` for both the model
            and the tokenizer.
        output_dir: Directory that receives ``model.onnx`` and the tokenizer
            files. It must already exist; it is not created here.

    Returns:
        None. When ``output_dir`` is missing or is not a directory, a message
        is printed and nothing is loaded or exported.
    """
    # Validate the destination before loading the model. ``isdir`` (rather
    # than ``exists``) also rejects a path that points at a regular file,
    # which would otherwise only fail once the export tries to write into it.
    if not os.path.isdir(output_dir):
        if os.path.exists(output_dir):
            print(f"Output path '{output_dir}' is not a directory")
        else:
            print(f"Output directory '{output_dir}' does not exist")
        return

    embedder = AutoModel.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    input_names = ["input_ids", "attention_mask", "token_type_ids"]
    output_names = ["last_hidden_state"]

    # Dummy inputs used only for tracing: one sequence of 32 int64 tokens.
    # The concrete sizes do not constrain the exported graph because both
    # axes are declared dynamic below.
    # NOTE(review): assumes the model's forward() accepts these three tensors
    # positionally in this order -- confirm for the architectures you export.
    dummy_shape = (1, 32)
    input_ids = torch.ones(*dummy_shape, dtype=torch.int64)
    attention_mask = torch.ones(*dummy_shape, dtype=torch.int64)
    token_type_ids = torch.zeros(*dummy_shape, dtype=torch.int64)
    args = (input_ids, attention_mask, token_type_ids)

    f = os.path.join(output_dir, "model.onnx")
    print(f"Exporting onnx model to {f}")

    # Axis 0 (batch) and axis 1 (named "dyn") are dynamic on every input and
    # on the output.
    dynamic = {0: "batch_size", 1: "dyn"}
    torch.onnx.export(
        embedder,
        args=args,
        f=f,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={name: dict(dynamic) for name in input_names + output_names},
        opset_version=14,
    )

    # Save the tokenizer files next to the exported model.
    tokenizer.save_pretrained(output_dir)
def main():
    """Parse the command line and run the ONNX export.

    Both ``--hf_model`` and ``--output_dir`` are required string options;
    their values are forwarded to ``export_model`` in that order.
    """
    cli = argparse.ArgumentParser()
    # The two options share the same settings, so register them in a loop.
    for flag in ("--hf_model", "--output_dir"):
        cli.add_argument(flag, type=str, required=True)
    options = cli.parse_args()
    export_model(options.hf_model, options.output_dir)


# Run the exporter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()