Spaces:
Runtime error
Runtime error
############################################################################################################################# | |
# Filename : app.py | |
# Description: A Streamlit application to showcase the importance of Responsible AI in LLMs. | |
# Author : Georgios Ioannou | |
# | |
# TODO: Add code for Google Gemma 7b and 7b-it. | |
# TODO: Write code documentation. | |
# Copyright © 2024 by Georgios Ioannou | |
############################################################################################################################# | |
# Import libraries. | |
import os # Load environment variable(s). | |
import requests # Send HTTP GET request to Hugging Face models for inference. | |
import streamlit as st # Build the GUI of the application. | |
import streamlit.components.v1 as components | |
from dataclasses import dataclass | |
from dotenv import find_dotenv, load_dotenv # Read local .env file. | |
from langchain.callbacks import get_openai_callback | |
from langchain.chains import ConversationChain | |
from langchain.llms import OpenAI | |
from policies import complex_policy, simple_policy | |
from transformers import pipeline # Access to Hugging Face models. | |
from typing import Literal | |
############################################################################################################################# | |
# Load environment variable(s) from a local .env file, if one can be found.
# python-dotenv does not override variables that are already set in the
# process environment, so values injected by the host still take precedence.
load_dotenv(find_dotenv())
# HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
############################################################################################################################# | |
@dataclass
class Message:
    """A single chat message shown in the conversation history.

    The ``@dataclass`` decorator generates the ``__init__`` that callers rely
    on (``Message("human", text)``); without it the class accepts no
    constructor arguments.
    """

    # Who produced the message: the user ("human") or the model ("ai").
    origin: Literal["human", "ai"]
    # The message text that is rendered in the chat bubble.
    message: str
############################################################################################################################# | |
# Remember (keep) the user's chat state across reruns of the application.
def initialize_session_state(state=None):
    """Seed the session state with defaults for any key that is not set yet.

    The app calls this at the start of every rendering, so existing values
    (e.g. the chat history and token count) are left untouched and only
    missing keys are filled in.

    Args:
        state: Mapping to initialize. Defaults to ``st.session_state``; any
            mutable mapping supporting ``in`` and item assignment works.
    """
    if state is None:
        state = st.session_state

    # Built on every call so that each session receives its own, fresh
    # history list rather than sharing one mutable default.
    defaults = {
        "history": [],
        "token_count": 0,
        "policy": "No Policy",
        "is_gemma_model": False,
        # Same file as the sidebar's "AI 1" icon option.
        "selected_ai_icon": "ai1.png",
        "selected_user_icon": "man.png",
    }
    for key, value in defaults.items():
        if key not in state:
            state[key] = value
############################################################################################################################# | |
def _build_safety_prompt(text, policy, role="User"):
    """Wrap *text* in the moderation prompt for the selected *policy*.

    Args:
        text: The message to assess (the user's prompt or the LLM's answer).
        policy: One of "No Policy", "Simple Policy" or "Complex Policy".
        role: Role name the checking LLM is asked to assess.

    Returns:
        The prompt to send to the checking LLM. With "No Policy" this is
        *text* itself, i.e. no moderation instructions are added.

    Raises:
        ValueError: If *policy* is not one of the three known names.
    """
    if policy == "No Policy":
        return text
    if policy == "Simple Policy":
        policy_text = simple_policy
    elif policy == "Complex Policy":
        policy_text = complex_policy
    else:
        raise ValueError(f"Unknown policy: {policy!r}")

    task = f"""
Task: Check if there is unsafe content in
'{role}' messages in conversations
according our safety policy with the below categories.
"""
    output_format = f"""
Provide your safety assessment for '{role}' in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories.
"""
    conversation = f"""
<BEGIN CONVERSATION>
User: {text}
<END CONVERSATION>
"""
    return f"""
{task}
{policy_text}
{conversation}
{output_format}
"""


