|
|
|
import streamlit as st |
|
import os, base64, re, time, json
|
import matplotlib.pyplot as plt |
|
from io import BytesIO |
|
from functools import partial |
|
|
from audio_recorder_streamlit import audio_recorder |
|
from PIL import Image, UnidentifiedImageError |
|
from openai import OpenAI |
|
from langchain_openai import ChatOpenAI |
|
from langchain_openai import OpenAIEmbeddings |
|
from langchain_anthropic import ChatAnthropic |
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
from langchain_google_genai import GoogleGenerativeAIEmbeddings |
|
from langchain_google_community import GoogleSearchAPIWrapper |
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
|
from langchain.schema import HumanMessage, AIMessage |
|
from langchain_community.utilities import BingSearchAPIWrapper |
|
from langchain_community.document_loaders import PyPDFLoader |
|
from langchain_community.document_loaders import Docx2txtLoader |
|
from langchain_community.document_loaders import TextLoader |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain_community.vectorstores import FAISS |
|
from langchain.tools import Tool, tool |
|
from langchain.tools.retriever import create_retriever_tool |
|
|
|
from langchain.agents import create_tool_calling_agent |
|
from langchain.agents import create_react_agent |
|
from langchain.agents import AgentExecutor |
|
from langchain_community.agent_toolkits.load_tools import load_tools |
|
|
|
from langchain_experimental.utilities import PythonREPL |
|
from langchain.callbacks.base import BaseCallbackHandler |
|
from pydantic import BaseModel, Field |
|
|
|
from typing import Union, List, Literal, Optional, Dict, Any, Annotated |
|
from matplotlib.figure import Figure |
|
from streamlit.runtime.uploaded_file_manager import UploadedFile |
|
from openai._legacy_response import HttpxBinaryResponseContent |
|
from tempfile import NamedTemporaryFile, TemporaryDirectory |
|
|
|
|
|
try: |
|
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"] |
|
os.environ["BING_SUBSCRIPTION_KEY"] = st.secrets.get("BING_SUBSCRIPTION_KEY", "") |
|
os.environ["GOOGLE_API_KEY"] = st.secrets.get("GOOGLE_API_KEY", "") |
|
os.environ["GOOGLE_CSE_ID"] = st.secrets.get("GOOGLE_CSE_ID", "") |
|
except KeyError as e: |
|
st.error(f"Missing required secret: {e}. Please set it in Hugging Face Space secrets.") |
|
st.stop() |
|
|
|
def initialize_session_state_variables() -> None: |
|
""" |
|
Initialize all the session state variables. |
|
""" |
|
default_values = { |
|
"ready": False, |
|
"openai": None, |
|
"history": [], |
|
"model_type": "GPT Models from OpenAI", |
|
"agent_type": 2 * ["Tool Calling"], |
|
"ai_role": 2 * ["You are a helpful AI assistant."], |
|
"prompt_exists": False, |
|
"temperature": [0.7, 0.7], |
|
"audio_bytes": None, |
|
"mic_used": False, |
|
"audio_response": None, |
|
"image_url": None, |
|
"image_description": None, |
|
"uploader_key": 0, |
|
"tool_names": [[], []], |
|
"bing_subscription_validity": False, |
|
"google_cse_id_validity": False, |
|
"vector_store_message": None, |
|
"retriever_tool": None, |
|
"show_uploader": False |
|
} |
|
|
|
for key, value in default_values.items(): |
|
if key not in st.session_state: |
|
st.session_state[key] = value |
|
|
|
|
|
|
|
class StreamHandler(BaseCallbackHandler): |
|
def __init__(self, container, initial_text=""): |
|
self.container = container |
|
self.text = initial_text |
|
|
|
def on_llm_new_token(self, token: Any, **kwargs) -> None: |
|
new_text = self._extract_text(token) |
|
if new_text: |
|
self.text += new_text |
|
self.container.markdown(self.text) |
|
|
|
def _extract_text(self, token: Any) -> str: |
|
if isinstance(token, str): |
|
return token |
|
elif isinstance(token, list): |
|
return ''.join(self._extract_text(t) for t in token) |
|
elif isinstance(token, dict): |
|
return token.get('text', '') |
|
else: |
|
return str(token) |
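
# Illustrative use of StreamHandler (a sketch of what run_agent() does below):

#   handler = StreamHandler(st.empty())

#   llm = ChatOpenAI(model="gpt-4o", streaming=True, callbacks=[handler])

# Each streamed token is appended to the placeholder as it arrives.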
|
|
|
|
|
def check_api_keys() -> None:

    """

    Callback that resets the ready flag so that the API keys are

    validated again on the next interaction.

    """

    st.session_state.ready = False
|
|
|
|
|
def message_history_to_string(extra_space: bool=True) -> str: |
|
""" |
|
Return a string of the chat history contained in |
|
st.session_state.history. |
|
""" |
|
|
|
history_list = [] |
|
for msg in st.session_state.history: |
|
if isinstance(msg, HumanMessage): |
|
history_list.append(f"Human: {msg.content}") |
|
else: |
|
history_list.append(f"AI: {msg.content}") |
|
new_lines = "\n\n" if extra_space else "\n" |
|
|
|
return new_lines.join(history_list) |
|
|
|
|
|
def get_chat_model( |
|
model: str, |
|
temperature: float, |
|
callbacks: List[BaseCallbackHandler] |
|
) -> Union[ChatOpenAI, ChatAnthropic, ChatGoogleGenerativeAI, None]: |
|
|
|
""" |
|
Get the appropriate chat model based on the given model name. |
|
""" |
|
|
|
model_map = { |
|
"gpt-": ChatOpenAI, |
|
} |
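    # Only OpenAI models are mapped here; the Union return type also covers

    # Anthropic and Google models, so the map can be extended, e.g. with the

    # (assumed) model-name prefixes:

    #   "claude-": ChatAnthropic,

    #   "gemini-": ChatGoogleGenerativeAI,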
|
for prefix, ModelClass in model_map.items(): |
|
if model.startswith(prefix): |
|
return ModelClass( |
|
model=model, |
|
temperature=temperature, |
|
streaming=True, |
|
callbacks=callbacks |
|
) |
|
return None |
|
|
|
|
|
def process_with_images( |
|
llm: Union[ChatOpenAI, ChatAnthropic, ChatGoogleGenerativeAI], |
|
message_content: str, |
|
image_urls: List[str] |
|
) -> str: |
|
|
|
""" |
|
Process the given history query with associated images using a language model. |
|
""" |
|
|
|
content_with_images = ( |
|
[{"type": "text", "text": message_content}] + |
|
[{"type": "image_url", "image_url": {"url": url}} for url in image_urls] |
|
) |
|
message_with_images = [HumanMessage(content=content_with_images)] |
|
|
|
return llm.invoke(message_with_images).content |
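
# Note: image_urls may be plain web URLs or the base64 data URLs produced by

# image_to_base64() below; both are passed through to the model unchanged.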
|
|
|
|
|
def process_with_tools( |
|
llm: Union[ChatOpenAI, ChatAnthropic, ChatGoogleGenerativeAI], |
|
tools: List[Tool], |
|
agent_type: str, |
|
agent_prompt: str, |
|
history_query: dict |
|
) -> str: |
|
|
|
""" |
|
Create an AI agent based on the specified agent type and tools, |
|
then use this agent to process the given history query. |
|
""" |
|
|
|
if agent_type == "Tool Calling": |
|
agent = create_tool_calling_agent(llm, tools, agent_prompt) |
|
else: |
|
agent = create_react_agent(llm, tools, agent_prompt) |
|
|
|
agent_executor = AgentExecutor( |
|
agent=agent, tools=tools, max_iterations=5, verbose=False, |
|
handle_parsing_errors=True, |
|
) |
|
|
|
return agent_executor.invoke(history_query)["output"] |
|
|
|
|
|
def run_agent( |
|
query: str, |
|
model: str, |
|
tools: List[Tool], |
|
image_urls: List[str], |
|
temperature: float=0.7, |
|
agent_type: Literal["Tool Calling", "ReAct"]="Tool Calling", |
|
) -> Union[str, None]: |
|
""" |
|
Generate text based on user queries. |
|
Args: |
|
query: User's query |
|
model: LLM like "gpt-4o" |
|
tools: list of tools such as Search and Retrieval |
|
image_urls: List of URLs for images |
|
temperature: Value between 0 and 1. Defaults to 0.7 |
|
agent_type: 'Tool Calling' or 'ReAct' |
|
Return: |
|
generated text |
|
""" |
|
|
|
try: |
|
|
|
if "Retrieval" in st.session_state.tool_names[0]: |
|
if st.session_state.retriever_tool: |
|
retriever_tool_name = "retriever" |
|
if retriever_tool_name not in [tool.name for tool in tools]: |
|
tools.append(st.session_state.retriever_tool) |
|
st.write(f"✅ **{retriever_tool_name} tool has been added successfully.**") |
|
else: |
|
st.error("❌ Retriever tool is not initialized. Please create a vector store first.") |
|
return None |
|
|
|
|
|
st.write("**Final Tools Being Used:**", [tool.name for tool in tools]) |
|
|
|
if "retriever" in [tool.name for tool in tools]: |
|
st.success("✅ Retriever tool is confirmed and ready for use.") |
|
elif "Retrieval" in st.session_state.tool_names[0]: |
|
st.warning("⚠️ 'Retrieval' was selected but the retriever tool is missing!") |
|
|
|
|
|
llm = get_chat_model(model, temperature, [StreamHandler(st.empty())]) |
|
if llm is None: |
|
st.error(f"❌ Unsupported model: {model}", icon="🚨") |
|
return None |
|
|
|
|
|
if agent_type == "Tool Calling": |
|
chat_history = st.session_state.history |
|
else: |
|
chat_history = message_history_to_string() |
|
|
|
history_query = {"chat_history": chat_history, "input": query} |
|
|
|
|
|
        message_with_no_image = st.session_state.chat_prompt.invoke(history_query)

        # Flatten the whole prompt to plain text; taking messages[0].content

        # would drop the user's query (and history) for the tool-calling prompt.

        message_content = message_with_no_image.to_string()
|
|
|
if image_urls: |
|
|
|
generated_text = process_with_images(llm, message_content, image_urls) |
|
human_message = HumanMessage( |
|
content=query, additional_kwargs={"image_urls": image_urls} |
|
) |
|
elif tools: |
|
|
|
generated_text = process_with_tools( |
|
llm, tools, agent_type, st.session_state.agent_prompt, history_query |
|
) |
|
human_message = HumanMessage(content=query) |
|
else: |
|
|
|
generated_text = llm.invoke(message_with_no_image).content |
|
human_message = HumanMessage(content=query) |
|
|
|
|
|
if isinstance(generated_text, list): |
|
generated_text = generated_text[0]["text"] |
|
|
|
|
|
st.session_state.history.append(human_message) |
|
st.session_state.history.append(AIMessage(content=generated_text)) |
|
|
|
return generated_text |
|
|
|
except Exception as e: |
|
st.error(f"An error occurred: {e}", icon="🚨") |
|
return None |
|
|
|
|
|
def openai_create_image( |
|
description: str, model: str="dall-e-3", size: str="1024x1024" |
|
) -> Optional[str]: |
|
|
|
""" |
|
Generate image based on user description. |
|
|
|
Args: |
|
description: User description |
|
model: Default set to "dall-e-3" |
|
size: Pixel size of the generated image |
|
|
|
Return: |
|
URL of the generated image |
|
""" |
|
|
|
try: |
|
with st.spinner("AI is generating..."): |
|
response = st.session_state.openai.images.generate( |
|
model=model, |
|
prompt=description, |
|
size=size, |
|
quality="standard", |
|
n=1, |
|
) |
|
image_url = response.data[0].url |
|
except Exception as e: |
|
image_url = None |
|
st.error(f"An error occurred: {e}", icon="🚨") |
|
|
|
return image_url |
|
|
|
|
|
def get_vector_store(uploaded_files: List[UploadedFile]) -> Optional[FAISS]: |
|
""" |
|
Take a list of UploadedFile objects as input, and return a FAISS vector store. |
|
""" |
|
if not uploaded_files: |
|
return None |
|
|
|
documents = [] |
|
loader_map = { |
|
".pdf": PyPDFLoader, |
|
".txt": TextLoader, |
|
".docx": Docx2txtLoader |
|
} |
|
|
|
try: |
|
|
|
with TemporaryDirectory() as temp_dir: |
|
for uploaded_file in uploaded_files: |
|
|
|
with NamedTemporaryFile(dir=temp_dir, delete=False) as temp_file: |
|
temp_file.write(uploaded_file.getbuffer()) |
|
filepath = temp_file.name |
|
|
|
file_ext = os.path.splitext(uploaded_file.name.lower())[1] |
|
loader_class = loader_map.get(file_ext) |
|
if not loader_class: |
|
st.error(f"Unsupported file type: {file_ext}", icon="🚨") |
|
return None |
|
|
|
|
|
loader = loader_class(filepath) |
|
documents.extend(loader.load()) |
|
|
|
with st.spinner("Vector store in preparation..."): |
|
text_splitter = RecursiveCharacterTextSplitter( |
|
chunk_size=1000, chunk_overlap=200 |
|
) |
|
doc = text_splitter.split_documents(documents) |
|
|
|
|
|
if st.session_state.model_type == "GPT Models from OpenAI": |
|
embeddings = OpenAIEmbeddings(model="text-embedding-3-large", dimensions=1536) |
|
else: |
|
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001") |
|
|
|
|
|
vector_store = FAISS.from_documents(doc, embeddings) |
|
|
|
except Exception as e: |
|
vector_store = None |
|
st.error(f"An error occurred: {e}", icon="🚨") |
|
|
|
return vector_store |
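
# With chunk_size=1000 and chunk_overlap=200, consecutive chunks share up to

# 200 characters, which helps the retriever keep context across boundaries.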
|
|
|
|
|
|
|
def get_retriever() -> None: |
|
""" |
|
Upload document(s), create a vector store, prepare a retriever tool, |
|
save the tool to the variable st.session_state.retriever_tool. |
|
""" |
|
|
|
|
|
st.write("") |
|
st.write("**Query Document(s)**") |
|
|
|
|
|
uploaded_files = st.file_uploader( |
|
label="Upload an article", |
|
type=["txt", "pdf", "docx"], |
|
accept_multiple_files=True, |
|
label_visibility="collapsed", |
|
key="document_upload_" + str(st.session_state.uploader_key), |
|
) |
|
|
|
|
|
if uploaded_files: |
|
|
|
if st.button(label="Create the vector store", key=f"create_vector_{st.session_state.uploader_key}"): |
|
st.info("Creating the vector store and initializing the retriever tool...") |
|
|
|
|
|
vector_store = get_vector_store(uploaded_files) |
|
|
|
if vector_store: |
|
uploaded_file_names = [file.name for file in uploaded_files] |
|
st.session_state.vector_store_message = ( |
|
f"Vector store for :blue[[{', '.join(uploaded_file_names)}]] is ready!" |
|
) |
|
|
|
|
|
retriever = vector_store.as_retriever() |
|
st.session_state.retriever_tool = create_retriever_tool( |
|
retriever, |
|
name="retriever", |
|
description="Search uploaded documents for information when queried.", |
|
) |
|
|
|
|
|
if "Retrieval" not in st.session_state.tool_names[0]: |
|
st.session_state.tool_names[0].append("Retrieval") |
|
|
|
st.success("✅ Retriever tool has been successfully initialized and is ready to use.") |
|
|
|
|
|
st.write("**Current Tools:**", st.session_state.tool_names[0]) |
|
else: |
|
st.error("❌ Failed to create vector store. Please check the uploaded files (supported formats: txt, pdf, docx).") |
|
else: |
|
st.info("Please upload document(s) to create the vector store.") |
|
|
|
|
|
|
|
|
|
def display_text_with_equations(text: str) -> None:

    """

    Render text with st.markdown, converting the LaTeX delimiters

    \\( \\) and \\[ \\] into $ and $$ so Streamlit can display equations.

    """

    modified_text = text.replace("\\(", "$").replace("\\)", "$")
|
|
|
|
|
modified_text = modified_text.replace("\\[", "$$").replace("\\]", "$$") |
|
|
|
|
|
st.markdown(modified_text) |
|
|
|
|
|
def read_audio(audio_bytes: bytes) -> Optional[str]: |
|
""" |
|
Read audio bytes and return the corresponding text. |
|
""" |
|
try: |
|
audio_data = BytesIO(audio_bytes) |
|
audio_data.name = "recorded_audio.wav" |
|
|
|
transcript = st.session_state.openai.audio.transcriptions.create( |
|
model="whisper-1", file=audio_data |
|
) |
|
text = transcript.text |
|
except Exception as e: |
|
text = None |
|
st.error(f"An error occurred: {e}", icon="🚨") |
|
|
|
return text |
|
|
|
|
|
def input_from_mic() -> Optional[str]: |
|
""" |
|
Convert audio input from mic to text and return it. |
|
If there is no audio input, None is returned. |
|
""" |
|
|
|
time.sleep(0.5) |
|
audio_bytes = audio_recorder( |
|
pause_threshold=3.0, text="Speak", icon_size="2x", |
|
recording_color="#e87070", neutral_color="#6aa36f" |
|
) |
|
|
|
if audio_bytes == st.session_state.audio_bytes or audio_bytes is None: |
|
return None |
|
else: |
|
st.session_state.audio_bytes = audio_bytes |
|
return read_audio(audio_bytes) |
|
|
|
|
|
def perform_tts(text: str) -> Optional[HttpxBinaryResponseContent]: |
|
""" |
|
Take text as input, perform text-to-speech (TTS), |
|
and return an audio_response. |
|
""" |
|
|
|
try: |
|
with st.spinner("TTS in progress..."): |
|
audio_response = st.session_state.openai.audio.speech.create( |
|
model="tts-1", |
|
voice="fable", |
|
input=text, |
|
) |
|
except Exception as e: |
|
audio_response = None |
|
st.error(f"An error occurred: {e}", icon="🚨") |
|
|
|
return audio_response |
|
|
|
|
|
def play_audio(audio_response: HttpxBinaryResponseContent) -> None: |
|
""" |
|
Take an audio response (a bytes-like object) |
|
from TTS as input, and play the audio. |
|
""" |
|
|
|
audio_data = audio_response.read() |
|
|
|
|
|
b64 = base64.b64encode(audio_data).decode("utf-8") |
|
|
|
|
|
md = f""" |
|
<audio controls autoplay style="width: 100%;"> |
|
        <source src="data:audio/mpeg;base64,{b64}" type="audio/mpeg">
|
Your browser does not support the audio element. |
|
</audio> |
|
""" |
|
|
|
|
|
st.markdown(md, unsafe_allow_html=True) |
|
|
|
|
|
def image_to_base64(image: Image.Image) -> str:
|
""" |
|
Convert an image object from PIL to a base64-encoded image, |
|
and return the resulting encoded image as a string to be used |
|
in place of a URL. |
|
""" |
|
|
|
|
|
if image.mode != "RGB": |
|
image = image.convert("RGB") |
|
|
|
|
|
buffered_image = BytesIO() |
|
image.save(buffered_image, format="JPEG") |
|
|
|
|
|
img_str = base64.b64encode(buffered_image.getvalue()) |
|
|
|
|
|
base64_image = img_str.decode("utf-8") |
|
|
|
return f"data:image/jpeg;base64,{base64_image}" |
|
|
|
|
|
def shorten_image(image: Image.Image, max_pixels: int = 1024) -> Image.Image:
|
""" |
|
Take an Image object as input, and shorten the image size |
|
if the image is greater than max_pixels x max_pixels. |
|
""" |
|
|
|
    if max(image.width, image.height) > max_pixels:

        if image.width > image.height:

            new_width, new_height = max_pixels, image.height * max_pixels // image.width

        else:

            new_width, new_height = image.width * max_pixels // image.height, max_pixels

        image = image.resize((new_width, new_height))
|
|
|
return image |
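
# Aspect ratio is preserved: with the default max_pixels, a 2048x1024 image

# becomes 1024x512, while images already within bounds are returned unchanged.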
|
|
|
|
|
def upload_image_files_return_urls(

    file_types: List[str] = ["jpg", "jpeg", "png", "bmp"]

) -> List[str]:
|
|
|
""" |
|
Upload image files, convert them to base64-encoded images, and |
|
return the list of the resulting encoded images to be used |
|
in place of URLs. |
|
""" |
|
|
|
st.write("") |
|
st.write("**Query Image(s)**") |
|
source = st.radio( |
|
label="Image selection", |
|
options=("Uploaded", "From URL"), |
|
horizontal=True, |
|
label_visibility="collapsed", |
|
) |
|
image_urls = [] |
|
|
|
if source == "Uploaded": |
|
uploaded_files = st.file_uploader( |
|
label="Upload images", |
|
            type=file_types,
|
accept_multiple_files=True, |
|
label_visibility="collapsed", |
|
key="image_upload_" + str(st.session_state.uploader_key), |
|
) |
|
if uploaded_files: |
|
try: |
|
for image_file in uploaded_files: |
|
image = Image.open(image_file) |
|
thumbnail = shorten_image(image, 300) |
|
st.image(thumbnail) |
|
image = shorten_image(image, 1024) |
|
image_urls.append(image_to_base64(image)) |
|
except UnidentifiedImageError as e: |
|
st.error(f"An error occurred: {e}", icon="🚨") |
|
else: |
|
image_url = st.text_input( |
|
label="URL of the image", |
|
label_visibility="collapsed", |
|
key="image_url_" + str(st.session_state.uploader_key), |
|
) |
|
if image_url: |
|
if is_url(image_url): |
|
st.image(image_url) |
|
image_urls = [image_url] |
|
else: |
|
st.error("Enter a proper URL", icon="🚨") |
|
|
|
return image_urls |
|
|
|
|
|
def fig_to_base64(fig: Figure) -> str: |
|
""" |
|
Convert a Figure object to a base64-encoded image, and return |
|
the resulting encoded image to be used in place of a URL. |
|
""" |
|
|
|
with BytesIO() as buffer: |
|
fig.savefig(buffer, format="JPEG") |
|
buffer.seek(0) |
|
image = Image.open(buffer) |
|
|
|
return image_to_base64(image) |
|
|
|
|
|
def is_url(text: str) -> bool: |
|
""" |
|
Determine whether text is a URL or not. |
|
""" |
|
|
|
regex = r"(http|https)://([\w_-]+(?:\.[\w_-]+)+)(:\S*)?" |
|
p = re.compile(regex) |
|
    return p.match(text) is not None
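
# For example, is_url("https://example.com/page") returns True, while

# is_url("example.com") returns False because the scheme is required.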
|
|
|
|
|
def reset_conversation() -> None: |
|
""" |
|
Reset the session_state variables for resetting the conversation. |
|
""" |
|
|
|
st.session_state.history = [] |
|
st.session_state.ai_role[1] = st.session_state.ai_role[0] |
|
st.session_state.prompt_exists = False |
|
st.session_state.temperature[1] = st.session_state.temperature[0] |
|
st.session_state.audio_response = None |
|
st.session_state.vector_store_message = None |
|
st.session_state.tool_names[1] = st.session_state.tool_names[0] |
|
st.session_state.agent_type[1] = st.session_state.agent_type[0] |
|
st.session_state.retriever_tool = None |
|
st.session_state.uploader_key = 0 |
|
|
|
|
|
def switch_between_apps() -> None: |
|
""" |
|
Keep the chat settings when switching the mode. |
|
""" |
|
|
|
st.session_state.temperature[1] = st.session_state.temperature[0] |
|
st.session_state.ai_role[1] = st.session_state.ai_role[0] |
|
st.session_state.tool_names[1] = st.session_state.tool_names[0] |
|
st.session_state.agent_type[1] = st.session_state.agent_type[0] |
|
|
|
|
|
@tool |
|
def python_repl( |
|
code: Annotated[str, "The python code to execute to generate your chart."], |
|
): |
|
"""Use this to execute python code. If you want to see the output of a value, |
|
you should print it out with `print(...)`. This is visible to the user.""" |
|
try: |
|
result = PythonREPL().run(code) |
|
except BaseException as e: |
|
return f"Failed to execute. Error: {repr(e)}" |
|
result_str = f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}" |
|
return ( |
|
result_str + "\n\nIf you have completed all tasks, respond with FINAL ANSWER." |
|
) |
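
# A sketch of how the agent may call this tool during a run:

#   python_repl.invoke({"code": "print(1 + 1)"})

# returns the executed code block plus its stdout ("2").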
|
|
|
|
|
def set_tools() -> List[Tool]: |
|
""" |
|
Set and return the tools for the agent. Tools that can be selected |
|
are internet_search, arxiv, wikipedia, python_repl, and retrieval. |
|
A Bing Subscription Key or Google CSE ID is required for internet_search. |
|
""" |
|
|
|
class MySearchToolInput(BaseModel): |
|
query: str = Field(description="search query to look up") |
|
|
|
|
|
arxiv = load_tools(["arxiv"])[0] |
|
wikipedia = load_tools(["wikipedia"])[0] |
|
|
|
tool_dictionary = { |
|
"ArXiv": arxiv, |
|
"Wikipedia": wikipedia, |
|
"Python_REPL": python_repl, |
|
"Retrieval": st.session_state.retriever_tool if st.session_state.retriever_tool else None |
|
} |
|
tool_options = ["ArXiv", "Wikipedia", "Python_REPL", "Retrieval"] |
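
    # Prefer Bing when a valid subscription key exists; otherwise fall back

    # to Google CSE, and omit internet search entirely when neither is set.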
|
|
|
|
|
if st.session_state.bing_subscription_validity: |
|
search = BingSearchAPIWrapper() |
|
elif st.session_state.google_cse_id_validity: |
|
search = GoogleSearchAPIWrapper() |
|
else: |
|
search = None |
|
|
|
if search is not None: |
|
internet_search = Tool( |
|
name="internet_search", |
|
description=( |
|
"A search engine for comprehensive, accurate, and trusted results. " |
|
"Useful for when you need to answer questions about current events. " |
|
"Input should be a search query." |
|
), |
|
func=partial(search.results, num_results=5), |
|
args_schema=MySearchToolInput, |
|
) |
|
tool_options.insert(0, "Search") |
|
tool_dictionary["Search"] = internet_search |
|
|
|
|
|
st.write("") |
|
st.write("**Tools**") |
|
tool_names = st.multiselect( |
|
label="assistant tools", |
|
options=tool_options, |
|
default=st.session_state.tool_names[1], |
|
label_visibility="collapsed", |
|
) |
|
|
|
|
|
if "Search" not in tool_options: |
|
st.write( |
|
"<small>Tools are disabled when images are uploaded and queried. " |
|
"To search the internet, obtain your Bing Subscription Key " |
|
"[here](https://portal.azure.com/) or Google CSE ID " |
|
"[here](https://programmablesearchengine.google.com/about/), " |
|
"and enter it in the sidebar. Once entered, 'Search' will be displayed " |
|
"in the list of tools. Note also that PythonREPL from LangChain is still " |
|
"in the experimental phase, so caution is advised.</small>", |
|
unsafe_allow_html=True, |
|
) |
|
else: |
|
st.write( |
|
"<small>Tools are disabled when images are uploaded and queried. " |
|
"Note also that PythonREPL from LangChain is still in the experimental phase, " |
|
"so caution is advised.</small>", |
|
unsafe_allow_html=True, |
|
) |
|
|
|
|
|
if "Retrieval" in tool_names: |
|
if not st.session_state.retriever_tool: |
|
st.info("Creating the vector store and initializing the retriever tool...") |
|
get_retriever() |
|
if st.session_state.retriever_tool: |
|
st.success("Retriever tool is ready for querying.") |
|
tool_dictionary["Retrieval"] = st.session_state.retriever_tool |
|
else: |
|
st.error("Failed to initialize the retriever tool. Please upload the document again.") |
|
tool_names.remove("Retrieval") |
|
|
|
|
|
tools = [ |
|
tool_dictionary[key] |
|
for key in tool_names if tool_dictionary[key] is not None |
|
] |
|
|
|
st.write("**Tools selected in set_tools:**", [tool.name for tool in tools]) |
|
st.session_state.tool_names[0] = tool_names |
|
|
|
return tools |
|
|
|
|
|
|
|
def set_prompts(agent_type: Literal["Tool Calling", "ReAct"]) -> None: |
|
""" |
|
Set chat and agent prompts for two different types of agents: |
|
Tool Calling and ReAct. |
|
""" |
|
|
|
if agent_type == "Tool Calling": |
|
st.session_state.chat_prompt = ChatPromptTemplate.from_messages([ |
|
( |
|
"system", |
|
f"{st.session_state.ai_role[0]} Your goal is to provide " |
|
"answers to human inquiries. Should the information not " |
|
"be available, inform the human explicitly that " |
|
"the answer could not be found." |
|
), |
|
MessagesPlaceholder(variable_name="chat_history"), |
|
("human", "{input}"), |
|
]) |
|
st.session_state.agent_prompt = ChatPromptTemplate.from_messages([ |
|
( |
|
"system", |
|
f"{st.session_state.ai_role[0]} Your goal is to provide answers to human inquiries. " |
|
"You should specify the source of your answers, whether they are based on internet search " |
|
"results ('internet_search'), scientific articles from arxiv.org ('arxiv'), Wikipedia documents ('wikipedia'), " |
|
"uploaded documents ('retriever'), or your general knowledge. " |
|
"Use the 'retriever' tool to answer questions specifically related to uploaded documents. " |
|
"If you cannot find relevant information in the documents using the 'retriever' tool, explicitly inform the user. " |
|
"Use Markdown syntax and include relevant sources, such as links (URLs)." |
|
), |
|
MessagesPlaceholder(variable_name="chat_history", optional=True), |
|
("human", "{input}"), |
|
MessagesPlaceholder(variable_name="agent_scratchpad"), |
|
]) |
|
else: |
|
st.session_state.chat_prompt = ChatPromptTemplate.from_template( |
|
f"{st.session_state.ai_role[0]} " |
|
"Your goal is to provide answers to human inquiries. " |
|
"Should the information not be available, inform the human " |
|
"explicitly that the answer could not be found.\n\n" |
|
"{chat_history}\n\nHuman: {input}\n\n" |
|
"AI: " |
|
) |
|
st.session_state.agent_prompt = ChatPromptTemplate.from_template( |
|
f"{st.session_state.ai_role[0]} " |
|
"Your goal is to provide answers to human inquiries. " |
|
"When giving your answers, tell the human what your response " |
|
"is based on and which tools you use. Use Markdown syntax " |
|
"and include relevant sources, such as links (URLs), following " |
|
"MLA format. Should the information not be available, inform " |
|
"the human explicitly that the answer could not be found.\n\n" |
|
"TOOLS:\n" |
|
"------\n\n" |
|
"You have access to the following tools:\n\n" |
|
"{tools}\n\n" |
|
"To use a tool, please use the following format:\n\n" |
|
"Thought: Do I need to use a tool? Yes\n" |
|
"Action: the action to take, should be one of [{tool_names}]\n" |
|
"Action Input: the input to the action\n" |
|
"Observation: the result of the action\n\n" |
|
"When you have a response to say to the Human, " |
|
"or if you do not need to use a tool, you MUST use " |
|
"the format:\n\n" |
|
"Thought: Do I need to use a tool? No\n" |
|
"Final Answer: [your response here]\n\n" |
|
"Begin!\n\n" |
|
"Previous conversation history:\n\n" |
|
"{chat_history}\n\n" |
|
"New input: {input}\n" |
|
"{agent_scratchpad}" |
|
) |
|
|
|
|
|
def print_conversation(no_of_msgs: Union[Literal["All"], int]) -> None: |
|
""" |
|
Print the conversation stored in st.session_state.history. |
|
""" |
|
|
|
if no_of_msgs == "All": |
|
no_of_msgs = len(st.session_state.history) |
|
|
|
for msg in st.session_state.history[-no_of_msgs:]: |
|
if isinstance(msg, HumanMessage): |
|
with st.chat_message("human"): |
|
st.write(msg.content) |
|
else: |
|
with st.chat_message("ai"): |
|
display_text_with_equations(msg.content) |
|
|
|
if urls := msg.additional_kwargs.get("image_urls"): |
|
for url in urls: |
|
st.image(url) |
|
|
|
|
|
if ( |
|
st.session_state.model_type == "GPT Models from OpenAI" |
|
and st.session_state.audio_response is not None |
|
): |
|
play_audio(st.session_state.audio_response) |
|
st.session_state.audio_response = None |
|
|
|
|
|
def serialize_messages( |
|
messages: List[Union[HumanMessage, AIMessage]] |
|
) -> List[Dict]: |
|
|
|
""" |
|
Serialize the list of messages into a list of dicts |
|
""" |
|
|
|
return [msg.dict() for msg in messages] |
|
|
|
|
|
def deserialize_messages( |
|
serialized_messages: List[Dict] |
|
) -> List[Union[HumanMessage, AIMessage]]: |
|
|
|
""" |
|
Deserialize the list of messages from a list of dicts |
|
""" |
|
|
|
deserialized_messages = [] |
|
for msg in serialized_messages: |
|
if msg['type'] == 'human': |
|
deserialized_messages.append(HumanMessage(**msg)) |
|
elif msg['type'] == 'ai': |
|
deserialized_messages.append(AIMessage(**msg)) |
|
return deserialized_messages |
|
|
|
|
|
def show_uploader() -> None: |
|
""" |
|
Set the flag to show the uploader. |
|
""" |
|
|
|
st.session_state.show_uploader = True |
|
|
|
|
|
def check_conversation_keys(lst: List[Dict[str, Any]]) -> bool: |
|
""" |
|
Check if all items in the given list are valid conversation entries. |
|
""" |
|
|
|
return all( |
|
isinstance(item, dict) and |
|
isinstance(item.get("content"), str) and |
|
isinstance(item.get("type"), str) and |
|
isinstance(item.get("additional_kwargs"), dict) |
|
for item in lst |
|
) |
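
# Each valid entry mirrors a serialized message, e.g. (illustrative):

#   {"type": "human", "content": "Hi", "additional_kwargs": {}}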
|
|
|
|
|
def load_conversation() -> bool: |
|
""" |
|
Load the conversation from a JSON file |
|
""" |
|
|
|
st.write("") |
|
st.write("**Choose a (JSON) conversation file**") |
|
uploaded_file = st.file_uploader( |
|
label="Load conversation", type="json", label_visibility="collapsed" |
|
) |
|
if uploaded_file: |
|
try: |
|
data = json.load(uploaded_file) |
|
if isinstance(data, list) and check_conversation_keys(data): |
|
st.session_state.history = deserialize_messages(data) |
|
return True |
|
            st.error(

                "The uploaded data does not conform to the expected format.", icon="🚨"

            )
|
except Exception as e: |
|
st.error(f"An error occurred: {e}", icon="🚨") |
|
|
|
return False |
|
|
|
|
|
def create_text(model: str) -> None: |
|
""" |
|
    Take a model name (e.g. "gpt-4o") as input and generate text

    based on user input by calling run_agent().
|
""" |
|
|
|
|
|
general_role = "You are a helpful AI assistant." |
|
english_teacher = ( |
|
"You are an AI English teacher who analyzes texts and corrects " |
|
"any grammatical issues if necessary." |
|
) |
|
translator = ( |
|
"You are an AI translator who translates English into Korean " |
|
"and Korean into English." |
|
) |
|
coding_adviser = ( |
|
"You are an AI expert in coding who provides advice on " |
|
"good coding styles." |
|
) |
|
science_assistant = "You are an AI science assistant." |
|
roles = ( |
|
general_role, english_teacher, translator, |
|
coding_adviser, science_assistant |
|
) |
|
|
|
with st.sidebar: |
|
st.write("") |
|
type_options = ("Tool Calling", "ReAct") |
|
st.write("**Agent Type**") |
|
        st.session_state.agent_type[0] = st.radio(
|
label="Agent Type", |
|
options=type_options, |
|
index=type_options.index(st.session_state.agent_type[1]), |
|
label_visibility="collapsed", |
|
) |
|
agent_type = st.session_state.agent_type[0] |
|
if st.session_state.model_type == "GPT Models from OpenAI": |
|
st.write("") |
|
st.write("**Text to Speech**") |
|
st.session_state.tts = st.radio( |
|
label="TTS", |
|
options=("Enabled", "Disabled", "Auto"), |
|
|
|
index=1, |
|
label_visibility="collapsed", |
|
) |
|
st.write("") |
|
st.write("**Temperature**") |
|
st.session_state.temperature[0] = st.slider( |
|
            label="Temperature (higher $\\Rightarrow$ more random)",
|
min_value=0.0, |
|
max_value=1.0, |
|
value=st.session_state.temperature[1], |
|
step=0.1, |
|
format="%.1f", |
|
label_visibility="collapsed", |
|
) |
|
st.write("") |
|
st.write("**Messages to Show**") |
|
no_of_msgs = st.radio( |
|
label="$\\textsf{Messages to show}$", |
|
options=("All", 20, 10), |
|
label_visibility="collapsed", |
|
horizontal=True, |
|
index=2, |
|
) |
|
|
|
st.write("") |
|
st.write("##### Message to AI") |
|
st.session_state.ai_role[0] = st.selectbox( |
|
label="AI's role", |
|
options=roles, |
|
index=roles.index(st.session_state.ai_role[1]), |
|
label_visibility="collapsed", |
|
) |
|
|
|
if st.session_state.ai_role[0] != st.session_state.ai_role[1]: |
|
reset_conversation() |
|
st.rerun() |
|
|
|
st.write("") |
|
st.write("##### Conversation with AI") |
|
|
|
|
|
print_conversation(no_of_msgs) |
|
|
|
|
|
c1, c2, c3 = st.columns(3) |
|
c1.button( |
|
        label="$~\\:\\,\\,$Reset$~\\:\\,\\,$",
|
on_click=reset_conversation |
|
) |
|
c2.download_button( |
|
label="Download", |
|
data=json.dumps(serialize_messages(st.session_state.history), indent=4), |
|
file_name="conversation_with_agent.json", |
|
mime="application/json", |
|
) |
|
c3.button( |
|
        label="$~~\\:\\,$Load$~~\\:\\,$",
|
on_click=show_uploader, |
|
) |
|
|
|
if st.session_state.show_uploader and load_conversation(): |
|
st.session_state.show_uploader = False |
|
st.rerun() |
|
|
|
|
|
set_prompts(agent_type) |
|
tools = set_tools() |
|
st.write("**Tools passed to run_agent:**", [tool.name for tool in tools]) |
|
|
|
|
|
image_urls = [] |
|
with st.sidebar: |
|
image_urls = upload_image_files_return_urls() |
|
|
|
if st.session_state.model_type == "GPT Models from OpenAI": |
|
audio_input = input_from_mic() |
|
if audio_input is not None: |
|
query = audio_input |
|
st.session_state.prompt_exists = True |
|
st.session_state.mic_used = True |
|
|
|
|
|
text_input = st.chat_input(placeholder="Enter your query") |
|
|
|
if text_input: |
|
query = text_input.strip() |
|
st.session_state.prompt_exists = True |
|
|
|
if st.session_state.prompt_exists: |
|
with st.chat_message("human"): |
|
st.write(query) |
|
|
|
with st.chat_message("ai"): |
|
generated_text = run_agent( |
|
query=query, |
|
model=model, |
|
tools=tools, |
|
image_urls=image_urls, |
|
temperature=st.session_state.temperature[0], |
|
agent_type=agent_type, |
|
) |
|
fig = plt.gcf() |
|
if fig and fig.get_axes(): |
|
generated_image_url = fig_to_base64(fig) |
|
st.session_state.history[-1].additional_kwargs["image_urls"] = [ |
|
generated_image_url |
|
] |
|
if ( |
|
st.session_state.model_type == "GPT Models from OpenAI" |
|
and generated_text is not None |
|
): |
|
|
|
cond1 = st.session_state.tts == "Enabled" |
|
cond2 = st.session_state.tts == "Auto" and st.session_state.mic_used |
|
if cond1 or cond2: |
|
st.session_state.audio_response = perform_tts(generated_text) |
|
st.session_state.mic_used = False |
|
|
|
st.session_state.prompt_exists = False |
|
|
|
if generated_text is not None: |
|
st.session_state.uploader_key += 1 |
|
st.rerun() |
|
|
|
|
|
def create_image(model: str) -> None: |
|
""" |
|
Generate image based on user description by calling openai_create_image(). |
|
""" |
|
|
|
|
|
with st.sidebar: |
|
st.write("") |
|
st.write("**Pixel size**") |
|
image_size = st.radio( |
|
label="$\\hspace{0.1em}\\texttt{Pixel size}$", |
|
options=("1024x1024", "1792x1024", "1024x1792"), |
|
|
|
index=0, |
|
label_visibility="collapsed", |
|
) |
|
|
|
st.write("") |
|
st.write("##### Description for your image") |
|
|
|
if st.session_state.image_url is not None: |
|
st.info(st.session_state.image_description) |
|
st.image(image=st.session_state.image_url, use_column_width=True) |
|
|
|
|
|
if st.session_state.model_type == "GPT Models from OpenAI": |
|
audio_input = input_from_mic() |
|
if audio_input is not None: |
|
st.session_state.image_description = audio_input |
|
st.session_state.prompt_exists = True |
|
|
|
|
|
text_input = st.chat_input( |
|
placeholder="Enter a description for your image", |
|
) |
|
if text_input: |
|
st.session_state.image_description = text_input.strip() |
|
st.session_state.prompt_exists = True |
|
|
|
if st.session_state.prompt_exists: |
|
st.session_state.image_url = openai_create_image( |
|
st.session_state.image_description, model, image_size |
|
) |
|
st.session_state.prompt_exists = False |
|
if st.session_state.image_url is not None: |
|
st.rerun() |
|
|
|
|
|
def create_text_image() -> None: |
|
""" |
|
Generate text or image by using LLM models like 'gpt-4o'. |
|
""" |
|
|
|
page_title = "LangChain LLM Agent" |
|
page_icon = "📚" |
|
|
|
st.set_page_config( |
|
page_title=page_title, |
|
page_icon=page_icon, |
|
layout="centered" |
|
) |
|
|
|
    st.write(f"## {page_icon} $\\,${page_title}")
|
|
|
|
|
    initialize_session_state_variables()

    # The OpenAI client is needed for the TTS, Whisper, and DALL·E calls above;

    # create it once per session (OPENAI_API_KEY is already in the environment).

    if st.session_state.openai is None:

        st.session_state.openai = OpenAI()
|
|
|
|
|
model_options = ["gpt-4o-mini", "gpt-4o", "dall-e-3"] |
|
|
|
|
|
with st.sidebar: |
|
st.write("**Select a Model**") |
|
model = st.radio( |
|
label="Models", |
|
options=model_options, |
|
index=1, |
|
label_visibility="collapsed", |
|
on_change=switch_between_apps, |
|
) |
|
|
|
st.write("---") |
|
|
|
|
|
|
if model == "dall-e-3": |
|
create_image(model) |
|
else: |
|
create_text(model) |
|
|
|
if __name__ == "__main__": |
|
create_text_image() |