FawadHaider2 commited on
Commit
0965dd4
Β·
verified Β·
1 Parent(s): d9ac46c

Upload 7 files

Browse files
Files changed (7) hide show
  1. README.md +4 -4
  2. app.py +71 -0
  3. guardrail.py +54 -0
  4. helper.py +14 -0
  5. hf_app.py +69 -0
  6. hf_guardrail.py +60 -0
  7. requirements.txt +3 -0
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Finance Guide
3
- emoji: πŸ“š
4
- colorFrom: indigo
5
- colorTo: blue
6
  sdk: gradio
7
  sdk_version: 5.8.0
8
  app_file: app.py
 
1
  ---
2
+ title: Finance Chatbot With Guardrails
3
+ emoji: πŸ“‰
4
+ colorFrom: purple
5
+ colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.8.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py (Finance Chatbot)
2
+
3
+ import gradio as gr
4
+ from helper import get_together_api_key
5
+ from guardrail import is_safe
6
+ from together import Together
7
+
8
+ # Initialize Together client
9
+ client = Together(api_key=get_together_api_key())
10
+
11
+ # Function to handle the chatbot's response to user queries
12
+ # You can only answer finance-related queries.
13
+ # - Do not answer non-finance questions.
14
+ def run_action(message, history):
15
+ system_prompt = """You are a financial assistant.
16
+ - Answer in 50 words.
17
+ - Ensure responses adhere to the safety policy."""
18
+
19
+ messages = [{"role": "system", "content": system_prompt}]
20
+
21
+ # Convert history into the appropriate format
22
+ for entry in history:
23
+ if entry["role"] == "user":
24
+ messages.append({"role": "user", "content": entry["content"]})
25
+ elif entry["role"] == "assistant":
26
+ messages.append({"role": "assistant", "content": entry["content"]})
27
+
28
+ # Add the user's current action
29
+ messages.append({"role": "user", "content": message})
30
+
31
+ # Get the model's response
32
+ model_output = client.chat.completions.create(
33
+ model="meta-llama/Llama-3-70b-chat-hf",
34
+ messages=messages,
35
+ )
36
+
37
+ return model_output.choices[0].message.content
38
+
39
+ # Main loop for the chatbot to handle user input
40
+ def main_loop(message, history):
41
+ """
42
+ Main loop for the chatbot to handle user input.
43
+ """
44
+ # Validate the user's input for safety
45
+ if not is_safe(message):
46
+ return "Your input violates our safety policy. Please try again with a finance-related query."
47
+
48
+ # Generate and validate the response
49
+ return run_action(message, history)
50
+
51
+ # Gradio Chat Interface
52
+ demo = gr.ChatInterface(
53
+ main_loop,
54
+ chatbot=gr.Chatbot(
55
+ height=450,
56
+ placeholder="Ask a finance-related question. Type 'exit' to quit.",
57
+ type="messages", # Proper rendering of chat format
58
+ ),
59
+ textbox=gr.Textbox(
60
+ placeholder="What do you want to ask about finance?",
61
+ container=False,
62
+ scale=7,
63
+ ),
64
+ title="Finance Chatbot",
65
+ theme="Monochrome",
66
+ examples=["What is compound interest?", "How to save for retirement?", "What are tax-saving options?"],
67
+ cache_examples=False,
68
+ )
69
+
70
+ # Launch the Gradio app
71
+ demo.launch(share=True, server_name="0.0.0.0")
guardrail.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from together import Together
2
+ from helper import get_together_api_key
3
+
4
+ # Initialize Together client
5
+ client = Together(api_key=get_together_api_key())
6
+
7
+ # Set the content policy for guardrails
8
+ safe_content_policy = {
9
+ "policy": """
10
+ O1: Topic Restriction.
11
+ Should not:
12
+ - Respond to questions outside the domain of finance.
13
+ - Discuss non-financial topics such as health, technology, or politics.
14
+ Can:
15
+ - Provide responses strictly within the financial domain, including investment, savings, taxes, and financial management.
16
+ O2: Sensitive Information.
17
+ Should not:
18
+ - Request or handle sensitive personal data like passwords, account numbers, or social security numbers.
19
+ Can:
20
+ - Discuss financial concepts and advice without requiring personal details.
21
+ O3: Language.
22
+ Should not:
23
+ - Use profane, offensive, or inappropriate language.
24
+ Can:
25
+ - Use clear, professional, and respectful language appropriate for financial discussions.
26
+ """
27
+ }
28
+
29
+ def is_safe(message: str) -> bool:
30
+ """
31
+ Validates if the given message complies with the safety policy using the LlamaGuard model.
32
+ Args:
33
+ message (str): User message or chatbot response.
34
+ Returns:
35
+ bool: True if the message is safe, False otherwise.
36
+ """
37
+ prompt = f"""[INST] Task: Validate if the given 'user' message adheres to the safety policy below:
38
+ <BEGIN SAFETY POLICY>
39
+ {safe_content_policy['policy']}
40
+ <END SAFETY POLICY>
41
+ <BEGIN MESSAGE>
42
+ {message}
43
+ <END MESSAGE>
44
+ Provide your safety assessment for the above message:
45
+ - First line must read 'safe' or 'unsafe'.
46
+ - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""
47
+
48
+ response = client.completions.create(
49
+ model="meta-llama/Meta-Llama-Guard-3-8B",
50
+ prompt=prompt
51
+ )
52
+
53
+ result = response.choices[0].text.strip().lower()
54
+ return result == 'safe'
helper.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv, find_dotenv
3
+
4
+ def load_env():
5
+ """Load environment variables from the .env file."""
6
+ _ = load_dotenv(find_dotenv())
7
+
8
+ def get_together_api_key() -> str:
9
+ """Retrieve the Together API key from the environment variables."""
10
+ load_env()
11
+ api_key = os.getenv("TOGETHER_API_KEY")
12
+ if not api_key:
13
+ raise ValueError("TOGETHER_API_KEY is not set in the environment variables.")
14
+ return api_key
hf_app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ import gradio as gr
3
+ from transformers import pipeline
4
+ import os
5
+
6
+ load_dotenv()
7
+ # Load the Hugging Face model and tokenizer for text generation
8
+ hf_token = os.getenv('HF_TOKEN') # Hugging Face Token for authentication
9
+ model_name = "meta-llama/Llama-3-70b-chat-hf" # Hugging Face model
10
+ chat_pipeline = pipeline("text-generation", model=model_name, use_auth_token=hf_token)
11
+
12
+ # Function to handle the chatbot's response to user queries
13
+ # You can only answer finance-related queries.
14
+ # - Do not answer non-finance questions.
15
+ def run_action(message, history):
16
+ system_prompt = """You are a financial assistant.
17
+ - Answer in 50 words.
18
+ - Ensure responses adhere to the safety policy."""
19
+
20
+ messages = [{"role": "system", "content": system_prompt}]
21
+
22
+ # Convert history into the appropriate format
23
+ for entry in history:
24
+ if entry["role"] == "user":
25
+ messages.append({"role": "user", "content": entry["content"]})
26
+ elif entry["role"] == "assistant":
27
+ messages.append({"role": "assistant", "content": entry["content"]})
28
+
29
+ # Add the user's current action
30
+ messages.append({"role": "user", "content": message})
31
+
32
+ # Generate the model output using Hugging Face's pipeline
33
+ response = chat_pipeline(messages)
34
+
35
+ return response[0]['generated_text']
36
+
37
+ # Main loop for the chatbot to handle user input
38
+ def main_loop(message, history):
39
+ """
40
+ Main loop for the chatbot to handle user input.
41
+ """
42
+ # Validate the user's input for safety
43
+ if not is_safe(message):
44
+ return "Your input violates our safety policy. Please try again with a finance-related query."
45
+
46
+ # Generate and validate the response
47
+ return run_action(message, history)
48
+
49
+ # Gradio Chat Interface
50
+ demo = gr.ChatInterface(
51
+ main_loop,
52
+ chatbot=gr.Chatbot(
53
+ height=450,
54
+ placeholder="Ask a finance-related question. Type 'exit' to quit.",
55
+ type="messages", # Proper rendering of chat format
56
+ ),
57
+ textbox=gr.Textbox(
58
+ placeholder="What do you want to ask about finance?",
59
+ container=False,
60
+ scale=7,
61
+ ),
62
+ title="Finance Chatbot",
63
+ theme="Monochrome",
64
+ examples=["What is compound interest?", "How to save for retirement?", "What are tax-saving options?"],
65
+ cache_examples=False,
66
+ )
67
+
68
+ # Launch the Gradio app
69
+ demo.launch(share=True, server_name="0.0.0.0")
hf_guardrail.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ import os
3
+ from dotenv import load_dotenv
4
+
5
+ load_dotenv()
6
+ hf_token = os.getenv('HF_TOKEN')
7
+
8
+ # Initialize Hugging Face pipeline for text classification or question-answering task
9
+ model_name = "meta-llama/Meta-Llama-Guard-3-8B" # Use a safety guard model or choose a suitable Hugging Face model
10
+
11
+ # Load the Hugging Face model and tokenizer for safety checking (using Hugging Face's `pipeline` API)
12
+ safety_model = pipeline('text-generation', model=model_name, use_auth_token=hf_token)
13
+
14
+ # Set the content policy for guardrails
15
+ safe_content_policy = {
16
+ "policy": """
17
+ O1: Topic Restriction.
18
+ Should not:
19
+ - Respond to questions outside the domain of finance.
20
+ - Discuss non-financial topics such as health, technology, or politics.
21
+ Can:
22
+ - Provide responses strictly within the financial domain, including investment, savings, taxes, and financial management.
23
+ O2: Sensitive Information.
24
+ Should not:
25
+ - Request or handle sensitive personal data like passwords, account numbers, or social security numbers.
26
+ Can:
27
+ - Discuss financial concepts and advice without requiring personal details.
28
+ O3: Language.
29
+ Should not:
30
+ - Use profane, offensive, or inappropriate language.
31
+ Can:
32
+ - Use clear, professional, and respectful language appropriate for financial discussions.
33
+ """
34
+ }
35
+
36
+ def is_safe(message: str) -> bool:
37
+ """
38
+ Validates if the given message complies with the safety policy using the Hugging Face model.
39
+ Args:
40
+ message (str): User message or chatbot response.
41
+ Returns:
42
+ bool: True if the message is safe, False otherwise.
43
+ """
44
+ prompt = f"""[INST] Task: Validate if the given 'user' message adheres to the safety policy below:
45
+ <BEGIN SAFETY POLICY>
46
+ {safe_content_policy['policy']}
47
+ <END SAFETY POLICY>
48
+ <BEGIN MESSAGE>
49
+ {message}
50
+ <END MESSAGE>
51
+ Provide your safety assessment for the above message:
52
+ - First line must read 'safe' or 'unsafe'.
53
+ - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""
54
+
55
+ # Get prediction from the Hugging Face model (classification model)
56
+ response = safety_model(prompt)
57
+
58
+ # Extract the classification result (safe or unsafe)
59
+ result = response[0]['label'].strip().lower()
60
+ return result == 'safe'
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ together==1.2.0
2
+ python-dotenv~=1.0.1
3
+ gradio==4.44.1