ONNX
English
stsb-roberta-base-off-topic / inference_onnx.py
Shing Yee
feat: add files
1d52dd2 unverified
raw
history blame
2.66 kB
"""
inference_onnx.py
This script leverages ONNX runtime to perform inference with a pre-trained model.
"""
import functools
import json
import sys

import numpy as np
import onnxruntime as rt
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
repo_path = "govtech/stsb-roberta-base-off-topic"
config_path = hf_hub_download(repo_id=repo_path, filename="config.json")
config_path = "config.json"
with open(config_path, 'r') as f:
config = json.load(f)
@functools.lru_cache(maxsize=8)
def _load_tokenizer(model_name):
    """Load the Hugging Face tokenizer for *model_name*, once per name."""
    return AutoTokenizer.from_pretrained(model_name)


@functools.lru_cache(maxsize=8)
def _load_session(model_fp):
    """Download *model_fp* from the Hub repo and open an ONNX Runtime session, once per file."""
    local_model_fp = hf_hub_download(repo_id=repo_path, filename=model_fp)
    # ORT >= 1.9 requires explicit providers on builds that ship more than the
    # CPU provider; prefer CUDA when this build has it, else fall back to CPU.
    providers = ["CPUExecutionProvider"]
    if "CUDAExecutionProvider" in rt.get_available_providers():
        providers.insert(0, "CUDAExecutionProvider")
    return rt.InferenceSession(local_model_fp, providers=providers)


def predict(sentence1, sentence2):
    """Classify a sentence pair with the ONNX model named in ``config``.

    The tokenizer and ONNX session are cached, so repeated calls do not reload
    them from disk or re-contact the Hub.

    Args:
        sentence1: First sentence of the pair.
        sentence2: Second sentence of the pair.

    Returns:
        A tuple ``(predicted_label, probabilities)``. ``predicted_label`` is the
        int index of the highest-probability class. ``probabilities`` is a
        numpy array holding the softmax over dim 1 of the model's first output
        (presumably shape ``(1, num_labels)`` -- confirm against the model).
    """
    embedding_cfg = config['classifier']['embedding']
    model_name = embedding_cfg['model_name']
    max_length = embedding_cfg['max_length']
    model_fp = embedding_cfg['model_fp']

    tokenizer = _load_tokenizer(model_name)
    # Encode both sentences together as a single pair input.
    encoding = tokenizer(
        sentence1, sentence2,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_token_type_ids=False,
    )

    session = _load_session(model_fp)
    # ONNX Runtime consumes numpy arrays on the host, so the tokenizer's CPU
    # tensors are passed straight through.
    # NOTE(review): assumes the graph's first two inputs are input_ids and
    # attention_mask, in that order -- confirm against the exported model.
    input_meta = session.get_inputs()
    onnx_inputs = {
        input_meta[0].name: encoding["input_ids"].numpy(),
        input_meta[1].name: encoding["attention_mask"].numpy(),
    }
    logits = session.run(None, onnx_inputs)[0]

    probabilities = torch.softmax(torch.tensor(logits), dim=1)
    predicted_label = torch.argmax(probabilities, dim=1).item()
    return predicted_label, probabilities.numpy()
if __name__ == "__main__":
# Load data
input_data = sys.argv[1]
sentence_pairs = json.loads(input_data)
# Validate input data format
if not all(isinstance(pair[0], str) and isinstance(pair[1], str) for pair in sentence_pairs):
raise ValueError("Each pair must contain two strings.")
for idx, (sentence1, sentence2) in enumerate(sentence_pairs):
# Generate prediction and scores
predicted_label, probabilities = predict(sentence1, sentence2)
# Print the results
print(f"Pair {idx + 1}:")
print(f" Sentence 1: {sentence1}")
print(f" Sentence 2: {sentence2}")
print(f" Predicted Label: {predicted_label}")
print(f" Probabilities: {probabilities}")
print('-' * 50)