#############################################################################################################################
# Filename   : app.py
# Description: A Streamlit application to showcase the importance of Responsible AI in LLMs.
# Author     : Georgios Ioannou
#
# TODO: Add code for Google Gemma 7b and 7b-it.
# TODO: Write code documentation.
# Copyright © 2024 by Georgios Ioannou
#############################################################################################################################
# Import libraries.
import os  # Load environment variable(s).
import requests  # Send HTTP GET request to Hugging Face models for inference.
import streamlit as st  # Build the GUI of the application.
import streamlit.components.v1 as components

from dataclasses import dataclass
from dotenv import find_dotenv, load_dotenv  # Read local .env file.
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationChain
from langchain.llms import OpenAI
from policies import complex_policy, simple_policy
from transformers import pipeline  # Access to Hugging Face models.
from typing import Literal

#############################################################################################################################
# Load environment variable(s).
# Populate os.environ from a local .env file when one can be found. By default load_dotenv() does not
# override variables that are already set in the process environment, so deployment-provided secrets win.
load_dotenv(find_dotenv())

# HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # None when the variable is not set anywhere.


#############################################################################################################################
@dataclass
class Message:
    """Class for keeping track of a chat message."""

    origin: Literal["human", "ai"]  # Who produced the message: the user ("human") or the model ("ai").
    message: str  # Text displayed in the chat history.


#############################################################################################################################
# Remember (keep the history of) the user's messages across each rendering of the application.
def initialize_session_state():
    """Seed st.session_state with default values for any key that is not set yet.

    Keys that already exist (chat history, token count, selected options, ...) are left untouched,
    so their values are kept from one rendering of the application to the next.
    """
    defaults = {
        "history": [],
        "token_count": 0,
        "policy": "No Policy",
        "is_gemma_model": False,
        # Must be one of the file names main() assigns for the AI icon ("ai1.png" / "ai2.png").
        "selected_ai_icon": "ai1.png",
        "selected_user_icon": "man.png",
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


#############################################################################################################################
def _build_safety_prompt(text):
    """Return the prompt sent to the safety-check chain for *text* under the selected policy.

    With "Simple Policy" or "Complex Policy", *text* is embedded as a 'User' turn between the task
    description, the policy categories and the required output format. With "No Policy" (or any
    other value) *text* is returned unchanged.
    """
    role = "User"  # This can be an Agent.
    task = f""" Task: Check if there is unsafe content in '{role}' messages in conversations according our safety policy with the below categories. """
    output_format = f""" Provide your safety assessment for '{role}' in the above conversation: - First line must read 'safe' or 'unsafe'. - If unsafe, a second line must include a comma-separated list of violated categories. """
    conversation = f""" User: {text} """

    policy = st.session_state.policy
    if policy == "Simple Policy":
        policy_text = simple_policy
    elif policy == "Complex Policy":
        policy_text = complex_policy
    else:  # "No Policy": send the raw text.
        return text
    return f""" {task} {policy_text} {conversation} {output_format} """


def _answer_with_safety_checks(human_prompt):
    """Check the user's input, answer it, check the answer, and record the outcome in the chat history.

    Args:
        human_prompt: The text the user submitted.
    """
    history = st.session_state.history

    # Gemma models are not wired up yet (see the TODO at the top of the file). Without this guard the
    # safety-check result below would never be assigned and the callback would crash.
    if st.session_state.is_gemma_model:
        history.append(Message("human", human_prompt))
        history.append(Message("ai", "Gemma models are not supported yet."))
        return

    # Safety Check 1: checking what goes in (user input).
    # "https://api-inference.huggingface.co/models/meta-llama/LlamaGuard-7b"
    safety_check_1 = st.session_state.conversation.run(_build_safety_prompt(human_prompt))
    history.append(Message("human", human_prompt))
    if "unsafe" in safety_check_1.lower():
        # Show the checker's verdict instead of answering.
        history.append(Message("ai", safety_check_1))
        return

    # The input is safe: answer the question with a separate chain for the selected model.
    conversation_chain = ConversationChain(
        llm=OpenAI(
            temperature=0.2,
            openai_api_key=OPENAI_API_KEY,
            model_name=st.session_state.model,
        ),
    )
    llm_response = conversation_chain.run(human_prompt)

    # Safety Check 2: checking what goes out (LLM output).
    # "https://api-inference.huggingface.co/models/meta-llama/LlamaGuard-7b"
    safety_check_2 = st.session_state.conversation.run(_build_safety_prompt(llm_response))
    if "unsafe" in safety_check_2.lower():
        history.append(
            Message(
                "ai",
                "THIS FROM THE AUTHOR OF THE CODE: LLM WANTED TO RESPOND UNSAFELY!",
            )
        )
    else:
        history.append(Message("ai", llm_response))


def on_click_callback():
    """Submit-button handler: run the safety-checked question/answer flow for the current prompt.

    Reads the submitted text from st.session_state.human_prompt, appends the resulting messages to
    st.session_state.history, and adds the OpenAI tokens used to st.session_state.token_count.
    """
    with get_openai_callback() as cb:
        try:
            _answer_with_safety_checks(st.session_state.human_prompt)
        finally:
            # cb.total_tokens is a running total for every call made inside this `with` block, so it is
            # added exactly once. Adding it after each individual call would count earlier calls repeatedly.
            st.session_state.token_count += cb.total_tokens


#############################################################################################################################
# Function to apply local CSS.
def local_css(file_name):
    """Inject the stylesheet stored at *file_name* into the page.

    Args:
        file_name: Path to a CSS file.
    """
    with open(file_name, encoding="utf-8") as f:
        # The CSS has to be wrapped in a <style> tag to take effect.
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)


