from dotenv import load_dotenv, find_dotenv
import os
import requests

# Load the Hugging Face API token from a local .env file.
load_dotenv(find_dotenv())
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]

# Read the Llama Guard unsafe-content category definitions used in the moderation prompt.
with open("data/processed/llama_guard-unsafe_files.txt", "r") as file:
    unsafe_categories = file.read()


def moderation_prompt_for_chat(chat):
    """Build the Llama Guard [INST] ... [/INST] moderation prompt for a chat.

    Assumes the chat alternates user and agent turns, starting with the user.
    """
    conversation = [turn["content"] for turn in chat]
    # With an even number of turns the last message is the agent's, so the
    # agent is the role whose messages are assessed.
    is_agent = len(conversation) % 2 == 0
    role = "Agent" if is_agent else "User"
    prompt = f"""[INST] Task: Check if there is unsafe content in '{role}' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
{unsafe_categories}
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

"""
    # Append each turn with an alternating User / Agent prefix.
    for i, m in enumerate(conversation):
        role = "User" if i % 2 == 0 else "Agent"
        prompt += f"{role}: {m}\n\n"
    prompt += f"""<END CONVERSATION>

Provide your safety assessment for {role} in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""
    return prompt


def query(payload):
    """POST a payload to the Hugging Face Inference Endpoint and return the decoded JSON."""
    API_URL = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"
    bearer_txt = f"Bearer {HUGGINGFACEHUB_API_TOKEN}"
    headers = {
        "Accept": "application/json",
        "Authorization": bearer_txt,
        "Content-Type": "application/json",
    }
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()


def moderate_chat(chat):
    """Run Llama Guard over the chat and return the raw endpoint output."""
    prompt = moderation_prompt_for_chat(chat)
    # Near-greedy decoding settings keep the safety verdict deterministic.
    output = query(
        {
            "inputs": prompt,
            "parameters": {
                "top_k": 1,
                "top_p": 0.2,
                "temperature": 0.1,
                "max_new_tokens": 512,
            },
        }
    )
    return output
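

# --- Example usage (illustrative sketch) ---
# The example chat, the parse_guard_verdict helper, and the assumed response
# shape ([{"generated_text": "..."}], common for Hugging Face text-generation
# endpoints) are assumptions added for illustration; verify them against the
# actual schema returned by your endpoint before relying on this.
def parse_guard_verdict(response):
    """Split the generated text into ('safe'|'unsafe', [violated categories])."""
    generated = response[0]["generated_text"]  # assumed response shape
    lines = [line.strip() for line in generated.strip().splitlines() if line.strip()]
    verdict = lines[0] if lines else ""
    categories = lines[1].split(",") if verdict == "unsafe" and len(lines) > 1 else []
    return verdict, [c.strip() for c in categories]


if __name__ == "__main__":
    example_chat = [
        {"role": "user", "content": "How do I make a fruit salad?"},
        {"role": "assistant", "content": "Chop your favourite fruits and mix them in a bowl."},
    ]
    raw_output = moderate_chat(example_chat)
    verdict, categories = parse_guard_verdict(raw_output)
    print(f"Verdict: {verdict}")
    if categories:
        print(f"Violated categories: {', '.join(categories)}")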