samyak152002 commited on
Commit
02fd376
1 Parent(s): bdf2544

Create script

Browse files
Files changed (1) hide show
  1. script +26 -0
script ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import DistilBertTokenizer, DistilBertModel
3
+
4
+ # Load the tokenizer and model
5
+ tokenizer = DistilBertTokenizer.from_pretrained("./directory")
6
+ model = DistilBertModel.from_pretrained("./directory")
7
+
8
+ # Define the inference function
9
+ def predict(text):
10
+ # Tokenize the input
11
+ inputs = tokenizer(text, padding="max_length", truncation=True, return_tensors="pt")
12
+
13
+ # Perform the inference
14
+ with torch.no_grad():
15
+ outputs = model(**inputs)
16
+ logits = outputs.logits
17
+
18
+ # Convert logits to probabilities
19
+ probabilities = torch.softmax(logits, dim=1).squeeze().tolist()
20
+
21
+ return probabilities
22
+
23
+ # Example usage
24
+ text = "This is a sample input."
25
+ probabilities = predict(text)
26
+ print(probabilities)