#############################################################################################################################
# Main function to create the Streamlit web application.
def main():
    """Build and render the Responsible AI Streamlit page.

    Renders the header, the sidebar selectors (model, policy, AI icon, user icon), the chat history,
    the chat form whose Submit button triggers on_click_callback, and the token counter.
    """
    import html  # Used to escape chat text before it is rendered with unsafe_allow_html=True.

    initialize_session_state()

    # Page title and favicon.
    st.set_page_config(page_title="Responsible AI", page_icon="⚖️")

    # Load CSS.
    local_css("./static/styles/styles.css")

    # Title.
    # NOTE(review): these header strings contain no HTML markup even though they are rendered with
    # unsafe_allow_html=True — confirm the intended markup.
    title = "\n\nResponsible AI\n\n"
    st.markdown(title, unsafe_allow_html=True)

    # Subtitle 1.
    title = "\n\nShowcase the importance of Responsible AI in LLMs\n\n"
    st.markdown(title, unsafe_allow_html=True)

    # Subtitle 2.
    title = "\n\nCUNY Tech Prep Tutorial 6\n\n"
    st.markdown(title, unsafe_allow_html=True)

    # Image, centered by placing it in the middle of three columns.
    image = "./static/ctp.png"
    _, center_column, _ = st.columns(3)
    with center_column:
        st.image(image=image)

    # Sidebar dropdown menu for Models. Each option is used directly as the model name.
    models = [
        "gpt-4-turbo",
        "gpt-4",
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-instruct",
        "gemma-7b",
        "gemma-7b-it",
    ]
    selected_model = st.sidebar.selectbox("Select Model:", models)
    st.sidebar.write(f"Current Model: {selected_model}")
    st.session_state.model = selected_model
    # Recomputed on every run so that switching back from a Gemma model to a GPT model clears the flag.
    st.session_state.is_gemma_model = "gemma" in st.session_state.model

    if "gpt" in st.session_state.model:
        # Chain used by on_click_callback for the safety checks.
        st.session_state.conversation = ConversationChain(
            llm=OpenAI(
                temperature=0.2,
                openai_api_key=OPENAI_API_KEY,
                model_name=st.session_state.model,
            ),
        )
    # TODO: Load the Gemma models from Hugging Face.

    # Sidebar dropdown menu for Policies.
    policies = ["No Policy", "Complex Policy", "Simple Policy"]
    selected_policy = st.sidebar.selectbox("Select Policy:", policies)
    st.sidebar.write(f"Current Policy: {selected_policy}")
    st.session_state.policy = selected_policy

    # Sidebar dropdown menu for AI Icons (label -> image file name).
    ai_icon_files = {"AI 1": "ai1.png", "AI 2": "ai2.png"}
    selected_ai_icon = st.sidebar.selectbox("AI Icon:", list(ai_icon_files))
    st.sidebar.write(f"Current AI Icon: {selected_ai_icon}")
    st.session_state.selected_ai_icon = ai_icon_files[selected_ai_icon]

    # Sidebar dropdown menu for User Icons (label -> image file name).
    user_icon_files = {"Man": "man.png", "Woman": "woman.png"}
    selected_user_icon = st.sidebar.selectbox("User Icon:", list(user_icon_files))
    st.sidebar.write(f"Current User Icon: {selected_user_icon}")
    st.session_state.selected_user_icon = user_icon_files[selected_user_icon]

    # Placeholder for the chat messages.
    chat_placeholder = st.container()
    # Placeholder for the user input.
    prompt_placeholder = st.form("chat-form")
    token_placeholder = st.empty()

    with chat_placeholder:
        for chat in st.session_state.history:
            # Messages contain user input and model output, and this markdown is rendered with
            # unsafe_allow_html=True, so escape <, > and & to keep them from injecting HTML.
            div = f"\n\u200b{html.escape(chat.message, quote=False)}\n"
            st.markdown(div, unsafe_allow_html=True)

        for _ in range(3):
            st.markdown("")

    # User prompt.
    with prompt_placeholder:
        st.markdown("**Chat**")
        cols = st.columns((6, 1))

        # Large text input in the left column.
        cols[0].text_input(
            "Chat",
            placeholder="What is your question?",
            label_visibility="collapsed",
            key="human_prompt",
        )

        # Red button in the right column.
        cols[1].form_submit_button(
            "Submit",
            type="primary",
            on_click=on_click_callback,
        )

    token_placeholder.caption(f" Used {st.session_state.token_count} tokens \n ")

    # GitHub repository of author.
    st.markdown("\n\nCheck out our GitHub repository\n\n", unsafe_allow_html=True)

    # Use the Enter key on the keyboard to click on the Submit button.
    # NOTE(review): the embedded HTML/script here is effectively empty — confirm the intended snippet.
    components.html(" ", height=0, width=0)


#############################################################################################################################
if __name__ == "__main__":
    main()