ONNX
English
stsb-roberta-base-off-topic / inference_safetensors.py
Shing Yee
feat: add files
1d52dd2 unverified
"""
inference_safetensors.py
Defines the architecture of the fine-tuned cross-encoder model used for Off-Topic classification and runs inference on sentence pairs using SafeTensors weights.
"""
import json
import torch
import sys
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModel
class CrossEncoderWithMLP(nn.Module):
    """Cross-encoder backbone followed by a small MLP and a linear classifier.

    The backbone's pooled output goes through two Linear+ReLU stages that
    shrink the feature width to ``hidden_size // 2`` and then
    ``hidden_size // 4``; a final linear layer maps that to ``num_labels``
    logits.

    Args:
        base_model: Encoder module exposing ``config.hidden_size`` and whose
            call result has a ``pooler_output`` attribute.
        num_labels: Number of output classes.
    """

    def __init__(self, base_model, num_labels=2):
        super().__init__()

        # Wrapped cross-encoder backbone
        self.base_model = base_model

        # Layer widths are all derived from the backbone's hidden size
        width = base_model.config.hidden_size
        half_width = width // 2
        quarter_width = width // 4

        # Two-stage MLP that narrows the pooled representation
        self.mlp = nn.Sequential(
            nn.Linear(width, half_width),
            nn.ReLU(),
            nn.Linear(half_width, quarter_width),
            nn.ReLU(),
        )

        # Final projection to class logits
        self.classifier = nn.Linear(quarter_width, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return the class logits (no softmax applied) for the encoded input."""
        # The sentence pair is encoded jointly in a single backbone pass
        encoded = self.base_model(input_ids, attention_mask)
        narrowed = self.mlp(encoded.pooler_output)
        return self.classifier(narrowed)
# Load configuration file
# NOTE(review): repo_path names the jina-embeddings-v2-small-en-off-topic
# repository, while this file appears to ship with stsb-roberta-base-off-topic.
# It is only used by the commented-out hub download below; confirm the intended
# repository before enabling that line.
repo_path = "govtech/jina-embeddings-v2-small-en-off-topic"
#config_path = hf_hub_download(repo_id=repo_path, filename="config.json")
config_path = "config.json"
# JSON is UTF-8 by specification; pin the encoding so reading the config does
# not depend on the platform's default locale encoding.
with open(config_path, 'r', encoding="utf-8") as f:
    config = json.load(f)
# Loaded (tokenizer, model, device) triples keyed by (model_name, weights path),
# so repeated predict() calls do not reload the model from disk each time.
_LOADED_MODELS = {}


def _get_model_components(model_name, model_weights_fp):
    """Load once, then return the cached tokenizer, classifier model and device."""
    cache_key = (model_name, model_weights_fp)
    if cache_key not in _LOADED_MODELS:
        # Always a torch.device, whichever backend is available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load tokenizer and base model
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        base_model = AutoModel.from_pretrained(model_name)
        model = CrossEncoderWithMLP(base_model, num_labels=2)

        # Load the fine-tuned SafeTensors weights into the model
        model.load_state_dict(load_file(model_weights_fp))
        model.to(device)
        model.eval()

        _LOADED_MODELS[cache_key] = (tokenizer, model, device)
    return _LOADED_MODELS[cache_key]


def predict(sentence1, sentence2):
    """
    Predicts the label for a pair of sentences using a fine-tuned model with SafeTensors weights.

    The model name, maximum sequence length and weights path are read from the
    module-level ``config`` under ``config['classifier']['embedding']``. The
    tokenizer and model are loaded on the first call and reused afterwards.

    Args:
    - sentence1 (str): The first input sentence.
    - sentence2 (str): The second input sentence.

    Returns:
    tuple:
    - predicted_label (int): The predicted label (e.g., 0 or 1).
    - probabilities (numpy.ndarray): The softmax probabilities for each class,
      with one row for the input pair.
    """
    # Load model configuration
    embedding_config = config['classifier']['embedding']
    model_name = embedding_config['model_name']
    max_length = embedding_config['max_length']
    model_weights_fp = embedding_config['model_weights_fp']

    tokenizer, model, device = _get_model_components(model_name, model_weights_fp)

    # Get inputs
    encoding = tokenizer(
        sentence1, sentence2,  # Takes in two sentences as a pair
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_token_type_ids=False
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # Get outputs
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_label = torch.argmax(probabilities, dim=1).item()

    return predicted_label, probabilities.cpu().numpy()
def validate_sentence_pairs(sentence_pairs):
    """Check that *sentence_pairs* is a list of two-string pairs and return it.

    Args:
        sentence_pairs: Decoded JSON input, expected to be a list whose entries
            are ``[sentence1, sentence2]`` pairs of strings.

    Returns:
        The same ``sentence_pairs`` object, unchanged.

    Raises:
        ValueError: If the input is not a list, or any entry is not a
            list/tuple of exactly two strings.
    """
    if not isinstance(sentence_pairs, list):
        raise ValueError("Input must be a JSON array of [sentence1, sentence2] pairs.")
    for pair in sentence_pairs:
        # The length check matters: the loop below unpacks exactly two items,
        # so longer or shorter entries must be rejected here.
        is_valid = (
            isinstance(pair, (list, tuple))
            and len(pair) == 2
            and all(isinstance(sentence, str) for sentence in pair)
        )
        if not is_valid:
            raise ValueError("Each pair must contain two strings.")
    return sentence_pairs


if __name__ == "__main__":
    # Load data
    if len(sys.argv) < 2:
        raise SystemExit(
            "Usage: python inference_safetensors.py "
            "'[[\"sentence1\", \"sentence2\"], ...]'"
        )
    input_data = sys.argv[1]
    sentence_pairs = json.loads(input_data)

    # Validate input data format
    validate_sentence_pairs(sentence_pairs)

    for idx, (sentence1, sentence2) in enumerate(sentence_pairs):
        # Generate prediction and scores
        predicted_label, probabilities = predict(sentence1, sentence2)

        # Print the results
        print(f"Pair {idx + 1}:")
        print(f" Sentence 1: {sentence1}")
        print(f" Sentence 2: {sentence2}")
        print(f" Predicted Label: {predicted_label}")
        print(f" Probabilities: {probabilities}")
        print('-' * 50)