"""Gradio multi-modal RAG demo.

Text questions are answered by a history-aware LangChain RAG chain over a
PGVector collection named after the selected document. Image queries retrieve
the most similar images from a per-document LlamaIndex multi-modal index whose
image vectors live in a local Qdrant database ("qdrant_db").

Required environment variables:
    OPENAI_API_KEY  -- OpenAI API key (read by the OpenAI clients).
    PG_CONNECTION   -- SQLAlchemy URL of the Postgres/pgvector database, e.g.
                       "postgresql+psycopg://user:password@host:5432/dbname".
"""

import functools
import os

import gradio as gr
import qdrant_client
from langchain.chains import create_history_aware_retriever
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.schema import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain_postgres.vectorstores import PGVector
from llama_index.core import SimpleDirectoryReader
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core.indices.multi_modal.base import MultiModalVectorStoreIndex
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.vector_stores.qdrant import QdrantVectorStore


def _require_env(name):
    """Return environment variable *name*; raise RuntimeError if it is unset or empty."""
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"Set the {name} environment variable before starting the app.")
    return value


# Secrets must come from the environment, never from source control. The API key
# and database password that were previously hard-coded here are exposed and
# must be rotated.
_require_env("OPENAI_API_KEY")
pg_connection = _require_env("PG_CONNECTION")

# Documents selectable in the UI. Each label is used both as the PGVector
# collection name and as the local directory ("./<label>") holding its images.
DOC_LABELS = ["LLaVA", "Interior"]
# Number of images returned per image query.
IMAGE_TOP_K = 3

embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
chat_llm = ChatOpenAI(temperature=0.5, model="gpt-4-turbo")

contextualize_q_system_prompt = (
    "Given a chat history and the latest user question "
    "which might reference context in the chat history, formulate a standalone question "
    "which can be understood without the chat history. Do NOT answer the question, "
    "just reformulate it if needed and otherwise return it as is."
)
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)

qa_system_prompt = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, just say that you don't know.\n"
    "\n"
    "context: {context}"
)
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
question_answer_chain = create_stuff_documents_chain(chat_llm, qa_prompt)

qd_client = qdrant_client.QdrantClient(path="qdrant_db")


def _check_doc_label(doc_label):
    """Raise gr.Error unless *doc_label* is one of DOC_LABELS.

    The label becomes a collection name and a filesystem path, so anything
    outside the allow-list (including None, i.e. nothing selected) is rejected.
    """
    if doc_label not in DOC_LABELS:
        raise gr.Error("Please select a document from the menu first.")


