|
from transformers import pipeline, BertModel, AutoTokenizer, PretrainedConfig,PreTrainedModel, Pipeline |
|
|
|
|
|
class SMSClassificationPipeline(Pipeline):
    """Pipeline that labels an SMS as a whole and labels each of its tokens.

    ``postprocess`` expects the wrapped model to return a pair
    ``(token_classification_logits, sequence_logits)`` and converts the
    arg-max of each into label names taken from the tables below.
    """

    # Label tables, indexed by the class ids the model predicts.
    # NOTE(review): the order presumably mirrors the label ids used when the
    # model was trained -- confirm against the model config before reordering.
    _SEQUENCE_LABELS = (
        "Transaction",
        "Courier",
        "OTP",
        "Expiry",
        "Misc",
        "Tele Marketing",
        "Spam",
    )

    _TOKEN_LABELS = (
        'O',
        'Courier Service',
        'Credit',
        'Date',
        'Debit',
        'Email',
        'Expiry',
        'Item',
        'Order ID',
        'Organization',
        'OTP',
        'Phone Number',
        'Refund',
        'Time',
        'Tracking ID',
        'URL',
    )

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments between the pipeline stages.

        No stage takes extra parameters, so every keyword argument is
        ignored and three empty dicts are returned (preprocess, forward,
        postprocess).
        """
        return {}, {}, {}

    def preprocess(self, text):
        """Tokenize ``text`` into tensors of the pipeline's framework."""
        return self.tokenizer(text, return_tensors=self.framework)

    def _forward(self, model_inputs):
        """Run the model on the tokenizer output and return its raw output."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Map the model's logits to label names.

        Args:
            model_outputs: a two-item iterable
                ``(token_classification_logits, sequence_logits)``.
                NOTE(review): assumes token logits are shaped
                (batch, seq_len, num_token_labels) and sequence logits
                (batch, num_sequence_labels) -- confirm against the model.

        Returns:
            A dict with ``"token_classfier"`` (one label name per position
            of the token-logit output; no positions are filtered out here)
            and ``"sequence_classfier"`` (a single label name).
        """
        token_logits, sequence_logits = model_outputs

        # Only the first item along the leading axis is decoded.
        token_label_ids = token_logits.argmax(2)[0].tolist()
        # Explicit int() so the table lookup does not depend on the tensor
        # type implementing ``__index__``.
        sequence_label_id = int(sequence_logits.argmax(1)[0])

        # The misspelled keys are kept verbatim: callers read them by name.
        return {
            "token_classfier": [self._TOKEN_LABELS[i] for i in token_label_ids],
            "sequence_classfier": self._SEQUENCE_LABELS[sequence_label_id],
        }