PerryCheng614 committed on
Commit
8cb1299
1 Parent(s): 5ef88f3

Upload inference script

Browse files
Files changed (1) hide show
  1. bert_inference.py +63 -0
bert_inference.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
2
+ import torch
3
+
4
class BertInference:
    """Single-text inference wrapper around a sequence-classification model.

    Loads a classifier checkpoint and a tokenizer, and maps the model's
    class indices to the labels ``query_with_pdf``, ``summarize_pdf`` and
    ``query_metadata``.
    """

    def __init__(self, model_path, tokenizer_name="google-bert/bert-base-uncased"):
        """Load the model and tokenizer.

        Args:
            model_path: Checkpoint directory or hub id passed to
                ``AutoModelForSequenceClassification.from_pretrained``.
            tokenizer_name: Tokenizer checkpoint to load. Defaults to
                ``"google-bert/bert-base-uncased"``; pass another id
                (e.g. ``"FacebookAI/xlm-roberta-base"``) when the model
                was trained with a different tokenizer.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path).to(self.device)
        # This class only runs inference, so keep dropout and similar
        # training-only layers switched off.
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # Class index -> human-readable label.
        self.label_map = {
            0: "query_with_pdf",
            1: "summarize_pdf",
            2: "query_metadata",
        }

    def predict(self, text):
        """Classify one piece of text.

        Args:
            text: A single input string. Batches are not supported: the
                predicted index is read with ``.item()``, which requires
                exactly one row of logits.

        Returns:
            A dict with:
                ``"predicted_class"``: label of the highest-probability class,
                ``"confidence"``: softmax probability of that class,
                ``"all_probabilities"``: mapping of every label to its
                softmax probability.
        """
        # truncation=True clips over-long inputs to the tokenizer's maximum
        # length instead of letting them overflow the model's position
        # embeddings and crash the forward pass.
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
        ).to(self.device)

        # No gradients are needed for prediction.
        with torch.no_grad():
            outputs = self.model(**inputs)
            predictions = torch.softmax(outputs.logits, dim=1)
            predicted_class = torch.argmax(predictions, dim=1).item()
            confidence = predictions[0][predicted_class].item()

        return {
            "predicted_class": self.label_map[predicted_class],
            "confidence": confidence,
            "all_probabilities": {
                self.label_map[i]: prob.item()
                for i, prob in enumerate(predictions[0])
            },
        }
38
+
39
def main():
    """Classify a few sample questions and print each result."""
    # Directory of the saved classifier; "output_xlm_roberta_bert" is the
    # alternative checkpoint directory noted alongside it.
    classifier = BertInference("output_dir_decision")

    sample_questions = (
        "What is television",
        "What is the summary",
        "What is GPU",
        "What is the title of this pdf?",
    )

    for query in sample_questions:
        outcome = classifier.predict(query)
        label = outcome['predicted_class']
        score = outcome['confidence']
        per_class = outcome['all_probabilities']

        print(f"\nQuestion: {query}")
        print(f"Predicted Class: {label}")
        print(f"Confidence: {score:.4f}")
        print("All Probabilities:")
        for name, probability in per_class.items():
            print(f" {name}: {probability:.4f}")
61
+
62
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()