def _history_to_messages(history):
    """Convert Gradio chat history to LangChain messages.

    Accepts both ChatInterface history formats:
      * "tuples": ``[[user_text, bot_text], ...]``
      * "messages": ``[{"role": "user"|"assistant", "content": ...}, ...]``
    Non-text entries (e.g. uploaded files) and missing bot replies are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            content = turn.get("content")
            if not isinstance(content, str):
                continue
            if turn.get("role") == "user":
                messages.append(HumanMessage(content=content))
            elif turn.get("role") == "assistant":
                messages.append(AIMessage(content=content))
        else:
            user_msg, bot_msg = turn[0], turn[1]
            if isinstance(user_msg, str):
                messages.append(HumanMessage(content=user_msg))
            if isinstance(bot_msg, str):
                # Must be an AIMessage: a bare string would be replayed as a user turn.
                messages.append(AIMessage(content=bot_msg))
    return messages


@functools.lru_cache(maxsize=None)
def _rag_chain(doc_label):
    """Build (once per document) the history-aware retrieval chain over PGVector.

    Cached so each document reuses one PGVector store (and its database engine)
    instead of opening a new one for every chat message. Only validated labels
    reach here, so the cache is bounded by len(DOC_LABELS).
    """
    text_store = PGVector(collection_name=doc_label, embeddings=embeddings, connection=pg_connection)
    history_aware_retriever = create_history_aware_retriever(
        chat_llm, text_store.as_retriever(), contextualize_q_prompt
    )
    return create_retrieval_chain(history_aware_retriever, question_answer_chain)


@functools.lru_cache(maxsize=None)
def _image_index(doc_label):
    """Build (once per process and document) the multi-modal index for ./<doc_label>.

    Each document gets its own Qdrant collection so queries never return
    another document's images. The collection is dropped before re-indexing:
    SimpleDirectoryReader assigns fresh document ids on each load (default
    filename_as_id=False), so indexing into an existing collection would
    accumulate duplicate points.
    """
    collection_name = f"image_collection_{doc_label}"
    qd_client.delete_collection(collection_name=collection_name)
    image_store = QdrantVectorStore(client=qd_client, collection_name=collection_name)
    storage_context = StorageContext.from_defaults(image_store=image_store)
    doc_imgs = SimpleDirectoryReader(f"./{doc_label}").load_data()
    return MultiModalVectorStoreIndex.from_documents(doc_imgs, storage_context=storage_context)


def response(message, history, doc_label):
    """Answer *message* with RAG over the selected document's text.

    Args:
        message: Latest user message from the chat box.
        history: This session's Gradio chat history (excluding *message*); used
            so follow-up questions can be rewritten into standalone ones.
        doc_label: Selected document; also the PGVector collection name.

    Returns:
        The assistant's answer text.

    Raises:
        gr.Error: If no valid document is selected.
    """
    _check_doc_label(doc_label)
    result = _rag_chain(doc_label).invoke(
        {"input": message, "chat_history": _history_to_messages(history)}
    )
    return result["answer"]


def img_retrieve(query, doc_label):
    """Return file paths of the IMAGE_TOP_K images in ./<doc_label> most similar to *query*.

    Uses text-to-image retrieval directly; no LLM answer is generated, since
    only the retrieved image paths are shown.

    Raises:
        gr.Error: If no valid document is selected or *query* is blank.
    """
    _check_doc_label(doc_label)
    if not query or not query.strip():
        raise gr.Error("Please enter a query.")
    retriever = _image_index(doc_label).as_retriever(image_similarity_top_k=IMAGE_TOP_K)
    return [hit.node.metadata["file_path"] for hit in retriever.text_to_image_retrieve(query)]


with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    with gr.Row():
        gr.Markdown(
            """
            # 🎨 Multi-modal RAG Chatbot
            """
        )
    with gr.Row():
        gr.Markdown(
            """Select document from the menu, and interact with the text and images in the document.
"""
        )
    with gr.Row():
        with gr.Column(scale=2):
            doc_label = gr.Dropdown(DOC_LABELS, label="Select a document:")
            chatbot = gr.ChatInterface(fn=response, additional_inputs=[doc_label], fill_height=True)
        with gr.Column(scale=1):
            sample_1 = "https://i.ytimg.com/vi/bLj_mR4Fnls/maxresdefault.jpg"
            sample_2 = "https://i.ytimg.com/vi/bOJdHU99OO8/maxresdefault.jpg"
            sample_3 = "https://blog.kakaocdn.net/dn/nqcUB/btrzYjTgjWl/jFFlIBrdkoKv4jbSyZbiEk/img.jpg"
            gallery = gr.Gallery(
                label="Retrieved images",
                show_label=True,
                preview=True,
                object_fit="contain",
                value=[(sample_1, 'sample_1'), (sample_2, 'sample_2'), (sample_3, 'sample_3')],
            )
            query = gr.Textbox(label="Enter query")
            button = gr.Button(value="Retrieve images")
            button.click(img_retrieve, [query, doc_label], gallery)


if __name__ == "__main__":
    # share=True publishes a public link; every visitor gets their own chat
    # history (passed in by Gradio), nothing is shared between sessions.
    demo.launch(share=True)