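# Multi-modal RAG chatbot demo.
# Text retrieval: LangChain history-aware RAG over a PGVector (Postgres) store.
# Image retrieval: LlamaIndex multi-modal index backed by a local Qdrant store.
# UI: Gradio Blocks with a ChatInterface plus an image-retrieval gallery.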
import os

import gradio as gr
import qdrant_client
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.schema import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_postgres.vectorstores import PGVector
from llama_index.core import SimpleDirectoryReader, StorageContext
from llama_index.core.indices.multi_modal.base import MultiModalVectorStoreIndex
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.vector_stores.qdrant import QdrantVectorStore
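# Embedding and chat models (assumes OPENAI_API_KEY is set in the environment).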
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
chat_llm = ChatOpenAI(temperature=0.5, model="gpt-4-turbo")
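# Prompt that rewrites the latest user turn into a standalone question,
# so retrieval still works when the question refers back to earlier messages.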
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
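# QA prompt: answer from the retrieved context, with the chat history included
# so follow-up questions keep their conversational context.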
qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \
context: {context}"""
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
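# Chain that "stuffs" the retrieved documents into the {context} slot of the QA prompt.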
question_answer_chain = create_stuff_documents_chain(chat_llm, qa_prompt)
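# Postgres/pgvector connection for the text store (assumes PG_PASSWORD and AWS_EC2_IP
# are set in the environment). The collections ("LLaVA", "Interior") are assumed to
# have been ingested separately.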
pg_password = os.getenv("PG_PASSWORD")
aws_ec2_ip = os.getenv("AWS_EC2_IP")
pg_connection = f"postgresql+psycopg://postgres:{pg_password}@{aws_ec2_ip}:5432/postgres"
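# Local Qdrant store for image embeddings, and the multi-modal LLM used to query it.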
qd_client = qdrant_client.QdrantClient(path="qdrant_db")
image_store = QdrantVectorStore(client=qd_client, collection_name="image_collection")
storage_context = StorageContext.from_defaults(image_store=image_store)
openai_mm_llm = OpenAIMultiModal(model="gpt-4o", max_new_tokens=1500)
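# Chat handler for gr.ChatInterface: builds a history-aware RAG chain over the
# PGVector collection selected in the dropdown and returns the answer text.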
def response(message, history, doc_label):
    text_store = PGVector(collection_name=doc_label,
                          embeddings=embeddings,
                          connection=pg_connection)
    retriever = text_store.as_retriever()
    history_aware_retriever = create_history_aware_retriever(chat_llm,
                                                             retriever,
                                                             contextualize_q_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
    # Gradio's `history` is not used here; the chain reads the module-level
    # chat_history, which holds LangChain message objects.
    result = rag_chain.invoke({"input": message, "chat_history": chat_history})
    chat_history.extend([HumanMessage(content=message), AIMessage(content=result["answer"])])
    return result["answer"]
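# Image retrieval for the gallery: indexes the images under ./<doc_label> into the
# multi-modal index and returns the file paths of the top-matching images.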
def img_retrieve(query, doc_label):
    doc_imgs = SimpleDirectoryReader(f"./{doc_label}").load_data()
    index = MultiModalVectorStoreIndex.from_documents(doc_imgs,
                                                      storage_context=storage_context)
    img_query_engine = index.as_query_engine(llm=openai_mm_llm,
                                             image_similarity_top_k=3)
    response_mm = img_query_engine.query(query)
    retrieved_imgs = [n.metadata["file_path"] for n in response_mm.metadata["image_nodes"]]
    return retrieved_imgs
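# Module-level chat history; shared by every chat turn in this process (not per-user)
# and reset only when the app restarts.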
chat_history = []
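# Gradio UI: document selector and chatbot on the left, image query and gallery on the right.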
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    with gr.Row():
        gr.Markdown(
            """
            # 🎨 Multi-modal RAG Chatbot
            """)
    with gr.Row():
        gr.Markdown("""Select a document from the menu, then interact with the text and images in that document.
        """)
    with gr.Row():
        with gr.Column(scale=2):
            doc_label = gr.Dropdown(["LLaVA", "Interior"], label="Select a document:")
            chatbot = gr.ChatInterface(fn=response, additional_inputs=[doc_label], fill_height=True)
        with gr.Column(scale=1):
            sample_1 = "https://i.pinimg.com/originals/e3/44/d7/e344d7631cd515edd36cc6930deaedec.jpg"
            sample_2 = "https://live.staticflickr.com/5307/5765340890_e386f42a99_b.jpg"
            sample_3 = "https://blog.kakaocdn.net/dn/nqcUB/btrzYjTgjWl/jFFlIBrdkoKv4jbSyZbiEk/img.jpg"
            gallery = gr.Gallery(label="Retrieved images",
                                 show_label=True, preview=True,
                                 object_fit="contain",
                                 value=[(sample_1, "sample_1"),
                                        (sample_2, "sample_2"),
                                        (sample_3, "sample_3")])
            query = gr.Textbox(label="Enter query")
            button = gr.Button(value="Retrieve images")
            button.click(img_retrieve, [query, doc_label], gallery)

demo.launch(share=True)