from typing import List, Tuple

import gradio as gr
import spacy
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import SpacyTextSplitter
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI

# SpacyTextSplitter needs a spaCy pipeline; fetch it once at startup.
spacy.cli.download("en_core_web_sm")

template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Use three sentences maximum. Keep the answer as concise as possible. Always say "thanks for asking!" at the end of the answer. |
|
Tips: Make sure to cite your sources, and use the exact words from the context. |
|
{context} |
|
Question: {question} |
|
Helpful Answer:""" |
|
QA_CHAIN_PROMPT = PromptTemplate.from_template(template) |
|
|
|
|
|
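# The prompt declares two input variables: {context}, filled with the
# retrieved chunks, and {question}, filled with the user's query. Both are
# supplied automatically by the QA chain in _retrieval_qa below.
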
def convert_chat_history_to_messages(chat_history: List[Tuple[str, str]]) -> List[BaseMessage]:
    """Convert Gradio's (human, ai) tuple history into LangChain messages."""
    result: List[BaseMessage] = []
    for human, ai in chat_history:
        result.append(HumanMessage(content=human))
        result.append(AIMessage(content=ai))
    return result

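# Example: convert_chat_history_to_messages([("Hi", "Hello!")]) returns
# [HumanMessage(content="Hi"), AIMessage(content="Hello!")].
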
class RAGDemo:
    def __init__(self):
        self.embedding = None
        self.vector_db = None
        self.chat_model = None

    def _init_chat_model(self, model_name, api_key):
        if not api_key:
            # gr.Error is an exception type; it must be raised to show up
            # as an error toast in the UI.
            raise gr.Error("Please enter model API key.")
        if 'glm' in model_name:
            raise gr.Error("GLM is not supported yet.")
        elif 'gemini' in model_name:
            self.chat_model = ChatGoogleGenerativeAI(
                google_api_key=api_key,
                model=model_name,  # use the model selected in the UI dropdown
                # Gemini does not accept a system role, so system messages
                # are folded into the human turn.
                convert_system_message_to_human=True,
            )

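    # Embeddings are computed remotely: HuggingFaceInferenceAPIEmbeddings calls
    # the hosted Hugging Face Inference API, so the key expected here is a
    # Hugging Face access token, separate from the LLM key above.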
    def _init_embedding(self, embedding_model_name, api_key):
        if not api_key:
            raise gr.Error("Please enter embedding API key.")
        if 'glm' in embedding_model_name:
            raise gr.Error("GLM is not supported yet.")
        self.embedding = HuggingFaceInferenceAPIEmbeddings(
            api_key=api_key, model_name=embedding_model_name
        )

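    # Indexing pipeline: load the PDF pages, split them into sentence-aware
    # chunks with spaCy, and embed them into a Chroma collection. With no
    # persist_directory, the index lives in memory only and is rebuilt each
    # time the settings form is submitted.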
    def _build_vector_db(self, file_path):
        if not file_path:
            raise gr.Error("Please upload a PDF file.")
        gr.Info("Building vector database...")
        loader = PyPDFLoader(file_path)
        pages = loader.load()

        text_splitter = SpacyTextSplitter(chunk_size=500, chunk_overlap=50)
        docs = text_splitter.split_documents(pages)

        self.vector_db = Chroma.from_documents(
            documents=docs, embedding=self.embedding
        )
        gr.Info("Vector database built successfully.")
        print("Vector database built successfully.")

    def _init_settings(self, model_name, api_key, embedding_model, embedding_api_key, data_file):
        self._init_chat_model(model_name, api_key)
        self._init_embedding(embedding_model, embedding_api_key)
        self._build_vector_db(data_file)

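    # Single-turn QA: RetrievalQA.from_chain_type defaults to the "stuff"
    # chain, which concatenates all retrieved chunks into the prompt's
    # {context} slot in one LLM call.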
    def _retrieval_qa(self, input_text):
        basic_qa = RetrievalQA.from_chain_type(
            self.chat_model,
            retriever=self.vector_db.as_retriever(),
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
            verbose=True,
        )
        resp = basic_qa.invoke(input_text)
        return resp['result']

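    # Multi-turn QA: ConversationalRetrievalChain first condenses the new
    # message plus the chat history into a standalone question, retrieves
    # with it, then answers; the history is rebuilt from Gradio's chat state
    # on every turn.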
    def _chat_qa(self, message, chat_history):
        if not message:
            return "", chat_history
        memory = ConversationBufferMemory(
            chat_memory=ChatMessageHistory(
                messages=convert_chat_history_to_messages(chat_history)
            ),
            memory_key="chat_history",
            return_messages=True,
        )
        qa = ConversationalRetrievalChain.from_llm(
            self.chat_model,
            retriever=self.vector_db.as_retriever(),
            memory=memory,
            verbose=True,
        )
        resp = qa.invoke(message)
        print(f">>> {resp}")
        chat_history.append((message, resp['answer']))
        return "", chat_history

    def _retry_chat_qa(self, chat_history):
        # Re-ask the most recent question, discarding its previous answer.
        message = ""
        if chat_history:
            message, _ = chat_history.pop()
        return self._chat_qa(message, chat_history)

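    # UI layout: three tabs - Settings (model/embedding keys plus PDF upload),
    # RAG (single-turn QA), and Chat RAG (multi-turn chat over the same index).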
    def __call__(self):
        with gr.Blocks(title="🔥 RAG Demo") as demo:
            gr.Markdown("# RAG Demo\n\nBased on the [RAG learning note](https://www.jianshu.com/p/9792f1e6c3f9) and "
                        "[rag-practice](https://github.com/hiwei93/rag-practice/tree/main)")
            with gr.Tab("Settings"):
                with gr.Row():
                    with gr.Column():
                        model_name = gr.Dropdown(
                            choices=['gemini-1.0-pro'],
                            value='gemini-1.0-pro',
                            label="model"
                        )
                        api_key = gr.Textbox(placeholder="your api key for LLM", label="api key")
                        embedding_model = gr.Dropdown(
                            choices=['sentence-transformers/all-MiniLM-L6-v2',
                                     'intfloat/multilingual-e5-large'],
                            value="sentence-transformers/all-MiniLM-L6-v2",
                            label="embedding model"
                        )
                        embedding_api_key = gr.Textbox(placeholder="your api key for embedding", label="embedding api key")
                    with gr.Column():
                        data_file = gr.File(file_count='single', label="pdf file")
                        initial_btn = gr.Button("submit")
            with gr.Tab("RAG"):
                with gr.Row():
                    with gr.Column():
                        input_text = gr.Textbox(placeholder="input your question...", label="input")
                        submit_btn = gr.Button("submit")
                    with gr.Column():
                        output = gr.TextArea(label="answer")
            with gr.Tab("Chat RAG"):
                chatbot = gr.Chatbot(label="chat with pdf")
                input_msg = gr.Textbox(placeholder="input your question...", label="input")
                with gr.Row():
                    clear_btn = gr.ClearButton([chatbot, input_msg], value="🧹 Clear")
                    retry_btn = gr.Button("♻️ Retry")

            initial_btn.click(
                self._init_settings,
                inputs=[model_name, api_key, embedding_model, embedding_api_key, data_file]
            )
            submit_btn.click(
                self._retrieval_qa,
                inputs=input_text,
                outputs=output,
            )
            input_msg.submit(
                self._chat_qa,
                inputs=[input_msg, chatbot],
                outputs=[input_msg, chatbot]
            )
            retry_btn.click(
                self._retry_chat_qa,
                inputs=chatbot,
                outputs=[input_msg, chatbot]
            )
        return demo

if __name__ == "__main__":
    app = RAGDemo()
    app().launch(debug=True)
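# Assuming this script is saved as e.g. app.py, run `python app.py`; Gradio
# serves the UI on http://127.0.0.1:7860 by default.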