ONNX
English
File size: 4,273 Bytes
1d52dd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
inference_safetensors.py

Defines the architecture of the fine-tuned cross-encoder model used for Off-Topic classification, loads its SafeTensors weights, and runs inference on sentence pairs.
"""
import json
import torch
import sys
import torch.nn as nn

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModel

class CrossEncoderWithMLP(nn.Module):
    """Cross-encoder backbone followed by a small MLP and a linear classifier head.

    Args:
        base_model: Encoder module that exposes ``config.hidden_size`` and whose
            forward output provides ``pooler_output``.
        num_labels: Size of the logit vector produced by the classifier head.
    """

    def __init__(self, base_model, num_labels=2):
        super().__init__()

        # Backbone that encodes the tokenized input
        self.base_model = base_model

        # All layer widths are derived from the backbone's hidden size
        width = base_model.config.hidden_size
        half_width = width // 2
        quarter_width = width // 4

        # Two Linear+ReLU stages that progressively shrink the pooled output
        shrink_to_half = nn.Linear(width, half_width)
        shrink_to_quarter = nn.Linear(half_width, quarter_width)
        self.mlp = nn.Sequential(
            shrink_to_half,
            nn.ReLU(),
            shrink_to_quarter,
            nn.ReLU(),
        )

        # Final projection from the MLP output to per-class logits
        self.classifier = nn.Linear(quarter_width, num_labels)

    def forward(self, input_ids, attention_mask):
        """Run the backbone, MLP and classifier; return the raw logits."""
        # The sentence pair is encoded jointly in a single backbone pass
        pooled = self.base_model(input_ids, attention_mask).pooler_output
        return self.classifier(self.mlp(pooled))

# Load configuration file
repo_path = "govtech/jina-embeddings-v2-small-en-off-topic"
# To fetch the config from the Hugging Face Hub instead of using a local copy:
# config_path = hf_hub_download(repo_id=repo_path, filename="config.json")
config_path = "config.json"

# JSON text is UTF-8 by specification, so do not rely on the platform's
# default encoding when reading it.
with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)

# Tokenizer/model pairs already loaded by predict(), keyed by
# (model_name, model_weights_fp, device) so repeated calls reuse them
# instead of re-reading the weights from disk every time.
_LOADED_MODELS = {}


def _load_tokenizer_and_model(model_name, model_weights_fp, device):
    """Return the (tokenizer, model) pair for the given settings, loading it at most once."""
    cache_key = (model_name, model_weights_fp, str(device))
    if cache_key not in _LOADED_MODELS:
        # Load tokenizer and base model
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        base_model = AutoModel.from_pretrained(model_name)
        model = CrossEncoderWithMLP(base_model, num_labels=2)

        # Load weights into the model
        weights = load_file(model_weights_fp)
        model.load_state_dict(weights)
        model.to(device)
        model.eval()

        _LOADED_MODELS[cache_key] = (tokenizer, model)
    return _LOADED_MODELS[cache_key]


def predict(sentence1, sentence2):
    """
    Predicts the label for a pair of sentences using a fine-tuned model with SafeTensors weights.

    The tokenizer and model are loaded on the first call and reused on later
    calls with the same configuration.

    Args:
    - sentence1 (str): The first input sentence.
    - sentence2 (str): The second input sentence.

    Returns:
    tuple:
        - predicted_label (int): The index of the class with the highest probability.
        - probabilities (numpy.ndarray): Softmax of the model logits over dim 1,
          moved to the CPU and converted to a NumPy array.
    """
    # Load model configuration
    embedding_config = config['classifier']['embedding']
    model_name = embedding_config['model_name']
    max_length = embedding_config['max_length']
    model_weights_fp = embedding_config['model_weights_fp']

    # Use the GPU when one is available; always a torch.device for consistency
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer, model = _load_tokenizer_and_model(model_name, model_weights_fp, device)

    # Get inputs
    encoding = tokenizer(
        sentence1, sentence2,  # Both sentences are tokenized together as one pair
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_token_type_ids=False
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # Get outputs (no gradients are needed for inference)
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_label = torch.argmax(probabilities, dim=1).item()

    return predicted_label, probabilities.cpu().numpy()

if __name__ == "__main__":
    # The single CLI argument is a JSON array of [sentence1, sentence2] pairs
    if len(sys.argv) < 2:
        sys.exit(
            f"Usage: python {sys.argv[0]} "
            "'[[\"sentence 1\", \"sentence 2\"], ...]'"
        )

    # Load data
    input_data = sys.argv[1]
    sentence_pairs = json.loads(input_data)

    # Validate input data format
    if not isinstance(sentence_pairs, list):
        raise ValueError("Input must be a JSON array of sentence pairs.")
    for pair in sentence_pairs:
        # Require a real two-element list: a bare string such as "ab" would
        # otherwise index like a pair of one-character strings.
        is_valid_pair = (
            isinstance(pair, list)
            and len(pair) == 2
            and all(isinstance(sentence, str) for sentence in pair)
        )
        if not is_valid_pair:
            raise ValueError("Each pair must contain two strings.")

    for idx, (sentence1, sentence2) in enumerate(sentence_pairs):

        # Generate prediction and scores
        predicted_label, probabilities = predict(sentence1, sentence2)

        # Print the results
        print(f"Pair {idx + 1}:")
        print(f"  Sentence 1: {sentence1}")
        print(f"  Sentence 2: {sentence2}")
        print(f"  Predicted Label: {predicted_label}")
        print(f"  Probabilities: {probabilities}")
        print('-' * 50)