|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import html
import os
from dataclasses import dataclass
from typing import Literal

import requests
import streamlit as st
import streamlit.components.v1 as components
from dotenv import find_dotenv, load_dotenv
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationChain
from langchain.llms import OpenAI
from transformers import pipeline

from policies import complex_policy, simple_policy
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load variables from the nearest .env file (if one exists) into os.environ
# before reading the key. find_dotenv/load_dotenv were imported but never
# called, so a key kept only in .env was silently ignored.
# load_dotenv does not override variables already set in the environment.
load_dotenv(find_dotenv())

# API key handed to every OpenAI LLM instance; None if it is not configured.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Message:
    """One entry of the chat transcript kept in ``st.session_state.history``.

    Attributes:
        origin: Who produced the entry: ``"human"`` for the user or ``"ai"``
            for the model.
        message: The text displayed inside the chat bubble.
    """

    origin: Literal["human", "ai"]  # speaker of this entry
    message: str  # text rendered in the bubble
|
|
|
|
|
|
|
|
|
|
|
|
|
def initialize_session_state():
    """Seed ``st.session_state`` with default values for keys not yet present.

    Keys that already exist are left untouched. Calling this at the start of
    every script run therefore preserves the chat history, token counter and
    the user's sidebar choices.
    """
    # Built fresh on every call, so the "history" list is never shared between sessions.
    defaults = {
        "history": [],
        "token_count": 0,
        "policy": "No Policy",
        "is_gemma_model": False,
        # Must match the icon files the sidebar selects ("ai1.png"/"ai2.png");
        # the previous default "a1.png" referred to a file name used nowhere else.
        "selected_ai_icon": "ai1.png",
        "selected_user_icon": "man.png",
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
|
|
|
|
|
|
|
|
|
|
|
def _build_guard_prompt(text):
    """Return the prompt sent to the checker chain for *text* under the active policy.

    For "Simple Policy" and "Complex Policy", *text* is wrapped in a
    safety-classification prompt. That prompt asks the model to answer 'safe'
    or 'unsafe' on its first line. For any other selection ("No Policy"),
    *text* is returned unchanged, so the chain simply answers it.
    """
    policy = {
        "Simple Policy": simple_policy,
        "Complex Policy": complex_policy,
    }.get(st.session_state.policy)
    if policy is None:
        return text

    role = "User"
    task = f"""
Task: Check if there is unsafe content in
'{role}' messages in conversations
according our safety policy with the below categories.
"""
    conversation = f"""
<BEGIN CONVERSATION>
User: {text}
<END CONVERSATION>
"""
    output_format = f"""
Provide your safety assessment for '{role}' in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories.
"""
    return f"""
{task}
{policy}
{conversation}
{output_format}
"""


def _is_flagged(verdict):
    """Return True when a checker reply contains the word 'unsafe' (case-insensitive)."""
    return "unsafe" in verdict.lower()


def _run_guarded_turn(human_prompt):
    """Screen the prompt, generate an answer, screen the answer, and record the outcome."""
    history = st.session_state.history
    checker = st.session_state.conversation

    input_verdict = checker.run(_build_guard_prompt(human_prompt))
    history.append(Message("human", human_prompt))
    if _is_flagged(input_verdict):
        # Show the checker's verdict instead of answering the prompt.
        history.append(Message("ai", input_verdict))
        return

    # The answer comes from a separate chain instance, not the checker chain
    # that has just processed the guard prompt.
    answer_chain = ConversationChain(
        llm=OpenAI(
            temperature=0.2,
            openai_api_key=OPENAI_API_KEY,
            model_name=st.session_state.model,
        ),
    )
    llm_response = answer_chain.run(human_prompt)

    output_verdict = checker.run(_build_guard_prompt(llm_response))
    if _is_flagged(output_verdict):
        history.append(
            Message(
                "ai",
                "THIS FROM THE AUTHOR OF THE CODE: LLM WANTED TO RESPOND UNSAFELY!",
            )
        )
    else:
        history.append(Message("ai", llm_response))


def on_click_callback():
    """Handle a chat submission: guard the input, answer it, guard the output.

    Reads the prompt from ``st.session_state.human_prompt`` and appends the
    resulting messages to ``st.session_state.history``. Tokens used by all LLM
    calls of this turn are added to ``st.session_state.token_count``.
    """
    human_prompt = st.session_state.human_prompt
    if not human_prompt or not human_prompt.strip():
        # Nothing to send; avoid spending an LLM call on an empty prompt.
        return

    if st.session_state.is_gemma_model:
        # No Gemma backend is implemented. Previously this path fell through
        # to an unassigned variable and raised NameError.
        st.session_state.history.append(Message("human", human_prompt))
        st.session_state.history.append(
            Message(
                "ai",
                "Gemma models are not available in this demo yet. "
                "Please select a GPT model.",
            )
        )
        return

    with get_openai_callback() as cb:
        try:
            _run_guarded_turn(human_prompt)
        finally:
            # cb.total_tokens is cumulative for the whole `with` block, so it
            # must be added exactly once. Adding it after every call overcounted.
            st.session_state.token_count += cb.total_tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
def local_css(file_name):
    """Inject the stylesheet at *file_name* into the page inside a ``<style>`` tag.

    Args:
        file_name: Path to a CSS file, read as UTF-8 text.
    """
    # An explicit encoding avoids decode failures on platforms whose locale
    # default is not UTF-8 (e.g. cp1252 on Windows).
    with open(file_name, encoding="utf-8") as f:
        css = f.read()
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _render_header():
    """Render the centred page titles and the logo image."""
    st.markdown(
        """<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">
Responsible AI</h1>""",
        unsafe_allow_html=True,
    )
    st.markdown(
        """<h3 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">
Showcase the importance of Responsible AI in LLMs</h3>""",
        unsafe_allow_html=True,
    )
    st.markdown(
        """<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: 0rem">
CUNY Tech Prep Tutorial 6</h2>""",
        unsafe_allow_html=True,
    )
    # Three equal columns; the image goes in the middle one to centre it.
    _, center_col, _ = st.columns(3)
    with center_col:
        st.image(image="./static/ctp.png")


def _select_model():
    """Sidebar model picker; stores the choice and, for GPT models, a fresh chain."""
    models = [
        "gpt-4-turbo",
        "gpt-4",
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-instruct",
        "gemma-7b",
        "gemma-7b-it",
    ]
    selected_model = st.sidebar.selectbox("Select Model:", models)
    st.sidebar.write(f"Current Model: {selected_model}")
    st.session_state.model = selected_model

    # Recomputed on every run. Previously the flag was only ever set to True,
    # so switching back from a Gemma model left the app stuck on the Gemma path.
    st.session_state.is_gemma_model = "gemma" in selected_model
    if "gpt" in selected_model:
        st.session_state.conversation = ConversationChain(
            llm=OpenAI(
                temperature=0.2,
                openai_api_key=OPENAI_API_KEY,
                model_name=selected_model,
            ),
        )
    # No chain is created for Gemma models here.


def _select_policy():
    """Sidebar safety-policy picker; stores the choice in session state."""
    policies = ["No Policy", "Complex Policy", "Simple Policy"]
    selected_policy = st.sidebar.selectbox("Select Policy:", policies)
    st.sidebar.write(f"Current Policy: {selected_policy}")
    st.session_state.policy = selected_policy


def _select_icons():
    """Sidebar avatar pickers; map each label to its icon file name in session state."""
    ai_icon_files = {"AI 1": "ai1.png", "AI 2": "ai2.png"}
    selected_ai_icon = st.sidebar.selectbox("AI Icon:", list(ai_icon_files))
    st.sidebar.write(f"Current AI Icon: {selected_ai_icon}")
    st.session_state.selected_ai_icon = ai_icon_files[selected_ai_icon]

    user_icon_files = {"Man": "man.png", "Woman": "woman.png"}
    selected_user_icon = st.sidebar.selectbox("User Icon:", list(user_icon_files))
    st.sidebar.write(f"Current User Icon: {selected_user_icon}")
    st.session_state.selected_user_icon = user_icon_files[selected_user_icon]


def _render_chat_history():
    """Render each history message as an HTML chat bubble, then add some spacing."""
    for chat in st.session_state.history:
        is_ai = chat.origin == "ai"
        row_class = "" if is_ai else "row-reverse"
        bubble_class = "ai-bubble" if is_ai else "human-bubble"
        icon = (
            st.session_state.selected_ai_icon
            if is_ai
            else st.session_state.selected_user_icon
        )
        # Message text comes from the user and the LLM. Escape it so it cannot
        # inject markup into HTML rendered with unsafe_allow_html=True.
        text = html.escape(chat.message)
        div = f"""
<div class="chat-row {row_class}">
    <img class="chat-icon" src="app/static/{icon}" width=32 height=32>
    <div class="chat-bubble {bubble_class}">
        &#8203;{text}
    </div>
</div>
"""
        st.markdown(div, unsafe_allow_html=True)

    for _ in range(3):
        st.markdown("")


def _render_prompt_form():
    """Render the prompt text box and the Submit button that triggers the callback."""
    st.markdown("**Chat**")
    cols = st.columns((6, 1))
    cols[0].text_input(
        "Chat",
        placeholder="What is your question?",
        label_visibility="collapsed",
        key="human_prompt",
    )
    cols[1].form_submit_button(
        "Submit",
        type="primary",
        on_click=on_click_callback,
    )


def _render_footer():
    """Render the centred GitHub link below the chat."""
    st.markdown(
        """
<p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b> Check out our
<a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;"> GitHub repository</a></b>
</p>
""",
        unsafe_allow_html=True,
    )


def _install_enter_key_submit():
    """Inject a zero-size script that clicks the Submit button when Enter is pressed."""
    components.html(
        """
<script>
const streamlitDoc = window.parent.document;

const buttons = Array.from(
    streamlitDoc.querySelectorAll('.stButton > button')
);
const submitButton = buttons.find(
    el => el.innerText === 'Submit'
);

streamlitDoc.addEventListener('keydown', function(e) {
    // Guard against the button not being found; calling .click() on
    // undefined would throw on every Enter press.
    if (e.key === 'Enter' && submitButton) {
        submitButton.click();
    }
});
</script>
""",
        height=0,
        width=0,
    )


def main():
    """Build the page: header, sidebar settings, chat transcript, prompt form and footer."""
    # Streamlit requires set_page_config to be the first Streamlit command of a run.
    st.set_page_config(page_title="Responsible AI", page_icon="⚖️")
    initialize_session_state()
    local_css("./static/styles/styles.css")

    _render_header()
    _select_model()
    _select_policy()
    _select_icons()

    # Create the containers first so the page order is: chat, form, token count.
    chat_placeholder = st.container()
    prompt_placeholder = st.form("chat-form")
    token_placeholder = st.empty()

    with chat_placeholder:
        _render_chat_history()

    with prompt_placeholder:
        _render_prompt_form()

    token_placeholder.caption(f"Used {st.session_state.token_count} tokens \n")

    _render_footer()
    _install_enter_key_submit()
|
|
|
|
|
|
|
|
|
|
|
# Build the app only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|