from transformers import pipeline
import os
from dotenv import load_dotenv

load_dotenv()
hf_token = os.getenv('HF_TOKEN')

# Use a Hugging Face safety guard model for content moderation
model_name = "meta-llama/Meta-Llama-Guard-3-8B"

# Load the guard model via Hugging Face's `pipeline` API (Llama Guard is a text-generation model)
safety_model = pipeline('text-generation', model=model_name, token=hf_token)
# Set the content policy for guardrails
safe_content_policy = {
    "policy": """
O1: Topic Restriction.
Should not:
- Respond to questions outside the domain of finance.
- Discuss non-financial topics such as health, technology, or politics.
Can:
- Provide responses strictly within the financial domain, including investment, savings, taxes, and financial management.
O2: Sensitive Information.
Should not:
- Request or handle sensitive personal data like passwords, account numbers, or social security numbers.
Can:
- Discuss financial concepts and advice without requiring personal details.
O3: Language.
Should not:
- Use profane, offensive, or inappropriate language.
Can:
- Use clear, professional, and respectful language appropriate for financial discussions.
"""
}
def is_safe(message: str) -> bool:
    """
    Validates if the given message complies with the safety policy using the guard model.
    Args:
        message (str): User message or chatbot response.
    Returns:
        bool: True if the message is safe, False otherwise.
    """
    prompt = f"""[INST] Task: Validate if the given 'user' message adheres to the safety policy below:
<BEGIN SAFETY POLICY>
{safe_content_policy['policy']}
<END SAFETY POLICY>
<BEGIN MESSAGE>
{message}
<END MESSAGE>
Provide your safety assessment for the above message:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""
    # Run the guard model; the verdict is short, so cap new tokens and drop the echoed prompt
    response = safety_model(prompt, max_new_tokens=10, return_full_text=False)
    # The first line of the generated text is either 'safe' or 'unsafe'
    result = response[0]['generated_text'].strip().lower()
    return result.split('\n')[0].strip() == 'safe'
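
# Usage sketch: a minimal check run, assuming this guard is called on user input before it
# reaches the finance chatbot. The sample messages below are hypothetical and not part of
# the original app; the second one should be flagged under policy O2 (sensitive information).
if __name__ == "__main__":
    samples = [
        "How much should I contribute to my retirement savings each year?",
        "My social security number is 123-45-6789, can you store it for me?",
    ]
    for sample in samples:
        verdict = "safe" if is_safe(sample) else "unsafe"
        print(f"{verdict}: {sample}")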