def _run_moderated_turn(human_prompt):
    """Check the prompt, answer it, check the answer, and record the chat.

    Appends the resulting ``Message`` objects to ``st.session_state.history``.
    """
    state = st.session_state

    if state.is_gemma_model:
        # Gemma inference is not implemented yet (see the TODO in the file
        # header). Report that instead of failing on missing responses.
        state.history.append(Message("human", human_prompt))
        state.history.append(
            Message(
                "ai",
                "Gemma models are not supported yet. Please select a GPT model.",
            )
        )
        return

    # Safety check 1: check what goes in (the user's prompt).
    check_1 = state.conversation.run(
        _build_safety_prompt(human_prompt, state.policy)
    )
    state.history.append(Message("human", human_prompt))
    if "unsafe" in check_1.lower():
        # Unsafe prompt: show the checker's verdict instead of answering.
        state.history.append(Message("ai", check_1))
        return

    # The prompt passed, so answer it. A new ConversationChain starts with
    # empty memory, so the answer does not see the moderation prompts that
    # were sent through `state.conversation`.
    answer_chain = ConversationChain(
        llm=OpenAI(
            temperature=0.2,
            openai_api_key=OPENAI_API_KEY,
            model_name=state.model,
        ),
    )
    llm_response = answer_chain.run(human_prompt)

    # Safety check 2: check what goes out (the LLM's answer). The same
    # prompt template (role "User") is used as for check 1.
    check_2 = state.conversation.run(
        _build_safety_prompt(llm_response, state.policy)
    )
    if "unsafe" in check_2.lower():
        state.history.append(
            Message(
                "ai",
                "THIS FROM THE AUTHOR OF THE CODE: LLM WANTED TO RESPOND UNSAFELY!",
            )
        )
    else:
        state.history.append(Message("ai", llm_response))


def on_click_callback():
    """Handle a chat submission from the form's Submit button.

    Reads the prompt from ``st.session_state.human_prompt``. It runs the input
    safety check, the answer and the output safety check, and appends the
    messages to ``st.session_state.history``. The OpenAI tokens used are added
    to ``st.session_state.token_count``.
    """
    human_prompt = st.session_state.human_prompt
    with get_openai_callback() as cb:
        try:
            _run_moderated_turn(human_prompt)
        finally:
            # cb.total_tokens is a running total for the whole `with` block,
            # so it is added exactly once; adding it after each LLM call
            # would count the earlier calls several times.
            st.session_state.token_count += cb.total_tokens
