not-lain commited on
Commit
17ad6cf
1 Parent(s): b335c9a

Create tunBertClassificationPipeline.py

Browse files
Files changed (1) hide show
  1. tunBertClassificationPipeline.py +25 -0
tunBertClassificationPipeline.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Pipeline
2
+ import torch
3
+
4
+ class TBCP(Pipeline):
5
+ def _sanitize_parameters(self, **kwargs):
6
+ preprocess_kwargs = {}
7
+ if "text_pair" in kwargs:
8
+ preprocess_kwargs["text_pair"] = kwargs["text_pair"]
9
+ return preprocess_kwargs, {}, {}
10
+
11
+ def preprocess(self, text, text_pair=None):
12
+ return self.tokenizer(text, text_pair=text_pair, return_tensors="pt")
13
+
14
+ def _forward(self, model_inputs):
15
+ return self.model(**model_inputs)
16
+
17
+ def postprocess(self, model_outputs):
18
+ logits = model_outputs.logits
19
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)
20
+
21
+ best_class = probabilities.argmax().item()
22
+ label = self.model.config.id2label[best_class]
23
+ score = probabilities.squeeze()[best_class].item()
24
+ logits = logits.squeeze().tolist()
25
+ return {"label": label, "score": score, "logits": logits}