menimeni123 commited on
Commit
c3085a4
·
1 Parent(s): 8e04930
Files changed (4) hide show
  1. .DS_Store +0 -0
  2. endpoint.py +0 -28
  3. handler.py +22 -25
  4. requirements.txt +3 -3
.DS_Store ADDED
Binary file (6.15 kB). View file
 
endpoint.py DELETED
@@ -1,28 +0,0 @@
1
- from huggingface_hub import InferenceClient, create_inference_endpoint
2
-
3
- # Create the inference endpoint
4
- endpoint = create_inference_endpoint(
5
- name="my-custom-endpoint",
6
- repository="path/to/your/model/repository",
7
- framework="custom",
8
- task="text-classification",
9
- accelerator="cpu", # or "gpu" if needed
10
- instance_size="medium",
11
- instance_type="c6i",
12
- region="us-east-1",
13
- custom_image={
14
- "health_route": "/healthz",
15
- "port": 8080,
16
- "url": "your-docker-image-url:latest"
17
- }
18
- )
19
-
20
- # Wait for the endpoint to be ready
21
- endpoint.wait()
22
-
23
- # Create a client to interact with the endpoint
24
- client = InferenceClient(endpoint.url)
25
-
26
- # Test the endpoint
27
- result = client.text_classification("This is a test input")
28
- print(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
handler.py CHANGED
@@ -1,38 +1,35 @@
1
- import os
2
- import torch
3
  from joblib import load
4
  from transformers import BertTokenizer
 
 
 
5
 
6
  class EndpointHandler:
7
  def __init__(self, path=""):
8
- self.model = load(os.path.join(path, "model.joblib"))
 
 
 
9
  self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
10
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
- self.model.to(self.device)
12
 
13
- def __call__(self, data):
14
- inputs = data.pop("inputs", data)
 
 
 
 
 
15
 
16
- # Ensure inputs is a list
17
- if isinstance(inputs, str):
18
- inputs = [inputs]
19
-
20
- # Tokenize inputs
21
- encoded_inputs = self.tokenizer(inputs, padding=True, truncation=True, max_length=128, return_tensors="pt")
22
-
23
- # Move inputs to the correct device
24
- input_ids = encoded_inputs['input_ids'].to(self.device)
25
- attention_mask = encoded_inputs['attention_mask'].to(self.device)
26
-
27
  # Perform inference
28
  with torch.no_grad():
29
  outputs = self.model(input_ids, attention_mask=attention_mask)
30
  logits = outputs.logits
31
- probabilities = torch.nn.functional.softmax(logits, dim=-1)
32
- predictions = torch.argmax(probabilities, dim=-1)
33
-
34
- # Convert predictions to human-readable labels
35
  class_names = ["JAILBREAK", "INJECTION", "PHISHING", "SAFE"]
36
- results = [{"label": class_names[pred], "score": prob[pred].item()} for pred, prob in zip(predictions, probabilities)]
37
-
38
- return {"predictions": results}
 
 
 
 
1
  from joblib import load
2
  from transformers import BertTokenizer
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from typing import Dict, Any
6
 
7
  class EndpointHandler:
8
  def __init__(self, path=""):
9
+ # Load the model
10
+ self.model = load(f"{path}/model.joblib")
11
+ self.model.eval()
12
+ # Load the tokenizer
13
  self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 
 
14
 
15
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
16
+ # Extract input text
17
+ text = data.get("inputs", "")
18
+ # Tokenize the input text
19
+ encoding = self.tokenizer(text, truncation=True, padding=True, max_length=128, return_tensors='pt')
20
+ input_ids = encoding['input_ids']
21
+ attention_mask = encoding['attention_mask']
22
 
 
 
 
 
 
 
 
 
 
 
 
23
  # Perform inference
24
  with torch.no_grad():
25
  outputs = self.model(input_ids, attention_mask=attention_mask)
26
  logits = outputs.logits
27
+ probabilities = F.softmax(logits, dim=-1)
28
+ confidence, predicted_class = torch.max(probabilities, dim=-1)
29
+
30
+ # Map predicted class to label
31
  class_names = ["JAILBREAK", "INJECTION", "PHISHING", "SAFE"]
32
+ predicted_label = class_names[predicted_class.item()]
33
+ confidence_score = confidence.item()
34
+
35
+ return {"label": predicted_label, "confidence": confidence_score}
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- torch==1.9.0
2
- transformers==4.44.2
3
- joblib==1.1.0
 
1
+ torch
2
+ transformers
3
+ joblib