File size: 4,511 Bytes
00e4075
 
 
 
bd9870c
 
 
00e4075
 
 
bd9870c
a6a602a
bd9870c
00e4075
bd9870c
00e4075
 
2edb6cf
00e4075
 
 
 
 
 
 
 
 
 
 
 
bd9870c
00e4075
 
bd9870c
00e4075
 
 
 
 
 
 
 
 
 
 
 
 
 
bd9870c
 
 
 
00e4075
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd9870c
 
 
 
 
 
 
 
 
 
 
 
 
 
00e4075
bd9870c
 
 
 
 
 
 
 
 
 
 
00e4075
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Docs:- https://huggingface.co/meta-llama/LlamaGuard-7b
import os
import requests

# Hugging Face model identifier for Llama Guard 7B.
model_id = "meta-llama/LlamaGuard-7b"


# The API token is read from the 'HUGGINGFACE_LLAMAGAURD_API_TOKEN' environment
# variable (note the "LLAMAGAURD" spelling); it is None when that variable is
# not set.
# NOTE(review): nothing visible here loads a .env file - presumably the
# application populates the environment before this module is imported; confirm.
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACE_LLAMAGAURD_API_TOKEN")

# updated on March 24th

# Load the unsafe-content category definitions that are embedded in the
# moderation prompt. The path is relative to the current working directory.
# Any failure is reported on stdout and leaves `unsafe_categories` as an empty
# string, so importing this module never fails because of the file.
unsafe_categories = ""
try:
    with open("data/llama_guard-unsafe_categories.txt", "r") as file:
        unsafe_categories = file.read()
except FileNotFoundError:
    print("File not found. Please check the file path.")
except PermissionError:
    print("Permission denied. You don't have the rights to read the file.")
except Exception as e:  # any other failure while opening or reading
    print(f"An error occurred while reading the file: {e}")


# added on March 24th
def moderation_prompt_for_chat(chat) -> str:
    """Build the Llama Guard moderation prompt for a single user message.

    The prompt wraps the module-level ``unsafe_categories`` text and the given
    message in an ``[INST] ... [/INST]`` instruction that asks for a first
    line of 'safe' or 'unsafe' and, if unsafe, a second line listing the
    violated categories.

    Args:
        chat: The user message. It is interpolated into the prompt verbatim,
            without any escaping.

    Returns:
        The complete prompt string.
    """
    # NOTE(review): if `chat` is untrusted text it can itself contain the
    # conversation delimiters used below - confirm whether callers need to
    # sanitise it first.
    prompt = f"""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
  <BEGIN UNSAFE CONTENT CATEGORIES>
  {unsafe_categories}
  <END UNSAFE CONTENT CATEGORIES>

  <BEGIN CONVERSATION>

  User : {chat}

  <END CONVERSATION>

  Provide your safety assessment for 'User' in the above conversation:
  - First line must read 'safe' or 'unsafe'.
  - If unsafe, a second line must include a comma-separated list of violated categories.[/INST]"""
    return prompt


def query(payload, timeout=60):
    """POST ``payload`` to the hosted Llama Guard endpoint and decode the reply.

    Request failures are reported through the return value rather than raised:
    every exception from the request, the status check or the JSON decoding is
    caught, printed, and turned into an error message.

    Args:
        payload: JSON-serialisable request body (sent via ``json=``).
        timeout: Seconds to wait for the endpoint, passed straight to
            ``requests.post``. Without a timeout a stalled endpoint would
            block the caller forever. Pass ``None`` to wait indefinitely.

    Returns:
        ``(data, None)`` on success, where ``data`` is the decoded JSON body,
        or ``(None, error_message)`` when anything went wrong.
    """
    API_URL = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"
    headers = {
        "Accept": "application/json",
        # The token comes from the module-level constant read at import time.
        "Authorization": f"Bearer {HUGGINGFACEHUB_API_TOKEN}",
        "Content-Type": "application/json",
    }
    try:
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=timeout
        )
        # Turn 4xx/5xx responses into HTTPError so they are reported below.
        response.raise_for_status()
        return response.json(), None
    except requests.exceptions.HTTPError as http_err:
        error_message = f"HTTP error occurred: {http_err}"
    except requests.exceptions.Timeout:
        # Listed before ConnectionError because ConnectTimeout derives from both.
        error_message = "The request to the API endpoint timed out."
    except requests.exceptions.ConnectionError:
        error_message = "Could not connect to the API endpoint."
    except Exception as err:
        # Deliberate catch-all (e.g. a non-JSON body): callers expect a tuple.
        error_message = f"An error occurred: {err}"

    print(error_message)
    return None, error_message


