from transformers import Pipeline, AutoModelForSequenceClassification, AutoTokenizer
import torch
from transformers.pipelines import PIPELINE_REGISTRY


class TBCP(Pipeline):
    """Text-classification pipeline for the TunBERT sequence classifier.

    Tokenizes an input text, runs the sequence-classification model on it and
    returns the arg-max class as ``"Label_<index>"`` together with the raw
    logits.
    """

    def __init__(self, **kwargs):
        """Initialise the pipeline; all keyword arguments go to ``Pipeline``.

        If ``tokenizer`` is given as a string (a model name or path), it is
        loaded with ``AutoTokenizer.from_pretrained``. Any other value, such as
        an already-instantiated tokenizer, is left as ``Pipeline.__init__``
        set it.
        """
        super().__init__(**kwargs)
        tokenizer = kwargs.get("tokenizer")
        # Only resolve names/paths. Calling from_pretrained on an
        # already-loaded tokenizer object fails. Indexing kwargs["tokenizer"]
        # directly raised KeyError when no tokenizer was supplied.
        if isinstance(tokenizer, str):
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments between the pipeline stages.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
            tuple. Only ``top_k`` is recognised, and it is routed to
            :meth:`postprocess`.
        """
        postprocess_kwargs = {}
        # Test for the key that is actually read. Checking "text_pair" here
        # made top_k a KeyError unless text_pair was also passed.
        if "top_k" in kwargs:
            postprocess_kwargs["top_k"] = kwargs["top_k"]
        return {}, {}, postprocess_kwargs

    def preprocess(self, text):
        """Tokenize ``text`` into PyTorch tensors for the model."""
        return self.tokenizer(text, return_tensors="pt")

    def _forward(self, model_inputs):
        """Run the model on the tokenized inputs and return its raw output."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs, top_k=None):
        """Turn model output into a label and its logits.

        Args:
            model_outputs: Model output exposing a ``.logits`` tensor.
            top_k: Accepted for call-site compatibility; currently unused.

        Returns:
            ``{"label": "Label_<index>", "logits": <list of floats>}``.
        """
        logits = model_outputs.logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        # argmax over the flattened tensor equals the predicted class only
        # when the batch holds a single example. That is presumably the case
        # here, since preprocess tokenizes one text; TODO confirm for batched
        # use.
        best_class = probabilities.argmax().item()
        return {
            "label": f"Label_{best_class}",
            "logits": logits.squeeze().tolist(),
        }


# Make the pipeline available to transformers' pipeline() factory under this
# task name, defaulting to the not-lain/TunBERT checkpoint.
PIPELINE_REGISTRY.register_pipeline(
    "TunBERT-classifier",
    pipeline_class=TBCP,
    pt_model=AutoModelForSequenceClassification,
    default={"pt": ("not-lain/TunBERT", "main")},
    type="text",  # current support type: text, audio, image, multimodal
)