stanslausmwongela committed on
Commit
424ae24
1 Parent(s): e4eab97

Added the prediction endpoint

Browse files
Files changed (1) hide show
  1. app/predict.py +154 -0
app/predict.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import json
4
+ import numpy as np
5
+ import torch
6
+ import heapq
7
+ import pandas as pd
8
+ from tqdm import tqdm
9
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
10
+ from torch.utils.data import TensorDataset, DataLoader
11
+
12
+
13
class Preprocess:
    """Clean and tokenize a free-text sentence for the sequence classifier.

    The tokenizer is loaded once in the constructor with
    ``AutoTokenizer.from_pretrained``; the remaining methods turn one raw
    sentence into a ``TensorDataset`` of ``(input_ids, attention_mask)``.
    """

    # Tokens dropped by ``clean_text``. Stored as a frozenset so the collection
    # is built once (not on every call) and membership tests are O(1).
    # NOTE(review): "lindam ama" contains a space, so it can never equal a
    # token produced by ``str.split()``; it is kept unchanged so filtering
    # behaves exactly as before -- confirm whether two entries were intended.
    STOPWORDS = frozenset([
        "i", "was", "transferred",
        "from", "to", "nilienda", "kituo",
        "cha", "lakini", "saa", "hii", "niko",
        "at", "nikahudumiwa", "pole",
        "deliver", "na", "ni", "baada", "ya",
        "kutumwa", "kutoka",
        "ndipo", "nikapewa", "hiyo", "lindam ama", "nikawa",
        "mgonjwa", "nikatibiwa", "in", "had", "a",
        "visit", "gynaecologist", "ndio",
        "karibu", "mimi", "sehemu", "hospitali",
        "serikali", "delivered", "katika", "kaunti", "kujifungua",
        "huko", "nilipoenda", "kwa", "bado", "naedelea",
        "sija", "maliza", "mwisho",
        "nilianza", "kliniki", "yangu",
        "nilianzia", "nilijifungua",
    ])

    def __init__(self, tokenizer_vocab_path, tokenizer_max_len):
        """Load the tokenizer and remember the maximum sequence length.

        Args:
            tokenizer_vocab_path: Identifier or path handed to
                ``AutoTokenizer.from_pretrained``.
            tokenizer_max_len: Value passed as ``max_length`` when encoding.
        """
        # SECURITY: the access token used to be a string literal in this file.
        # A token committed to source control must be considered leaked and
        # should be revoked; it is now supplied through the environment.
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_vocab_path,
            use_auth_token=self._hub_token(),
        )
        self.max_len = tokenizer_max_len

    @staticmethod
    def _hub_token():
        """Return the hub token from ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN``, or ``None``."""
        return (os.environ.get("HF_TOKEN")
                or os.environ.get("HUGGING_FACE_HUB_TOKEN")
                or None)

    def clean_text(self, text):
        """Lower-case ``text`` and drop every whitespace-separated stop word.

        Args:
            text: Raw input sentence.

        Returns:
            The remaining tokens joined by single spaces (``""`` when nothing
            is left), e.g. ``"I was transferred from Nairobi Hospital"``
            becomes ``"nairobi hospital"``.
        """
        stopwords = self.STOPWORDS
        return ' '.join(word for word in text.lower().split()
                        if word not in stopwords)

    def encode_fn(self, text_single):
        """Tokenize ``text_single`` (e.g. ``'Nairobi Hospital'``).

        The tokenizer is called with padding and truncation enabled,
        ``max_length=self.max_len`` and ``return_tensors='pt'``.

        Returns:
            Tuple ``(input_ids, attention_mask)`` taken from the encoding.
        """
        encoded = self.tokenizer(text_single,
                                 padding=True,
                                 truncation=True,
                                 max_length=self.max_len,
                                 return_tensors='pt')
        return encoded['input_ids'], encoded['attention_mask']

    def process_tokenizer(self, text_single):
        """Encode one sentence and wrap the tensors in a ``TensorDataset``.

        Returns:
            ``TensorDataset(input_ids, attention_mask)``, ready to be fed to a
            ``DataLoader``.
        """
        input_ids, attention_mask = self.encode_fn(text_single)
        return TensorDataset(input_ids, attention_mask)
