File size: 762 Bytes
02fd376
 
 
 
bc9b616
 
02fd376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from transformers import (
    DistilBertForSequenceClassification,
    DistilBertModel,
    DistilBertTokenizer,
)

# Load the tokenizer and model.
# from_pretrained() takes a model id or a *directory* holding the saved files
# (e.g. config.json, pytorch_model.bin, vocab.txt, tokenizer_config.json), not
# the path of one individual file. "." reads them from the current working
# directory, which is where the original relative file names pointed.
MODEL_DIR = "."
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_DIR)
# Inference reads `outputs.logits`, which only a model with a classification
# head produces; the bare DistilBertModel returns hidden states only.
# NOTE(review): assumes the checkpoint in MODEL_DIR was fine-tuned for
# sequence classification (config.json carries num_labels) — otherwise the
# classification head is randomly initialised. Confirm with the model owner.
model = DistilBertForSequenceClassification.from_pretrained(MODEL_DIR)

# Define the inference function
def predict(text, *, tok=None, net=None):
    """Classify *text* and return its per-label probabilities.

    Args:
        text: A string, or a list of strings, passed straight to the
            tokenizer.
        tok: Optional tokenizer to use instead of the module-level
            ``tokenizer``. Must accept the Hugging Face tokenizer call
            signature and return a mapping of tensors.
        net: Optional model to use instead of the module-level ``model``.
            Its output must expose a ``logits`` attribute.

    Returns:
        For a single input (one string, or a one-element list), a flat list
        of floats with one entry per label, summing to 1. For a batch of
        several inputs, a nested list with one such row per input.
    """
    tok = tokenizer if tok is None else tok
    net = model if net is None else net

    # padding=True pads only to the longest input in the batch rather than to
    # the tokenizer's maximum length. The attention mask hides padding, so
    # this should only save compute (results equal up to float rounding).
    inputs = tok(text, padding=True, truncation=True, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = net(**inputs).logits

    # Softmax over the label axis (last). squeeze(0) drops only the batch axis
    # of a single input; a bare squeeze() would also collapse a single-label
    # head into a scalar and return a float instead of a list.
    return torch.softmax(logits, dim=-1).squeeze(0).tolist()

# Example usage. Guarded so that importing this module (e.g. to reuse
# predict()) does not run a forward pass and print as a side effect.
if __name__ == "__main__":
    text = "This is a sample input."
    probabilities = predict(text)
    print(probabilities)