import os

import torch
from joblib import load
from transformers import BertTokenizer


class EndpointHandler:
    def __init__(self, path=""):
        # Load the serialized classifier and the tokenizer it was trained with.
        self.model = load(os.path.join(path, "model.joblib"))
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()  # disable dropout etc. for deterministic inference

    def __call__(self, data):
        # Accept either {"inputs": ...} or the raw payload itself.
        inputs = data.pop("inputs", data)

        # Normalize a single string into a batch of one.
        if isinstance(inputs, str):
            inputs = [inputs]

        # Tokenize the batch with padding/truncation to a fixed maximum length.
        encoded_inputs = self.tokenizer(
            inputs, padding=True, truncation=True, max_length=128, return_tensors="pt"
        )
        input_ids = encoded_inputs["input_ids"].to(self.device)
        attention_mask = encoded_inputs["attention_mask"].to(self.device)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            probabilities = torch.nn.functional.softmax(logits, dim=-1)
            predictions = torch.argmax(probabilities, dim=-1)

        # Map predicted class indices to labels and report the winning probability.
        class_names = ["JAILBREAK", "INJECTION", "PHISHING", "SAFE"]
        results = [
            {"label": class_names[pred], "score": prob[pred].item()}
            for pred, prob in zip(predictions, probabilities)
        ]

        return {"predictions": results}