smsner_model3 / sms_classifier.py
AbidHasan95's picture
First commit
1455e81 verified
raw
history blame contribute delete
No virus
2.16 kB
from transformers import pipeline, BertModel, AutoTokenizer, PretrainedConfig,PreTrainedModel, Pipeline
class SMSClassificationPipeline(Pipeline):
    """Pipeline that tags each token of an SMS and classifies the whole message.

    The wrapped model is expected to return a two-element output,
    ``(token_classification_logits, sequence_logits)``, which ``postprocess``
    decodes into human-readable labels.
    """

    # Message-level categories, indexed by the argmax of the sequence logits.
    # NOTE(review): this order must match the label ids the model was trained
    # with — confirm against the training config.
    SEQ_LABELS = (
        "Transaction",
        "Courier",
        "OTP",
        "Expiry",
        "Misc",
        "Tele Marketing",
        "Spam",
    )

    # Per-token entity tags, indexed by the argmax of the token logits.
    # NOTE(review): this order must match the training label ids — confirm.
    TOKEN_CLASS_LABELS = (
        'O',
        'Courier Service',
        'Credit',
        'Date',
        'Debit',
        'Email',
        'Expiry',
        'Item',
        'Order ID',
        'Organization',
        'OTP',
        'Phone Number',
        'Refund',
        'Time',
        'Tracking ID',
        'URL',
    )

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) kwargs.

        This pipeline takes no extra parameters, so all three dicts are empty
        and any ``kwargs`` passed in are ignored.
        """
        return {}, {}, {}

    def preprocess(self, text):
        """Tokenize *text*, returning tensors for the pipeline's framework."""
        return self.tokenizer(text, return_tensors=self.framework)

    def _forward(self, model_inputs):
        """Run the model on the tokenized inputs and return its raw output."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Decode raw model outputs into token tags and a message category.

        Args:
            model_outputs: Two-element model output
                ``(token_classification_logits, sequence_logits)``.
                Presumably shaped ``(batch, seq_len, num_token_labels)`` and
                ``(batch, num_seq_labels)`` — TODO confirm. Only batch
                element 0 is decoded.

        Returns:
            A dict with two entries:
            ``"token_classfier"`` is a list with one tag string per position
            of the tokenized sequence. ``"sequence_classfier"`` is a single
            category string.
        """
        token_logits, sequence_logits = model_outputs

        # Take the argmax over the label axis, then keep only the first batch item.
        token_ids = token_logits.argmax(2)[0].tolist()
        # Convert with int() so list indexing does not depend on the tensor
        # type implementing __index__.
        seq_id = int(sequence_logits.argmax(1)[0])

        return {
            # The key spellings ("classfier") are kept as-is because callers read them.
            "token_classfier": [self.TOKEN_CLASS_LABELS[i] for i in token_ids],
            "sequence_classfier": self.SEQ_LABELS[seq_id],
        }