GeorgiosIoannouCoder commited on
Commit
7284b57
1 Parent(s): dcfc540

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +399 -400
app.py CHANGED
@@ -1,400 +1,399 @@
1
- #############################################################################################################################
2
- # Filename : app.py
3
- # Description: A Streamlit application to showcase the importance of Responsible AI in LLMs.
4
- # Author : Georgios Ioannou
5
- #
6
- # TODO: Add code for Google Gemma 7b and 7b-it.
7
- # TODO: Write code documentation.
8
- # Copyright © 2024 by Georgios Ioannou
9
- #############################################################################################################################
10
- # Import libraries.
11
-
12
- import os # Load environment variable(s).
13
- import requests # Send HTTP GET request to Hugging Face models for inference.
14
- import streamlit as st # Build the GUI of the application.
15
- import streamlit.components.v1 as components
16
-
17
- from dataclasses import dataclass
18
- from dotenv import find_dotenv, load_dotenv # Read local .env file.
19
- from langchain.callbacks import get_openai_callback
20
- from langchain.chains import ConversationChain
21
- from langchain.llms import OpenAI
22
- from policies import complex_policy, simple_policy
23
- from transformers import pipeline # Access to Hugging Face models.
24
- from typing import Literal
25
-
26
-
27
- #############################################################################################################################
28
- # Load environment variable(s).
29
-
30
- load_dotenv(find_dotenv()) # Read local .env file.
31
- HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
32
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
33
-
34
-
35
- #############################################################################################################################
36
-
37
-
38
- @dataclass
39
- class Message:
40
- """Class for keeping track of a chat message."""
41
-
42
- origin: Literal["human", "ai"]
43
- message: str
44
-
45
-
46
- #############################################################################################################################
47
- # Remember (keep history of) the user's messages across each rerender of the application.
48
-
49
-
50
- def initialize_session_state():
51
- if "history" not in st.session_state:
52
- st.session_state.history = []
53
- if "token_count" not in st.session_state:
54
- st.session_state.token_count = 0
55
- if "policy" not in st.session_state:
56
- st.session_state.policy = "No Policy"
57
- if "is_gemma_model" not in st.session_state:
58
- st.session_state.is_gemma_model = False
59
- if "selected_ai_icon" not in st.session_state:
60
- st.session_state.selected_ai_icon = "a1.png"
61
- if "selected_user_icon" not in st.session_state:
62
- st.session_state.selected_user_icon = "man.png"
63
-
64
-
65
- #############################################################################################################################
66
-
67
-
68
- def on_click_callback():
69
- with get_openai_callback() as cb:
70
- human_prompt = st.session_state.human_prompt
71
-
72
- # Policy + User Query Text.
73
-
74
- role = "User" # This can be an Agent.
75
-
76
- task = f"""
77
- Task: Check if there is unsafe content in
78
- '{role}' messages in conversations
79
- according our safety policy with the below categories.
80
- """
81
-
82
- output_format = f"""
83
- Provide your safety assessment for '{role}' in the above conversation:
84
- - First line must read 'safe' or 'unsafe'.
85
- - If unsafe, a second line must include a comma-separated list of violated categories.
86
- """
87
-
88
- query = human_prompt
89
-
90
- conversation = f"""
91
- <BEGIN CONVERSATION>
92
- User: {query}
93
- <END CONVERSATION>
94
- """
95
-
96
- if st.session_state.policy == "Simple Policy":
97
- prompt = f"""
98
- {task}
99
- {simple_policy}
100
- {conversation}
101
- {output_format}
102
- """
103
- elif st.session_state.policy == "Complex Policy":
104
- prompt = f"""
105
- {task}
106
- {complex_policy}
107
- {conversation}
108
- {output_format}
109
- """
110
- elif st.session_state.policy == "No Policy":
111
- prompt = human_prompt
112
-
113
- # Getting the llm response for safety check 1.
114
- # "https://api-inference.huggingface.co/models/meta-llama/LlamaGuard-7b"
115
- if st.session_state.is_gemma_model:
116
- pass
117
- else:
118
- llm_response_safety_check_1 = st.session_state.conversation.run(prompt)
119
- st.session_state.history.append(Message("human", human_prompt))
120
- st.session_state.token_count += cb.total_tokens
121
-
122
- # Checking if response is safe. Safety Check 1. Checking what goes in (user input).
123
- if (
124
- "unsafe" in llm_response_safety_check_1.lower()
125
- ): # If respone is unsafe return unsafe.
126
- st.session_state.history.append(Message("ai", llm_response_safety_check_1))
127
- return
128
- else: # If respone is safe answer the question.
129
- if st.session_state.is_gemma_model:
130
- pass
131
- else:
132
- conversation_chain = ConversationChain(
133
- llm=OpenAI(
134
- temperature=0.2,
135
- openai_api_key=OPENAI_API_KEY,
136
- model_name=st.session_state.model,
137
- ),
138
- )
139
- llm_response = conversation_chain.run(human_prompt)
140
- # st.session_state.history.append(Message("ai", llm_response))
141
- st.session_state.token_count += cb.total_tokens
142
-
143
- # Policy + LLM Response.
144
- query = llm_response
145
-
146
- conversation = f"""
147
- <BEGIN CONVERSATION>
148
- User: {query}
149
- <END CONVERSATION>
150
- """
151
-
152
- if st.session_state.policy == "Simple Policy":
153
- prompt = f"""
154
- {task}
155
- {simple_policy}
156
- {conversation}
157
- {output_format}
158
- """
159
- elif st.session_state.policy == "Complex Policy":
160
- prompt = f"""
161
- {task}
162
- {complex_policy}
163
- {conversation}
164
- {output_format}
165
- """
166
- elif st.session_state.policy == "No Policy":
167
- prompt = llm_response
168
-
169
- # Getting the llm response for safety check 2.
170
- # "https://api-inference.huggingface.co/models/meta-llama/LlamaGuard-7b"
171
- if st.session_state.is_gemma_model:
172
- pass
173
- else:
174
- llm_response_safety_check_2 = st.session_state.conversation.run(prompt)
175
- st.session_state.token_count += cb.total_tokens
176
-
177
- # Checking if response is safe. Safety Check 2. Checking what goes out (llm output).
178
- if (
179
- "unsafe" in llm_response_safety_check_2.lower()
180
- ): # If respone is unsafe return.
181
- st.session_state.history.append(
182
- Message(
183
- "ai",
184
- "THIS FROM THE AUTHOR OF THE CODE: LLM WANTED TO RESPOND UNSAFELY!",
185
- )
186
- )
187
- else:
188
- st.session_state.history.append(Message("ai", llm_response))
189
-
190
-
191
- #############################################################################################################################
192
- # Function to apply local CSS.
193
-
194
-
195
- def local_css(file_name):
196
- with open(file_name) as f:
197
- st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
198
-
199
-
200
- #############################################################################################################################
201
-
202
-
203
- # Main function to create the Streamlit web application.
204
-
205
-
206
- def main():
207
- # try:
208
- initialize_session_state()
209
-
210
- # Page title and favicon.
211
- st.set_page_config(page_title="Responsible AI", page_icon="⚖️")
212
-
213
- # Load CSS.
214
- local_css("./static/styles/styles.css")
215
-
216
- # Title.
217
- title = f"""<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">
218
- Responsible AI</h1>"""
219
- st.markdown(title, unsafe_allow_html=True)
220
-
221
- # Subtitle 1.
222
- title = f"""<h3 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">
223
- Showcase the importance of Responsible AI in LLMs</h3>"""
224
- st.markdown(title, unsafe_allow_html=True)
225
-
226
- # Subtitle 2.
227
- title = f"""<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: 0rem">
228
- CUNY Tech Prep Tutorial 6</h2>"""
229
- st.markdown(title, unsafe_allow_html=True)
230
-
231
- # Image.
232
- image = "./static/ctp.png"
233
- left_co, cent_co, last_co = st.columns(3)
234
- with cent_co:
235
- st.image(image=image)
236
-
237
- # Sidebar dropdown menu for Models.
238
- models = [
239
- "gpt-4-turbo",
240
- "gpt-4",
241
- "gpt-3.5-turbo",
242
- "gpt-3.5-turbo-instruct",
243
- "gemma-7b",
244
- "gemma-7b-it",
245
- ]
246
- selected_model = st.sidebar.selectbox("Select Model:", models)
247
- st.sidebar.write(f"Current Model: {selected_model}")
248
-
249
- if selected_model == "gpt-4-turbo":
250
- st.session_state.model = "gpt-4-turbo"
251
- elif selected_model == "gpt-4":
252
- st.session_state.model = "gpt-4"
253
- elif selected_model == "gpt-3.5-turbo":
254
- st.session_state.model = "gpt-3.5-turbo"
255
- elif selected_model == "gpt-3.5-turbo-instruct":
256
- st.session_state.model = "gpt-3.5-turbo-instruct"
257
- elif selected_model == "gemma-7b":
258
- st.session_state.model = "gemma-7b"
259
- elif selected_model == "gemma-7b-it":
260
- st.session_state.model = "gemma-7b-it"
261
-
262
- if "gpt" in st.session_state.model:
263
- st.session_state.conversation = ConversationChain(
264
- llm=OpenAI(
265
- temperature=0.2,
266
- openai_api_key=OPENAI_API_KEY,
267
- model_name=st.session_state.model,
268
- ),
269
- )
270
- elif "gemma" in st.session_state.model:
271
- # Load model from Hugging Face.
272
- st.session_state.is_gemma_model = True
273
- pass
274
-
275
- # Sidebar dropdown menu for Policies.
276
- policies = ["No Policy", "Complex Policy", "Simple Policy"]
277
- selected_policy = st.sidebar.selectbox("Select Policy:", policies)
278
- st.sidebar.write(f"Current Policy: {selected_policy}")
279
-
280
- if selected_policy == "No Policy":
281
- st.session_state.policy = "No Policy"
282
- elif selected_policy == "Complex Policy":
283
- st.session_state.policy = "Complex Policy"
284
- elif selected_policy == "Simple Policy":
285
- st.session_state.policy = "Simple Policy"
286
-
287
- # Sidebar dropdown menu for AI Icons.
288
- ai_icons = ["AI 1", "AI 2"]
289
- selected_ai_icon = st.sidebar.selectbox("AI Icon:", ai_icons)
290
- st.sidebar.write(f"Current AI Icon: {selected_ai_icon}")
291
-
292
- if selected_ai_icon == "AI 1":
293
- st.session_state.selected_ai_icon = "ai1.png"
294
- elif selected_ai_icon == "AI 2":
295
- st.session_state.selected_ai_icon = "ai2.png"
296
-
297
- # Sidebar dropdown menu for User Icons.
298
- user_icons = ["Man", "Woman"]
299
- selected_user_icon = st.sidebar.selectbox("User Icon:", user_icons)
300
- st.sidebar.write(f"Current User Icon: {selected_user_icon}")
301
-
302
- if selected_user_icon == "Man":
303
- st.session_state.selected_user_icon = "man.png"
304
- elif selected_user_icon == "Woman":
305
- st.session_state.selected_user_icon = "woman.png"
306
-
307
- # Placeholder for the chat messages.
308
- chat_placeholder = st.container()
309
- # Placeholder for the user input.
310
- prompt_placeholder = st.form("chat-form")
311
- token_placeholder = st.empty()
312
-
313
- with chat_placeholder:
314
- for chat in st.session_state.history:
315
- div = f"""
316
- <div class="chat-row
317
- {'' if chat.origin == 'ai' else 'row-reverse'}">
318
- <img class="chat-icon" src="app/static/{
319
- st.session_state.selected_ai_icon if chat.origin == 'ai'
320
- else st.session_state.selected_user_icon}"
321
- width=32 height=32>
322
- <div class="chat-bubble
323
- {'ai-bubble' if chat.origin == 'ai' else 'human-bubble'}">
324
- &#8203;{chat.message}
325
- </div>
326
- </div>
327
- """
328
- st.markdown(div, unsafe_allow_html=True)
329
-
330
- for _ in range(3):
331
- st.markdown("")
332
-
333
- # User prompt.
334
- with prompt_placeholder:
335
- st.markdown("**Chat**")
336
- cols = st.columns((6, 1))
337
-
338
- # Large text input in the left column.
339
- cols[0].text_input(
340
- "Chat",
341
- placeholder="What is your question?",
342
- label_visibility="collapsed",
343
- key="human_prompt",
344
- )
345
- # Red button in the right column.
346
- cols[1].form_submit_button(
347
- "Submit",
348
- type="primary",
349
- on_click=on_click_callback,
350
- )
351
-
352
- token_placeholder.caption(
353
- f"""
354
- Used {st.session_state.token_count} tokens \n
355
- """
356
- )
357
-
358
- # GitHub repository of author.
359
-
360
- st.markdown(
361
- f"""
362
- <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b> Check out our
363
- <a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;"> GitHub repository</a></b>
364
- </p>
365
- """,
366
- unsafe_allow_html=True,
367
- )
368
-
369
- # Use the Enter key in the keyborad to click on the Submit button.
370
- components.html(
371
- """
372
- <script>
373
- const streamlitDoc = window.parent.document;
374
-
375
- const buttons = Array.from(
376
- streamlitDoc.querySelectorAll('.stButton > button')
377
- );
378
- const submitButton = buttons.find(
379
- el => el.innerText === 'Submit'
380
- );
381
-
382
- streamlitDoc.addEventListener('keydown', function(e) {
383
- switch (e.key) {
384
- case 'Enter':
385
- submitButton.click();
386
- break;
387
- }
388
- });
389
- </script>
390
- """,
391
- height=0,
392
- width=0,
393
- )
394
-
395
-
396
- #############################################################################################################################
397
-
398
-
399
- if __name__ == "__main__":
400
- main()
 
