pratikshahp committed on
Commit a177a19 · verified · 1 Parent(s): 991a228

Create hf_guardrail.py

Files changed (1): hf_guardrail.py (+57, -0)
hf_guardrail.py ADDED
@@ -0,0 +1,57 @@
from transformers import pipeline
import os

# Initialize a Hugging Face text-generation pipeline for the safety check
model_name = "meta-llama/Meta-Llama-Guard-3-8B"  # a safety guard model; substitute another suitable Hugging Face model if needed
hf_token = os.getenv('HF_TOKEN')

# Load the safety model via Hugging Face's `pipeline` API
# (`token` replaces the deprecated `use_auth_token` argument)
safety_model = pipeline('text-generation', model=model_name, token=hf_token)

# Content policy enforced by the guardrail
safe_content_policy = {
    "policy": """
O1: Topic Restriction.
Should not:
- Respond to questions outside the domain of finance.
- Discuss non-financial topics such as health, technology, or politics.
Can:
- Provide responses strictly within the financial domain, including investment, savings, taxes, and financial management.
O2: Sensitive Information.
Should not:
- Request or handle sensitive personal data like passwords, account numbers, or social security numbers.
Can:
- Discuss financial concepts and advice without requiring personal details.
O3: Language.
Should not:
- Use profane, offensive, or inappropriate language.
Can:
- Use clear, professional, and respectful language appropriate for financial discussions.
"""
}

def is_safe(message: str) -> bool:
    """
    Validates whether the given message complies with the safety policy,
    using the Hugging Face model.

    Args:
        message (str): User message or chatbot response.

    Returns:
        bool: True if the message is safe, False otherwise.
    """
    prompt = f"""[INST] Task: Validate if the given 'user' message adheres to the safety policy below:
<BEGIN SAFETY POLICY>
{safe_content_policy['policy']}
<END SAFETY POLICY>
<BEGIN MESSAGE>
{message}
<END MESSAGE>
Provide your safety assessment for the above message:
- The first line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""

    # Generate the assessment; `return_full_text=False` strips the echoed prompt,
    # and a handful of new tokens is enough for the one- or two-line verdict
    response = safety_model(prompt, max_new_tokens=20, return_full_text=False)

    # A text-generation pipeline returns 'generated_text' (not a classifier 'label');
    # the verdict sits on the first line of the output
    result = response[0]['generated_text'].strip().split('\n')[0].lower()
    return result == 'safe'
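
A minimal usage sketch, not part of the commit: it assumes HF_TOKEN is set, the account has been granted access to the gated meta-llama/Meta-Llama-Guard-3-8B checkpoint, and the hardware can host an 8B model; the example messages and the __main__ guard are illustrative only.

if __name__ == "__main__":
    examples = [
        "How should I split my savings between index funds and bonds?",  # on-topic: finance
        "What's the best way to treat a sprained ankle?",                # off-topic: violates O1
    ]
    for msg in examples:
        verdict = "safe" if is_safe(msg) else "unsafe"
        print(f"{verdict}: {msg}")

One design note: the [INST] ... [/INST] wrapper above is the Llama 2 instruction format, while Llama Guard 3 is Llama 3 based and ships with a chat template that already encodes the moderation task. A hedged alternative sketch, assuming the model's built-in hazard taxonomy is acceptable in place of the custom finance policy above:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    model_name, token=hf_token, torch_dtype=torch.bfloat16, device_map="auto"
)

chat = [{"role": "user", "content": "Please share your admin password."}]
input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(model.device)
output = model.generate(input_ids=input_ids, max_new_tokens=20, pad_token_id=0)
# Decode only the newly generated tokens; the first line reads 'safe' or 'unsafe'
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))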