from transformers import pipeline
import os
from dotenv import load_dotenv

load_dotenv()
hf_token = os.getenv('HF_TOKEN')

# Choose a safety guard model for moderating messages
# (Llama Guard 3 is published on the Hub as "meta-llama/Llama-Guard-3-8B")
model_name = "meta-llama/Llama-Guard-3-8B"

# Load the model and tokenizer for safety checking via Hugging Face's `pipeline` API;
# Llama Guard is a generative model, so the task is 'text-generation'
safety_model = pipeline('text-generation', model=model_name, token=hf_token)

# Set the content policy for guardrails
safe_content_policy = {
    "policy": """
    O1: Topic Restriction.
    Should not:
    - Respond to questions outside the domain of finance.
    - Discuss non-financial topics such as health, technology, or politics.
    Can:
    - Provide responses strictly within the financial domain, including investment, savings, taxes, and financial management.
    O2: Sensitive Information.
    Should not:
    - Request or handle sensitive personal data like passwords, account numbers, or social security numbers.
    Can:
    - Discuss financial concepts and advice without requiring personal details.
    O3: Language.
    Should not:
    - Use profane, offensive, or inappropriate language.
    Can:
    - Use clear, professional, and respectful language appropriate for financial discussions.
    """
}

def is_safe(message: str) -> bool:
    """
    Validates if the given message complies with the safety policy using the Hugging Face model.

    Args:
        message (str): User message or chatbot response.

    Returns:
        bool: True if the message is safe, False otherwise.
    """
    prompt = f"""[INST] Task: Validate if the given 'user' message adheres to the safety policy below:
    {safe_content_policy['policy']}
    {message}
    Provide your safety assessment for the above message:
    - First line must read 'safe' or 'unsafe'.
    - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""

    # Generate the model's verdict; return_full_text=False strips the prompt
    # from the output so only the newly generated assessment remains
    response = safety_model(prompt, max_new_tokens=10, return_full_text=False)

    # A text-generation pipeline returns [{'generated_text': ...}], not a
    # classification-style 'label', so parse the first line of the generated text
    result = response[0]['generated_text'].strip().splitlines()[0].lower()
    return result == 'safe'
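
# --- Usage sketch (illustrative, not part of the original snippet) ---
# A minimal example of wiring is_safe() in front of a chatbot. It assumes the
# pipeline above loaded successfully and the HF_TOKEN grants access to the
# gated Llama Guard weights; the sample messages below are hypothetical.
if __name__ == "__main__":
    for user_message in [
        "What are some tax-efficient ways to save for retirement?",  # on-topic
        "Can you share my account password with me?",                # policy violation
    ]:
        if is_safe(user_message):
            print(f"SAFE   -> forward to the finance chatbot: {user_message!r}")
        else:
            print(f"UNSAFE -> return a refusal instead of answering: {user_message!r}")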