pratikshahp commited on
Commit
bb2efc2
·
verified ·
1 Parent(s): b31696f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -22
app.py CHANGED
@@ -2,12 +2,13 @@ import os
2
  from dotenv import load_dotenv
3
  import gradio as gr
4
  from langchain_huggingface import HuggingFaceEndpoint
 
5
 
6
  # Load environment variables
7
  load_dotenv()
8
  HF_TOKEN = os.getenv("HF_TOKEN")
9
 
10
- # Initialize the Hugging Face endpoint for inference (Primary model)
11
  llm = HuggingFaceEndpoint(
12
  repo_id="mistralai/Mistral-7B-Instruct-v0.3", # Replace with your model repo
13
  huggingfacehub_api_token=HF_TOKEN.strip(),
@@ -15,37 +16,30 @@ llm = HuggingFaceEndpoint(
15
  max_new_tokens=100
16
  )
17
 
18
- # Initialize the Meta-Llama Guard model for content safety
19
- guard_llm = HuggingFaceEndpoint(
20
- repo_id="meta-llama/Meta-Llama-Guard-2-8B", # Replace with your guard model repo
21
- huggingfacehub_api_token=HF_TOKEN.strip(),
22
- temperature=0.7,
23
- max_new_tokens=100
24
- )
25
 
26
- # Function to handle chatbot response with guardrails
27
- def guarded_chatbot_response(message):
28
  try:
29
- # Get raw response from the primary model
30
  raw_response = llm(message)
31
 
32
- # Guardrail logic: Ensure the response is safe and appropriate
33
- guard_prompt = f"""
34
- You are a content safety guard for kids under 16. Analyze the following text and ensure it is friendly and appropriate for children.
35
- If appropriate, return the text as is. If not, rephrase or respond with "Content not suitable."
36
- Text: {raw_response}
37
- """
38
 
39
- # Get the guarded response from the Meta-Llama Guard model
40
- guard_response = guard_llm(guard_prompt)
41
- return guard_response
 
 
42
  except Exception as e:
43
  return f"Error: {e}"
44
 
45
  # Gradio Interface for Chatbot with Guardrails
46
  with gr.Blocks() as app_with_guardrails:
47
  gr.Markdown("## Chatbot With Guardrails")
48
- gr.Markdown("This chatbot ensures all responses are appropriate for kids under 16.")
49
 
50
  # Input and output
51
  with gr.Row():
@@ -55,7 +49,7 @@ with gr.Blocks() as app_with_guardrails:
55
 
56
  # Button click event
57
  submit_button.click(
58
- guarded_chatbot_response,
59
  inputs=[user_input],
60
  outputs=[response_output]
61
  )
 
2
  from dotenv import load_dotenv
3
  import gradio as gr
4
  from langchain_huggingface import HuggingFaceEndpoint
5
+ from transformers import pipeline
6
 
7
  # Load environment variables
8
  load_dotenv()
9
  HF_TOKEN = os.getenv("HF_TOKEN")
10
 
11
+ # Initialize the Hugging Face endpoint for text generation (Mistral model)
12
  llm = HuggingFaceEndpoint(
13
  repo_id="mistralai/Mistral-7B-Instruct-v0.3", # Replace with your model repo
14
  huggingfacehub_api_token=HF_TOKEN.strip(),
 
16
  max_new_tokens=100
17
  )
18
 
19
+ # Initialize content moderation model (e.g., JinaAI ContentFilter or similar)
20
+ content_filter = pipeline("text-classification", model="JinaAI/ContentFilter", tokenizer="JinaAI/ContentFilter")
 
 
 
 
 
21
 
22
# Function to handle chatbot response and guardrails
def chatbot_response_with_guardrails(message):
    """Generate a reply with the primary LLM, then screen it for safety.

    Parameters
    ----------
    message : str
        The user's chat input, as passed from the Gradio textbox.

    Returns
    -------
    str
        The raw model reply when the content filter deems it safe,
        "Content not suitable." when the filter flags it, or an
        "Error: ..." string if any step fails (the Gradio UI simply
        displays whatever string is returned).
    """
    try:
        # Generate raw response from the primary model (Mistral).
        # `.invoke()` is the supported LangChain call style; calling the
        # LLM object directly (`llm(message)`) is deprecated and warns in
        # recent langchain-core versions. The return value is the same str.
        raw_response = llm.invoke(message)

        # Classify the generated text with the moderation pipeline.
        # NOTE(review): this assumes the classifier emits a lowercase
        # 'toxic' label — confirm against the actual model's label schema.
        result = content_filter(raw_response)

        # Reject flagged content outright; otherwise pass the reply through.
        if result[0]['label'] == 'toxic':  # Adjust based on your model's output
            return "Content not suitable."
        return raw_response
    except Exception as e:
        # Boundary handler: surface the failure as text so the UI shows it
        # instead of the app crashing mid-callback.
        return f"Error: {e}"
38
 
39
  # Gradio Interface for Chatbot with Guardrails
40
  with gr.Blocks() as app_with_guardrails:
41
  gr.Markdown("## Chatbot With Guardrails")
42
+ gr.Markdown("This chatbot ensures all responses are appropriate.")
43
 
44
  # Input and output
45
  with gr.Row():
 
49
 
50
  # Button click event
51
  submit_button.click(
52
+ chatbot_response_with_guardrails,
53
  inputs=[user_input],
54
  outputs=[response_output]
55
  )