Spaces:
Build error
Build error
File size: 1,626 Bytes
5285b72 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
#!/usr/bin/env python
from transformers import AutoModel, AutoTokenizer
import torch
import argparse
import os
def export_model(model_id: str, output_dir: str) -> None:
    """Export a Hugging Face model to ONNX and save its tokenizer next to it.

    The model and tokenizer are loaded with ``AutoModel.from_pretrained`` and
    ``AutoTokenizer.from_pretrained``. The model is exported with dummy
    ``input_ids`` / ``attention_mask`` / ``token_type_ids`` tensors to
    ``<output_dir>/model.onnx``, and the tokenizer is then saved into the
    same directory.

    Args:
        model_id: Identifier handed to ``from_pretrained`` for both the model
            and the tokenizer.
        output_dir: Existing directory that receives ``model.onnx`` and the
            tokenizer files.

    Returns:
        None. When ``output_dir`` is not an existing directory a message is
        printed and nothing is loaded or exported.
    """
    # isdir() rather than exists(): a regular file at this path would satisfy
    # exists() but cannot hold model.onnx, so reject it before loading anything.
    if not os.path.isdir(output_dir):
        print(f"Output directory '{output_dir}' does not exist")
        return

    embedder = AutoModel.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    input_names = ["input_ids", "attention_mask", "token_type_ids"]
    output_names = ["last_hidden_state"]

    # Dummy (1, 32) int64 inputs used only for the export call; both axes are
    # declared dynamic below.
    input_ids = torch.ones(1, 32, dtype=torch.int64)
    attention_mask = torch.ones(1, 32, dtype=torch.int64)
    token_type_ids = torch.zeros(1, 32, dtype=torch.int64)
    args = (input_ids, attention_mask, token_type_ids)

    # Every named input and output shares the same dynamic batch and
    # sequence axes.
    dynamic_axes = {
        name: {0: "batch_size", 1: "dyn"}
        for name in (*input_names, *output_names)
    }

    f = os.path.join(output_dir, "model.onnx")
    print(f"Exporting onnx model to {f}")
    torch.onnx.export(
        embedder,
        args=args,
        f=f,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=14,
    )

    tokenizer.save_pretrained(output_dir)
def main():
    """Read ``--hf_model`` and ``--output_dir`` from the command line and export.

    Both options are required strings; they are forwarded to ``export_model``
    as the model identifier and the target directory respectively.
    """
    cli = argparse.ArgumentParser()
    for flag in ("--hf_model", "--output_dir"):
        cli.add_argument(flag, type=str, required=True)
    options = cli.parse_args()
    export_model(model_id=options.hf_model, output_dir=options.output_dir)
# Run the command-line export only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|