# helem-llm / handler.py
# Custom inference handler for Hugging Face Inference Endpoints
# (scraped page chrome — commit 91cf739 — removed; kept as a comment header
# so the file parses as Python).
import os
import torch
from joblib import load
from transformers import BertTokenizer
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a BERT-based
    prompt-safety classifier.

    Loads a joblib-serialized torch model plus the ``bert-base-uncased``
    tokenizer, and classifies input text(s) into one of four labels:
    JAILBREAK / INJECTION / PHISHING / SAFE.
    """

    # Label order must match the index order the classifier was trained
    # with — TODO(review): confirm against the training config.
    CLASS_NAMES = ["JAILBREAK", "INJECTION", "PHISHING", "SAFE"]

    def __init__(self, path=""):
        """Load model weights and tokenizer, and move the model to the
        best available device.

        Args:
            path: Directory containing ``model.joblib`` (the endpoint's
                repository root at deploy time).
        """
        # NOTE(review): joblib is an unusual container for a torch model —
        # confirm model.joblib really holds a torch.nn.Module with a
        # .logits-bearing forward output.
        self.model = load(os.path.join(path, "model.joblib"))
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Fix: put the model in inference mode; without this, dropout and
        # batch-norm run in training mode and predictions are nondeterministic.
        self.model.eval()

    def __call__(self, data):
        """Classify one or more texts.

        Args:
            data: Request payload. Either a dict with an ``"inputs"`` key
                (str or list of str), or a bare str / list of str.

        Returns:
            ``{"predictions": [{"label": <str>, "score": <float>}, ...]}``,
            one entry per input text, score being the softmax probability
            of the predicted class.
        """
        # Fix: .get() instead of .pop() so the caller's dict is not mutated,
        # and an isinstance guard so bare (non-dict) payloads don't crash.
        inputs = data.get("inputs", data) if isinstance(data, dict) else data
        if isinstance(inputs, str):
            inputs = [inputs]

        encoded = self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        # Only the forward pass needs gradient tracking disabled.
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits

        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        predictions = torch.argmax(probabilities, dim=-1)
        results = [
            {"label": self.CLASS_NAMES[pred], "score": prob[pred].item()}
            for pred, prob in zip(predictions, probabilities)
        ]
        return {"predictions": results}