|
|
|
import logging
import os

import joblib
import torch
from torch.nn.functional import softmax
from transformers import BertTokenizer, BertForSequenceClassification
|
|
|
|
|
# Tokenizer for the uncased BERT vocabulary, loaded by name at import time.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only ever load a trusted artifact here. The path is relative, so it
# is resolved against the process's current working directory.
model = joblib.load('model.joblib')
model.to(device)
# Evaluation mode, so training-only behaviour (e.g. dropout) is off while
# serving requests.
model.eval()
|
|
|
|
|
# Human-readable labels, indexed by the position of the model's output logit.
# NOTE(review): this order must match the label ids used when the model was
# trained -- confirm against the training config.
class_names = [
    "JAILBREAK",
    "INJECTION",
    "PHISHING",
    "SAFE",
]
|
|
|
def preprocess(text):
    """Encode ``text`` with the module-level BERT tokenizer as PyTorch tensors.

    Sequences are truncated to at most 128 tokens. ``padding=True`` pads a
    batch to its longest member; a single string is left unpadded.

    Args:
        text: The string to encode (presumably a list of strings also works,
            since the tokenizer accepts batches -- confirm with callers).

    Returns:
        The tokenizer output, which includes ``'input_ids'`` and
        ``'attention_mask'`` entries as ``torch`` tensors.
    """
    return tokenizer(
        text,
        max_length=128,
        truncation=True,
        padding=True,
        return_tensors='pt',
    )
|
|
|
def inference(model_inputs):
    """Classify the text in a request payload with the loaded model.

    This function will be called for every inference request.

    Args:
        model_inputs: Request payload; the text to classify is read from its
            ``'text'`` key.

    Returns:
        One of three dicts:

        * ``{'classification': label, 'confidence': prob}`` on success, where
          ``label`` comes from ``class_names`` and ``prob`` is the softmax
          probability of that label.
        * ``{'message': ...}`` when ``'text'`` is missing or ``None``.
        * ``{'error': str(exc)}`` if any step raises. The exception is also
          logged with its traceback.
    """
    try:
        text = model_inputs.get('text', None)
        if text is None:
            return {'message': 'No text provided for inference.'}

        encoding = preprocess(text)
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        # No gradients are needed for serving; skipping autograd tracking
        # saves memory and time.
        with torch.no_grad():
            logits = model(input_ids, attention_mask=attention_mask).logits
            probabilities = softmax(logits, dim=-1)
            confidence, predicted_class = torch.max(probabilities, dim=-1)

        # .item() requires a single-element tensor, so a batch of more than
        # one text raises here and is reported through the error branch.
        return {
            'classification': class_names[predicted_class.item()],
            'confidence': confidence.item(),
        }
    except Exception as e:
        # Request boundary: report the failure to the caller rather than
        # crashing, but keep the full traceback in the logs so it can be
        # diagnosed.
        logging.getLogger(__name__).exception("Inference request failed")
        return {'error': str(e)}
|
|