61
+
62
+
63
class Facility_Model:
    """Inference wrapper around a Hugging Face sequence-classification model.

    The checkpoint is loaded once in the constructor and switched to eval
    mode. ``inference`` runs the model on a prepared dataset, converts the
    logits of the first row into label probabilities, removes labels that
    contain an excluded keyword and returns the most probable ones as JSON.
    """

    # CSV with the columns "label" and "corresponding_label" (class index ->
    # label name); used when no explicit path is given to the constructor.
    DEFAULT_LABEL_RELATION_PATH = '/content/drive/MyDrive/dhis14000/dhis_label_relation_14357.csv'

    # A label is dropped from the results when its lower-cased name contains
    # any of these substrings.
    EXCLUDED_KEYWORDS = ("dental", "optical", "eye")

    def __init__(self, facility_model_path: str,
                 max_len: int,
                 use_gpu: bool = False,
                 label_relation_path=None):
        """Load the classifier and place it on the selected device.

        Args:
            facility_model_path: Identifier or path handed to
                ``AutoModelForSequenceClassification.from_pretrained``.
            max_len: Stored on the instance as ``self.max_len``.
            use_gpu: Run on CUDA when it is available. Defaults to ``False``
                (CPU), which is what this class always did before.
            label_relation_path: Optional path of the label-relation CSV;
                defaults to ``DEFAULT_LABEL_RELATION_PATH``.
        """
        self.max_len = max_len
        self.softmax = torch.nn.Softmax(dim=1)
        # Fall back to CPU when GPU use is requested but CUDA is unavailable.
        self.gpu = bool(use_gpu) and torch.cuda.is_available()
        self.label_relation_path = label_relation_path or self.DEFAULT_LABEL_RELATION_PATH
        self._label_intent_dict = None  # filled lazily on first use

        # SECURITY: the access token used to be a string literal in this file.
        # A token committed to source control must be considered leaked and
        # should be revoked; it is now supplied through the environment.
        self.model = AutoModelForSequenceClassification.from_pretrained(
            facility_model_path,
            use_auth_token=self._hub_token(),
        )
        self.model.eval()  # set pytorch model for inference mode

        if self.gpu:
            seed = 42
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            self.device = torch.device('cuda')
        else:
            self.device = 'cpu'

        self.model = self.model.to(self.device)

        # Wrap in DataParallel only when actually running on GPUs. Wrapping a
        # model that stays on the CPU makes the forward pass fail on hosts
        # with several GPUs, because DataParallel expects the parameters on
        # its first CUDA device.
        if self.gpu and torch.cuda.device_count() > 1:
            self.model = torch.nn.DataParallel(self.model)

    @staticmethod
    def _hub_token():
        """Return the hub token from ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN``, or ``None``."""
        return (os.environ.get("HF_TOKEN")
                or os.environ.get("HUGGING_FACE_HUB_TOKEN")
                or None)

    def predict_single(self, model, pred_data):
        """Run ``model`` over ``pred_data`` and return softmax probabilities.

        Args:
            model: Callable accepting ``input_ids`` and ``attention_mask``
                keyword arguments and returning an object with ``.logits``.
            pred_data: Dataset whose items start with
                ``(input_ids, attention_mask)``.

        Returns:
            numpy array with one row of probabilities per dataset item, in
            dataset order (previously only the final batch of 10 survived).

        Raises:
            ValueError: If ``pred_data`` yields no batches.
        """
        pred_dataloader = DataLoader(pred_data, batch_size=10, shuffle=False)
        batch_probabilities = []
        with torch.no_grad():
            for batch in pred_dataloader:
                outputs = model(input_ids=batch[0].to(self.device),
                                attention_mask=batch[1].to(self.device))
                probability = self.softmax(outputs.logits)
                batch_probabilities.append(probability.detach().cpu().numpy())
        if not batch_probabilities:
            raise ValueError("pred_data is empty; nothing to predict")
        return np.concatenate(batch_probabilities, axis=0)

    def _load_label_intent_dict(self):
        """Return the class-index -> label mapping, reading the CSV only once."""
        if self._label_intent_dict is None:
            path_table = pd.read_csv(self.label_relation_path)
            self._label_intent_dict = (
                path_table[["label", "corresponding_label"]]
                .set_index("corresponding_label")["label"]
                .to_dict()
            )
        return self._label_intent_dict

    def output_intent_probability(self, pred) -> dict:
        """Map the first row of ``pred`` to ``{label: probability}``.

        Args:
            pred: 2-D array of probabilities; only row 0 is used and column
                ``j`` is looked up as class index ``j`` in the label table.

        Returns:
            Dictionary from label name to probability (as a Python float so
            the result is always JSON serializable).

        Raises:
            KeyError: If a class index is missing from the label table.
        """
        label_intent_dict = self._load_label_intent_dict()
        first_row = pred[0] if pred.shape[1] else ()
        return {label_intent_dict[intent]: float(first_row[intent])
                for intent in range(pred.shape[1])}

    def inference(self, prepared_data, top_k=4):
        """Predict labels for one prepared sentence and return them as JSON.

        Args:
            prepared_data: Dataset accepted by ``predict_single``.
            top_k: Number of highest-probability labels to keep (default 4).

        Returns:
            JSON string of a one-element list holding a ``{label: probability}``
            object ordered from most to least probable.
        """
        prob_distribution = self.predict_single(self.model, prepared_data)
        prediction_results = self.output_intent_probability(prob_distribution.astype(float))

        # Drop every label whose name mentions an excluded keyword.
        filtered_results = (
            (intent, prob) for intent, prob in prediction_results.items()
            if not any(keyword in intent.lower() for keyword in self.EXCLUDED_KEYWORDS)
        )

        # nlargest is documented as equivalent to
        # sorted(..., key=key, reverse=True)[:n] without sorting everything.
        top_results = dict(heapq.nlargest(top_k, filtered_results, key=lambda item: item[1]))
        return json.dumps([top_results])
144
+
145
+
146
# Hub identifier passed to both the classifier and the tokenizer below.
jacaranda_hugging_face_model = "Jacaranda/dhis_14000_600k_Test_Model"

# Module-level instances, created when this module is imported: the
# classifier first, then the text preprocessor, both with a length of 128.
obj_Facility_Model = Facility_Model(
    facility_model_path=jacaranda_hugging_face_model,
    max_len=128,
)

processor = Preprocess(
    tokenizer_vocab_path=jacaranda_hugging_face_model,
    tokenizer_max_len=128,
)