""" inference_safetensors.py Defines the architecture of the fine-tuned embedding model used for Off-Topic classification. """ import json import torch import sys import torch.nn as nn from huggingface_hub import hf_hub_download from safetensors.torch import load_file from transformers import AutoTokenizer, AutoModel class CrossEncoderWithMLP(nn.Module): def __init__(self, base_model, num_labels=2): super(CrossEncoderWithMLP, self).__init__() # Existing cross-encoder model self.base_model = base_model # Hidden size of the base model hidden_size = base_model.config.hidden_size # MLP layers after combining the cross-encoders self.mlp = nn.Sequential( nn.Linear(hidden_size, hidden_size // 2), # Input: a single sentence nn.ReLU(), nn.Linear(hidden_size // 2, hidden_size // 4), # Reduce the size of the layer nn.ReLU() ) # Classifier head self.classifier = nn.Linear(hidden_size // 4, num_labels) def forward(self, input_ids, attention_mask): # Encode the pair of sentences in one pass outputs = self.base_model(input_ids, attention_mask) pooled_output = outputs.pooler_output # Pass the pooled output through mlp layers mlp_output = self.mlp(pooled_output) # Pass the final MLP output through the classifier logits = self.classifier(mlp_output) return logits # Load configuration file repo_path = "govtech/jina-embeddings-v2-small-en-off-topic" #config_path = hf_hub_download(repo_id=repo_path, filename="config.json") config_path = "config.json" with open(config_path, 'r') as f: config = json.load(f) def predict(sentence1, sentence2): """ Predicts the label for a pair of sentences using a fine-tuned model with SafeTensors weights. Args: - sentence1 (str): The first input sentence. - sentence2 (str): The second input sentence. Returns: tuple: - predicted_label (int): The predicted label (e.g., 0 or 1). - probabilities (numpy.ndarray): The probabilities for each class. """ # Load model configuration model_name = config['classifier']['embedding']['model_name'] max_length = config['classifier']['embedding']['max_length'] model_weights_fp = config['classifier']['embedding']['model_weights_fp'] # Load tokenizer and base model device = torch.device("cuda") if torch.cuda.is_available() else "cpu" tokenizer = AutoTokenizer.from_pretrained(model_name) base_model = AutoModel.from_pretrained(model_name) model = CrossEncoderWithMLP(base_model, num_labels=2) # Load weights into the model weights = load_file(model_weights_fp) model.load_state_dict(weights) model.to(device) model.eval() # Get inputs encoding = tokenizer( sentence1, sentence2, # Takes in a two sentences as a pair return_tensors="pt", truncation=True, padding="max_length", max_length=max_length, return_token_type_ids=False ) input_ids = encoding["input_ids"].to(device) attention_mask = encoding["attention_mask"].to(device) # Get outputs with torch.no_grad(): outputs = model(input_ids=input_ids, attention_mask=attention_mask) probabilities = torch.softmax(outputs, dim=1) predicted_label = torch.argmax(probabilities, dim=1).item() return predicted_label, probabilities.cpu().numpy() if __name__ == "__main__": # Load data input_data = sys.argv[1] sentence_pairs = json.loads(input_data) # Validate input data format if not all(isinstance(pair[0], str) and isinstance(pair[1], str) for pair in sentence_pairs): raise ValueError("Each pair must contain two strings.") for idx, (sentence1, sentence2) in enumerate(sentence_pairs): # Generate prediction and scores predicted_label, probabilities = predict(sentence1, sentence2) # Print the results print(f"Pair {idx + 1}:") print(f" Sentence 1: {sentence1}") print(f" Sentence 2: {sentence2}") print(f" Predicted Label: {predicted_label}") print(f" Probabilities: {probabilities}") print('-' * 50)