# mai-sam-sj-v1 / main.py
# Hugging Face Space file by samsonleegh — commit 53b4105 ("Upload 7 files"), 6.9 kB.
import chromadb
from llama_index.core.base.embeddings.base import similarity
#from llama_index.llms.ollama import Ollama
from llama_index.llms.groq import Groq
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings, DocumentSummaryIndex
from llama_index.core import StorageContext, get_response_synthesizer
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import load_index_from_storage
import os
from dotenv import load_dotenv
from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler, CBEventType
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor import SimilarityPostprocessor
import time
import gradio as gr
from llama_index.core.memory import ChatMemoryBuffer
from llama_parse import LlamaParse
from llama_index.core import PromptTemplate
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.chat_engine import CondenseQuestionChatEngine
# Load API keys from a local .env file (if present) into the environment.
load_dotenv()
GROQ_API_KEY = os.getenv('GROQ_API_KEY')
LLAMAINDEX_API_KEY = os.getenv('LLAMAINDEX_API_KEY')

# Global callback manager: LlamaDebugHandler prints a trace of LlamaIndex
# events when each operation ends (print_trace_on_end=True).
llama_debug = LlamaDebugHandler(print_trace_on_end=True)
callback_manager = CallbackManager([llama_debug])
Settings.callback_manager = callback_manager

# LLM hosted on Groq. The key read above is passed explicitly, matching how
# the LlamaParse key is handled below, instead of leaving GROQ_API_KEY unused
# and relying on the client to re-read the environment.
# A smaller alternative model is "llama3-8b-8192".
llm = Groq(model="llama3-70b-8192", api_key=GROQ_API_KEY)
Settings.llm = llm

# Embedding model (HuggingFace BGE small) used for indexing and querying.
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
Settings.embed_model = embed_model

# Chunking applied to documents before embedding.
splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=50)
Settings.transformations = [splitter]

# LlamaParse client; used as the .pdf extractor when the index is built.
parser = LlamaParse(
    api_key=LLAMAINDEX_API_KEY,
    result_type="markdown",  # "markdown" and "text" are available
    verbose=True,
)
# Create or load the vector index. If ./vectordb exists, the persisted index
# is loaded from it; otherwise every .pdf/.docx under ./data (recursively) is
# read, chunked, embedded and persisted to ./vectordb for the next run.
if os.path.exists("./vectordb"):
    print("Index Exists!")
    storage_context = StorageContext.from_defaults(persist_dir="./vectordb")
    index = load_index_from_storage(storage_context)
else:
    def filename_fn(filename):
        # Record the source file path in each document's metadata.
        return {"file_name": filename}

    required_exts = [".pdf", ".docx"]
    # PDFs are parsed by LlamaParse; .docx uses SimpleDirectoryReader's
    # built-in reader for that extension.
    file_extractor = {".pdf": parser}
    reader = SimpleDirectoryReader(
        input_dir="./data",
        file_extractor=file_extractor,
        required_exts=required_exts,
        recursive=True,
        file_metadata=filename_fn,
    )
    documents = reader.load_data()
    # print() does no %-formatting, so interpolate the count directly.
    print(f"index creating with `{len(documents)}` documents")
    index = VectorStoreIndex.from_documents(
        documents, embed_model=embed_model, transformations=[splitter]
    )
    index.storage_context.persist(persist_dir="./vectordb")

# NOTE: the two triple-quoted strings below are disabled experiments kept for
# reference. As bare string expressions they are evaluated and discarded, so
# none of this code runs.
# Experiment 1: a DocumentSummaryIndex persisted to ./docsummarydb and built
# with a tree_summarize synthesizer. It does not use the LlamaParse extractor.
"""
#create document summary index
if os.path.exists("./docsummarydb"):
print("Index Exists!")
storage_context = StorageContext.from_defaults(persist_dir="./docsummarydb")
doc_index = load_index_from_storage(storage_context)
else:
filename_fn = lambda filename: {"file_name": filename}
required_exts = [".pdf",".docx"]
reader = SimpleDirectoryReader(
input_dir="./data",
required_exts=required_exts,
recursive=True,
file_metadata=filename_fn
)
documents = reader.load_data()
print("index creating with `%d` documents", len(documents))
response_synthesizer = get_response_synthesizer(
response_mode="tree_summarize", use_async=True
)
doc_index = DocumentSummaryIndex.from_documents(
documents,
llm = llm,
transformations = [splitter],
response_synthesizer = response_synthesizer,
show_progress = True
)
doc_index.storage_context.persist(persist_dir="./docsummarydb")
"""
# Experiment 2: a retriever over that summary index.
# NOTE(review): DocumentSummaryIndexEmbeddingRetriever is not imported at the
# top of this file. Add the import before re-enabling this.
"""
retriever = DocumentSummaryIndexEmbeddingRetriever(
doc_index,
similarity_top_k=5,
)
"""
# Dense retriever over the vector index: fetch the 10 most similar chunks for
# each query. MMR options, currently disabled:
#   vector_store_query_mode="mmr", vector_store_kwargs={"mmr_threshold": 0.4}
retriever = VectorIndexRetriever(index=index, similarity_top_k=10)

