menimeni123
commited on
Commit
·
c3085a4
1
Parent(s):
8e04930
latest
Browse files- .DS_Store +0 -0
- endpoint.py +0 -28
- handler.py +22 -25
- requirements.txt +3 -3
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
endpoint.py
DELETED
@@ -1,28 +0,0 @@
|
|
1 |
-
from huggingface_hub import InferenceClient, create_inference_endpoint
|
2 |
-
|
3 |
-
# Create the inference endpoint
|
4 |
-
endpoint = create_inference_endpoint(
|
5 |
-
name="my-custom-endpoint",
|
6 |
-
repository="path/to/your/model/repository",
|
7 |
-
framework="custom",
|
8 |
-
task="text-classification",
|
9 |
-
accelerator="cpu", # or "gpu" if needed
|
10 |
-
instance_size="medium",
|
11 |
-
instance_type="c6i",
|
12 |
-
region="us-east-1",
|
13 |
-
custom_image={
|
14 |
-
"health_route": "/healthz",
|
15 |
-
"port": 8080,
|
16 |
-
"url": "your-docker-image-url:latest"
|
17 |
-
}
|
18 |
-
)
|
19 |
-
|
20 |
-
# Wait for the endpoint to be ready
|
21 |
-
endpoint.wait()
|
22 |
-
|
23 |
-
# Create a client to interact with the endpoint
|
24 |
-
client = InferenceClient(endpoint.url)
|
25 |
-
|
26 |
-
# Test the endpoint
|
27 |
-
result = client.text_classification("This is a test input")
|
28 |
-
print(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
handler.py
CHANGED
@@ -1,38 +1,35 @@
|
|
1 |
-
import os
|
2 |
-
import torch
|
3 |
from joblib import load
|
4 |
from transformers import BertTokenizer
|
|
|
|
|
|
|
5 |
|
6 |
class EndpointHandler:
|
7 |
def __init__(self, path=""):
|
8 |
-
|
|
|
|
|
|
|
9 |
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
10 |
-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
-
self.model.to(self.device)
|
12 |
|
13 |
-
def __call__(self, data):
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
-
# Ensure inputs is a list
|
17 |
-
if isinstance(inputs, str):
|
18 |
-
inputs = [inputs]
|
19 |
-
|
20 |
-
# Tokenize inputs
|
21 |
-
encoded_inputs = self.tokenizer(inputs, padding=True, truncation=True, max_length=128, return_tensors="pt")
|
22 |
-
|
23 |
-
# Move inputs to the correct device
|
24 |
-
input_ids = encoded_inputs['input_ids'].to(self.device)
|
25 |
-
attention_mask = encoded_inputs['attention_mask'].to(self.device)
|
26 |
-
|
27 |
# Perform inference
|
28 |
with torch.no_grad():
|
29 |
outputs = self.model(input_ids, attention_mask=attention_mask)
|
30 |
logits = outputs.logits
|
31 |
-
probabilities =
|
32 |
-
|
33 |
-
|
34 |
-
#
|
35 |
class_names = ["JAILBREAK", "INJECTION", "PHISHING", "SAFE"]
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
1 |
from joblib import load
|
2 |
from transformers import BertTokenizer
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from typing import Dict, Any
|
6 |
|
7 |
class EndpointHandler:
|
8 |
def __init__(self, path=""):
|
9 |
+
# Load the model
|
10 |
+
self.model = load(f"{path}/model.joblib")
|
11 |
+
self.model.eval()
|
12 |
+
# Load the tokenizer
|
13 |
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
|
|
|
|
14 |
|
15 |
+
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
16 |
+
# Extract input text
|
17 |
+
text = data.get("inputs", "")
|
18 |
+
# Tokenize the input text
|
19 |
+
encoding = self.tokenizer(text, truncation=True, padding=True, max_length=128, return_tensors='pt')
|
20 |
+
input_ids = encoding['input_ids']
|
21 |
+
attention_mask = encoding['attention_mask']
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
# Perform inference
|
24 |
with torch.no_grad():
|
25 |
outputs = self.model(input_ids, attention_mask=attention_mask)
|
26 |
logits = outputs.logits
|
27 |
+
probabilities = F.softmax(logits, dim=-1)
|
28 |
+
confidence, predicted_class = torch.max(probabilities, dim=-1)
|
29 |
+
|
30 |
+
# Map predicted class to label
|
31 |
class_names = ["JAILBREAK", "INJECTION", "PHISHING", "SAFE"]
|
32 |
+
predicted_label = class_names[predicted_class.item()]
|
33 |
+
confidence_score = confidence.item()
|
34 |
+
|
35 |
+
return {"label": predicted_label, "confidence": confidence_score}
|
requirements.txt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
-
torch
|
2 |
-
transformers
|
3 |
-
joblib
|
|
|
1 |
+
torch
|
2 |
+
transformers
|
3 |
+
joblib
|