############################################################################################################################# | |
# Function to apply local CSS.
def local_css(file_name):
    """Inject the stylesheet at *file_name* into the page inside a <style> tag.

    Args:
        file_name: Path of the CSS file to read.
    """
    # Explicit encoding: the platform default may not be UTF-8 (e.g. Windows).
    with open(file_name, encoding="utf-8") as css_file:
        css = css_file.read()
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
############################################################################################################################# | |
# Main function to create the Streamlit web application.
def main():
    """Build and render the Responsible AI chat page.

    Sets up the page and reads the sidebar choices (model, policy, chat
    icons) into the session state. It then renders the chat history and the
    input form, and wires the Submit button to ``on_click_callback``.
    """
    import html  # Escape chat text before embedding it in HTML.

    # Page title and favicon. Called first: older Streamlit versions require
    # set_page_config to be the first Streamlit command on the page.
    st.set_page_config(page_title="Responsible AI", page_icon="⚖️")
    initialize_session_state()

    # Load CSS.
    local_css("./static/styles/styles.css")

    # Title.
    title = """<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">
Responsible AI</h1>"""
    st.markdown(title, unsafe_allow_html=True)

    # Subtitle 1.
    title = """<h3 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">
Showcase the importance of Responsible AI in LLMs Using Policies</h3>"""
    st.markdown(title, unsafe_allow_html=True)

    # Subtitle 2.
    title = """<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: 0rem">
CUNY Tech Prep Tutorial 6</h2>"""
    st.markdown(title, unsafe_allow_html=True)

    # Image, centered by placing it in the middle of three columns.
    _, center_col, _ = st.columns(3)
    with center_col:
        st.image(image="./static/ctp.png")

    # Sidebar dropdown menu for Models.
    models = [
        "gpt-4-turbo",
        "gpt-4",
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-instruct",
        "gemma-7b",
        "gemma-7b-it",
    ]
    selected_model = st.sidebar.selectbox("Select Model:", models)
    st.sidebar.write(f"Current Model: {selected_model}")
    st.session_state.model = selected_model
    # Recomputed on every run so that switching from a Gemma model back to a
    # GPT model clears the flag again.
    st.session_state.is_gemma_model = "gemma" in selected_model
    if "gpt" in selected_model:
        st.session_state.conversation = ConversationChain(
            llm=OpenAI(
                temperature=0.2,
                openai_api_key=OPENAI_API_KEY,
                model_name=selected_model,
            ),
        )
    # Gemma models: loading from Hugging Face is not implemented yet.

    # Sidebar dropdown menu for Policies.
    policies = ["No Policy", "Complex Policy", "Simple Policy"]
    selected_policy = st.sidebar.selectbox("Select Policy:", policies)
    st.sidebar.write(f"Current Policy: {selected_policy}")
    st.session_state.policy = selected_policy

    # Sidebar dropdown menu for AI Icons (label -> file under app/static/).
    ai_icon_files = {"AI 1": "ai1.png", "AI 2": "ai2.png"}
    selected_ai_icon = st.sidebar.selectbox("AI Icon:", list(ai_icon_files))
    st.sidebar.write(f"Current AI Icon: {selected_ai_icon}")
    st.session_state.selected_ai_icon = ai_icon_files[selected_ai_icon]

    # Sidebar dropdown menu for User Icons (label -> file under app/static/).
    user_icon_files = {"Man": "man.png", "Woman": "woman.png"}
    selected_user_icon = st.sidebar.selectbox("User Icon:", list(user_icon_files))
    st.sidebar.write(f"Current User Icon: {selected_user_icon}")
    st.session_state.selected_user_icon = user_icon_files[selected_user_icon]

    # Placeholders for the chat messages, the user input and the token count.
    chat_placeholder = st.container()
    prompt_placeholder = st.form("chat-form")
    token_placeholder = st.empty()

    with chat_placeholder:
        for chat in st.session_state.history:
            is_ai = chat.origin == "ai"
            row_class = "" if is_ai else "row-reverse"
            icon = (
                st.session_state.selected_ai_icon
                if is_ai
                else st.session_state.selected_user_icon
            )
            bubble_class = "ai-bubble" if is_ai else "human-bubble"
            # The message is user input or LLM output: escape it so characters
            # such as "<" or "&" are shown as text, not parsed as HTML.
            div = f"""
<div class="chat-row
{row_class}">
<img class="chat-icon" src="app/static/{icon}"
width=32 height=32>
<div class="chat-bubble
{bubble_class}">
\u200b{html.escape(chat.message)}
</div>
</div>
"""
            st.markdown(div, unsafe_allow_html=True)
        for _ in range(3):
            st.markdown("")

    # User prompt.
    with prompt_placeholder:
        st.markdown("**Chat**")
        cols = st.columns((6, 1))
        # Large text input in the left column.
        cols[0].text_input(
            "Chat",
            placeholder="What is your question?",
            label_visibility="collapsed",
            key="human_prompt",
        )
        # Primary-styled Submit button in the right column.
        cols[1].form_submit_button(
            "Submit",
            type="primary",
            on_click=on_click_callback,
        )

    token_placeholder.caption(
        f"""
Used {st.session_state.token_count} tokens \n
"""
    )

    # GitHub repository of author.
    st.markdown(
        """
<p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b> Check out our
<a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;"> GitHub repository</a></b>
</p>
""",
        unsafe_allow_html=True,
    )

    # Use the Enter key on the keyboard to click the Submit button. The button
    # is looked up on each key press, and nothing happens if it isn't found,
    # instead of throwing on a missing element.
    components.html(
        """
<script>
const streamlitDoc = window.parent.document;
streamlitDoc.addEventListener('keydown', function(e) {
    if (e.key !== 'Enter') {
        return;
    }
    const submitButton = Array.from(
        streamlitDoc.querySelectorAll('.stButton > button')
    ).find(el => el.innerText === 'Submit');
    if (submitButton) {
        submitButton.click();
    }
});
</script>
""",
        height=0,
        width=0,
    )
############################################################################################################################# | |
# Entry point: build the page when this file is executed as the main script.
if __name__ == "__main__":
    main()