def query1(payload):
    """POST ``payload`` to the hosted Llama Guard endpoint and return its JSON.

    No error handling is done here: the HTTP status is not checked, and any
    exception raised by the request or by JSON decoding propagates to the
    caller.

    Args:
        payload: JSON-serialisable request body (sent via ``json=``).

    Returns:
        The decoded JSON body of the response.
    """
    endpoint = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"
    reply = requests.post(
        endpoint,
        headers={
            "Accept": "application/json",
            "Authorization": f"Bearer {HUGGINGFACEHUB_API_TOKEN}",
            "Content-Type": "application/json",
        },
        json=payload,
    )
    return reply.json()


def moderate_chat(chat):
    """Run Llama Guard moderation on one user message.

    Builds the moderation prompt for ``chat``, sends it through ``query`` with
    fixed generation parameters, and prints both the prompt and the raw output
    to stdout.

    Args:
        chat: The user message to assess.

    Returns:
        The ``(output, error_msg)`` pair produced by ``query``.
    """
    generation_settings = {
        "top_k": 1,
        "top_p": 0.2,
        "temperature": 0.1,
        "max_new_tokens": 512,
    }
    llamaguard_prompt = moderation_prompt_for_chat(chat)
    request_body = {"inputs": llamaguard_prompt, "parameters": generation_settings}

    verdict, failure = query(request_body)

    print("Llamaguard prompt****", llamaguard_prompt)
    print("Llamaguard output****", verdict)

    return verdict, failure


# added on March 24th
def load_category_names_from_string(file_content):
    """Parse ``code: name`` category headings out of a block of text.

    Only lines that begin with the letter "O" and contain exactly one colon
    are used; everything else is ignored. The text before the colon becomes
    the key and the text after it the value, both stripped of surrounding
    whitespace. A code that appears more than once keeps its last name.

    Args:
        file_content: The full text to scan, with lines separated by "\\n".

    Returns:
        A dict mapping category code to category name.
    """
    mapping = {}
    for raw_line in file_content.split("\n"):
        if not raw_line.startswith("O"):
            continue
        pieces = raw_line.split(":")
        if len(pieces) != 2:
            # Zero or several colons: not a plain "code: name" heading.
            continue
        code, name = pieces
        mapping[code.strip()] = name.strip()
    return mapping


def get_category_name(input_str):
    """Return the category name(s) for the code(s) on the second line of a verdict.

    ``input_str`` is expected to look like the output requested by the
    moderation prompt: a first line ('safe' / 'unsafe') and, when unsafe, a
    second line with one category code or a comma-separated list of codes.
    Codes are resolved against the module-level ``unsafe_categories`` text.

    Args:
        input_str: The raw verdict text, with lines separated by "\\n".

    Returns:
        The category name for a single code; for several codes, their names
        joined by ", ". Any code that is not known is reported as
        "Unknown Category", which is also returned when there is no second
        line or it is empty.
    """
    category_names = load_category_names_from_string(unsafe_categories)

    lines = input_str.split("\n")
    if len(lines) < 2:
        # A bare verdict such as "safe" has no category line; indexing it
        # unconditionally used to raise IndexError.
        return "Unknown Category"

    second_line = lines[1].strip()

    # Exact match first, so a single code resolves exactly as before.
    if second_line in category_names:
        return category_names[second_line]

    # The prompt asks for a comma-separated list, so resolve each code.
    codes = [code.strip() for code in second_line.split(",")]
    codes = [code for code in codes if code]
    if not codes:
        return "Unknown Category"

    return ", ".join(
        category_names.get(code, "Unknown Category") for code in codes
    )