DrishtiSharma committed on
Commit 9419dfe · verified · 1 Parent(s): 2af0ce8

Create better_responses.py

Files changed (1)
1. better_responses.py +1229 -0
better_responses.py ADDED
@@ -0,0 +1,1229 @@
# ref: https://github.com/twy80/LangChain_llm_Agent/tree/main
import streamlit as st
import os, base64, re, time, json
import matplotlib.pyplot as plt
from io import BytesIO
from functools import partial
from tempfile import NamedTemporaryFile, TemporaryDirectory
from audio_recorder_streamlit import audio_recorder
from PIL import Image, UnidentifiedImageError
from openai import OpenAI
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_google_community import GoogleSearchAPIWrapper
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import HumanMessage, AIMessage
from langchain_community.utilities import BingSearchAPIWrapper
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders import Docx2txtLoader
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.tools import Tool, tool
from langchain.tools.retriever import create_retriever_tool
# from langchain.agents import create_openai_tools_agent
from langchain.agents import create_tool_calling_agent
from langchain.agents import create_react_agent
from langchain.agents import AgentExecutor
from langchain_community.agent_toolkits.load_tools import load_tools
# from langchain_experimental.tools import PythonREPLTool
from langchain_experimental.utilities import PythonREPL
from langchain.callbacks.base import BaseCallbackHandler
from pydantic import BaseModel, Field
# The following are for type annotations
from typing import Union, List, Literal, Optional, Dict, Any, Annotated
from matplotlib.figure import Figure
from streamlit.runtime.uploaded_file_manager import UploadedFile
from openai._legacy_response import HttpxBinaryResponseContent

# Load API keys from Hugging Face secrets
try:
    os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
    os.environ["BING_SUBSCRIPTION_KEY"] = st.secrets.get("BING_SUBSCRIPTION_KEY", "")
    os.environ["GOOGLE_API_KEY"] = st.secrets.get("GOOGLE_API_KEY", "")
    os.environ["GOOGLE_CSE_ID"] = st.secrets.get("GOOGLE_CSE_ID", "")
except KeyError as e:
    st.error(f"Missing required secret: {e}. Please set it in Hugging Face Space secrets.")
    st.stop()


def initialize_session_state_variables() -> None:
    """
    Initialize all the session state variables.
    """
    default_values = {
        "ready": False,
        "openai": None,
        "history": [],
        "model_type": "GPT Models from OpenAI",
        "agent_type": 2 * ["Tool Calling"],
        "ai_role": 2 * ["You are a helpful AI assistant."],
        "prompt_exists": False,
        "temperature": [0.7, 0.7],
        "audio_bytes": None,
        "mic_used": False,
        "audio_response": None,
        "image_url": None,
        "image_description": None,
        "uploader_key": 0,
        "tool_names": [[], []],
        "bing_subscription_validity": False,
        "google_cse_id_validity": False,
        "vector_store_message": None,
        "retriever_tool": None,
        "show_uploader": False,
    }

    for key, value in default_values.items():
        if key not in st.session_state:
            st.session_state[key] = value

class StreamHandler(BaseCallbackHandler):
    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: Any, **kwargs) -> None:
        new_text = self._extract_text(token)
        if new_text:
            self.text += new_text
            self.container.markdown(self.text)

    def _extract_text(self, token: Any) -> str:
        if isinstance(token, str):
            return token
        elif isinstance(token, list):
            return "".join(self._extract_text(t) for t in token)
        elif isinstance(token, dict):
            return token.get("text", "")
        else:
            return str(token)

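# A minimal sketch (not called anywhere) of how StreamHandler is meant to be
# wired up: tokens arriving via on_llm_new_token are appended to self.text and
# re-rendered in the container. The model name below is illustrative only.
def _demo_stream_handler() -> None:
    handler = StreamHandler(st.empty())
    llm = ChatOpenAI(model="gpt-4o-mini", streaming=True, callbacks=[handler])
    llm.invoke("Say hello.")  # the container fills in as tokens stream
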
def check_api_keys() -> None:
    # Unset this flag to check the validity of the OpenAI API key
    st.session_state.ready = False


def message_history_to_string(extra_space: bool=True) -> str:
    """
    Return a string of the chat history contained in
    st.session_state.history.
    """
    history_list = []
    for msg in st.session_state.history:
        if isinstance(msg, HumanMessage):
            history_list.append(f"Human: {msg.content}")
        else:
            history_list.append(f"AI: {msg.content}")
    new_lines = "\n\n" if extra_space else "\n"

    return new_lines.join(history_list)

def get_chat_model(
    model: str,
    temperature: float,
    callbacks: List[BaseCallbackHandler]
) -> Union[ChatOpenAI, ChatAnthropic, ChatGoogleGenerativeAI, None]:
    """
    Get the appropriate chat model based on the given model name.
    """
    model_map = {
        "gpt-": ChatOpenAI,
    }
    for prefix, ModelClass in model_map.items():
        if model.startswith(prefix):
            return ModelClass(
                model=model,
                temperature=temperature,
                streaming=True,
                callbacks=callbacks
            )
    return None

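# The prefix map makes it easy to route other providers, should they be
# enabled. A sketch (unused; assumes the corresponding API keys are
# configured, and the prefixes/classes simply mirror the imports above):
_EXTENDED_MODEL_MAP = {
    "gpt-": ChatOpenAI,
    "claude-": ChatAnthropic,
    "gemini-": ChatGoogleGenerativeAI,
}
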
def process_with_images(
    llm: Union[ChatOpenAI, ChatAnthropic, ChatGoogleGenerativeAI],
    message_content: str,
    image_urls: List[str]
) -> str:
    """
    Process the given history query with associated images
    using a language model.
    """
    content_with_images = (
        [{"type": "text", "text": message_content}] +
        [{"type": "image_url", "image_url": {"url": url}} for url in image_urls]
    )
    message_with_images = [HumanMessage(content=content_with_images)]

    return llm.invoke(message_with_images).content


def process_with_tools(
    llm: Union[ChatOpenAI, ChatAnthropic, ChatGoogleGenerativeAI],
    tools: List[Tool],
    agent_type: str,
    agent_prompt: str,
    history_query: dict
) -> str:
    """
    Create an AI agent based on the specified agent type and tools,
    then use this agent to process the given history query.
    """
    if agent_type == "Tool Calling":
        agent = create_tool_calling_agent(llm, tools, agent_prompt)
    else:
        agent = create_react_agent(llm, tools, agent_prompt)

    agent_executor = AgentExecutor(
        agent=agent, tools=tools, max_iterations=5, verbose=False,
        handle_parsing_errors=True,
    )

    return agent_executor.invoke(history_query)["output"]

def run_agent(
    query: str,
    model: str,
    tools: List[Tool],
    image_urls: List[str],
    temperature: float=0.7,
    agent_type: Literal["Tool Calling", "ReAct"]="Tool Calling",
) -> Union[str, None]:
    """
    Generate text based on user queries.

    Args:
        query: User's query
        model: LLM like "gpt-4o"
        tools: List of tools such as Search and Retrieval
        image_urls: List of URLs for images
        temperature: Value between 0 and 1. Defaults to 0.7
        agent_type: 'Tool Calling' or 'ReAct'

    Return:
        generated text
    """
    try:
        # Ensure the retriever tool is included when "Retrieval" is selected
        if "Retrieval" in st.session_state.tool_names[0]:
            if st.session_state.retriever_tool:
                retriever_tool_name = "retriever"  # ensure naming consistency
                if retriever_tool_name not in [tool.name for tool in tools]:
                    tools.append(st.session_state.retriever_tool)
                    st.write(f"✅ **{retriever_tool_name} tool has been added successfully.**")
            else:
                st.error("❌ Retriever tool is not initialized. Please create a vector store first.")
                return None  # Exit early to avoid broken tool usage

        # Debugging: print the final tools list
        st.write("**Final Tools Being Used:**", [tool.name for tool in tools])

        if "retriever" in [tool.name for tool in tools]:
            st.success("✅ Retriever tool is confirmed and ready for use.")
        elif "Retrieval" in st.session_state.tool_names[0]:
            st.warning("⚠️ 'Retrieval' was selected but the retriever tool is missing!")

        # Initialize the LLM
        llm = get_chat_model(model, temperature, [StreamHandler(st.empty())])
        if llm is None:
            st.error(f"❌ Unsupported model: {model}", icon="🚨")
            return None

        # Prepare chat history
        if agent_type == "Tool Calling":
            chat_history = st.session_state.history
        else:
            chat_history = message_history_to_string()

        history_query = {"chat_history": chat_history, "input": query}

        # Generate message content
        message_with_no_image = st.session_state.chat_prompt.invoke(history_query)
        message_content = message_with_no_image.messages[0].content

        if image_urls:
            # Handle images if provided
            generated_text = process_with_images(llm, message_content, image_urls)
            human_message = HumanMessage(
                content=query, additional_kwargs={"image_urls": image_urls}
            )
        elif tools:
            # Use tools for query execution
            generated_text = process_with_tools(
                llm, tools, agent_type, st.session_state.agent_prompt, history_query
            )
            human_message = HumanMessage(content=query)
        else:
            # Fall back to basic query execution without tools
            generated_text = llm.invoke(message_with_no_image).content
            human_message = HumanMessage(content=query)

        # Convert the response into plain text
        if isinstance(generated_text, list):
            generated_text = generated_text[0]["text"]

        # Update conversation history
        st.session_state.history.append(human_message)
        st.session_state.history.append(AIMessage(content=generated_text))

        return generated_text

    except Exception as e:
        st.error(f"An error occurred: {e}", icon="🚨")
        return None

def openai_create_image(
    description: str, model: str="dall-e-3", size: str="1024x1024"
) -> Optional[str]:
    """
    Generate an image based on the user's description.

    Args:
        description: User description
        model: Defaults to "dall-e-3"
        size: Pixel size of the generated image

    Return:
        URL of the generated image
    """
    try:
        with st.spinner("AI is generating..."):
            response = st.session_state.openai.images.generate(
                model=model,
                prompt=description,
                size=size,
                quality="standard",
                n=1,
            )
        image_url = response.data[0].url
    except Exception as e:
        image_url = None
        st.error(f"An error occurred: {e}", icon="🚨")

    return image_url

def get_vector_store(uploaded_files: List[UploadedFile]) -> Optional[FAISS]:
    """
    Take a list of UploadedFile objects as input,
    and return a FAISS vector store.
    """
    if not uploaded_files:
        return None

    documents = []
    loader_map = {
        ".pdf": PyPDFLoader,
        ".txt": TextLoader,
        ".docx": Docx2txtLoader
    }

    try:
        # Use a temporary directory instead of a fixed 'files/' directory
        with TemporaryDirectory() as temp_dir:
            for uploaded_file in uploaded_files:
                # Create a temporary file in the system's temporary directory
                with NamedTemporaryFile(dir=temp_dir, delete=False) as temp_file:
                    temp_file.write(uploaded_file.getbuffer())
                    filepath = temp_file.name

                file_ext = os.path.splitext(uploaded_file.name.lower())[1]
                loader_class = loader_map.get(file_ext)
                if not loader_class:
                    st.error(f"Unsupported file type: {file_ext}", icon="🚨")
                    return None

                # Load the document using the selected loader
                loader = loader_class(filepath)
                documents.extend(loader.load())

            with st.spinner("Vector store in preparation..."):
                text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=1000, chunk_overlap=200
                )
                doc = text_splitter.split_documents(documents)

                # Choose embeddings
                if st.session_state.model_type == "GPT Models from OpenAI":
                    embeddings = OpenAIEmbeddings(
                        model="text-embedding-3-large", dimensions=1536
                    )
                else:
                    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

                # Create the FAISS vector database
                vector_store = FAISS.from_documents(doc, embeddings)

    except Exception as e:
        vector_store = None
        st.error(f"An error occurred: {e}", icon="🚨")

    return vector_store

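# A sketch (not called anywhere) of querying a vector store built above.
# The question string is illustrative; as_retriever() and
# get_relevant_documents() are the standard LangChain retriever calls.
def _demo_query_vector_store(vector_store: FAISS) -> None:
    retriever = vector_store.as_retriever()
    docs = retriever.get_relevant_documents("What is this document about?")
    for doc in docs:
        st.write(doc.page_content[:200])
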
def get_retriever() -> None:
    """
    Upload document(s), create a vector store, prepare a retriever tool,
    and save the tool to the variable st.session_state.retriever_tool.
    """
    # Section title
    st.write("")
    st.write("**Query Document(s)**")

    # File upload input
    uploaded_files = st.file_uploader(
        label="Upload an article",
        type=["txt", "pdf", "docx"],
        accept_multiple_files=True,
        label_visibility="collapsed",
        key="document_upload_" + str(st.session_state.uploader_key),
    )

    # Check if files are uploaded
    if uploaded_files:
        # Use a unique button key to avoid duplicate presses
        if st.button(label="Create the vector store", key=f"create_vector_{st.session_state.uploader_key}"):
            st.info("Creating the vector store and initializing the retriever tool...")

            # Attempt to create the vector store
            vector_store = get_vector_store(uploaded_files)

            if vector_store:
                uploaded_file_names = [file.name for file in uploaded_files]
                st.session_state.vector_store_message = (
                    f"Vector store for :blue[[{', '.join(uploaded_file_names)}]] is ready!"
                )

                # Initialize the retriever and create the tool
                retriever = vector_store.as_retriever()
                st.session_state.retriever_tool = create_retriever_tool(
                    retriever,
                    name="retriever",
                    description="Search uploaded documents for information when queried.",
                )

                # Add "Retrieval" to the tools list if not already present
                if "Retrieval" not in st.session_state.tool_names[0]:
                    st.session_state.tool_names[0].append("Retrieval")

                st.success("✅ Retriever tool has been successfully initialized and is ready to use.")

                # Debugging output
                st.write("**Current Tools:**", st.session_state.tool_names[0])
            else:
                st.error("❌ Failed to create the vector store. Please check the uploaded files (supported formats: txt, pdf, docx).")
    else:
        st.info("Please upload document(s) to create the vector store.")

def display_text_with_equations(text: str):
    # Replace inline LaTeX equation delimiters \( ... \) with $
    modified_text = text.replace("\\(", "$").replace("\\)", "$")

    # Replace block LaTeX equation delimiters \[ ... \] with $$
    modified_text = modified_text.replace("\\[", "$$").replace("\\]", "$$")

    # Use st.markdown to display the formatted text with equations
    st.markdown(modified_text)

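# For example, "The identity \\(e^{i\\pi} + 1 = 0\\) is famous." becomes
# "The identity $e^{i\pi} + 1 = 0$ is famous.", which st.markdown renders
# as an inline equation.
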
def read_audio(audio_bytes: bytes) -> Optional[str]:
    """
    Read audio bytes and return the corresponding text.
    """
    try:
        audio_data = BytesIO(audio_bytes)
        audio_data.name = "recorded_audio.wav"  # dummy name

        transcript = st.session_state.openai.audio.transcriptions.create(
            model="whisper-1", file=audio_data
        )
        text = transcript.text
    except Exception as e:
        text = None
        st.error(f"An error occurred: {e}", icon="🚨")

    return text


def input_from_mic() -> Optional[str]:
    """
    Convert audio input from the mic to text and return it.
    If there is no audio input, None is returned.
    """
    time.sleep(0.5)
    audio_bytes = audio_recorder(
        pause_threshold=3.0, text="Speak", icon_size="2x",
        recording_color="#e87070", neutral_color="#6aa36f"
    )

    if audio_bytes == st.session_state.audio_bytes or audio_bytes is None:
        return None
    else:
        st.session_state.audio_bytes = audio_bytes
        return read_audio(audio_bytes)

def perform_tts(text: str) -> Optional[HttpxBinaryResponseContent]:
    """
    Take text as input, perform text-to-speech (TTS),
    and return an audio response.
    """
    try:
        with st.spinner("TTS in progress..."):
            audio_response = st.session_state.openai.audio.speech.create(
                model="tts-1",
                voice="fable",
                input=text,
            )
    except Exception as e:
        audio_response = None
        st.error(f"An error occurred: {e}", icon="🚨")

    return audio_response


def play_audio(audio_response: HttpxBinaryResponseContent) -> None:
    """
    Take an audio response (a bytes-like object)
    from TTS as input, and play the audio.
    """
    audio_data = audio_response.read()

    # Encode audio data to base64
    b64 = base64.b64encode(audio_data).decode("utf-8")

    # Create a markdown string to embed the audio player with the base64 source
    md = f"""
    <audio controls autoplay style="width: 100%;">
    <source src="data:audio/mp3;base64,{b64}" type="audio/mp3">
    Your browser does not support the audio element.
    </audio>
    """

    # Use Streamlit to render the audio player
    st.markdown(md, unsafe_allow_html=True)

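# A sketch of the TTS round trip (not called anywhere): synthesize speech
# from a string and play it inline. The sample text is illustrative.
def _demo_tts() -> None:
    audio_response = perform_tts("Hello from the agent.")
    if audio_response is not None:
        play_audio(audio_response)
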
def image_to_base64(image: Image.Image) -> str:
    """
    Convert a PIL image object to a base64-encoded image, and return
    the resulting encoded image as a string to be used
    in place of a URL.
    """
    # Convert the image to RGB mode if necessary
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Save the image to a BytesIO object
    buffered_image = BytesIO()
    image.save(buffered_image, format="JPEG")

    # Convert BytesIO to bytes and encode to base64
    img_str = base64.b64encode(buffered_image.getvalue())

    # Convert bytes to string
    base64_image = img_str.decode("utf-8")

    return f"data:image/jpeg;base64,{base64_image}"


def shorten_image(image: Image.Image, max_pixels: int=1024) -> Image.Image:
    """
    Take an image object as input, and shrink the image
    if it is larger than max_pixels x max_pixels.
    """
    if max(image.width, image.height) > max_pixels:
        # Scale the longer side down to max_pixels, preserving aspect ratio
        if image.width > image.height:
            new_width = max_pixels
            new_height = image.height * max_pixels // image.width
        else:
            new_width = image.width * max_pixels // image.height
            new_height = max_pixels

        image = image.resize((new_width, new_height))

    return image

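# A sketch of the resize-and-encode pipeline used for image queries
# (not called anywhere; the file path is illustrative):
def _demo_encode_image() -> None:
    image = Image.open("example.jpg")
    image = shorten_image(image, 1024)  # cap the longer side at 1024 px
    data_url = image_to_base64(image)   # "data:image/jpeg;base64,..."
    st.image(data_url)
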
def upload_image_files_return_urls(
    type: List[str]=["jpg", "jpeg", "png", "bmp"]
) -> List[str]:
    """
    Upload image files, convert them to base64-encoded images, and
    return the list of the resulting encoded images to be used
    in place of URLs.
    """
    st.write("")
    st.write("**Query Image(s)**")
    source = st.radio(
        label="Image selection",
        options=("Uploaded", "From URL"),
        horizontal=True,
        label_visibility="collapsed",
    )
    image_urls = []

    if source == "Uploaded":
        uploaded_files = st.file_uploader(
            label="Upload images",
            type=type,
            accept_multiple_files=True,
            label_visibility="collapsed",
            key="image_upload_" + str(st.session_state.uploader_key),
        )
        if uploaded_files:
            try:
                for image_file in uploaded_files:
                    image = Image.open(image_file)
                    thumbnail = shorten_image(image, 300)
                    st.image(thumbnail)
                    image = shorten_image(image, 1024)
                    image_urls.append(image_to_base64(image))
            except UnidentifiedImageError as e:
                st.error(f"An error occurred: {e}", icon="🚨")
    else:
        image_url = st.text_input(
            label="URL of the image",
            label_visibility="collapsed",
            key="image_url_" + str(st.session_state.uploader_key),
        )
        if image_url:
            if is_url(image_url):
                st.image(image_url)
                image_urls = [image_url]
            else:
                st.error("Enter a proper URL", icon="🚨")

    return image_urls

def fig_to_base64(fig: Figure) -> str:
    """
    Convert a Figure object to a base64-encoded image, and return
    the resulting encoded image to be used in place of a URL.
    """
    with BytesIO() as buffer:
        fig.savefig(buffer, format="JPEG")
        buffer.seek(0)
        image = Image.open(buffer)

        return image_to_base64(image)


def is_url(text: str) -> bool:
    """
    Determine whether the given text is a URL.
    """
    regex = r"(http|https)://([\w_-]+(?:\.[\w_-]+)+)(:\S*)?"
    p = re.compile(regex)
    return p.match(text) is not None

def reset_conversation() -> None:
    """
    Reset the session state variables for resetting the conversation.
    """
    st.session_state.history = []
    st.session_state.ai_role[1] = st.session_state.ai_role[0]
    st.session_state.prompt_exists = False
    st.session_state.temperature[1] = st.session_state.temperature[0]
    st.session_state.audio_response = None
    st.session_state.vector_store_message = None
    st.session_state.tool_names[1] = st.session_state.tool_names[0]
    st.session_state.agent_type[1] = st.session_state.agent_type[0]
    st.session_state.retriever_tool = None
    st.session_state.uploader_key = 0


def switch_between_apps() -> None:
    """
    Keep the chat settings when switching the mode.
    """
    st.session_state.temperature[1] = st.session_state.temperature[0]
    st.session_state.ai_role[1] = st.session_state.ai_role[0]
    st.session_state.tool_names[1] = st.session_state.tool_names[0]
    st.session_state.agent_type[1] = st.session_state.agent_type[0]

@tool
def python_repl(
    code: Annotated[str, "The python code to execute to generate your chart."],
):
    """Use this to execute python code. If you want to see the output of a value,
    you should print it out with `print(...)`. This is visible to the user."""
    try:
        result = PythonREPL().run(code)
    except BaseException as e:
        return f"Failed to execute. Error: {repr(e)}"
    result_str = f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}"
    return (
        result_str + "\n\nIf you have completed all tasks, respond with FINAL ANSWER."
    )

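# A sketch (not called anywhere) of invoking the tool directly; tools created
# with @tool expose the standard Runnable interface, so .invoke takes a dict
# keyed by the argument name.
def _demo_python_repl() -> str:
    return python_repl.invoke({"code": "print(1 + 1)"})  # "...Stdout: 2..."
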
def set_tools() -> List[Tool]:
    """
    Set and return the tools for the agent. Tools that can be selected
    are internet_search, arxiv, wikipedia, python_repl, and retrieval.
    A Bing Subscription Key or Google CSE ID is required for internet_search.
    """

    class MySearchToolInput(BaseModel):
        query: str = Field(description="search query to look up")

    # Load tools
    arxiv = load_tools(["arxiv"])[0]
    wikipedia = load_tools(["wikipedia"])[0]

    # The Python REPL tool defined above is used directly here
    tool_dictionary = {
        "ArXiv": arxiv,
        "Wikipedia": wikipedia,
        "Python_REPL": python_repl,
        "Retrieval": st.session_state.retriever_tool if st.session_state.retriever_tool else None
    }
    tool_options = ["ArXiv", "Wikipedia", "Python_REPL", "Retrieval"]

    # Add the Search tool dynamically if credentials are valid
    if st.session_state.bing_subscription_validity:
        search = BingSearchAPIWrapper()
    elif st.session_state.google_cse_id_validity:
        search = GoogleSearchAPIWrapper()
    else:
        search = None

    if search is not None:
        internet_search = Tool(
            name="internet_search",
            description=(
                "A search engine for comprehensive, accurate, and trusted results. "
                "Useful for when you need to answer questions about current events. "
                "Input should be a search query."
            ),
            func=partial(search.results, num_results=5),
            args_schema=MySearchToolInput,
        )
        tool_options.insert(0, "Search")
        tool_dictionary["Search"] = internet_search

    # UI for selecting tools
    st.write("")
    st.write("**Tools**")
    tool_names = st.multiselect(
        label="assistant tools",
        options=tool_options,
        default=st.session_state.tool_names[1],
        label_visibility="collapsed",
    )

    # Instructions if the Search tool is unavailable
    if "Search" not in tool_options:
        st.write(
            "<small>Tools are disabled when images are uploaded and queried. "
            "To search the internet, obtain your Bing Subscription Key "
            "[here](https://portal.azure.com/) or Google CSE ID "
            "[here](https://programmablesearchengine.google.com/about/), "
            "and enter it in the sidebar. Once entered, 'Search' will be displayed "
            "in the list of tools. Note also that PythonREPL from LangChain is still "
            "in the experimental phase, so caution is advised.</small>",
            unsafe_allow_html=True,
        )
    else:
        st.write(
            "<small>Tools are disabled when images are uploaded and queried. "
            "Note also that PythonREPL from LangChain is still in the experimental phase, "
            "so caution is advised.</small>",
            unsafe_allow_html=True,
        )

    # Handle Retrieval tool initialization
    if "Retrieval" in tool_names:
        if not st.session_state.retriever_tool:
            st.info("Creating the vector store and initializing the retriever tool...")
            get_retriever()
            if st.session_state.retriever_tool:
                st.success("Retriever tool is ready for querying.")
                tool_dictionary["Retrieval"] = st.session_state.retriever_tool
            else:
                st.error("Failed to initialize the retriever tool. Please upload the document again.")
                tool_names.remove("Retrieval")  # Prevent use of a broken Retrieval tool

    # Final tool selection
    tools = [
        tool_dictionary[key]
        for key in tool_names if tool_dictionary[key] is not None
    ]

    st.write("**Tools selected in set_tools:**", [tool.name for tool in tools])
    st.session_state.tool_names[0] = tool_names

    return tools

def set_prompts(agent_type: Literal["Tool Calling", "ReAct"]) -> None:
    """
    Set chat and agent prompts for two different types of agents:
    Tool Calling and ReAct.
    """
    if agent_type == "Tool Calling":
        st.session_state.chat_prompt = ChatPromptTemplate.from_messages([
            (
                "system",
                f"{st.session_state.ai_role[0]} Your goal is to provide "
                "answers to human inquiries. Should the information not "
                "be available, inform the human explicitly that "
                "the answer could not be found."
            ),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
        ])
        st.session_state.agent_prompt = ChatPromptTemplate.from_messages([
            (
                "system",
                f"{st.session_state.ai_role[0]} Your goal is to provide answers to human inquiries. "
                "You should specify the source of your answers, whether they are based on internet search "
                "results ('internet_search'), scientific articles from arxiv.org ('arxiv'), Wikipedia documents ('wikipedia'), "
                "uploaded documents ('retriever'), or your general knowledge. "
                "Use the 'retriever' tool to answer questions specifically related to uploaded documents. "
                "If you cannot find relevant information in the documents using the 'retriever' tool, explicitly inform the user. "
                "Use Markdown syntax and include relevant sources, such as links (URLs)."
            ),
            MessagesPlaceholder(variable_name="chat_history", optional=True),
            ("human", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ])
    else:
        st.session_state.chat_prompt = ChatPromptTemplate.from_template(
            f"{st.session_state.ai_role[0]} "
            "Your goal is to provide answers to human inquiries. "
            "Should the information not be available, inform the human "
            "explicitly that the answer could not be found.\n\n"
            "{chat_history}\n\nHuman: {input}\n\n"
            "AI: "
        )
        st.session_state.agent_prompt = ChatPromptTemplate.from_template(
            f"{st.session_state.ai_role[0]} "
            "Your goal is to provide answers to human inquiries. "
            "When giving your answers, tell the human what your response "
            "is based on and which tools you use. Use Markdown syntax "
            "and include relevant sources, such as links (URLs), following "
            "MLA format. Should the information not be available, inform "
            "the human explicitly that the answer could not be found.\n\n"
            "TOOLS:\n"
            "------\n\n"
            "You have access to the following tools:\n\n"
            "{tools}\n\n"
            "To use a tool, please use the following format:\n\n"
            "Thought: Do I need to use a tool? Yes\n"
            "Action: the action to take, should be one of [{tool_names}]\n"
            "Action Input: the input to the action\n"
            "Observation: the result of the action\n\n"
            "When you have a response to say to the Human, "
            "or if you do not need to use a tool, you MUST use "
            "the format:\n\n"
            "Thought: Do I need to use a tool? No\n"
            "Final Answer: [your response here]\n\n"
            "Begin!\n\n"
            "Previous conversation history:\n\n"
            "{chat_history}\n\n"
            "New input: {input}\n"
            "{agent_scratchpad}"
        )

def print_conversation(no_of_msgs: Union[Literal["All"], int]) -> None:
    """
    Print the conversation stored in st.session_state.history.
    """
    if no_of_msgs == "All":
        no_of_msgs = len(st.session_state.history)

    for msg in st.session_state.history[-no_of_msgs:]:
        if isinstance(msg, HumanMessage):
            with st.chat_message("human"):
                st.write(msg.content)
        else:
            with st.chat_message("ai"):
                display_text_with_equations(msg.content)

        if urls := msg.additional_kwargs.get("image_urls"):
            for url in urls:
                st.image(url)

    # Play TTS
    if (
        st.session_state.model_type == "GPT Models from OpenAI"
        and st.session_state.audio_response is not None
    ):
        play_audio(st.session_state.audio_response)
        st.session_state.audio_response = None

def serialize_messages(
    messages: List[Union[HumanMessage, AIMessage]]
) -> List[Dict]:
    """
    Serialize the list of messages into a list of dicts.
    """
    return [msg.dict() for msg in messages]


def deserialize_messages(
    serialized_messages: List[Dict]
) -> List[Union[HumanMessage, AIMessage]]:
    """
    Deserialize the list of messages from a list of dicts.
    """
    deserialized_messages = []
    for msg in serialized_messages:
        if msg["type"] == "human":
            deserialized_messages.append(HumanMessage(**msg))
        elif msg["type"] == "ai":
            deserialized_messages.append(AIMessage(**msg))
    return deserialized_messages

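# A sketch of the serialize/deserialize round trip used by the Download and
# Load buttons (not called anywhere; the message contents are illustrative):
def _demo_round_trip() -> None:
    history = [HumanMessage(content="Hi"), AIMessage(content="Hello!")]
    data = json.dumps(serialize_messages(history))     # what Download writes
    restored = deserialize_messages(json.loads(data))  # what Load reads back
    assert restored[0].content == "Hi"
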
def show_uploader() -> None:
    """
    Set the flag to show the uploader.
    """
    st.session_state.show_uploader = True


def check_conversation_keys(lst: List[Dict[str, Any]]) -> bool:
    """
    Check if all items in the given list are valid conversation entries.
    """
    return all(
        isinstance(item, dict) and
        isinstance(item.get("content"), str) and
        isinstance(item.get("type"), str) and
        isinstance(item.get("additional_kwargs"), dict)
        for item in lst
    )

def load_conversation() -> bool:
    """
    Load the conversation from a JSON file.
    """
    st.write("")
    st.write("**Choose a (JSON) conversation file**")
    uploaded_file = st.file_uploader(
        label="Load conversation", type="json", label_visibility="collapsed"
    )
    if uploaded_file:
        try:
            data = json.load(uploaded_file)
            if isinstance(data, list) and check_conversation_keys(data):
                st.session_state.history = deserialize_messages(data)
                return True
            st.error(
                "The uploaded data does not conform to the expected format.", icon="🚨"
            )
        except Exception as e:
            st.error(f"An error occurred: {e}", icon="🚨")

    return False

def create_text(model: str) -> None:
    """
    Take an LLM as input and generate text based on user input
    by calling run_agent().
    """
    # Initial system prompts
    general_role = "You are a helpful AI assistant."
    english_teacher = (
        "You are an AI English teacher who analyzes texts and corrects "
        "any grammatical issues if necessary."
    )
    translator = (
        "You are an AI translator who translates English into Korean "
        "and Korean into English."
    )
    coding_adviser = (
        "You are an AI expert in coding who provides advice on "
        "good coding styles."
    )
    science_assistant = "You are an AI science assistant."
    roles = (
        general_role, english_teacher, translator,
        coding_adviser, science_assistant
    )

    with st.sidebar:
        st.write("")
        type_options = ("Tool Calling", "ReAct")
        st.write("**Agent Type**")
        st.session_state.agent_type[0] = st.sidebar.radio(
            label="Agent Type",
            options=type_options,
            index=type_options.index(st.session_state.agent_type[1]),
            label_visibility="collapsed",
        )
        agent_type = st.session_state.agent_type[0]
        if st.session_state.model_type == "GPT Models from OpenAI":
            st.write("")
            st.write("**Text to Speech**")
            st.session_state.tts = st.radio(
                label="TTS",
                options=("Enabled", "Disabled", "Auto"),
                # horizontal=True,
                index=1,
                label_visibility="collapsed",
            )
        st.write("")
        st.write("**Temperature**")
        st.session_state.temperature[0] = st.slider(
            label="Temperature (higher $\Rightarrow$ more random)",
            min_value=0.0,
            max_value=1.0,
            value=st.session_state.temperature[1],
            step=0.1,
            format="%.1f",
            label_visibility="collapsed",
        )
        st.write("")
        st.write("**Messages to Show**")
        no_of_msgs = st.radio(
            label="$\\textsf{Messages to show}$",
            options=("All", 20, 10),
            label_visibility="collapsed",
            horizontal=True,
            index=2,
        )

    st.write("")
    st.write("##### Message to AI")
    st.session_state.ai_role[0] = st.selectbox(
        label="AI's role",
        options=roles,
        index=roles.index(st.session_state.ai_role[1]),
        label_visibility="collapsed",
    )

    if st.session_state.ai_role[0] != st.session_state.ai_role[1]:
        reset_conversation()
        st.rerun()

    st.write("")
    st.write("##### Conversation with AI")

    # Print conversation
    print_conversation(no_of_msgs)

    # Reset, download, or load the conversation
    c1, c2, c3 = st.columns(3)
    c1.button(
        label="$~\:\,\,$Reset$~\:\,\,$",
        on_click=reset_conversation
    )
    c2.download_button(
        label="Download",
        data=json.dumps(serialize_messages(st.session_state.history), indent=4),
        file_name="conversation_with_agent.json",
        mime="application/json",
    )
    c3.button(
        label="$~~\:\,$Load$~~\:\,$",
        on_click=show_uploader,
    )

    if st.session_state.show_uploader and load_conversation():
        st.session_state.show_uploader = False
        st.rerun()

    # Set the agent prompts and tools
    set_prompts(agent_type)
    tools = set_tools()
    st.write("**Tools passed to run_agent:**", [tool.name for tool in tools])

    image_urls = []
    with st.sidebar:
        image_urls = upload_image_files_return_urls()

    query = None  # guard against a stale prompt_exists flag
    if st.session_state.model_type == "GPT Models from OpenAI":
        audio_input = input_from_mic()
        if audio_input is not None:
            query = audio_input
            st.session_state.prompt_exists = True
            st.session_state.mic_used = True

    # Use your keyboard
    text_input = st.chat_input(placeholder="Enter your query")

    if text_input:
        query = text_input.strip()
        st.session_state.prompt_exists = True

    if st.session_state.prompt_exists:
        with st.chat_message("human"):
            st.write(query)

        with st.chat_message("ai"):
            generated_text = run_agent(
                query=query,
                model=model,
                tools=tools,
                image_urls=image_urls,
                temperature=st.session_state.temperature[0],
                agent_type=agent_type,
            )
            fig = plt.gcf()
            if fig and fig.get_axes():
                generated_image_url = fig_to_base64(fig)
                st.session_state.history[-1].additional_kwargs["image_urls"] = [
                    generated_image_url
                ]
        if (
            st.session_state.model_type == "GPT Models from OpenAI"
            and generated_text is not None
        ):
            # TTS under two conditions
            cond1 = st.session_state.tts == "Enabled"
            cond2 = st.session_state.tts == "Auto" and st.session_state.mic_used
            if cond1 or cond2:
                st.session_state.audio_response = perform_tts(generated_text)
            st.session_state.mic_used = False

        st.session_state.prompt_exists = False

        if generated_text is not None:
            st.session_state.uploader_key += 1
            st.rerun()

def create_image(model: str) -> None:
    """
    Generate an image based on the user's description
    by calling openai_create_image().
    """
    # Set the image size
    with st.sidebar:
        st.write("")
        st.write("**Pixel size**")
        image_size = st.radio(
            label="$\\hspace{0.1em}\\texttt{Pixel size}$",
            options=("1024x1024", "1792x1024", "1024x1792"),
            # horizontal=True,
            index=0,
            label_visibility="collapsed",
        )

    st.write("")
    st.write("##### Description for your image")

    if st.session_state.image_url is not None:
        st.info(st.session_state.image_description)
        st.image(image=st.session_state.image_url, use_column_width=True)

    # Get an image description using the microphone
    if st.session_state.model_type == "GPT Models from OpenAI":
        audio_input = input_from_mic()
        if audio_input is not None:
            st.session_state.image_description = audio_input
            st.session_state.prompt_exists = True

    # Get an image description using the keyboard
    text_input = st.chat_input(
        placeholder="Enter a description for your image",
    )
    if text_input:
        st.session_state.image_description = text_input.strip()
        st.session_state.prompt_exists = True

    if st.session_state.prompt_exists:
        st.session_state.image_url = openai_create_image(
            st.session_state.image_description, model, image_size
        )
        st.session_state.prompt_exists = False
        if st.session_state.image_url is not None:
            st.rerun()

def create_text_image() -> None:
    """
    Generate text or an image by using LLM models like 'gpt-4o'.
    """
    page_title = "LangChain LLM Agent"
    page_icon = "📚"

    st.set_page_config(
        page_title=page_title,
        page_icon=page_icon,
        layout="centered"
    )

    st.write(f"## {page_icon} $\,${page_title}")

    # Initialize all the session state variables
    initialize_session_state_variables()

    # Create the OpenAI client once; the key was loaded into the
    # environment from the Space secrets above.
    if st.session_state.openai is None:
        st.session_state.openai = OpenAI()

    # Define model options directly here
    model_options = ["gpt-4o-mini", "gpt-4o", "dall-e-3"]

    # Sidebar content
    with st.sidebar:
        st.write("**Select a Model**")
        model = st.radio(
            label="Models",
            options=model_options,
            index=1,  # default to the second option
            label_visibility="collapsed",
            on_change=switch_between_apps,
        )

        st.write("---")

    # Main logic for generating text or an image
    if model == "dall-e-3":
        create_image(model)
    else:
        create_text(model)


if __name__ == "__main__":
    create_text_image()