menimeni123
commited on
Commit
•
60fbaa9
1
Parent(s):
f311c70
entrophy method
Browse files- handler.py +30 -9
handler.py
CHANGED
@@ -18,17 +18,38 @@ class EndpointHandler:
|
|
18 |
return self.predict(inputs)
|
19 |
|
20 |
def predict(self, text):
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
with torch.no_grad():
|
26 |
outputs = self.model(**encoded_input)
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
def get_pipeline():
|
34 |
return EndpointHandler
|
|
|
18 |
return self.predict(inputs)
|
19 |
|
20 |
def predict(self, text):
|
21 |
+
# Tokenize and encode the input
|
22 |
+
encoded_input = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
|
23 |
+
|
24 |
+
# Get model prediction
|
25 |
with torch.no_grad():
|
26 |
outputs = self.model(**encoded_input)
|
27 |
+
logits = outputs.logits
|
28 |
+
|
29 |
+
# Get probabilities
|
30 |
+
probabilities = F.softmax(logits, dim=-1).squeeze().numpy()
|
31 |
+
|
32 |
+
# Get predicted class and confidence
|
33 |
+
predicted_class_idx = np.argmax(probabilities)
|
34 |
+
predicted_label = self.labels[predicted_class_idx]
|
35 |
+
confidence = probabilities[predicted_class_idx]
|
36 |
+
|
37 |
+
# Additional analysis
|
38 |
+
entropy = -np.sum(probabilities * np.log(probabilities + 1e-9))
|
39 |
+
max_prob_ratio = np.max(probabilities) / np.sort(probabilities)[-2]
|
40 |
+
|
41 |
+
# Adjust confidence based on entropy and probability ratio
|
42 |
+
adjusted_confidence = confidence * (1 - entropy/np.log(len(probabilities))) * max_prob_ratio
|
43 |
+
|
44 |
+
# Lower the confidence for very short inputs
|
45 |
+
if len(text.split()) < 4:
|
46 |
+
adjusted_confidence *= 0.5
|
47 |
+
|
48 |
+
return {
|
49 |
+
"label": predicted_label,
|
50 |
+
"score": float(adjusted_confidence),
|
51 |
+
"raw_scores": {label: float(prob) for label, prob in zip(self.labels.values(), probabilities)}
|
52 |
+
}
|
53 |
|
54 |
def get_pipeline():
|
55 |
return EndpointHandler
|