not-lain commited on
Commit
297b879
1 Parent(s): 17ad6cf

Update tunBertClassificationPipeline.py

Browse files
Files changed (1) hide show
  1. tunBertClassificationPipeline.py +6 -6
tunBertClassificationPipeline.py CHANGED
@@ -3,18 +3,18 @@ import torch
3
 
4
  class TBCP(Pipeline):
5
  def _sanitize_parameters(self, **kwargs):
6
- preprocess_kwargs = {}
7
  if "text_pair" in kwargs:
8
- preprocess_kwargs["text_pair"] = kwargs["text_pair"]
9
- return preprocess_kwargs, {}, {}
10
 
11
- def preprocess(self, text, text_pair=None):
12
- return self.tokenizer(text, text_pair=text_pair, return_tensors="pt")
13
 
14
  def _forward(self, model_inputs):
15
  return self.model(**model_inputs)
16
 
17
- def postprocess(self, model_outputs):
18
  logits = model_outputs.logits
19
  probabilities = torch.nn.functional.softmax(logits, dim=-1)
20
 
 
3
 
4
  class TBCP(Pipeline):
5
  def _sanitize_parameters(self, **kwargs):
6
+ postprocess_kwargs = {}
7
  if "text_pair" in kwargs:
8
+ postprocess_kwargs["top_k"] = kwargs["top_k"]
9
+ return {}, {}, postprocess_kwargs
10
 
11
+ def preprocess(self, text):
12
+ return self.tokenizer(text, return_tensors="pt")
13
 
14
  def _forward(self, model_inputs):
15
  return self.model(**model_inputs)
16
 
17
+ def postprocess(self, model_outputs,top_k = None):
18
  logits = model_outputs.logits
19
  probabilities = torch.nn.functional.softmax(logits, dim=-1)
20