import os

import torch
from joblib import load
from transformers import BertTokenizer


class EndpointHandler:
    """Text-classification inference handler (Hugging Face custom endpoint style).

    Loads a joblib-serialized model (assumed to be a torch module whose output
    exposes ``.logits`` — confirmed by usage below) plus a BERT tokenizer, and
    classifies input text into one of four labels:
    JAILBREAK / INJECTION / PHISHING / SAFE.
    """

    # Index -> label mapping used to decode argmax predictions.
    # NOTE(review): order must match the label order the model was trained with.
    CLASS_NAMES = ["JAILBREAK", "INJECTION", "PHISHING", "SAFE"]

    def __init__(self, path=""):
        """Load model and tokenizer, and move the model to the best device.

        Args:
            path: Directory containing ``model.joblib``. Defaults to the
                current directory.
        """
        self.model = load(os.path.join(path, "model.joblib"))
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Fix: put the model in eval mode so dropout/batch-norm layers are
        # disabled during inference; the original left it in training mode.
        self.model.eval()

    def __call__(self, data):
        """Classify one or more texts.

        Args:
            data: Either a dict with an ``"inputs"`` key (a string or a list
                of strings), a bare string, or a list of strings.

        Returns:
            ``{"predictions": [{"label": str, "score": float}, ...]}`` — one
            entry per input text, with the softmax probability of the
            predicted class.
        """
        # Fix: use .get() instead of .pop() so the caller's payload dict is
        # not mutated, and tolerate non-dict payloads (bare str / list).
        inputs = data.get("inputs", data) if isinstance(data, dict) else data

        # Ensure inputs is a list so the tokenizer always sees a batch.
        if isinstance(inputs, str):
            inputs = [inputs]

        # Tokenize the batch; 128 tokens matches the original hard-coded cap.
        encoded_inputs = self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt",
        )

        # Move tensors to the same device as the model.
        input_ids = encoded_inputs['input_ids'].to(self.device)
        attention_mask = encoded_inputs['attention_mask'].to(self.device)

        # Inference without gradient tracking.
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            probabilities = torch.nn.functional.softmax(logits, dim=-1)
            predictions = torch.argmax(probabilities, dim=-1)

        # Decode class indices into human-readable labels with their scores.
        results = [
            {"label": self.CLASS_NAMES[pred], "score": prob[pred].item()}
            for pred, prob in zip(predictions, probabilities)
        ]
        return {"predictions": results}