1
+ #############################################################################################################################
2
+ # Filename : app.py
3
+ # Description: A Streamlit application to showcase the importance of Responsible AI in LLMs.
4
+ # Author : Georgios Ioannou
5
+ #
6
+ # TODO: Add code for Google Gemma 7b and 7b-it.
7
+ # TODO: Write code documentation.
8
+ # Copyright © 2024 by Georgios Ioannou
9
+ #############################################################################################################################
10
+ # Import libraries.
11
+
12
+ import os # Load environment variable(s).
13
+ import requests # Send HTTP GET request to Hugging Face models for inference.
14
+ import streamlit as st # Build the GUI of the application.
15
+ import streamlit.components.v1 as components
16
+
17
+ from dataclasses import dataclass
18
+ from dotenv import find_dotenv, load_dotenv # Read local .env file.
19
+ from langchain.callbacks import get_openai_callback
20
+ from langchain.chains import ConversationChain
21
+ from langchain.llms import OpenAI
22
+ from policies import complex_policy, simple_policy
23
+ from transformers import pipeline # Access to Hugging Face models.
24
+ from typing import Literal
25
+
26
+
27
+ #############################################################################################################################
28
+ # Load environment variable(s).
29
+
30
+ # HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
31
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
32
+
33
+
34
+ #############################################################################################################################
35
+
36
+
37
+ @dataclass
38
+ class Message:
39
+ """Class for keeping track of a chat message."""
40
+
41
+ origin: Literal["human", "ai"]
42
+ message: str
43
+
44
+
45
+ #############################################################################################################################
46
+ # Remember (keep history of) the user's messages across each rerender of the application.
47
+
48
+
49
+ def initialize_session_state():
50
+ if "history" not in st.session_state:
51
+ st.session_state.history = []
52
+ if "token_count" not in st.session_state:
53
+ st.session_state.token_count = 0
54
+ if "policy" not in st.session_state:
55
+ st.session_state.policy = "No Policy"
56
+ if "is_gemma_model" not in st.session_state:
57
+ st.session_state.is_gemma_model = False
58
+ if "selected_ai_icon" not in st.session_state:
59
+ st.session_state.selected_ai_icon = "a1.png"
60
+ if "selected_user_icon" not in st.session_state:
61
+ st.session_state.selected_user_icon = "man.png"
62
+
63
+
64
+ #############################################################################################################################
65
+
66
+
67
+ def on_click_callback():
68
+ with get_openai_callback() as cb:
69
+ human_prompt = st.session_state.human_prompt
70
+
71
+ # Policy + User Query Text.
72
+
73
+ role = "User" # This can be an Agent.
74
+
75
+ task = f"""
76
+ Task: Check if there is unsafe content in
77
+ '{role}' messages in conversations
78
+ according our safety policy with the below categories.
79
+ """
80
+
81
+ output_format = f"""
82
+ Provide your safety assessment for '{role}' in the above conversation:
83
+ - First line must read 'safe' or 'unsafe'.
84
+ - If unsafe, a second line must include a comma-separated list of violated categories.
85
+ """
86
+
87
+ query = human_prompt
88
+
89
+ conversation = f"""
90
+ <BEGIN CONVERSATION>
91
+ User: {query}
92
+ <END CONVERSATION>
93
+ """
94
+
95
+ if st.session_state.policy == "Simple Policy":
96
+ prompt = f"""
97
+ {task}
98
+ {simple_policy}
99
+ {conversation}
100
+ {output_format}
101
+ """
102
+ elif st.session_state.policy == "Complex Policy":
103
+ prompt = f"""
104
+ {task}
105
+ {complex_policy}
106
+ {conversation}
107
+ {output_format}
108
+ """
109
+ elif st.session_state.policy == "No Policy":
110
+ prompt = human_prompt
111
+
112
+ # Getting the llm response for safety check 1.
113
+ # "https://api-inference.huggingface.co/models/meta-llama/LlamaGuard-7b"
114
+ if st.session_state.is_gemma_model:
115
+ pass
116
+ else:
117
+ llm_response_safety_check_1 = st.session_state.conversation.run(prompt)
118
+ st.session_state.history.append(Message("human", human_prompt))
119
+ st.session_state.token_count += cb.total_tokens
120
+
121
+ # Checking if response is safe. Safety Check 1. Checking what goes in (user input).
122
+ if (
123
+ "unsafe" in llm_response_safety_check_1.lower()
124
+ ): # If respone is unsafe return unsafe.
125
+ st.session_state.history.append(Message("ai", llm_response_safety_check_1))
126
+ return
127
+ else: # If respone is safe answer the question.
128
+ if st.session_state.is_gemma_model:
129
+ pass
130
+ else:
131
+ conversation_chain = ConversationChain(
132
+ llm=OpenAI(
133
+ temperature=0.2,
134
+ openai_api_key=OPENAI_API_KEY,
135
+ model_name=st.session_state.model,
136
+ ),
137
+ )
138
+ llm_response = conversation_chain.run(human_prompt)
139
+ # st.session_state.history.append(Message("ai", llm_response))
140
+ st.session_state.token_count += cb.total_tokens
141
+
142
+ # Policy + LLM Response.
143
+ query = llm_response
144
+
145
+ conversation = f"""
146
+ <BEGIN CONVERSATION>
147
+ User: {query}
148
+ <END CONVERSATION>
149
+ """
150
+
151
+ if st.session_state.policy == "Simple Policy":
152
+ prompt = f"""
153
+ {task}
154
+ {simple_policy}
155
+ {conversation}
156
+ {output_format}
157
+ """
158
+ elif st.session_state.policy == "Complex Policy":
159
+ prompt = f"""
160
+ {task}
161
+ {complex_policy}
162
+ {conversation}
163
+ {output_format}
164
+ """
165
+ elif st.session_state.policy == "No Policy":
166
+ prompt = llm_response
167
+
168
+ # Getting the llm response for safety check 2.
169
+ # "https://api-inference.huggingface.co/models/meta-llama/LlamaGuard-7b"
170
+ if st.session_state.is_gemma_model:
171
+ pass
172
+ else:
173
+ llm_response_safety_check_2 = st.session_state.conversation.run(prompt)
174
+ st.session_state.token_count += cb.total_tokens
175
+
176
+ # Checking if response is safe. Safety Check 2. Checking what goes out (llm output).
177
+ if (
178
+ "unsafe" in llm_response_safety_check_2.lower()
179
+ ): # If respone is unsafe return.
180
+ st.session_state.history.append(
181
+ Message(
182
+ "ai",
183
+ "THIS FROM THE AUTHOR OF THE CODE: LLM WANTED TO RESPOND UNSAFELY!",
184
+ )
185
+ )
186
+ else:
187
+ st.session_state.history.append(Message("ai", llm_response))
188
+
189
+
190
+ #############################################################################################################################
191
+ # Function to apply local CSS.
192
+
193
+
194
+ def local_css(file_name):
195
+ with open(file_name) as f:
196
+ st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
197
+
198
+
199
+ #############################################################################################################################
200
+
201
+
202
+ # Main function to create the Streamlit web application.
203
+
204
+
205
+ def main():
206
+ # try:
207
+ initialize_session_state()
208
+
209
+ # Page title and favicon.
210
+ st.set_page_config(page_title="Responsible AI", page_icon="⚖️")
211
+
212
+ # Load CSS.
213
+ local_css("./static/styles/styles.css")
214
+
215
+ # Title.
216
+ title = f"""<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">
217
+ Responsible AI</h1>"""
218
+ st.markdown(title, unsafe_allow_html=True)
219
+
220
+ # Subtitle 1.
221
+ title = f"""<h3 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">
222
+ Showcase the importance of Responsible AI in LLMs</h3>"""
223
+ st.markdown(title, unsafe_allow_html=True)
224
+
225
+ # Subtitle 2.
226
+ title = f"""<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: 0rem">
227
+ CUNY Tech Prep Tutorial 6</h2>"""
228
+ st.markdown(title, unsafe_allow_html=True)
229
+
230
+ # Image.
231
+ image = "./static/ctp.png"
232
+ left_co, cent_co, last_co = st.columns(3)
233
+ with cent_co:
234
+ st.image(image=image)
235
+
236
+ # Sidebar dropdown menu for Models.
237
+ models = [
238
+ "gpt-4-turbo",
239
+ "gpt-4",
240
+ "gpt-3.5-turbo",
241
+ "gpt-3.5-turbo-instruct",
242
+ "gemma-7b",
243
+ "gemma-7b-it",
244
+ ]
245
+ selected_model = st.sidebar.selectbox("Select Model:", models)
246
+ st.sidebar.write(f"Current Model: {selected_model}")
247
+
248
+ if selected_model == "gpt-4-turbo":
249
+ st.session_state.model = "gpt-4-turbo"
250
+ elif selected_model == "gpt-4":
251
+ st.session_state.model = "gpt-4"
252
+ elif selected_model == "gpt-3.5-turbo":
253
+ st.session_state.model = "gpt-3.5-turbo"
254
+ elif selected_model == "gpt-3.5-turbo-instruct":
255
+ st.session_state.model = "gpt-3.5-turbo-instruct"
256
+ elif selected_model == "gemma-7b":
257
+ st.session_state.model = "gemma-7b"
258
+ elif selected_model == "gemma-7b-it":
259
+ st.session_state.model = "gemma-7b-it"
260
+
261
+ if "gpt" in st.session_state.model:
262
+ st.session_state.conversation = ConversationChain(
263
+ llm=OpenAI(
264
+ temperature=0.2,
265
+ openai_api_key=OPENAI_API_KEY,
266
+ model_name=st.session_state.model,
267
+ ),
268
+ )
269
+ elif "gemma" in st.session_state.model:
270
+ # Load model from Hugging Face.
271
+ st.session_state.is_gemma_model = True
272
+ pass
273
+
274
+ # Sidebar dropdown menu for Policies.
275
+ policies = ["No Policy", "Complex Policy", "Simple Policy"]
276
+ selected_policy = st.sidebar.selectbox("Select Policy:", policies)
277
+ st.sidebar.write(f"Current Policy: {selected_policy}")
278
+
279
+ if selected_policy == "No Policy":
280
+ st.session_state.policy = "No Policy"
281
+ elif selected_policy == "Complex Policy":
282
+ st.session_state.policy = "Complex Policy"
283
+ elif selected_policy == "Simple Policy":
284
+ st.session_state.policy = "Simple Policy"
285
+
286
+ # Sidebar dropdown menu for AI Icons.
287
+ ai_icons = ["AI 1", "AI 2"]
288
+ selected_ai_icon = st.sidebar.selectbox("AI Icon:", ai_icons)
289
+ st.sidebar.write(f"Current AI Icon: {selected_ai_icon}")
290
+
291
+ if selected_ai_icon == "AI 1":
292
+ st.session_state.selected_ai_icon = "ai1.png"
293
+ elif selected_ai_icon == "AI 2":
294
+ st.session_state.selected_ai_icon = "ai2.png"
295
+
296
+ # Sidebar dropdown menu for User Icons.
297
+ user_icons = ["Man", "Woman"]
298
+ selected_user_icon = st.sidebar.selectbox("User Icon:", user_icons)
299
+ st.sidebar.write(f"Current User Icon: {selected_user_icon}")
300
+
301
+ if selected_user_icon == "Man":
302
+ st.session_state.selected_user_icon = "man.png"
303
+ elif selected_user_icon == "Woman":
304
+ st.session_state.selected_user_icon = "woman.png"
305
+
306
+ # Placeholder for the chat messages.
307
+ chat_placeholder = st.container()
308
+ # Placeholder for the user input.
309
+ prompt_placeholder = st.form("chat-form")
310
+ token_placeholder = st.empty()
311
+
312
+ with chat_placeholder:
313
+ for chat in st.session_state.history:
314
+ div = f"""
315
+ <div class="chat-row
316
+ {'' if chat.origin == 'ai' else 'row-reverse'}">
317
+ <img class="chat-icon" src="app/static/{
318
+ st.session_state.selected_ai_icon if chat.origin == 'ai'
319
+ else st.session_state.selected_user_icon}"
320
+ width=32 height=32>
321
+ <div class="chat-bubble
322
+ {'ai-bubble' if chat.origin == 'ai' else 'human-bubble'}">
323
+ &#8203;{chat.message}
324
+ </div>
325
+ </div>
326
+ """
327
+ st.markdown(div, unsafe_allow_html=True)
328
+
329
+ for _ in range(3):
330
+ st.markdown("")
331
+
332
+ # User prompt.
333
+ with prompt_placeholder:
334
+ st.markdown("**Chat**")
335
+ cols = st.columns((6, 1))
336
+
337
+ # Large text input in the left column.
338
+ cols[0].text_input(
339
+ "Chat",
340
+ placeholder="What is your question?",
341
+ label_visibility="collapsed",
342
+ key="human_prompt",
343
+ )
344
+ # Red button in the right column.
345
+ cols[1].form_submit_button(
346
+ "Submit",
347
+ type="primary",
348
+ on_click=on_click_callback,
349
+ )
350
+
351
+ token_placeholder.caption(
352
+ f"""
353
+ Used {st.session_state.token_count} tokens \n
354
+ """
355
+ )
356
+
357
+ # GitHub repository of author.
358
+
359
+ st.markdown(
360
+ f"""
361
+ <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b> Check out our
362
+ <a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;"> GitHub repository</a></b>
363
+ </p>
364
+ """,
365
+ unsafe_allow_html=True,
366
+ )
367
+
368
+ # Use the Enter key in the keyborad to click on the Submit button.
369
+ components.html(
370
+ """
371
+ <script>
372
+ const streamlitDoc = window.parent.document;
373
+
374
+ const buttons = Array.from(
375
+ streamlitDoc.querySelectorAll('.stButton > button')
376
+ );
377
+ const submitButton = buttons.find(
378
+ el => el.innerText === 'Submit'
379
+ );
380
+
381
+ streamlitDoc.addEventListener('keydown', function(e) {
382
+ switch (e.key) {
383
+ case 'Enter':
384
+ submitButton.click();
385
+ break;
386
+ }
387
+ });
388
+ </script>
389
+ """,
390
+ height=0,
391
+ width=0,
392
+ )
393
+
394
+
395
+ #############################################################################################################################
396
+
397
+
398
+ if __name__ == "__main__":
399
+ main()