# TunBERT / tunBertClassificationPipeline.py
# Hugging Face Hub page header: not-lain, commit e4ca3a6
# ("commit files to HF hub"), file size 1.43 kB.
from transformers import Pipeline, AutoModelForSequenceClassification,AutoTokenizer
import torch
from transformers.pipelines import PIPELINE_REGISTRY
class TBCP(Pipeline):
    """Sequence-classification pipeline returning a label and raw logits.

    The pipeline tokenizes the input text, runs it through ``self.model`` and
    reports the index of the highest-scoring class as ``"Label_<index>"``
    together with the model's raw logits.
    """

    def __init__(self, **kwargs):
        """Initialise the pipeline and resolve the tokenizer.

        Args:
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``. When
                ``tokenizer`` is given as a string it is treated as a
                checkpoint identifier and loaded with
                ``AutoTokenizer.from_pretrained``.
        """
        super().__init__(**kwargs)
        tokenizer = kwargs.get("tokenizer")
        # Only reload when we were handed an identifier. An already-built
        # tokenizer object (or no tokenizer at all) is left as the base class
        # stored it, instead of failing inside ``from_pretrained`` or raising
        # ``KeyError``.
        if isinstance(tokenizer, str):
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into preprocess/forward/postprocess kwargs.

        Only ``top_k`` is recognised; it is routed to ``postprocess``.

        Returns:
            A 3-tuple ``(preprocess_kwargs, forward_kwargs,
            postprocess_kwargs)``; the first two are always empty.
        """
        postprocess_kwargs = {}
        # The key that is checked must be the key that is read; checking a
        # different key would drop ``top_k`` or raise ``KeyError``.
        if "top_k" in kwargs:
            postprocess_kwargs["top_k"] = kwargs["top_k"]
        return {}, {}, postprocess_kwargs

    def preprocess(self, text):
        """Tokenize ``text`` into PyTorch tensors for the model."""
        return self.tokenizer(text, return_tensors="pt")

    def _forward(self, model_inputs):
        """Run the model on the tokenized inputs and return its outputs."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs, top_k=None):
        """Convert model outputs into a label and a list of logits.

        Args:
            model_outputs: Object exposing a ``logits`` tensor.
            top_k: Accepted for interface compatibility; currently unused.

        Returns:
            ``{"label": "Label_<index>", "logits": [...]}`` where ``<index>``
            is the position of the largest logit and ``logits`` is the
            squeezed logits tensor converted to a Python list.
        """
        logits = model_outputs.logits
        # Softmax is monotonic, so the arg-max of the logits is the predicted
        # class; no probability is reported in the output.
        best_class = logits.argmax().item()
        return {
            "label": f"Label_{best_class}",
            "logits": logits.squeeze().tolist(),
        }
# Make the custom pipeline available under the "TunBERT-classifier" task name.
PIPELINE_REGISTRY.register_pipeline(
    task="TunBERT-classifier",
    pipeline_class=TBCP,
    pt_model=AutoModelForSequenceClassification,
    # Checkpoint (repository id, revision) used for the PyTorch backend by default.
    default={"pt": ("not-lain/TunBERT", "main")},
    # Currently supported types: text, audio, image, multimodal.
    type="text",
)