# Response synthesizer with library-default settings.
# NOTE(review): this object is not passed to the query engine below, which
# constructs its own tree_summarize synthesizer.
response_synthesizer = get_response_synthesizer()
# Customising the QA prompt made answers worse, so the template below is kept
# disabled. It is a bare string and is never executed, so qa_prompt is never
# created and the query engine keeps its library-default prompts.
"""
# set up prompt template
qa_prompt_tmpl = (
"Context information from multiple sources is below.\n"
"---------------------\n"
"{context_str}\n"
"---------------------\n"
"Given the information from multiple sources and not prior knowledge, "
"answer the query.\n"
"Query: {query_str}\n"
"Answer: "
)
qa_prompt = PromptTemplate(qa_prompt_tmpl)
"""
# Query pipeline:
#   1. retrieve candidates with `retriever`;
#   2. drop nodes whose similarity score is below 0.53;
#   3. synthesize the answer in tree_summarize mode, with verbose logging.
query_engine = RetrieverQueryEngine(
    retriever=retriever,
    node_postprocessors=[
        SimilarityPostprocessor(similarity_cutoff=0.53),
    ],
    response_synthesizer=get_response_synthesizer(
        response_mode="tree_summarize",
        verbose=True,
    ),
)

# Show the prompt templates in effect, for inspection at startup.
print(query_engine.get_prompts())

# Manual smoke test:
#   response = query_engine.query("What happens if the distributor wants its own warehouse for pizzahood?")
#   print(response)
# Condense-question chat engine: each follow-up message is rewritten into a
# standalone question (using custom_prompt and the chat history), and that
# question is answered by query_engine.
# NOTE(review): the Gradio handler in this file calls query_engine directly;
# the chat engine is only used by its commented-out alternative.

custom_prompt = PromptTemplate(
    """\
Given a conversation (between Human and Assistant) and a follow up message from Human, \
rewrite the message to be a standalone question that captures all relevant context \
from the conversation. If you are unsure, ask for more information.
<Chat History>
{chat_history}
<Follow Up Message>
{question}
<Standalone question>
"""
)

# Seed conversation (list of ChatMessage) the chat engine starts from.
custom_chat_history = [
    ChatMessage(role=MessageRole.USER, content="Hello assistant."),
    ChatMessage(role=MessageRole.ASSISTANT, content="Hello user."),
]

# The seed history must be stored in the memory buffer itself. When an
# explicit `memory` is given, CondenseQuestionChatEngine.from_defaults ignores
# its `chat_history` argument, so the greeting would otherwise be dropped.
memory = ChatMemoryBuffer.from_defaults(
    chat_history=custom_chat_history,
    token_limit=10000,
)

chat_engine = CondenseQuestionChatEngine.from_defaults(
    query_engine=query_engine,
    condense_question_prompt=custom_prompt,
    chat_history=custom_chat_history,
    verbose=True,
    memory=memory,
)
# Gradio UI: a chat window whose answers come from query_engine and are
# revealed one character at a time to simulate streaming.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="⏎ for sending",
                     placeholder="Ask me something",)
    clear = gr.Button("Delete")

    def user(user_message, history):
        """Clear the textbox and append the message as a pending turn."""
        return "", history + [[user_message, None]]

    def bot(history):
        """Answer the latest user turn, yielding the history as the reply grows."""
        user_message = history[-1][0]
        # bot_message = chat_engine.chat(user_message)  # chat-engine alternative
        # The leading space keeps the appended instruction from fusing with
        # the user's last word (e.g. "...warehouseLet's think...").
        bot_message = query_engine.query(
            user_message
            + " Let's think step by step to get the correct answer. If you cannot provide an answer, say you don't know."
        )
        history[-1][1] = ""
        for character in bot_message.response:
            history[-1][1] += character
            time.sleep(0.01)
            yield history

    # On Enter: record the user turn immediately (unqueued), then stream the reply.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    # "Delete" clears the visible conversation.
    clear.click(lambda: None, None, chatbot, queue=False)

# demo.queue()
demo.launch(share=False)